diff --git a/mediapipe/tasks/c/vision/image_classifier/BUILD b/mediapipe/tasks/c/vision/image_classifier/BUILD index df0e636c5..e8ac090e9 100644 --- a/mediapipe/tasks/c/vision/image_classifier/BUILD +++ b/mediapipe/tasks/c/vision/image_classifier/BUILD @@ -30,13 +30,12 @@ cc_library( "//mediapipe/tasks/c/components/processors:classifier_options_converter", "//mediapipe/tasks/c/core:base_options", "//mediapipe/tasks/c/core:base_options_converter", + "//mediapipe/tasks/cc/vision/core:running_mode", "//mediapipe/tasks/cc/vision/image_classifier", "//mediapipe/tasks/cc/vision/utils:image_utils", "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/time", ], alwayslink = 1, ) diff --git a/mediapipe/tasks/c/vision/image_classifier/image_classifier.cc b/mediapipe/tasks/c/vision/image_classifier/image_classifier.cc index 4245ca4cd..ff6f5bdfc 100644 --- a/mediapipe/tasks/c/vision/image_classifier/image_classifier.cc +++ b/mediapipe/tasks/c/vision/image_classifier/image_classifier.cc @@ -15,6 +15,8 @@ limitations under the License. #include "mediapipe/tasks/c/vision/image_classifier/image_classifier.h" +#include +#include #include #include @@ -26,6 +28,7 @@ limitations under the License. #include "mediapipe/tasks/c/components/containers/classification_result_converter.h" #include "mediapipe/tasks/c/components/processors/classifier_options_converter.h" #include "mediapipe/tasks/c/core/base_options_converter.h" +#include "mediapipe/tasks/cc/vision/core/running_mode.h" #include "mediapipe/tasks/cc/vision/image_classifier/image_classifier.h" #include "mediapipe/tasks/cc/vision/utils/image_utils.h" @@ -41,7 +44,10 @@ using ::mediapipe::tasks::c::components::processors:: CppConvertToClassifierOptions; using ::mediapipe::tasks::c::core::CppConvertToBaseOptions; using ::mediapipe::tasks::vision::CreateImageFromBuffer; +using ::mediapipe::tasks::vision::core::RunningMode; using ::mediapipe::tasks::vision::image_classifier::ImageClassifier; +typedef ::mediapipe::tasks::vision::image_classifier::ImageClassifierResult + CppImageClassifierResult; int CppProcessError(absl::Status status, char** error_msg) { if (error_msg) { @@ -60,6 +66,53 @@ ImageClassifier* CppImageClassifierCreate(const ImageClassifierOptions& options, CppConvertToBaseOptions(options.base_options, &cpp_options->base_options); CppConvertToClassifierOptions(options.classifier_options, &cpp_options->classifier_options); + cpp_options->running_mode = static_cast(options.running_mode); + + // Enable callback for processing live stream data when the running mode is + // set to RunningMode::LIVE_STREAM. + if (cpp_options->running_mode == RunningMode::LIVE_STREAM) { + if (options.result_callback == nullptr) { + const absl::Status status = absl::InvalidArgumentError( + "Provided null pointer to callback function."); + ABSL_LOG(ERROR) << "Failed to create ImageClassifier: " << status; + CppProcessError(status, error_msg); + return nullptr; + } + + ImageClassifierOptions::result_callback_fn result_callback = + options.result_callback; + cpp_options->result_callback = + [result_callback](absl::StatusOr cpp_result, + const Image& image, int64_t timestamp) { + char* error_msg = nullptr; + + if (!cpp_result.ok()) { + ABSL_LOG(ERROR) << "Classification failed: " << cpp_result.status(); + CppProcessError(cpp_result.status(), &error_msg); + result_callback(nullptr, MpImage(), timestamp, error_msg); + free(error_msg); + return; + } + + // Result is valid for the lifetime of the callback function. + ImageClassifierResult result; + CppConvertToClassificationResult(*cpp_result, &result); + + const auto& image_frame = image.GetImageFrameSharedPtr(); + const MpImage mp_image = { + .type = MpImage::IMAGE_FRAME, + .image_frame = { + .format = static_cast<::ImageFormat>(image_frame->Format()), + .image_buffer = image_frame->PixelData(), + .width = image_frame->Width(), + .height = image_frame->Height()}}; + + result_callback(&result, mp_image, timestamp, + /* error_msg= */ nullptr); + + CppCloseClassificationResult(&result); + }; + } auto classifier = ImageClassifier::Create(std::move(cpp_options)); if (!classifier.ok()) { @@ -75,8 +128,8 @@ int CppImageClassifierClassify(void* classifier, const MpImage* image, ImageClassifierResult* result, char** error_msg) { if (image->type == MpImage::GPU_BUFFER) { - absl::Status status = - absl::InvalidArgumentError("gpu buffer not supported yet"); + const absl::Status status = + absl::InvalidArgumentError("GPU Buffer not supported yet."); ABSL_LOG(ERROR) << "Classification failed: " << status.message(); return CppProcessError(status, error_msg); @@ -102,6 +155,68 @@ int CppImageClassifierClassify(void* classifier, const MpImage* image, return 0; } +int CppImageClassifierClassifyForVideo(void* classifier, const MpImage* image, + int64_t timestamp_ms, + ImageClassifierResult* result, + char** error_msg) { + if (image->type == MpImage::GPU_BUFFER) { + absl::Status status = + absl::InvalidArgumentError("GPU Buffer not supported yet"); + + ABSL_LOG(ERROR) << "Classification failed: " << status.message(); + return CppProcessError(status, error_msg); + } + + const auto img = CreateImageFromBuffer( + static_cast(image->image_frame.format), + image->image_frame.image_buffer, image->image_frame.width, + image->image_frame.height); + + if (!img.ok()) { + ABSL_LOG(ERROR) << "Failed to create Image: " << img.status(); + return CppProcessError(img.status(), error_msg); + } + + auto cpp_classifier = static_cast(classifier); + auto cpp_result = cpp_classifier->ClassifyForVideo(*img, timestamp_ms); + if (!cpp_result.ok()) { + ABSL_LOG(ERROR) << "Classification failed: " << cpp_result.status(); + return CppProcessError(cpp_result.status(), error_msg); + } + CppConvertToClassificationResult(*cpp_result, result); + return 0; +} + +int CppImageClassifierClassifyAsync(void* classifier, const MpImage* image, + int64_t timestamp_ms, char** error_msg) { + if (image->type == MpImage::GPU_BUFFER) { + absl::Status status = + absl::InvalidArgumentError("GPU Buffer not supported yet"); + + ABSL_LOG(ERROR) << "Classification failed: " << status.message(); + return CppProcessError(status, error_msg); + } + + const auto img = CreateImageFromBuffer( + static_cast(image->image_frame.format), + image->image_frame.image_buffer, image->image_frame.width, + image->image_frame.height); + + if (!img.ok()) { + ABSL_LOG(ERROR) << "Failed to create Image: " << img.status(); + return CppProcessError(img.status(), error_msg); + } + + auto cpp_classifier = static_cast(classifier); + auto cpp_result = cpp_classifier->ClassifyAsync(*img, timestamp_ms); + if (!cpp_result.ok()) { + ABSL_LOG(ERROR) << "Data preparation for the image classification failed: " + << cpp_result; + return CppProcessError(cpp_result, error_msg); + } + return 0; +} + void CppImageClassifierCloseResult(ImageClassifierResult* result) { CppCloseClassificationResult(result); } @@ -134,6 +249,22 @@ int image_classifier_classify_image(void* classifier, const MpImage* image, CppImageClassifierClassify(classifier, image, result, error_msg); } +int image_classifier_classify_for_video(void* classifier, const MpImage* image, + int64_t timestamp_ms, + ImageClassifierResult* result, + char** error_msg) { + return mediapipe::tasks::c::vision::image_classifier:: + CppImageClassifierClassifyForVideo(classifier, image, timestamp_ms, + result, error_msg); +} + +int image_classifier_classify_async(void* classifier, const MpImage* image, + int64_t timestamp_ms, char** error_msg) { + return mediapipe::tasks::c::vision::image_classifier:: + CppImageClassifierClassifyAsync(classifier, image, timestamp_ms, + error_msg); +} + void image_classifier_close_result(ImageClassifierResult* result) { mediapipe::tasks::c::vision::image_classifier::CppImageClassifierCloseResult( result); diff --git a/mediapipe/tasks/c/vision/image_classifier/image_classifier.h b/mediapipe/tasks/c/vision/image_classifier/image_classifier.h index 60dc4a2c4..549c3f300 100644 --- a/mediapipe/tasks/c/vision/image_classifier/image_classifier.h +++ b/mediapipe/tasks/c/vision/image_classifier/image_classifier.h @@ -92,9 +92,16 @@ struct ImageClassifierOptions { // The user-defined result callback for processing live stream data. // The result callback should only be specified when the running mode is set - // to RunningMode::LIVE_STREAM. - typedef void (*result_callback_fn)(ImageClassifierResult*, const MpImage*, - int64_t); + // to RunningMode::LIVE_STREAM. Arguments of the callback function include: + // the pointer to classification result, the image that result was obtained + // on, the timestamp relevant to classification results and pointer to error + // message in case of any failure. The validity of the passed arguments is + // true for the lifetime of the callback function. + // + // A caller is responsible for closing image classifier result. + typedef void (*result_callback_fn)(ImageClassifierResult* result, + const MpImage image, int64_t timestamp_ms, + char* error_msg); result_callback_fn result_callback; }; @@ -110,13 +117,22 @@ MP_EXPORT void* image_classifier_create(struct ImageClassifierOptions* options, // If an error occurs, returns an error code and sets the error parameter to an // an error message (if `error_msg` is not nullptr). You must free the memory // allocated for the error message. -// -// TODO: Add API for video and live stream processing. MP_EXPORT int image_classifier_classify_image(void* classifier, const MpImage* image, ImageClassifierResult* result, char** error_msg = nullptr); +MP_EXPORT int image_classifier_classify_for_video(void* classifier, + const MpImage* image, + int64_t timestamp_ms, + ImageClassifierResult* result, + char** error_msg = nullptr); + +MP_EXPORT int image_classifier_classify_async(void* classifier, + const MpImage* image, + int64_t timestamp_ms, + char** error_msg = nullptr); + // Frees the memory allocated inside a ImageClassifierResult result. // Does not free the result pointer itself. MP_EXPORT void image_classifier_close_result(ImageClassifierResult* result); diff --git a/mediapipe/tasks/c/vision/image_classifier/image_classifier_test.cc b/mediapipe/tasks/c/vision/image_classifier/image_classifier_test.cc index e8e84d864..790f5ce36 100644 --- a/mediapipe/tasks/c/vision/image_classifier/image_classifier_test.cc +++ b/mediapipe/tasks/c/vision/image_classifier/image_classifier_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "mediapipe/tasks/c/vision/image_classifier/image_classifier.h" +#include #include #include @@ -36,12 +37,13 @@ using testing::HasSubstr; constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/"; constexpr char kModelName[] = "mobilenet_v2_1.0_224.tflite"; constexpr float kPrecision = 1e-4; +constexpr int kIterations = 100; std::string GetFullPath(absl::string_view file_name) { return JoinPath("./", kTestDataDirectory, file_name); } -TEST(ImageClassifierTest, SmokeTest) { +TEST(ImageClassifierTest, ImageModeTest) { const auto image = DecodeImageFromFile(GetFullPath("burger.jpg")); ASSERT_TRUE(image.ok()); @@ -63,14 +65,13 @@ TEST(ImageClassifierTest, SmokeTest) { void* classifier = image_classifier_create(&options); EXPECT_NE(classifier, nullptr); + const auto& image_frame = image->GetImageFrameSharedPtr(); const MpImage mp_image = { .type = MpImage::IMAGE_FRAME, - .image_frame = { - .format = static_cast( - image->GetImageFrameSharedPtr()->Format()), - .image_buffer = image->GetImageFrameSharedPtr()->PixelData(), - .width = image->GetImageFrameSharedPtr()->Width(), - .height = image->GetImageFrameSharedPtr()->Height()}}; + .image_frame = {.format = static_cast(image_frame->Format()), + .image_buffer = image_frame->PixelData(), + .width = image_frame->Width(), + .height = image_frame->Height()}}; ImageClassifierResult result; image_classifier_classify_image(classifier, &mp_image, &result); @@ -84,6 +85,120 @@ TEST(ImageClassifierTest, SmokeTest) { image_classifier_close(classifier); } +TEST(ImageClassifierTest, VideoModeTest) { + const auto image = DecodeImageFromFile(GetFullPath("burger.jpg")); + ASSERT_TRUE(image.ok()); + + const std::string model_path = GetFullPath(kModelName); + ImageClassifierOptions options = { + /* base_options= */ {/* model_asset_buffer= */ nullptr, + /* model_asset_path= */ model_path.c_str()}, + /* running_mode= */ RunningMode::VIDEO, + /* classifier_options= */ + {/* display_names_locale= */ nullptr, + /* max_results= */ 3, + /* score_threshold= */ 0.0, + /* category_allowlist= */ nullptr, + /* category_allowlist_count= */ 0, + /* category_denylist= */ nullptr, + /* category_denylist_count= */ 0}, + /* result_callback= */ nullptr, + }; + + void* classifier = image_classifier_create(&options); + EXPECT_NE(classifier, nullptr); + + const auto& image_frame = image->GetImageFrameSharedPtr(); + const MpImage mp_image = { + .type = MpImage::IMAGE_FRAME, + .image_frame = {.format = static_cast(image_frame->Format()), + .image_buffer = image_frame->PixelData(), + .width = image_frame->Width(), + .height = image_frame->Height()}}; + + for (int i = 0; i < kIterations; ++i) { + ImageClassifierResult result; + image_classifier_classify_for_video(classifier, &mp_image, i, &result); + EXPECT_EQ(result.classifications_count, 1); + EXPECT_EQ(result.classifications[0].categories_count, 3); + EXPECT_EQ( + std::string{result.classifications[0].categories[0].category_name}, + "cheeseburger"); + EXPECT_NEAR(result.classifications[0].categories[0].score, 0.7939f, + kPrecision); + image_classifier_close_result(&result); + } + image_classifier_close(classifier); +} + +// A structure to support LiveStreamModeTest below. This structure holds a +// static method `Fn` for a callback function of C API. A `static` qualifier +// allows to take an address of the method to follow API style. Another static +// struct member is `last_timestamp` that is used to verify that current +// timestamp is greater than the previous one. +struct LiveStreamModeCallback { + static int64_t last_timestamp; + static void Fn(ImageClassifierResult* classifier_result, const MpImage image, + int64_t timestamp, char* error_msg) { + ASSERT_NE(classifier_result, nullptr); + ASSERT_EQ(error_msg, nullptr); + EXPECT_EQ( + std::string{ + classifier_result->classifications[0].categories[0].category_name}, + "cheeseburger"); + EXPECT_NEAR(classifier_result->classifications[0].categories[0].score, + 0.7939f, kPrecision); + EXPECT_GT(image.image_frame.width, 0); + EXPECT_GT(image.image_frame.height, 0); + EXPECT_GT(timestamp, last_timestamp); + last_timestamp++; + } +}; +int64_t LiveStreamModeCallback::last_timestamp = -1; + +TEST(ImageClassifierTest, LiveStreamModeTest) { + const auto image = DecodeImageFromFile(GetFullPath("burger.jpg")); + ASSERT_TRUE(image.ok()); + + const std::string model_path = GetFullPath(kModelName); + + ImageClassifierOptions options = { + /* base_options= */ {/* model_asset_buffer= */ nullptr, + /* model_asset_path= */ model_path.c_str()}, + /* running_mode= */ RunningMode::LIVE_STREAM, + /* classifier_options= */ + {/* display_names_locale= */ nullptr, + /* max_results= */ 3, + /* score_threshold= */ 0.0, + /* category_allowlist= */ nullptr, + /* category_allowlist_count= */ 0, + /* category_denylist= */ nullptr, + /* category_denylist_count= */ 0}, + /* result_callback= */ LiveStreamModeCallback::Fn, + }; + + void* classifier = image_classifier_create(&options); + EXPECT_NE(classifier, nullptr); + + const auto& image_frame = image->GetImageFrameSharedPtr(); + const MpImage mp_image = { + .type = MpImage::IMAGE_FRAME, + .image_frame = {.format = static_cast(image_frame->Format()), + .image_buffer = image_frame->PixelData(), + .width = image_frame->Width(), + .height = image_frame->Height()}}; + + for (int i = 0; i < kIterations; ++i) { + EXPECT_GE(image_classifier_classify_async(classifier, &mp_image, i), 0); + } + image_classifier_close(classifier); + + // Due to the flow limiter, the total of outputs might be smaller than the + // number of iterations. + EXPECT_LE(LiveStreamModeCallback::last_timestamp, kIterations); + EXPECT_GT(LiveStreamModeCallback::last_timestamp, 0); +} + TEST(ImageClassifierTest, InvalidArgumentHandling) { // It is an error to set neither the asset buffer nor the path. ImageClassifierOptions options = { @@ -124,7 +239,7 @@ TEST(ImageClassifierTest, FailedClassificationHandling) { ImageClassifierResult result; char* error_msg; image_classifier_classify_image(classifier, &mp_image, &result, &error_msg); - EXPECT_THAT(error_msg, HasSubstr("gpu buffer not supported yet")); + EXPECT_THAT(error_msg, HasSubstr("GPU Buffer not supported yet")); free(error_msg); image_classifier_close(classifier); }