303 lines
		
	
	
		
			12 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
			
		
		
	
	
			303 lines
		
	
	
		
			12 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
| /* Copyright 2023 The MediaPipe Authors.
 | |
| 
 | |
| Licensed under the Apache License, Version 2.0 (the "License");
 | |
| you may not use this file except in compliance with the License.
 | |
| You may obtain a copy of the License at
 | |
| 
 | |
|     http://www.apache.org/licenses/LICENSE-2.0
 | |
| 
 | |
| Unless required by applicable law or agreed to in writing, software
 | |
| distributed under the License is distributed on an "AS IS" BASIS,
 | |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | |
| See the License for the specific language governing permissions and
 | |
| limitations under the License.
 | |
| ==============================================================================*/
 | |
| 
 | |
| #include "mediapipe/tasks/c/vision/image_embedder/image_embedder.h"
 | |
| 
 | |
| #include <cstdint>
 | |
| #include <cstdlib>
 | |
| #include <string>
 | |
| 
 | |
| #include "absl/flags/flag.h"
 | |
| #include "absl/strings/string_view.h"
 | |
| #include "mediapipe/framework/deps/file_path.h"
 | |
| #include "mediapipe/framework/formats/image.h"
 | |
| #include "mediapipe/framework/port/gmock.h"
 | |
| #include "mediapipe/framework/port/gtest.h"
 | |
| #include "mediapipe/tasks/c/vision/core/common.h"
 | |
| #include "mediapipe/tasks/cc/vision/utils/image_utils.h"
 | |
| 
 | |
