Add unit tests for C layer for the input types of Text Classifier
PiperOrigin-RevId: 569553038
This commit is contained in:
		
							parent
							
								
									6915a79e28
								
							
						
					
					
						commit
						96fa10b906
					
				| 
						 | 
				
			
			@ -30,3 +30,15 @@ cc_library(
 | 
			
		|||
        "//mediapipe/tasks/cc/components/processors:classifier_options",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
cc_test(
 | 
			
		||||
    name = "classifier_options_converter_test",
 | 
			
		||||
    srcs = ["classifier_options_converter_test.cc"],
 | 
			
		||||
    deps = [
 | 
			
		||||
        ":classifier_options",
 | 
			
		||||
        ":classifier_options_converter",
 | 
			
		||||
        "//mediapipe/framework/port:gtest",
 | 
			
		||||
        "//mediapipe/tasks/cc/components/processors:classifier_options",
 | 
			
		||||
        "@com_google_googletest//:gtest_main",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -26,7 +26,7 @@ extern "C" {
 | 
			
		|||
struct ClassifierOptions {
 | 
			
		||||
  // The locale to use for display names specified through the TFLite Model
 | 
			
		||||
  // Metadata, if any. Defaults to English.
 | 
			
		||||
  char* display_names_locale;
 | 
			
		||||
  const char* display_names_locale;
 | 
			
		||||
 | 
			
		||||
  // The maximum number of top-scored classification results to return. If < 0,
 | 
			
		||||
  // all available results will be returned. If 0, an invalid argument error is
 | 
			
		||||
| 
						 | 
				
			
			@ -40,14 +40,14 @@ struct ClassifierOptions {
 | 
			
		|||
  // The allowlist of category names. If non-empty, detection results whose
 | 
			
		||||
  // category name is not in this set will be filtered out. Duplicate or unknown
 | 
			
		||||
  // category names are ignored. Mutually exclusive with category_denylist.
 | 
			
		||||
  char** category_allowlist;
 | 
			
		||||
  const char** category_allowlist;
 | 
			
		||||
  // The number of elements in the category allowlist.
 | 
			
		||||
  uint32_t category_allowlist_count;
 | 
			
		||||
 | 
			
		||||
  // The denylist of category names. If non-empty, detection results whose
 | 
			
		||||
  // category name is in this set will be filtered out. Duplicate or unknown
 | 
			
		||||
  // category names are ignored. Mutually exclusive with category_allowlist.
 | 
			
		||||
  char** category_denylist;
 | 
			
		||||
  const char** category_denylist;
 | 
			
		||||
  // The number of elements in the category denylist.
 | 
			
		||||
  uint32_t category_denylist_count;
 | 
			
		||||
};
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -26,7 +26,7 @@ void CppConvertToClassifierOptions(
 | 
			
		|||
    const ClassifierOptions& in,
 | 
			
		||||
    mediapipe::tasks::components::processors::ClassifierOptions* out) {
 | 
			
		||||
  out->display_names_locale =
 | 
			
		||||
      in.display_names_locale ? std::string(in.display_names_locale) : "";
 | 
			
		||||
      in.display_names_locale ? std::string(in.display_names_locale) : "en";
 | 
			
		||||
  out->max_results = in.max_results;
 | 
			
		||||
  out->score_threshold = in.score_threshold;
 | 
			
		||||
  out->category_allowlist =
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -0,0 +1,82 @@
 | 
			
		|||
/* Copyright 2023 The MediaPipe Authors.
 | 
			
		||||
 | 
			
		||||
Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
you may not use this file except in compliance with the License.
 | 
			
		||||
You may obtain a copy of the License at
 | 
			
		||||
 | 
			
		||||
    http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
 | 
			
		||||
Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
See the License for the specific language governing permissions and
 | 
			
		||||
limitations under the License.
 | 
			
		||||
==============================================================================*/
 | 
			
		||||
 | 
			
		||||
#include "mediapipe/tasks/c/components/processors/classifier_options_converter.h"
 | 
			
		||||
 | 
			
		||||
#include <string>
 | 
			
		||||
#include <vector>
 | 
			
		||||
 | 
			
		||||
#include "mediapipe/framework/port/gtest.h"
 | 
			
		||||
#include "mediapipe/tasks/c/components/processors/classifier_options.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/components/processors/classifier_options.h"
 | 
			
		||||
 | 
			
		||||
namespace mediapipe::tasks::c::components::processors {
 | 
			
		||||
 | 
			
		||||
constexpr char kCategoryAllowlist[] = "fruit";
 | 
			
		||||
constexpr char kCategoryDenylist[] = "veggies";
 | 
			
		||||
constexpr char kDisplayNamesLocaleGerman[] = "de";
 | 
			
		||||
 | 
			
		||||
TEST(ClassifierOptionsConverterTest, ConvertsClassifierOptionsCustomValues) {
 | 
			
		||||
  std::vector<const char*> category_allowlist = {kCategoryAllowlist};
 | 
			
		||||
  std::vector<const char*> category_denylist = {kCategoryDenylist};
 | 
			
		||||
 | 
			
		||||
  ClassifierOptions c_classifier_options = {
 | 
			
		||||
      /* display_names_locale= */ kDisplayNamesLocaleGerman,
 | 
			
		||||
      /* max_results= */ 1,
 | 
			
		||||
      /* score_threshold= */ 0.1,
 | 
			
		||||
      /* category_allowlist= */ category_allowlist.data(),
 | 
			
		||||
      /* category_allowlist_count= */ 1,
 | 
			
		||||
      /* category_denylist= */ category_denylist.data(),
 | 
			
		||||
      /* category_denylist_count= */ 1};
 | 
			
		||||
 | 
			
		||||
  mediapipe::tasks::components::processors::ClassifierOptions
 | 
			
		||||
      cpp_classifier_options = {};
 | 
			
		||||
 | 
			
		||||
  CppConvertToClassifierOptions(c_classifier_options, &cpp_classifier_options);
 | 
			
		||||
  EXPECT_EQ(cpp_classifier_options.display_names_locale, "de");
 | 
			
		||||
  EXPECT_EQ(cpp_classifier_options.max_results, 1);
 | 
			
		||||
  EXPECT_FLOAT_EQ(cpp_classifier_options.score_threshold, 0.1);
 | 
			
		||||
  EXPECT_EQ(cpp_classifier_options.category_allowlist,
 | 
			
		||||
            std::vector<std::string>{"fruit"});
 | 
			
		||||
  EXPECT_EQ(cpp_classifier_options.category_denylist,
 | 
			
		||||
            std::vector<std::string>{"veggies"});
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST(ClassifierOptionsConverterTest, ConvertsClassifierOptionsDefaultValues) {
 | 
			
		||||
  std::vector<const char*> category_allowlist = {kCategoryAllowlist};
 | 
			
		||||
  std::vector<const char*> category_denylist = {kCategoryDenylist};
 | 
			
		||||
 | 
			
		||||
  ClassifierOptions c_classifier_options = {/* display_names_locale= */ nullptr,
 | 
			
		||||
                                            /* max_results= */ -1,
 | 
			
		||||
                                            /* score_threshold= */ 0.0,
 | 
			
		||||
                                            /* category_allowlist= */ nullptr,
 | 
			
		||||
                                            /* category_allowlist_count= */ 0,
 | 
			
		||||
                                            /* category_denylist= */ nullptr,
 | 
			
		||||
                                            /* category_denylist_count= */ 0};
 | 
			
		||||
 | 
			
		||||
  mediapipe::tasks::components::processors::ClassifierOptions
 | 
			
		||||
      cpp_classifier_options = {};
 | 
			
		||||
 | 
			
		||||
  CppConvertToClassifierOptions(c_classifier_options, &cpp_classifier_options);
 | 
			
		||||
  EXPECT_EQ(cpp_classifier_options.display_names_locale, "en");
 | 
			
		||||
  EXPECT_EQ(cpp_classifier_options.max_results, -1);
 | 
			
		||||
  EXPECT_FLOAT_EQ(cpp_classifier_options.score_threshold, 0.0);
 | 
			
		||||
  EXPECT_EQ(cpp_classifier_options.category_allowlist,
 | 
			
		||||
            std::vector<std::string>{});
 | 
			
		||||
  EXPECT_EQ(cpp_classifier_options.category_denylist,
 | 
			
		||||
            std::vector<std::string>{});
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
}  // namespace mediapipe::tasks::c::components::processors
 | 
			
		||||
| 
						 | 
				
			
			@ -30,3 +30,15 @@ cc_library(
 | 
			
		|||
        "//mediapipe/tasks/cc/core:base_options",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
cc_test(
 | 
			
		||||
    name = "base_options_converter_test",
 | 
			
		||||
    srcs = ["base_options_converter_test.cc"],
 | 
			
		||||
    deps = [
 | 
			
		||||
        ":base_options",
 | 
			
		||||
        ":base_options_converter",
 | 
			
		||||
        "//mediapipe/framework/port:gtest",
 | 
			
		||||
        "//mediapipe/tasks/cc/core:base_options",
 | 
			
		||||
        "@com_google_googletest//:gtest_main",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -23,10 +23,10 @@ extern "C" {
 | 
			
		|||
// Base options for MediaPipe C Tasks.
 | 
			
		||||
struct BaseOptions {
 | 
			
		||||
  // The model asset file contents as a string.
 | 
			
		||||
  char* model_asset_buffer;
 | 
			
		||||
  const char* model_asset_buffer;
 | 
			
		||||
 | 
			
		||||
  // The path to the model asset to open and mmap in memory.
 | 
			
		||||
  char* model_asset_path;
 | 
			
		||||
  const char* model_asset_path;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
#ifdef __cplusplus
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -21,7 +21,7 @@ limitations under the License.
 | 
			
		|||
#include "mediapipe/tasks/c/core/base_options.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/core/base_options.h"
 | 
			
		||||
 | 
			
		||||
namespace mediapipe::tasks::c::components::containers {
 | 
			
		||||
namespace mediapipe::tasks::c::core {
 | 
			
		||||
 | 
			
		||||
void CppConvertToBaseOptions(const BaseOptions& in,
 | 
			
		||||
                             mediapipe::tasks::core::BaseOptions* out) {
 | 
			
		||||
| 
						 | 
				
			
			@ -33,4 +33,4 @@ void CppConvertToBaseOptions(const BaseOptions& in,
 | 
			
		|||
      in.model_asset_path ? std::string(in.model_asset_path) : "";
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
}  // namespace mediapipe::tasks::c::components::containers
 | 
			
		||||
}  // namespace mediapipe::tasks::c::core
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -19,11 +19,11 @@ limitations under the License.
 | 
			
		|||
#include "mediapipe/tasks/c/core/base_options.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/core/base_options.h"
 | 
			
		||||
 | 
			
		||||
namespace mediapipe::tasks::c::components::containers {
 | 
			
		||||
namespace mediapipe::tasks::c::core {
 | 
			
		||||
 | 
			
		||||
void CppConvertToBaseOptions(const BaseOptions& in,
 | 
			
		||||
                             mediapipe::tasks::core::BaseOptions* out);
 | 
			
		||||
 | 
			
		||||
}  // namespace mediapipe::tasks::c::components::containers
 | 
			
		||||
}  // namespace mediapipe::tasks::c::core
 | 
			
		||||
 | 
			
		||||
#endif  // MEDIAPIPE_TASKS_C_CORE_BASE_OPTIONS_H_
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
							
								
								
									
										51
									
								
								mediapipe/tasks/c/core/base_options_converter_test.cc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										51
									
								
								mediapipe/tasks/c/core/base_options_converter_test.cc
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,51 @@
 | 
			
		|||
/* Copyright 2023 The MediaPipe Authors.
 | 
			
		||||
 | 
			
		||||
Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
you may not use this file except in compliance with the License.
 | 
			
		||||
You may obtain a copy of the License at
 | 
			
		||||
 | 
			
		||||
    http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
 | 
			
		||||
Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
See the License for the specific language governing permissions and
 | 
			
		||||
limitations under the License.
 | 
			
		||||
==============================================================================*/
 | 
			
		||||
 | 
			
		||||
#include "mediapipe/tasks/c/core/base_options_converter.h"
 | 
			
		||||
 | 
			
		||||
#include <string>
 | 
			
		||||
 | 
			
		||||
#include "mediapipe/framework/port/gtest.h"
 | 
			
		||||
#include "mediapipe/tasks/c/core/base_options.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/core/base_options.h"
 | 
			
		||||
 | 
			
		||||
namespace mediapipe::tasks::c::core {
 | 
			
		||||
 | 
			
		||||
constexpr char kAssetBuffer[] = "abc";
 | 
			
		||||
constexpr char kModelAssetPath[] = "abc.tflite";
 | 
			
		||||
 | 
			
		||||
TEST(BaseOptionsConverterTest, ConvertsBaseOptionsAssetBuffer) {
 | 
			
		||||
  BaseOptions c_base_options = {/* model_asset_buffer= */ kAssetBuffer,
 | 
			
		||||
                                /* model_asset_path= */ nullptr};
 | 
			
		||||
 | 
			
		||||
  mediapipe::tasks::core::BaseOptions cpp_base_options = {};
 | 
			
		||||
 | 
			
		||||
  CppConvertToBaseOptions(c_base_options, &cpp_base_options);
 | 
			
		||||
  EXPECT_EQ(*cpp_base_options.model_asset_buffer, std::string{kAssetBuffer});
 | 
			
		||||
  EXPECT_EQ(cpp_base_options.model_asset_path, "");
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST(BaseOptionsConverterTest, ConvertsBaseOptionsAssetPath) {
 | 
			
		||||
  BaseOptions c_base_options = {/* model_asset_buffer= */ nullptr,
 | 
			
		||||
                                /* model_asset_path= */ kModelAssetPath};
 | 
			
		||||
 | 
			
		||||
  mediapipe::tasks::core::BaseOptions cpp_base_options = {};
 | 
			
		||||
 | 
			
		||||
  CppConvertToBaseOptions(c_base_options, &cpp_base_options);
 | 
			
		||||
  EXPECT_EQ(cpp_base_options.model_asset_buffer.get(), nullptr);
 | 
			
		||||
  EXPECT_EQ(cpp_base_options.model_asset_path, std::string{kModelAssetPath});
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
}  // namespace mediapipe::tasks::c::core
 | 
			
		||||
| 
						 | 
				
			
			@ -30,11 +30,11 @@ namespace mediapipe::tasks::c::text::text_classifier {
 | 
			
		|||
 | 
			
		||||
namespace {
 | 
			
		||||
 | 
			
		||||
using ::mediapipe::tasks::c::components::containers::CppConvertToBaseOptions;
 | 
			
		||||
using ::mediapipe::tasks::c::components::containers::
 | 
			
		||||
    CppConvertToClassificationResult;
 | 
			
		||||
using ::mediapipe::tasks::c::components::processors::
 | 
			
		||||
    CppConvertToClassifierOptions;
 | 
			
		||||
using ::mediapipe::tasks::c::core::CppConvertToBaseOptions;
 | 
			
		||||
using ::mediapipe::tasks::text::text_classifier::TextClassifier;
 | 
			
		||||
}  // namespace
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue
	
	Block a user