Merge branch 'master' into c-image-embedder-api

This commit is contained in:
Kinar R 2023-11-07 20:37:09 +05:30 committed by GitHub
commit 42a916ad4f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 74 additions and 61 deletions

View File

@ -516,7 +516,6 @@ cc_library(
":gpu_buffer_storage", ":gpu_buffer_storage",
":image_frame_view", ":image_frame_view",
"//mediapipe/framework/formats:image_frame", "//mediapipe/framework/formats:image_frame",
"//mediapipe/framework/port:ret_check",
"@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/strings:str_format",
], ],
) )
@ -1224,13 +1223,9 @@ mediapipe_cc_test(
], ],
requires_full_emulation = True, requires_full_emulation = True,
deps = [ deps = [
":gl_texture_buffer",
":gl_texture_util",
":gpu_buffer_format", ":gpu_buffer_format",
":gpu_buffer_storage_ahwb", ":gpu_buffer_storage_ahwb",
":gpu_test_base",
"//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/tool:test_util",
], ],
) )

View File

@ -238,7 +238,7 @@ ImageFormat::Format ImageFormatForGpuBufferFormat(GpuBufferFormat format) {
case GpuBufferFormat::kRGBAFloat128: case GpuBufferFormat::kRGBAFloat128:
return ImageFormat::VEC32F4; return ImageFormat::VEC32F4;
case GpuBufferFormat::kRGBA32: case GpuBufferFormat::kRGBA32:
return ImageFormat::SRGBA; // TODO: this likely maps to ImageFormat::SRGBA
case GpuBufferFormat::kGrayHalf16: case GpuBufferFormat::kGrayHalf16:
case GpuBufferFormat::kOneComponent8Alpha: case GpuBufferFormat::kOneComponent8Alpha:
case GpuBufferFormat::kOneComponent8Red: case GpuBufferFormat::kOneComponent8Red:

View File

@ -22,9 +22,12 @@ extern "C" {
// Base options for MediaPipe C Tasks. // Base options for MediaPipe C Tasks.
struct BaseOptions { struct BaseOptions {
// The model asset file contents as a string. // The model asset file contents as bytes.
const char* model_asset_buffer; const char* model_asset_buffer;
// The size of the model assets buffer (or `0` if not set).
const unsigned int model_asset_buffer_count;
// The path to the model asset to open and mmap in memory. // The path to the model asset to open and mmap in memory.
const char* model_asset_path; const char* model_asset_path;
}; };

View File

@ -27,7 +27,9 @@ void CppConvertToBaseOptions(const BaseOptions& in,
mediapipe::tasks::core::BaseOptions* out) { mediapipe::tasks::core::BaseOptions* out) {
out->model_asset_buffer = out->model_asset_buffer =
in.model_asset_buffer in.model_asset_buffer
? std::make_unique<std::string>(in.model_asset_buffer) ? std::make_unique<std::string>(
in.model_asset_buffer,
in.model_asset_buffer + in.model_asset_buffer_count)
: nullptr; : nullptr;
out->model_asset_path = out->model_asset_path =
in.model_asset_path ? std::string(in.model_asset_path) : ""; in.model_asset_path ? std::string(in.model_asset_path) : "";

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "mediapipe/tasks/c/core/base_options_converter.h" #include "mediapipe/tasks/c/core/base_options_converter.h"
#include <cstring>
#include <string> #include <string>
#include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/gtest.h"
@ -28,6 +29,8 @@ constexpr char kModelAssetPath[] = "abc.tflite";
TEST(BaseOptionsConverterTest, ConvertsBaseOptionsAssetBuffer) { TEST(BaseOptionsConverterTest, ConvertsBaseOptionsAssetBuffer) {
BaseOptions c_base_options = {/* model_asset_buffer= */ kAssetBuffer, BaseOptions c_base_options = {/* model_asset_buffer= */ kAssetBuffer,
/* model_asset_buffer_count= */
static_cast<unsigned int>(strlen(kAssetBuffer)),
/* model_asset_path= */ nullptr}; /* model_asset_path= */ nullptr};
mediapipe::tasks::core::BaseOptions cpp_base_options = {}; mediapipe::tasks::core::BaseOptions cpp_base_options = {};
@ -39,6 +42,7 @@ TEST(BaseOptionsConverterTest, ConvertsBaseOptionsAssetBuffer) {
TEST(BaseOptionsConverterTest, ConvertsBaseOptionsAssetPath) { TEST(BaseOptionsConverterTest, ConvertsBaseOptionsAssetPath) {
BaseOptions c_base_options = {/* model_asset_buffer= */ nullptr, BaseOptions c_base_options = {/* model_asset_buffer= */ nullptr,
/* model_asset_buffer_count= */ 0,
/* model_asset_path= */ kModelAssetPath}; /* model_asset_path= */ kModelAssetPath};
mediapipe::tasks::core::BaseOptions cpp_base_options = {}; mediapipe::tasks::core::BaseOptions cpp_base_options = {};

View File

@ -60,18 +60,18 @@ struct LanguageDetectorOptions {
// Creates a LanguageDetector from the provided `options`. // Creates a LanguageDetector from the provided `options`.
// Returns a pointer to the language detector on success. // Returns a pointer to the language detector on success.
// If an error occurs, returns `nullptr` and sets the error parameter to an // If an error occurs, returns `nullptr` and sets the error parameter to an
// an error message (if `error_msg` is not nullptr). You must free the memory // an error message (if `error_msg` is not `nullptr`). You must free the memory
// allocated for the error message. // allocated for the error message.
MP_EXPORT void* language_detector_create( MP_EXPORT void* language_detector_create(
struct LanguageDetectorOptions* options, char** error_msg = nullptr); struct LanguageDetectorOptions* options, char** error_msg);
// Performs language detection on the input `text`. Returns `0` on success. // Performs language detection on the input `text`. Returns `0` on success.
// If an error occurs, returns an error code and sets the error parameter to an // If an error occurs, returns an error code and sets the error parameter to an
// an error message (if `error_msg` is not nullptr). You must free the memory // an error message (if `error_msg` is not `nullptr`). You must free the memory
// allocated for the error message. // allocated for the error message.
MP_EXPORT int language_detector_detect(void* detector, const char* utf8_str, MP_EXPORT int language_detector_detect(void* detector, const char* utf8_str,
LanguageDetectorResult* result, LanguageDetectorResult* result,
char** error_msg = nullptr); char** error_msg);
// Frees the memory allocated inside a LanguageDetectorResult result. Does not // Frees the memory allocated inside a LanguageDetectorResult result. Does not
// free the result pointer itself. // free the result pointer itself.
@ -79,10 +79,9 @@ MP_EXPORT void language_detector_close_result(LanguageDetectorResult* result);
// Shuts down the LanguageDetector when all the work is done. Frees all memory. // Shuts down the LanguageDetector when all the work is done. Frees all memory.
// If an error occurs, returns an error code and sets the error parameter to an // If an error occurs, returns an error code and sets the error parameter to an
// an error message (if `error_msg` is not nullptr). You must free the memory // an error message (if `error_msg` is not `nullptr`). You must free the memory
// allocated for the error message. // allocated for the error message.
MP_EXPORT int language_detector_close(void* detector, MP_EXPORT int language_detector_close(void* detector, char** error_msg);
char** error_msg = nullptr);
#ifdef __cplusplus #ifdef __cplusplus
} // extern C } // extern C

View File

@ -44,6 +44,7 @@ TEST(LanguageDetectorTest, SmokeTest) {
std::string model_path = GetFullPath(kTestLanguageDetectorModelPath); std::string model_path = GetFullPath(kTestLanguageDetectorModelPath);
LanguageDetectorOptions options = { LanguageDetectorOptions options = {
/* base_options= */ {/* model_asset_buffer= */ nullptr, /* base_options= */ {/* model_asset_buffer= */ nullptr,
/* model_asset_buffer_count= */ 0,
/* model_asset_path= */ model_path.c_str()}, /* model_asset_path= */ model_path.c_str()},
/* classifier_options= */ /* classifier_options= */
{/* display_names_locale= */ nullptr, {/* display_names_locale= */ nullptr,
@ -55,22 +56,24 @@ TEST(LanguageDetectorTest, SmokeTest) {
/* category_denylist_count= */ 0}, /* category_denylist_count= */ 0},
}; };
void* detector = language_detector_create(&options); void* detector = language_detector_create(&options, /* error_msg */ nullptr);
EXPECT_NE(detector, nullptr); EXPECT_NE(detector, nullptr);
LanguageDetectorResult result; LanguageDetectorResult result;
language_detector_detect(detector, kTestString, &result); language_detector_detect(detector, kTestString, &result,
/* error_msg */ nullptr);
EXPECT_EQ(std::string(result.predictions[0].language_code), "fr"); EXPECT_EQ(std::string(result.predictions[0].language_code), "fr");
EXPECT_NEAR(result.predictions[0].probability, 0.999781, kPrecision); EXPECT_NEAR(result.predictions[0].probability, 0.999781, kPrecision);
language_detector_close_result(&result); language_detector_close_result(&result);
language_detector_close(detector); language_detector_close(detector, /* error_msg */ nullptr);
} }
TEST(LanguageDetectorTest, ErrorHandling) { TEST(LanguageDetectorTest, ErrorHandling) {
// It is an error to set neither the asset buffer nor the path. // It is an error to set neither the asset buffer nor the path.
LanguageDetectorOptions options = { LanguageDetectorOptions options = {
/* base_options= */ {/* model_asset_buffer= */ nullptr, /* base_options= */ {/* model_asset_buffer= */ nullptr,
/* model_asset_buffer_count= */ 0,
/* model_asset_path= */ nullptr}, /* model_asset_path= */ nullptr},
/* classifier_options= */ {}, /* classifier_options= */ {},
}; };

View File

@ -44,18 +44,18 @@ struct TextClassifierOptions {
// Creates a TextClassifier from the provided `options`. // Creates a TextClassifier from the provided `options`.
// Returns a pointer to the text classifier on success. // Returns a pointer to the text classifier on success.
// If an error occurs, returns `nullptr` and sets the error parameter to an // If an error occurs, returns `nullptr` and sets the error parameter to an
// an error message (if `error_msg` is not nullptr). You must free the memory // an error message (if `error_msg` is not `nullptr`). You must free the memory
// allocated for the error message. // allocated for the error message.
MP_EXPORT void* text_classifier_create(struct TextClassifierOptions* options, MP_EXPORT void* text_classifier_create(struct TextClassifierOptions* options,
char** error_msg = nullptr); char** error_msg);
// Performs classification on the input `text`. Returns `0` on success. // Performs classification on the input `text`. Returns `0` on success.
// If an error occurs, returns an error code and sets the error parameter to an // If an error occurs, returns an error code and sets the error parameter to an
// an error message (if `error_msg` is not nullptr). You must free the memory // an error message (if `error_msg` is not `nullptr`). You must free the memory
// allocated for the error message. // allocated for the error message.
MP_EXPORT int text_classifier_classify(void* classifier, const char* utf8_str, MP_EXPORT int text_classifier_classify(void* classifier, const char* utf8_str,
TextClassifierResult* result, TextClassifierResult* result,
char** error_msg = nullptr); char** error_msg);
// Frees the memory allocated inside a TextClassifierResult result. Does not // Frees the memory allocated inside a TextClassifierResult result. Does not
// free the result pointer itself. // free the result pointer itself.
@ -63,10 +63,9 @@ MP_EXPORT void text_classifier_close_result(TextClassifierResult* result);
// Shuts down the TextClassifier when all the work is done. Frees all memory. // Shuts down the TextClassifier when all the work is done. Frees all memory.
// If an error occurs, returns an error code and sets the error parameter to an // If an error occurs, returns an error code and sets the error parameter to an
// an error message (if `error_msg` is not nullptr). You must free the memory // an error message (if `error_msg` is not `nullptr`). You must free the memory
// allocated for the error message. // allocated for the error message.
MP_EXPORT int text_classifier_close(void* classifier, MP_EXPORT int text_classifier_close(void* classifier, char** error_msg);
char** error_msg = nullptr);
#ifdef __cplusplus #ifdef __cplusplus
} // extern C } // extern C

View File

@ -43,6 +43,7 @@ TEST(TextClassifierTest, SmokeTest) {
std::string model_path = GetFullPath(kTestBertModelPath); std::string model_path = GetFullPath(kTestBertModelPath);
TextClassifierOptions options = { TextClassifierOptions options = {
/* base_options= */ {/* model_asset_buffer= */ nullptr, /* base_options= */ {/* model_asset_buffer= */ nullptr,
/* model_asset_buffer_count= */ 0,
/* model_asset_path= */ model_path.c_str()}, /* model_asset_path= */ model_path.c_str()},
/* classifier_options= */ /* classifier_options= */
{/* display_names_locale= */ nullptr, {/* display_names_locale= */ nullptr,
@ -54,11 +55,12 @@ TEST(TextClassifierTest, SmokeTest) {
/* category_denylist_count= */ 0}, /* category_denylist_count= */ 0},
}; };
void* classifier = text_classifier_create(&options); void* classifier = text_classifier_create(&options, /* error_msg */ nullptr);
EXPECT_NE(classifier, nullptr); EXPECT_NE(classifier, nullptr);
TextClassifierResult result; TextClassifierResult result;
text_classifier_classify(classifier, kTestString, &result); text_classifier_classify(classifier, kTestString, &result,
/* error_msg */ nullptr);
EXPECT_EQ(result.classifications_count, 1); EXPECT_EQ(result.classifications_count, 1);
EXPECT_EQ(result.classifications[0].categories_count, 2); EXPECT_EQ(result.classifications[0].categories_count, 2);
EXPECT_EQ(std::string{result.classifications[0].categories[0].category_name}, EXPECT_EQ(std::string{result.classifications[0].categories[0].category_name},
@ -67,13 +69,14 @@ TEST(TextClassifierTest, SmokeTest) {
kPrecision); kPrecision);
text_classifier_close_result(&result); text_classifier_close_result(&result);
text_classifier_close(classifier); text_classifier_close(classifier, /* error_msg */ nullptr);
} }
TEST(TextClassifierTest, ErrorHandling) { TEST(TextClassifierTest, ErrorHandling) {
// It is an error to set neither the asset buffer nor the path. // It is an error to set neither the asset buffer nor the path.
TextClassifierOptions options = { TextClassifierOptions options = {
/* base_options= */ {/* model_asset_buffer= */ nullptr, /* base_options= */ {/* model_asset_buffer= */ nullptr,
/* model_asset_buffer_count= */ 0,
/* model_asset_path= */ nullptr}, /* model_asset_path= */ nullptr},
/* classifier_options= */ {}, /* classifier_options= */ {},
}; };

View File

@ -47,15 +47,14 @@ struct TextEmbedderOptions {
// an error message (if `error_msg` is not `nullptr`). You must free the memory // an error message (if `error_msg` is not `nullptr`). You must free the memory
// allocated for the error message. // allocated for the error message.
MP_EXPORT void* text_embedder_create(struct TextEmbedderOptions* options, MP_EXPORT void* text_embedder_create(struct TextEmbedderOptions* options,
char** error_msg = nullptr); char** error_msg);
// Performs embedding extraction on the input `text`. Returns `0` on success. // Performs embedding extraction on the input `text`. Returns `0` on success.
// If an error occurs, returns an error code and sets the error parameter to an // If an error occurs, returns an error code and sets the error parameter to an
// an error message (if `error_msg` is not `nullptr`). You must free the memory // an error message (if `error_msg` is not `nullptr`). You must free the memory
// allocated for the error message. // allocated for the error message.
MP_EXPORT int text_embedder_embed(void* embedder, const char* utf8_str, MP_EXPORT int text_embedder_embed(void* embedder, const char* utf8_str,
TextEmbedderResult* result, TextEmbedderResult* result, char** error_msg);
char** error_msg = nullptr);
// Frees the memory allocated inside a TextEmbedderResult result. Does not // Frees the memory allocated inside a TextEmbedderResult result. Does not
// free the result pointer itself. // free the result pointer itself.
@ -65,7 +64,7 @@ MP_EXPORT void text_embedder_close_result(TextEmbedderResult* result);
// If an error occurs, returns an error code and sets the error parameter to an // If an error occurs, returns an error code and sets the error parameter to an
// an error message (if `error_msg` is not `nullptr`). You must free the memory // an error message (if `error_msg` is not `nullptr`). You must free the memory
// allocated for the error message. // allocated for the error message.
MP_EXPORT int text_embedder_close(void* embedder, char** error_msg = nullptr); MP_EXPORT int text_embedder_close(void* embedder, char** error_msg);
// Utility function to compute cosine similarity [1] between two embeddings. // Utility function to compute cosine similarity [1] between two embeddings.
// May return an InvalidArgumentError if e.g. the embeddings are of different // May return an InvalidArgumentError if e.g. the embeddings are of different

View File

@ -47,21 +47,22 @@ TEST(TextEmbedderTest, SmokeTest) {
std::string model_path = GetFullPath(kTestBertModelPath); std::string model_path = GetFullPath(kTestBertModelPath);
TextEmbedderOptions options = { TextEmbedderOptions options = {
/* base_options= */ {/* model_asset_buffer= */ nullptr, /* base_options= */ {/* model_asset_buffer= */ nullptr,
/* model_asset_buffer_count= */ 0,
/* model_asset_path= */ model_path.c_str()}, /* model_asset_path= */ model_path.c_str()},
/* embedder_options= */ /* embedder_options= */
{/* l2_normalize= */ false, /* quantize= */ true}, {/* l2_normalize= */ false, /* quantize= */ true},
}; };
void* embedder = text_embedder_create(&options); void* embedder = text_embedder_create(&options, /* error_msg */ nullptr);
EXPECT_NE(embedder, nullptr); EXPECT_NE(embedder, nullptr);
TextEmbedderResult result; TextEmbedderResult result;
text_embedder_embed(embedder, kTestString0, &result); text_embedder_embed(embedder, kTestString0, &result, /* error_msg */ nullptr);
EXPECT_EQ(result.embeddings_count, 1); EXPECT_EQ(result.embeddings_count, 1);
EXPECT_EQ(result.embeddings[0].values_count, 512); EXPECT_EQ(result.embeddings[0].values_count, 512);
text_embedder_close_result(&result); text_embedder_close_result(&result);
text_embedder_close(embedder); text_embedder_close(embedder, /* error_msg */ nullptr);
} }
TEST(TextEmbedderTest, SucceedsWithCosineSimilarity) { TEST(TextEmbedderTest, SucceedsWithCosineSimilarity) {
@ -78,9 +79,9 @@ TEST(TextEmbedderTest, SucceedsWithCosineSimilarity) {
// Extract both embeddings. // Extract both embeddings.
TextEmbedderResult result0; TextEmbedderResult result0;
text_embedder_embed(embedder, kTestString0, &result0); text_embedder_embed(embedder, kTestString0, &result0, /* error_msg */ nullptr);
TextEmbedderResult result1; TextEmbedderResult result1;
text_embedder_embed(embedder, kTestString1, &result1); text_embedder_embed(embedder, kTestString1, &result1, /* error_msg */ nullptr);
// Check cosine similarity. // Check cosine similarity.
double similarity; double similarity;
@ -95,6 +96,7 @@ TEST(TextEmbedderTest, ErrorHandling) {
// It is an error to set neither the asset buffer nor the path. // It is an error to set neither the asset buffer nor the path.
TextEmbedderOptions options = { TextEmbedderOptions options = {
/* base_options= */ {/* model_asset_buffer= */ nullptr, /* base_options= */ {/* model_asset_buffer= */ nullptr,
/* model_asset_buffer_count= */ 0,
/* model_asset_path= */ nullptr}, /* model_asset_path= */ nullptr},
/* embedder_options= */ {}, /* embedder_options= */ {},
}; };

View File

@ -55,11 +55,7 @@ cc_test(
":image_classifier_lib", ":image_classifier_lib",
"//mediapipe/framework/deps:file_path", "//mediapipe/framework/deps:file_path",
"//mediapipe/framework/formats:image", "//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:image_frame",
"//mediapipe/framework/formats:image_frame_opencv",
"//mediapipe/framework/port:gtest", "//mediapipe/framework/port:gtest",
"//mediapipe/framework/port:opencv_core",
"//mediapipe/framework/port:opencv_imgproc",
"//mediapipe/tasks/c/components/containers:category", "//mediapipe/tasks/c/components/containers:category",
"//mediapipe/tasks/cc/vision/utils:image_utils", "//mediapipe/tasks/cc/vision/utils:image_utils",
"@com_google_absl//absl/flags:flag", "@com_google_absl//absl/flags:flag",

View File

@ -108,30 +108,30 @@ struct ImageClassifierOptions {
// Creates an ImageClassifier from provided `options`. // Creates an ImageClassifier from provided `options`.
// Returns a pointer to the image classifier on success. // Returns a pointer to the image classifier on success.
// If an error occurs, returns `nullptr` and sets the error parameter to an // If an error occurs, returns `nullptr` and sets the error parameter to an
// an error message (if `error_msg` is not nullptr). You must free the memory // an error message (if `error_msg` is not `nullptr`). You must free the memory
// allocated for the error message. // allocated for the error message.
MP_EXPORT void* image_classifier_create(struct ImageClassifierOptions* options, MP_EXPORT void* image_classifier_create(struct ImageClassifierOptions* options,
char** error_msg = nullptr); char** error_msg);
// Performs image classification on the input `image`. Returns `0` on success. // Performs image classification on the input `image`. Returns `0` on success.
// If an error occurs, returns an error code and sets the error parameter to an // If an error occurs, returns an error code and sets the error parameter to an
// an error message (if `error_msg` is not nullptr). You must free the memory // an error message (if `error_msg` is not `nullptr`). You must free the memory
// allocated for the error message. // allocated for the error message.
MP_EXPORT int image_classifier_classify_image(void* classifier, MP_EXPORT int image_classifier_classify_image(void* classifier,
const MpImage* image, const MpImage* image,
ImageClassifierResult* result, ImageClassifierResult* result,
char** error_msg = nullptr); char** error_msg);
MP_EXPORT int image_classifier_classify_for_video(void* classifier, MP_EXPORT int image_classifier_classify_for_video(void* classifier,
const MpImage* image, const MpImage* image,
int64_t timestamp_ms, int64_t timestamp_ms,
ImageClassifierResult* result, ImageClassifierResult* result,
char** error_msg = nullptr); char** error_msg);
MP_EXPORT int image_classifier_classify_async(void* classifier, MP_EXPORT int image_classifier_classify_async(void* classifier,
const MpImage* image, const MpImage* image,
int64_t timestamp_ms, int64_t timestamp_ms,
char** error_msg = nullptr); char** error_msg);
// Frees the memory allocated inside a ImageClassifierResult result. // Frees the memory allocated inside a ImageClassifierResult result.
// Does not free the result pointer itself. // Does not free the result pointer itself.
@ -139,10 +139,9 @@ MP_EXPORT void image_classifier_close_result(ImageClassifierResult* result);
// Frees image classifier. // Frees image classifier.
// If an error occurs, returns an error code and sets the error parameter to an // If an error occurs, returns an error code and sets the error parameter to an
// an error message (if `error_msg` is not nullptr). You must free the memory // an error message (if `error_msg` is not `nullptr`). You must free the memory
// allocated for the error message. // allocated for the error message.
MP_EXPORT int image_classifier_close(void* classifier, MP_EXPORT int image_classifier_close(void* classifier, char** error_msg);
char** error_msg = nullptr);
#ifdef __cplusplus #ifdef __cplusplus
} // extern C } // extern C

View File

@ -50,6 +50,7 @@ TEST(ImageClassifierTest, ImageModeTest) {
const std::string model_path = GetFullPath(kModelName); const std::string model_path = GetFullPath(kModelName);
ImageClassifierOptions options = { ImageClassifierOptions options = {
/* base_options= */ {/* model_asset_buffer= */ nullptr, /* base_options= */ {/* model_asset_buffer= */ nullptr,
/* model_asset_buffer_count= */ 0,
/* model_asset_path= */ model_path.c_str()}, /* model_asset_path= */ model_path.c_str()},
/* running_mode= */ RunningMode::IMAGE, /* running_mode= */ RunningMode::IMAGE,
/* classifier_options= */ /* classifier_options= */
@ -62,7 +63,7 @@ TEST(ImageClassifierTest, ImageModeTest) {
/* category_denylist_count= */ 0}, /* category_denylist_count= */ 0},
}; };
void* classifier = image_classifier_create(&options); void* classifier = image_classifier_create(&options, /* error_msg */ nullptr);
EXPECT_NE(classifier, nullptr); EXPECT_NE(classifier, nullptr);
const auto& image_frame = image->GetImageFrameSharedPtr(); const auto& image_frame = image->GetImageFrameSharedPtr();
@ -74,7 +75,8 @@ TEST(ImageClassifierTest, ImageModeTest) {
.height = image_frame->Height()}}; .height = image_frame->Height()}};
ImageClassifierResult result; ImageClassifierResult result;
image_classifier_classify_image(classifier, &mp_image, &result); image_classifier_classify_image(classifier, &mp_image, &result,
/* error_msg */ nullptr);
EXPECT_EQ(result.classifications_count, 1); EXPECT_EQ(result.classifications_count, 1);
EXPECT_EQ(result.classifications[0].categories_count, 1001); EXPECT_EQ(result.classifications[0].categories_count, 1001);
EXPECT_EQ(std::string{result.classifications[0].categories[0].category_name}, EXPECT_EQ(std::string{result.classifications[0].categories[0].category_name},
@ -82,7 +84,7 @@ TEST(ImageClassifierTest, ImageModeTest) {
EXPECT_NEAR(result.classifications[0].categories[0].score, 0.7939f, EXPECT_NEAR(result.classifications[0].categories[0].score, 0.7939f,
kPrecision); kPrecision);
image_classifier_close_result(&result); image_classifier_close_result(&result);
image_classifier_close(classifier); image_classifier_close(classifier, /* error_msg */ nullptr);
} }
TEST(ImageClassifierTest, VideoModeTest) { TEST(ImageClassifierTest, VideoModeTest) {
@ -92,6 +94,7 @@ TEST(ImageClassifierTest, VideoModeTest) {
const std::string model_path = GetFullPath(kModelName); const std::string model_path = GetFullPath(kModelName);
ImageClassifierOptions options = { ImageClassifierOptions options = {
/* base_options= */ {/* model_asset_buffer= */ nullptr, /* base_options= */ {/* model_asset_buffer= */ nullptr,
/* model_asset_buffer_count= */ 0,
/* model_asset_path= */ model_path.c_str()}, /* model_asset_path= */ model_path.c_str()},
/* running_mode= */ RunningMode::VIDEO, /* running_mode= */ RunningMode::VIDEO,
/* classifier_options= */ /* classifier_options= */
@ -105,7 +108,7 @@ TEST(ImageClassifierTest, VideoModeTest) {
/* result_callback= */ nullptr, /* result_callback= */ nullptr,
}; };
void* classifier = image_classifier_create(&options); void* classifier = image_classifier_create(&options, /* error_msg */ nullptr);
EXPECT_NE(classifier, nullptr); EXPECT_NE(classifier, nullptr);
const auto& image_frame = image->GetImageFrameSharedPtr(); const auto& image_frame = image->GetImageFrameSharedPtr();
@ -118,7 +121,8 @@ TEST(ImageClassifierTest, VideoModeTest) {
for (int i = 0; i < kIterations; ++i) { for (int i = 0; i < kIterations; ++i) {
ImageClassifierResult result; ImageClassifierResult result;
image_classifier_classify_for_video(classifier, &mp_image, i, &result); image_classifier_classify_for_video(classifier, &mp_image, i, &result,
/* error_msg */ nullptr);
EXPECT_EQ(result.classifications_count, 1); EXPECT_EQ(result.classifications_count, 1);
EXPECT_EQ(result.classifications[0].categories_count, 3); EXPECT_EQ(result.classifications[0].categories_count, 3);
EXPECT_EQ( EXPECT_EQ(
@ -128,7 +132,7 @@ TEST(ImageClassifierTest, VideoModeTest) {
kPrecision); kPrecision);
image_classifier_close_result(&result); image_classifier_close_result(&result);
} }
image_classifier_close(classifier); image_classifier_close(classifier, /* error_msg */ nullptr);
} }
// A structure to support LiveStreamModeTest below. This structure holds a // A structure to support LiveStreamModeTest below. This structure holds a
@ -164,6 +168,7 @@ TEST(ImageClassifierTest, LiveStreamModeTest) {
ImageClassifierOptions options = { ImageClassifierOptions options = {
/* base_options= */ {/* model_asset_buffer= */ nullptr, /* base_options= */ {/* model_asset_buffer= */ nullptr,
/* model_asset_buffer_count= */ 0,
/* model_asset_path= */ model_path.c_str()}, /* model_asset_path= */ model_path.c_str()},
/* running_mode= */ RunningMode::LIVE_STREAM, /* running_mode= */ RunningMode::LIVE_STREAM,
/* classifier_options= */ /* classifier_options= */
@ -177,7 +182,7 @@ TEST(ImageClassifierTest, LiveStreamModeTest) {
/* result_callback= */ LiveStreamModeCallback::Fn, /* result_callback= */ LiveStreamModeCallback::Fn,
}; };
void* classifier = image_classifier_create(&options); void* classifier = image_classifier_create(&options, /* error_msg */ nullptr);
EXPECT_NE(classifier, nullptr); EXPECT_NE(classifier, nullptr);
const auto& image_frame = image->GetImageFrameSharedPtr(); const auto& image_frame = image->GetImageFrameSharedPtr();
@ -189,9 +194,11 @@ TEST(ImageClassifierTest, LiveStreamModeTest) {
.height = image_frame->Height()}}; .height = image_frame->Height()}};
for (int i = 0; i < kIterations; ++i) { for (int i = 0; i < kIterations; ++i) {
EXPECT_GE(image_classifier_classify_async(classifier, &mp_image, i), 0); EXPECT_GE(image_classifier_classify_async(classifier, &mp_image, i,
/* error_msg */ nullptr),
0);
} }
image_classifier_close(classifier); image_classifier_close(classifier, /* error_msg */ nullptr);
// Due to the flow limiter, the total of outputs might be smaller than the // Due to the flow limiter, the total of outputs might be smaller than the
// number of iterations. // number of iterations.
@ -203,6 +210,7 @@ TEST(ImageClassifierTest, InvalidArgumentHandling) {
// It is an error to set neither the asset buffer nor the path. // It is an error to set neither the asset buffer nor the path.
ImageClassifierOptions options = { ImageClassifierOptions options = {
/* base_options= */ {/* model_asset_buffer= */ nullptr, /* base_options= */ {/* model_asset_buffer= */ nullptr,
/* model_asset_buffer_count= */ 0,
/* model_asset_path= */ nullptr}, /* model_asset_path= */ nullptr},
/* classifier_options= */ {}, /* classifier_options= */ {},
}; };
@ -220,6 +228,7 @@ TEST(ImageClassifierTest, FailedClassificationHandling) {
const std::string model_path = GetFullPath(kModelName); const std::string model_path = GetFullPath(kModelName);
ImageClassifierOptions options = { ImageClassifierOptions options = {
/* base_options= */ {/* model_asset_buffer= */ nullptr, /* base_options= */ {/* model_asset_buffer= */ nullptr,
/* model_asset_buffer_count= */ 0,
/* model_asset_path= */ model_path.c_str()}, /* model_asset_path= */ model_path.c_str()},
/* running_mode= */ RunningMode::IMAGE, /* running_mode= */ RunningMode::IMAGE,
/* classifier_options= */ /* classifier_options= */
@ -232,7 +241,7 @@ TEST(ImageClassifierTest, FailedClassificationHandling) {
/* category_denylist_count= */ 0}, /* category_denylist_count= */ 0},
}; };
void* classifier = image_classifier_create(&options); void* classifier = image_classifier_create(&options, /* error_msg */ nullptr);
EXPECT_NE(classifier, nullptr); EXPECT_NE(classifier, nullptr);
const MpImage mp_image = {.type = MpImage::GPU_BUFFER, .gpu_buffer = {}}; const MpImage mp_image = {.type = MpImage::GPU_BUFFER, .gpu_buffer = {}};
@ -241,7 +250,7 @@ TEST(ImageClassifierTest, FailedClassificationHandling) {
image_classifier_classify_image(classifier, &mp_image, &result, &error_msg); image_classifier_classify_image(classifier, &mp_image, &result, &error_msg);
EXPECT_THAT(error_msg, HasSubstr("GPU Buffer not supported yet")); EXPECT_THAT(error_msg, HasSubstr("GPU Buffer not supported yet"));
free(error_msg); free(error_msg);
image_classifier_close(classifier); image_classifier_close(classifier, /* error_msg */ nullptr);
} }
} // namespace } // namespace