| namespace {
 | |
| 
 | |
| using ::mediapipe::file::JoinPath;
 | |
| using ::mediapipe::tasks::vision::DecodeImageFromFile;
 | |
| using testing::HasSubstr;
 | |
| 
 | |
// Locations of the test assets, relative to the test's working directory.
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/";
constexpr char kModelName[] = "mobilenet_v3_small_100_224_embedder.tflite";
constexpr char kImageFile[] = "burger.jpg";

// Absolute tolerance used when comparing embedding values and similarities.
constexpr float kPrecision = 1e-6;

// Number of frames fed to the embedder in the video and live-stream tests.
constexpr int kIterations = 100;
 | |
| 
 | |
| std::string GetFullPath(absl::string_view file_name) {
 | |
|   return JoinPath("./", kTestDataDirectory, file_name);
 | |
| }
 | |
| 
 | |
| // Utility function to check the sizes, head_index and head_names of a result
 | |
| // produced by kMobileNetV3Embedder.
 | |
| void CheckMobileNetV3Result(const ImageEmbedderResult& result, bool quantized) {
 | |
|   EXPECT_EQ(result.embeddings_count, 1);
 | |
|   EXPECT_EQ(result.embeddings[0].head_index, 0);
 | |
|   EXPECT_EQ(std::string{result.embeddings[0].head_name}, "feature");
 | |
|   if (quantized) {
 | |
|     EXPECT_EQ(result.embeddings[0].values_count, 1024);
 | |
|   } else {
 | |
|     EXPECT_EQ(result.embeddings[0].values_count, 1024);
 | |
|   }
 | |
| }
 | |
| 
 | |
| TEST(ImageEmbedderTest, ImageModeTest) {
 | |
|   const auto image = DecodeImageFromFile(GetFullPath(kImageFile));
 | |
|   ASSERT_TRUE(image.ok());
 | |
| 
 | |
|   const std::string model_path = GetFullPath(kModelName);
 | |
|   ImageEmbedderOptions options = {
 | |
|       /* base_options= */ {/* model_asset_buffer= */ nullptr,
 | |
|                            /* model_asset_buffer_count= */ 0,
 | |
|                            /* model_asset_path= */ model_path.c_str()},
 | |
|       /* running_mode= */ RunningMode::IMAGE,
 | |
|       /* embedder_options= */
 | |
|       {/* l2_normalize= */ true,
 | |
|        /* quantize= */ false}};
 | |
| 
 | |
|   void* embedder = image_embedder_create(&options,
 | |
|                                          /* error_msg */ nullptr);
 | |
|   EXPECT_NE(embedder, nullptr);
 | |
| 
 | |
|   const MpImage mp_image = {
 | |
|       .type = MpImage::IMAGE_FRAME,
 | |
|       .image_frame = {
 | |
|           .format = static_cast<ImageFormat>(
 | |
|               image->GetImageFrameSharedPtr()->Format()),
 | |
|           .image_buffer = image->GetImageFrameSharedPtr()->PixelData(),
 | |
|           .width = image->GetImageFrameSharedPtr()->Width(),
 | |
|           .height = image->GetImageFrameSharedPtr()->Height()}};
 | |
| 
 | |
|   ImageEmbedderResult result;
 | |
|   image_embedder_embed_image(embedder, &mp_image, &result,
 | |
|                              /* error_msg */ nullptr);
 | |
|   CheckMobileNetV3Result(result, false);
 | |
|   EXPECT_NEAR(result.embeddings[0].float_embedding[0], -0.0142344, kPrecision);
 | |
|   image_embedder_close_result(&result);
 | |
|   image_embedder_close(embedder, /* error_msg */ nullptr);
 | |
| }
 | |
| 
 | |
| TEST(ImageEmbedderTest, SucceedsWithCosineSimilarity) {
 | |
|   const auto image = DecodeImageFromFile(GetFullPath("burger.jpg"));
 | |
|   ASSERT_TRUE(image.ok());
 | |
|   const auto crop = DecodeImageFromFile(GetFullPath("burger_crop.jpg"));
 | |
|   ASSERT_TRUE(crop.ok());
 | |
| 
 | |
|   const std::string model_path = GetFullPath(kModelName);
 | |
|   ImageEmbedderOptions options = {
 | |
|       /* base_options= */ {/* model_asset_buffer= */ nullptr,
 | |
|                            /* model_asset_buffer_count= */ 0,
 | |
|                            /* model_asset_path= */ model_path.c_str()},
 | |
|       /* running_mode= */ RunningMode::IMAGE,
 | |
|       /* embedder_options= */
 | |
|       {/* l2_normalize= */ true,
 | |
|        /* quantize= */ false}};
 | |
| 
 | |
|   void* embedder = image_embedder_create(&options,
 | |
|                                          /* error_msg */ nullptr);
 | |
|   EXPECT_NE(embedder, nullptr);
 | |
| 
 | |
|   const MpImage mp_image = {
 | |
|       .type = MpImage::IMAGE_FRAME,
 | |
|       .image_frame = {
 | |
|           .format = static_cast<ImageFormat>(
 | |
|               image->GetImageFrameSharedPtr()->Format()),
 | |
|           .image_buffer = image->GetImageFrameSharedPtr()->PixelData(),
 | |
|           .width = image->GetImageFrameSharedPtr()->Width(),
 | |
|           .height = image->GetImageFrameSharedPtr()->Height()}};
 | |
| 
 | |
|   const MpImage mp_crop = {
 | |
|       .type = MpImage::IMAGE_FRAME,
 | |
|       .image_frame = {
 | |
|           .format = static_cast<ImageFormat>(
 | |
|               crop->GetImageFrameSharedPtr()->Format()),
 | |
|           .image_buffer = crop->GetImageFrameSharedPtr()->PixelData(),
 | |
|           .width = crop->GetImageFrameSharedPtr()->Width(),
 | |
|           .height = crop->GetImageFrameSharedPtr()->Height()}};
 | |
| 
 | |
|   // Extract both embeddings.
 | |
|   ImageEmbedderResult image_result;
 | |
|   image_embedder_embed_image(embedder, &mp_image, &image_result,
 | |
|                              /* error_msg */ nullptr);
 | |
|   ImageEmbedderResult crop_result;
 | |
|   image_embedder_embed_image(embedder, &mp_crop, &crop_result,
 | |
|                              /* error_msg */ nullptr);
 | |
| 
 | |
|   // Check results.
 | |
|   CheckMobileNetV3Result(image_result, false);
 | |
|   CheckMobileNetV3Result(crop_result, false);
 | |
|   // Check cosine similarity.
 | |
|   double similarity;
 | |
|   image_embedder_cosine_similarity(image_result.embeddings[0],
 | |
|                                    crop_result.embeddings[0], &similarity,
 | |
|                                    /* error_msg */ nullptr);
 | |
|   double expected_similarity = 0.925519;
 | |
|   EXPECT_LE(abs(similarity - expected_similarity), kPrecision);
 | |
|   image_embedder_close_result(&image_result);
 | |
|   image_embedder_close_result(&crop_result);
 | |
|   image_embedder_close(embedder, /* error_msg */ nullptr);
 | |
| }
 | |
| 
 | |
| TEST(ImageEmbedderTest, VideoModeTest) {
 | |
|   const auto image = DecodeImageFromFile(GetFullPath(kImageFile));
 | |
|   ASSERT_TRUE(image.ok());
 | |
| 
 | |
|   const std::string model_path = GetFullPath(kModelName);
 | |
|   ImageEmbedderOptions options = {
 | |
|       /* base_options= */ {/* model_asset_buffer= */ nullptr,
 | |
|                            /* model_asset_buffer_count= */ 0,
 | |
|                            /* model_asset_path= */ model_path.c_str()},
 | |
|       /* running_mode= */ RunningMode::VIDEO,
 | |
|       /* embedder_options= */
 | |
|       {/* l2_normalize= */ true,
 | |
|        /* quantize= */ false}};
 | |
| 
 | |
|   void* embedder = image_embedder_create(&options,
 | |
|                                          /* error_msg */ nullptr);
 | |
|   EXPECT_NE(embedder, nullptr);
 | |
| 
 | |
|   const auto& image_frame = image->GetImageFrameSharedPtr();
 | |
|   const MpImage mp_image = {
 | |
|       .type = MpImage::IMAGE_FRAME,
 | |
|       .image_frame = {.format = static_cast<ImageFormat>(image_frame->Format()),
 | |
|                       .image_buffer = image_frame->PixelData(),
 | |
|                       .width = image_frame->Width(),
 | |
|                       .height = image_frame->Height()}};
 | |
| 
 | |
|   for (int i = 0; i < kIterations; ++i) {
 | |
|     ImageEmbedderResult result;
 | |
|     image_embedder_embed_for_video(embedder, &mp_image, i, &result,
 | |
|                                    /* error_msg */ nullptr);
 | |
|     CheckMobileNetV3Result(result, false);
 | |
|     EXPECT_NEAR(result.embeddings[0].float_embedding[0], -0.0142344,
 | |
|                 kPrecision);
 | |
|     image_embedder_close_result(&result);
 | |
|   }
 | |
|   image_embedder_close(embedder, /* error_msg */ nullptr);
 | |
| }
 | |
| 
 | |
| // A structure to support LiveStreamModeTest below. This structure holds a
 | |
| // static method `Fn` for a callback function of C API. A `static` qualifier
 | |
| // allows to take an address of the method to follow API style. Another static
 | |
| // struct member is `last_timestamp` that is used to verify that current
 | |
| // timestamp is greater than the previous one.
 | |
| struct LiveStreamModeCallback {
 | |
|   static int64_t last_timestamp;
 | |
|   static void Fn(ImageEmbedderResult* embedder_result, const MpImage image,
 | |
|                  int64_t timestamp, char* error_msg) {
 | |
|     ASSERT_NE(embedder_result, nullptr);
 | |
|     ASSERT_EQ(error_msg, nullptr);
 | |
|     CheckMobileNetV3Result(*embedder_result, false);
 | |
|     EXPECT_NEAR(embedder_result->embeddings[0].float_embedding[0], -0.0142344,
 | |
|                 kPrecision);
 | |
|     EXPECT_GT(image.image_frame.width, 0);
 | |
|     EXPECT_GT(image.image_frame.height, 0);
 | |
|     EXPECT_GT(timestamp, last_timestamp);
 | |
|     last_timestamp++;
 | |
|   }
 | |
| };
 | |
| int64_t LiveStreamModeCallback::last_timestamp = -1;
 | |
| 
 | |
| TEST(ImageEmbedderTest, LiveStreamModeTest) {
 | |
|   const auto image = DecodeImageFromFile(GetFullPath(kImageFile));
 | |
|   ASSERT_TRUE(image.ok());
 | |
| 
 | |
|   const std::string model_path = GetFullPath(kModelName);
 | |
| 
 | |
|   ImageEmbedderOptions options = {
 | |
|       /* base_options= */ {/* model_asset_buffer= */ nullptr,
 | |
|                            /* model_asset_buffer_count= */ 0,
 | |
|                            /* model_asset_path= */ model_path.c_str()},
 | |
|       /* running_mode= */ RunningMode::LIVE_STREAM,
 | |
|       /* embedder_options= */
 | |
|       {/* l2_normalize= */ true,
 | |
|        /* quantize= */ false},
 | |
|       /* result_callback= */ LiveStreamModeCallback::Fn,
 | |
|   };
 | |
| 
 | |
|   void* embedder = image_embedder_create(&options,
 | |
|                                          /* error_msg */ nullptr);
 | |
|   EXPECT_NE(embedder, nullptr);
 | |
| 
 | |
|   const auto& image_frame = image->GetImageFrameSharedPtr();
 | |
|   const MpImage mp_image = {
 | |
|       .type = MpImage::IMAGE_FRAME,
 | |
|       .image_frame = {.format = static_cast<ImageFormat>(image_frame->Format()),
 | |
|                       .image_buffer = image_frame->PixelData(),
 | |
|                       .width = image_frame->Width(),
 | |
|                       .height = image_frame->Height()}};
 | |
| 
 | |
|   for (int i = 0; i < kIterations; ++i) {
 | |
|     EXPECT_GE(image_embedder_embed_async(embedder, &mp_image, i,
 | |
|                                          /* error_msg */ nullptr),
 | |
|               0);
 | |
|   }
 | |
|   image_embedder_close(embedder, /* error_msg */ nullptr);
 | |
| 
 | |
|   // Due to the flow limiter, the total of outputs might be smaller than the
 | |
|   // number of iterations.
 | |
|   EXPECT_LE(LiveStreamModeCallback::last_timestamp, kIterations);
 | |
|   EXPECT_GT(LiveStreamModeCallback::last_timestamp, 0);
 | |
| }
 | |
| 
 | |
| TEST(ImageEmbedderTest, InvalidArgumentHandling) {
 | |
|   // It is an error to set neither the asset buffer nor the path.
 | |
|   ImageEmbedderOptions options = {
 | |
|       /* base_options= */ {/* model_asset_buffer= */ nullptr,
 | |
|                            /* model_asset_buffer_count= */ 0,
 | |
|                            /* model_asset_path= */ nullptr},
 | |
|       /* embedder_options= */ {},
 | |
|   };
 | |
| 
 | |
|   char* error_msg;
 | |
|   void* embedder = image_embedder_create(&options, &error_msg);
 | |
|   EXPECT_EQ(embedder, nullptr);
 | |
| 
 | |
|   EXPECT_THAT(error_msg, HasSubstr("ExternalFile must specify"));
 | |
| 
 | |
|   free(error_msg);
 | |
| }
 | |
| 
 | |
| TEST(ImageEmbedderTest, FailedEmbeddingHandling) {
 | |
|   const std::string model_path = GetFullPath(kModelName);
 | |
|   ImageEmbedderOptions options = {
 | |
|       /* base_options= */ {/* model_asset_buffer= */ nullptr,
 | |
|                            /* model_asset_buffer_count= */ 0,
 | |
|                            /* model_asset_path= */ model_path.c_str()},
 | |
|       /* running_mode= */ RunningMode::IMAGE,
 | |
|       /* embedder_options= */
 | |
|       {/* l2_normalize= */ false,
 | |
|        /* quantize= */ false},
 | |
|   };
 | |
| 
 | |
|   void* embedder = image_embedder_create(&options,
 | |
|                                          /* error_msg */ nullptr);
 | |
|   EXPECT_NE(embedder, nullptr);
 | |
| 
 | |
|   const MpImage mp_image = {.type = MpImage::GPU_BUFFER, .gpu_buffer = {}};
 | |
|   ImageEmbedderResult result;
 | |
|   char* error_msg;
 | |
|   image_embedder_embed_image(embedder, &mp_image, &result, &error_msg);
 | |
|   EXPECT_THAT(error_msg, HasSubstr("GPU Buffer not supported yet."));
 | |
|   free(error_msg);
 | |
|   image_embedder_close(embedder, /* error_msg */ nullptr);
 | |
| }
 | |
| 
 | |
| }  // namespace
 |