Add unit tests for C layer for the input types of Text Classifier
PiperOrigin-RevId: 569553038
This commit is contained in:
		
							parent
							
								
									6915a79e28
								
							
						
					
					
						commit
						96fa10b906
					
				| 
						 | 
					@ -30,3 +30,15 @@ cc_library(
 | 
				
			||||||
        "//mediapipe/tasks/cc/components/processors:classifier_options",
 | 
					        "//mediapipe/tasks/cc/components/processors:classifier_options",
 | 
				
			||||||
    ],
 | 
					    ],
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					cc_test(
 | 
				
			||||||
 | 
					    name = "classifier_options_converter_test",
 | 
				
			||||||
 | 
					    srcs = ["classifier_options_converter_test.cc"],
 | 
				
			||||||
 | 
					    deps = [
 | 
				
			||||||
 | 
					        ":classifier_options",
 | 
				
			||||||
 | 
					        ":classifier_options_converter",
 | 
				
			||||||
 | 
					        "//mediapipe/framework/port:gtest",
 | 
				
			||||||
 | 
					        "//mediapipe/tasks/cc/components/processors:classifier_options",
 | 
				
			||||||
 | 
					        "@com_google_googletest//:gtest_main",
 | 
				
			||||||
 | 
					    ],
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -26,7 +26,7 @@ extern "C" {
 | 
				
			||||||
struct ClassifierOptions {
 | 
					struct ClassifierOptions {
 | 
				
			||||||
  // The locale to use for display names specified through the TFLite Model
 | 
					  // The locale to use for display names specified through the TFLite Model
 | 
				
			||||||
  // Metadata, if any. Defaults to English.
 | 
					  // Metadata, if any. Defaults to English.
 | 
				
			||||||
  char* display_names_locale;
 | 
					  const char* display_names_locale;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  // The maximum number of top-scored classification results to return. If < 0,
 | 
					  // The maximum number of top-scored classification results to return. If < 0,
 | 
				
			||||||
  // all available results will be returned. If 0, an invalid argument error is
 | 
					  // all available results will be returned. If 0, an invalid argument error is
 | 
				
			||||||
| 
						 | 
					@ -40,14 +40,14 @@ struct ClassifierOptions {
 | 
				
			||||||
  // The allowlist of category names. If non-empty, detection results whose
 | 
					  // The allowlist of category names. If non-empty, detection results whose
 | 
				
			||||||
  // category name is not in this set will be filtered out. Duplicate or unknown
 | 
					  // category name is not in this set will be filtered out. Duplicate or unknown
 | 
				
			||||||
  // category names are ignored. Mutually exclusive with category_denylist.
 | 
					  // category names are ignored. Mutually exclusive with category_denylist.
 | 
				
			||||||
  char** category_allowlist;
 | 
					  const char** category_allowlist;
 | 
				
			||||||
  // The number of elements in the category allowlist.
 | 
					  // The number of elements in the category allowlist.
 | 
				
			||||||
  uint32_t category_allowlist_count;
 | 
					  uint32_t category_allowlist_count;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  // The denylist of category names. If non-empty, detection results whose
 | 
					  // The denylist of category names. If non-empty, detection results whose
 | 
				
			||||||
  // category name is in this set will be filtered out. Duplicate or unknown
 | 
					  // category name is in this set will be filtered out. Duplicate or unknown
 | 
				
			||||||
  // category names are ignored. Mutually exclusive with category_allowlist.
 | 
					  // category names are ignored. Mutually exclusive with category_allowlist.
 | 
				
			||||||
  char** category_denylist;
 | 
					  const char** category_denylist;
 | 
				
			||||||
  // The number of elements in the category denylist.
 | 
					  // The number of elements in the category denylist.
 | 
				
			||||||
  uint32_t category_denylist_count;
 | 
					  uint32_t category_denylist_count;
 | 
				
			||||||
};
 | 
					};
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -26,7 +26,7 @@ void CppConvertToClassifierOptions(
 | 
				
			||||||
    const ClassifierOptions& in,
 | 
					    const ClassifierOptions& in,
 | 
				
			||||||
    mediapipe::tasks::components::processors::ClassifierOptions* out) {
 | 
					    mediapipe::tasks::components::processors::ClassifierOptions* out) {
 | 
				
			||||||
  out->display_names_locale =
 | 
					  out->display_names_locale =
 | 
				
			||||||
      in.display_names_locale ? std::string(in.display_names_locale) : "";
 | 
					      in.display_names_locale ? std::string(in.display_names_locale) : "en";
 | 
				
			||||||
  out->max_results = in.max_results;
 | 
					  out->max_results = in.max_results;
 | 
				
			||||||
  out->score_threshold = in.score_threshold;
 | 
					  out->score_threshold = in.score_threshold;
 | 
				
			||||||
  out->category_allowlist =
 | 
					  out->category_allowlist =
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1,82 @@
 | 
				
			||||||
 | 
					/* Copyright 2023 The MediaPipe Authors.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					You may obtain a copy of the License at
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					limitations under the License.
 | 
				
			||||||
 | 
					==============================================================================*/
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include "mediapipe/tasks/c/components/processors/classifier_options_converter.h"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include <string>
 | 
				
			||||||
 | 
					#include <vector>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include "mediapipe/framework/port/gtest.h"
 | 
				
			||||||
 | 
					#include "mediapipe/tasks/c/components/processors/classifier_options.h"
 | 
				
			||||||
 | 
					#include "mediapipe/tasks/cc/components/processors/classifier_options.h"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					namespace mediapipe::tasks::c::components::processors {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					constexpr char kCategoryAllowlist[] = "fruit";
 | 
				
			||||||
 | 
					constexpr char kCategoryDenylist[] = "veggies";
 | 
				
			||||||
 | 
					constexpr char kDisplayNamesLocaleGerman[] = "de";
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					TEST(ClassifierOptionsConverterTest, ConvertsClassifierOptionsCustomValues) {
 | 
				
			||||||
 | 
					  std::vector<const char*> category_allowlist = {kCategoryAllowlist};
 | 
				
			||||||
 | 
					  std::vector<const char*> category_denylist = {kCategoryDenylist};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  ClassifierOptions c_classifier_options = {
 | 
				
			||||||
 | 
					      /* display_names_locale= */ kDisplayNamesLocaleGerman,
 | 
				
			||||||
 | 
					      /* max_results= */ 1,
 | 
				
			||||||
 | 
					      /* score_threshold= */ 0.1,
 | 
				
			||||||
 | 
					      /* category_allowlist= */ category_allowlist.data(),
 | 
				
			||||||
 | 
					      /* category_allowlist_count= */ 1,
 | 
				
			||||||
 | 
					      /* category_denylist= */ category_denylist.data(),
 | 
				
			||||||
 | 
					      /* category_denylist_count= */ 1};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  mediapipe::tasks::components::processors::ClassifierOptions
 | 
				
			||||||
 | 
					      cpp_classifier_options = {};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  CppConvertToClassifierOptions(c_classifier_options, &cpp_classifier_options);
 | 
				
			||||||
 | 
					  EXPECT_EQ(cpp_classifier_options.display_names_locale, "de");
 | 
				
			||||||
 | 
					  EXPECT_EQ(cpp_classifier_options.max_results, 1);
 | 
				
			||||||
 | 
					  EXPECT_FLOAT_EQ(cpp_classifier_options.score_threshold, 0.1);
 | 
				
			||||||
 | 
					  EXPECT_EQ(cpp_classifier_options.category_allowlist,
 | 
				
			||||||
 | 
					            std::vector<std::string>{"fruit"});
 | 
				
			||||||
 | 
					  EXPECT_EQ(cpp_classifier_options.category_denylist,
 | 
				
			||||||
 | 
					            std::vector<std::string>{"veggies"});
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					TEST(ClassifierOptionsConverterTest, ConvertsClassifierOptionsDefaultValues) {
 | 
				
			||||||
 | 
					  std::vector<const char*> category_allowlist = {kCategoryAllowlist};
 | 
				
			||||||
 | 
					  std::vector<const char*> category_denylist = {kCategoryDenylist};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  ClassifierOptions c_classifier_options = {/* display_names_locale= */ nullptr,
 | 
				
			||||||
 | 
					                                            /* max_results= */ -1,
 | 
				
			||||||
 | 
					                                            /* score_threshold= */ 0.0,
 | 
				
			||||||
 | 
					                                            /* category_allowlist= */ nullptr,
 | 
				
			||||||
 | 
					                                            /* category_allowlist_count= */ 0,
 | 
				
			||||||
 | 
					                                            /* category_denylist= */ nullptr,
 | 
				
			||||||
 | 
					                                            /* category_denylist_count= */ 0};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  mediapipe::tasks::components::processors::ClassifierOptions
 | 
				
			||||||
 | 
					      cpp_classifier_options = {};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  CppConvertToClassifierOptions(c_classifier_options, &cpp_classifier_options);
 | 
				
			||||||
 | 
					  EXPECT_EQ(cpp_classifier_options.display_names_locale, "en");
 | 
				
			||||||
 | 
					  EXPECT_EQ(cpp_classifier_options.max_results, -1);
 | 
				
			||||||
 | 
					  EXPECT_FLOAT_EQ(cpp_classifier_options.score_threshold, 0.0);
 | 
				
			||||||
 | 
					  EXPECT_EQ(cpp_classifier_options.category_allowlist,
 | 
				
			||||||
 | 
					            std::vector<std::string>{});
 | 
				
			||||||
 | 
					  EXPECT_EQ(cpp_classifier_options.category_denylist,
 | 
				
			||||||
 | 
					            std::vector<std::string>{});
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					}  // namespace mediapipe::tasks::c::components::processors
 | 
				
			||||||
| 
						 | 
					@ -30,3 +30,15 @@ cc_library(
 | 
				
			||||||
        "//mediapipe/tasks/cc/core:base_options",
 | 
					        "//mediapipe/tasks/cc/core:base_options",
 | 
				
			||||||
    ],
 | 
					    ],
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					cc_test(
 | 
				
			||||||
 | 
					    name = "base_options_converter_test",
 | 
				
			||||||
 | 
					    srcs = ["base_options_converter_test.cc"],
 | 
				
			||||||
 | 
					    deps = [
 | 
				
			||||||
 | 
					        ":base_options",
 | 
				
			||||||
 | 
					        ":base_options_converter",
 | 
				
			||||||
 | 
					        "//mediapipe/framework/port:gtest",
 | 
				
			||||||
 | 
					        "//mediapipe/tasks/cc/core:base_options",
 | 
				
			||||||
 | 
					        "@com_google_googletest//:gtest_main",
 | 
				
			||||||
 | 
					    ],
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -23,10 +23,10 @@ extern "C" {
 | 
				
			||||||
// Base options for MediaPipe C Tasks.
 | 
					// Base options for MediaPipe C Tasks.
 | 
				
			||||||
struct BaseOptions {
 | 
					struct BaseOptions {
 | 
				
			||||||
  // The model asset file contents as a string.
 | 
					  // The model asset file contents as a string.
 | 
				
			||||||
  char* model_asset_buffer;
 | 
					  const char* model_asset_buffer;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  // The path to the model asset to open and mmap in memory.
 | 
					  // The path to the model asset to open and mmap in memory.
 | 
				
			||||||
  char* model_asset_path;
 | 
					  const char* model_asset_path;
 | 
				
			||||||
};
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
#ifdef __cplusplus
 | 
					#ifdef __cplusplus
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -21,7 +21,7 @@ limitations under the License.
 | 
				
			||||||
#include "mediapipe/tasks/c/core/base_options.h"
 | 
					#include "mediapipe/tasks/c/core/base_options.h"
 | 
				
			||||||
#include "mediapipe/tasks/cc/core/base_options.h"
 | 
					#include "mediapipe/tasks/cc/core/base_options.h"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
namespace mediapipe::tasks::c::components::containers {
 | 
					namespace mediapipe::tasks::c::core {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
void CppConvertToBaseOptions(const BaseOptions& in,
 | 
					void CppConvertToBaseOptions(const BaseOptions& in,
 | 
				
			||||||
                             mediapipe::tasks::core::BaseOptions* out) {
 | 
					                             mediapipe::tasks::core::BaseOptions* out) {
 | 
				
			||||||
| 
						 | 
					@ -33,4 +33,4 @@ void CppConvertToBaseOptions(const BaseOptions& in,
 | 
				
			||||||
      in.model_asset_path ? std::string(in.model_asset_path) : "";
 | 
					      in.model_asset_path ? std::string(in.model_asset_path) : "";
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
}  // namespace mediapipe::tasks::c::components::containers
 | 
					}  // namespace mediapipe::tasks::c::core
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -19,11 +19,11 @@ limitations under the License.
 | 
				
			||||||
#include "mediapipe/tasks/c/core/base_options.h"
 | 
					#include "mediapipe/tasks/c/core/base_options.h"
 | 
				
			||||||
#include "mediapipe/tasks/cc/core/base_options.h"
 | 
					#include "mediapipe/tasks/cc/core/base_options.h"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
namespace mediapipe::tasks::c::components::containers {
 | 
					namespace mediapipe::tasks::c::core {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
void CppConvertToBaseOptions(const BaseOptions& in,
 | 
					void CppConvertToBaseOptions(const BaseOptions& in,
 | 
				
			||||||
                             mediapipe::tasks::core::BaseOptions* out);
 | 
					                             mediapipe::tasks::core::BaseOptions* out);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
}  // namespace mediapipe::tasks::c::components::containers
 | 
					}  // namespace mediapipe::tasks::c::core
 | 
				
			||||||
 | 
					
 | 
				
			||||||
#endif  // MEDIAPIPE_TASKS_C_CORE_BASE_OPTIONS_H_
 | 
					#endif  // MEDIAPIPE_TASKS_C_CORE_BASE_OPTIONS_H_
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
							
								
								
									
										51
									
								
								mediapipe/tasks/c/core/base_options_converter_test.cc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										51
									
								
								mediapipe/tasks/c/core/base_options_converter_test.cc
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
					@ -0,0 +1,51 @@
 | 
				
			||||||
 | 
					/* Copyright 2023 The MediaPipe Authors.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					You may obtain a copy of the License at
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					limitations under the License.
 | 
				
			||||||
 | 
					==============================================================================*/
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include "mediapipe/tasks/c/core/base_options_converter.h"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include <string>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include "mediapipe/framework/port/gtest.h"
 | 
				
			||||||
 | 
					#include "mediapipe/tasks/c/core/base_options.h"
 | 
				
			||||||
 | 
					#include "mediapipe/tasks/cc/core/base_options.h"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					namespace mediapipe::tasks::c::core {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					constexpr char kAssetBuffer[] = "abc";
 | 
				
			||||||
 | 
					constexpr char kModelAssetPath[] = "abc.tflite";
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					TEST(BaseOptionsConverterTest, ConvertsBaseOptionsAssetBuffer) {
 | 
				
			||||||
 | 
					  BaseOptions c_base_options = {/* model_asset_buffer= */ kAssetBuffer,
 | 
				
			||||||
 | 
					                                /* model_asset_path= */ nullptr};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  mediapipe::tasks::core::BaseOptions cpp_base_options = {};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  CppConvertToBaseOptions(c_base_options, &cpp_base_options);
 | 
				
			||||||
 | 
					  EXPECT_EQ(*cpp_base_options.model_asset_buffer, std::string{kAssetBuffer});
 | 
				
			||||||
 | 
					  EXPECT_EQ(cpp_base_options.model_asset_path, "");
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					TEST(BaseOptionsConverterTest, ConvertsBaseOptionsAssetPath) {
 | 
				
			||||||
 | 
					  BaseOptions c_base_options = {/* model_asset_buffer= */ nullptr,
 | 
				
			||||||
 | 
					                                /* model_asset_path= */ kModelAssetPath};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  mediapipe::tasks::core::BaseOptions cpp_base_options = {};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  CppConvertToBaseOptions(c_base_options, &cpp_base_options);
 | 
				
			||||||
 | 
					  EXPECT_EQ(cpp_base_options.model_asset_buffer.get(), nullptr);
 | 
				
			||||||
 | 
					  EXPECT_EQ(cpp_base_options.model_asset_path, std::string{kModelAssetPath});
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					}  // namespace mediapipe::tasks::c::core
 | 
				
			||||||
| 
						 | 
					@ -30,11 +30,11 @@ namespace mediapipe::tasks::c::text::text_classifier {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
namespace {
 | 
					namespace {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
using ::mediapipe::tasks::c::components::containers::CppConvertToBaseOptions;
 | 
					 | 
				
			||||||
using ::mediapipe::tasks::c::components::containers::
 | 
					using ::mediapipe::tasks::c::components::containers::
 | 
				
			||||||
    CppConvertToClassificationResult;
 | 
					    CppConvertToClassificationResult;
 | 
				
			||||||
using ::mediapipe::tasks::c::components::processors::
 | 
					using ::mediapipe::tasks::c::components::processors::
 | 
				
			||||||
    CppConvertToClassifierOptions;
 | 
					    CppConvertToClassifierOptions;
 | 
				
			||||||
 | 
					using ::mediapipe::tasks::c::core::CppConvertToBaseOptions;
 | 
				
			||||||
using ::mediapipe::tasks::text::text_classifier::TextClassifier;
 | 
					using ::mediapipe::tasks::text::text_classifier::TextClassifier;
 | 
				
			||||||
}  // namespace
 | 
					}  // namespace
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user