mediapipe/mediapipe/tasks/cc/vision/image_embedder/image_embedder_test.cc
2022-10-24 10:14:05 -07:00

628 lines
27 KiB
C++

/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/vision/image_embedder/image_embedder.h"
#include <memory>
#include <utility>
#include "absl/flags/flag.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
#include "tensorflow/lite/core/api/op_resolver.h"
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
#include "tensorflow/lite/kernels/builtin_op_kernels.h"
#include "tensorflow/lite/mutable_op_resolver.h"
namespace mediapipe {
namespace tasks {
namespace vision {
namespace image_embedder {
namespace {
using ::mediapipe::file::JoinPath;
using ::mediapipe::tasks::components::containers::Rect;
using ::mediapipe::tasks::components::containers::proto::EmbeddingResult;
using ::mediapipe::tasks::vision::core::ImageProcessingOptions;
using ::testing::HasSubstr;
using ::testing::Optional;
// Directory of the vision test assets; the tests join it under "./".
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/";
// File name of the MobileNetV3 embedder model exercised by these tests.
constexpr char kMobileNetV3Embedder[] =
    "mobilenet_v3_small_100_224_embedder.tflite";
// Maximum absolute deviation allowed between a computed cosine similarity and
// its expected value.
constexpr double kSimilarityTolerancy = 1.0e-6;
// Utility function to check the sizes, head_index and head_names of a result
// procuded by kMobileNetV3Embedder.
void CheckMobileNetV3Result(const EmbeddingResult& result, bool quantized) {
EXPECT_EQ(result.embeddings().size(), 1);
EXPECT_EQ(result.embeddings(0).head_index(), 0);
EXPECT_EQ(result.embeddings(0).head_name(), "feature");
EXPECT_EQ(result.embeddings(0).entries().size(), 1);
if (quantized) {
EXPECT_EQ(
result.embeddings(0).entries(0).quantized_embedding().values().size(),
1024);
} else {
EXPECT_EQ(result.embeddings(0).entries(0).float_embedding().values().size(),
1024);
}
}
// OpResolver that registers only the builtin ops needed by the MobileNetV3 test
// model; used to check that a selective op resolver is sufficient.
class MobileNetV3OpResolver : public ::tflite::MutableOpResolver {
 public:
  MobileNetV3OpResolver() {
    namespace builtin = ::tflite::ops::builtin;
    AddBuiltin(::tflite::BuiltinOperator_MUL, builtin::Register_MUL());
    AddBuiltin(::tflite::BuiltinOperator_SUB, builtin::Register_SUB());
    AddBuiltin(::tflite::BuiltinOperator_CONV_2D, builtin::Register_CONV_2D());
    AddBuiltin(::tflite::BuiltinOperator_HARD_SWISH,
               builtin::Register_HARD_SWISH());
    AddBuiltin(::tflite::BuiltinOperator_DEPTHWISE_CONV_2D,
               builtin::Register_DEPTHWISE_CONV_2D());
    AddBuiltin(::tflite::BuiltinOperator_MEAN, builtin::Register_MEAN());
    AddBuiltin(::tflite::BuiltinOperator_ADD, builtin::Register_ADD());
    AddBuiltin(::tflite::BuiltinOperator_AVERAGE_POOL_2D,
               builtin::Register_AVERAGE_POOL_2D());
    AddBuiltin(::tflite::BuiltinOperator_RESHAPE, builtin::Register_RESHAPE());
  }
  // Copy construction is disabled.
  MobileNetV3OpResolver(const MobileNetV3OpResolver& r) = delete;
};
// OpResolver that registers only SOFTMAX, and therefore lacks ops the
// MobileNetV3 test model requires; creating an embedder with it is expected to
// fail.
class MobileNetV3OpResolverMissingOps : public ::tflite::MutableOpResolver {
 public:
  MobileNetV3OpResolverMissingOps() {
    namespace builtin = ::tflite::ops::builtin;
    AddBuiltin(::tflite::BuiltinOperator_SOFTMAX, builtin::Register_SOFTMAX());
  }
  // Copy construction is disabled.
  MobileNetV3OpResolverMissingOps(const MobileNetV3OpResolverMissingOps& r) =
      delete;
};
// Fixture for tests of ImageEmbedder::Create().
struct CreateTest : tflite_shims::testing::Test {};
TEST_F(CreateTest, SucceedsWithSelectiveOpResolver) {
  // Creation must succeed when the resolver provides every op the model uses.
  auto embedder_options = std::make_unique<ImageEmbedderOptions>();
  embedder_options->base_options.model_asset_path =
      JoinPath("./", kTestDataDirectory, kMobileNetV3Embedder);
  embedder_options->base_options.op_resolver =
      std::make_unique<MobileNetV3OpResolver>();
  MP_ASSERT_OK(ImageEmbedder::Create(std::move(embedder_options)));
}
TEST_F(CreateTest, FailsWithSelectiveOpResolverMissingOps) {
  auto embedder_options = std::make_unique<ImageEmbedderOptions>();
  embedder_options->base_options.model_asset_path =
      JoinPath("./", kTestDataDirectory, kMobileNetV3Embedder);
  embedder_options->base_options.op_resolver =
      std::make_unique<MobileNetV3OpResolverMissingOps>();
  // Without the required ops the interpreter cannot be built, which Create()
  // reports as an internal error.
  const auto embedder_or = ImageEmbedder::Create(std::move(embedder_options));
  const absl::Status& status = embedder_or.status();
  EXPECT_EQ(status.code(), absl::StatusCode::kInternal);
  EXPECT_THAT(status.message(),
              HasSubstr("interpreter_builder(&interpreter) == kTfLiteOk"));
}
TEST_F(CreateTest, FailsWithMissingModel) {
  // Default-constructed options leave the model asset unspecified.
  const auto embedder_or =
      ImageEmbedder::Create(std::make_unique<ImageEmbedderOptions>());
  const absl::Status& status = embedder_or.status();
  EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
  EXPECT_THAT(
      status.message(),
      HasSubstr("ExternalFile must specify at least one of 'file_content', "
                "'file_name', 'file_pointer_meta' or 'file_descriptor_meta'."));
  EXPECT_THAT(status.GetPayload(kMediaPipeTasksPayload),
              Optional(absl::Cord(absl::StrCat(
                  MediaPipeTasksStatus::kRunnerInitializationError))));
}
TEST_F(CreateTest, FailsWithIllegalCallbackInImageOrVideoMode) {
  // Result callbacks belong to live stream mode only; supplying one in either
  // synchronous mode must be rejected when the embedder is created.
  const core::RunningMode kSynchronousModes[] = {core::RunningMode::IMAGE,
                                                 core::RunningMode::VIDEO};
  for (const core::RunningMode mode : kSynchronousModes) {
    auto embedder_options = std::make_unique<ImageEmbedderOptions>();
    embedder_options->base_options.model_asset_path =
        JoinPath("./", kTestDataDirectory, kMobileNetV3Embedder);
    embedder_options->running_mode = mode;
    embedder_options->result_callback = [](absl::StatusOr<EmbeddingResult>,
                                           const Image& image,
                                           int64 timestamp_ms) {};
    const auto embedder_or =
        ImageEmbedder::Create(std::move(embedder_options));
    const absl::Status& status = embedder_or.status();
    EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
    EXPECT_THAT(
        status.message(),
        HasSubstr("a user-defined result callback shouldn't be provided"));
    EXPECT_THAT(status.GetPayload(kMediaPipeTasksPayload),
                Optional(absl::Cord(absl::StrCat(
                    MediaPipeTasksStatus::kInvalidTaskGraphConfigError))));
  }
}
TEST_F(CreateTest, FailsWithMissingCallbackInLiveStreamMode) {
  // Live stream mode hands results to a user callback, so omitting it is an
  // invalid configuration.
  auto embedder_options = std::make_unique<ImageEmbedderOptions>();
  embedder_options->base_options.model_asset_path =
      JoinPath("./", kTestDataDirectory, kMobileNetV3Embedder);
  embedder_options->running_mode = core::RunningMode::LIVE_STREAM;
  const auto embedder_or = ImageEmbedder::Create(std::move(embedder_options));
  const absl::Status& status = embedder_or.status();
  EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
  EXPECT_THAT(status.message(),
              HasSubstr("a user-defined result callback must be provided"));
  EXPECT_THAT(status.GetPayload(kMediaPipeTasksPayload),
              Optional(absl::Cord(absl::StrCat(
                  MediaPipeTasksStatus::kInvalidTaskGraphConfigError))));
}
// Fixture for tests exercising the image running mode.
struct ImageModeTest : tflite_shims::testing::Test {};
TEST_F(ImageModeTest, FailsWithCallingWrongMethod) {
  MP_ASSERT_OK_AND_ASSIGN(
      Image image,
      DecodeImageFromFile(JoinPath("./", kTestDataDirectory, "burger.jpg")));
  auto embedder_options = std::make_unique<ImageEmbedderOptions>();
  embedder_options->base_options.model_asset_path =
      JoinPath("./", kTestDataDirectory, kMobileNetV3Embedder);
  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageEmbedder> image_embedder,
                          ImageEmbedder::Create(std::move(embedder_options)));
  // Each call to an entry point of another running mode must be rejected as an
  // invalid argument carrying the wrong-mode payload; only the message differs.
  const auto expect_wrong_mode = [](const absl::Status& status,
                                    const char* expected_fragment) {
    EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
    EXPECT_THAT(status.message(), HasSubstr(expected_fragment));
    EXPECT_THAT(status.GetPayload(kMediaPipeTasksPayload),
                Optional(absl::Cord(absl::StrCat(
                    MediaPipeTasksStatus::kRunnerApiCalledInWrongModeError))));
  };
  expect_wrong_mode(image_embedder->EmbedForVideo(image, 0).status(),
                    "not initialized with the video mode");
  expect_wrong_mode(image_embedder->EmbedAsync(image, 0),
                    "not initialized with the live stream mode");
  MP_ASSERT_OK(image_embedder->Close());
}
TEST_F(ImageModeTest, SucceedsWithoutL2Normalization) {
  auto options = std::make_unique<ImageEmbedderOptions>();
  options->base_options.model_asset_path =
      JoinPath("./", kTestDataDirectory, kMobileNetV3Embedder);
  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageEmbedder> image_embedder,
                          ImageEmbedder::Create(std::move(options)));
  // Load images: one is a crop of the other.
  MP_ASSERT_OK_AND_ASSIGN(
      Image image,
      DecodeImageFromFile(JoinPath("./", kTestDataDirectory, "burger.jpg")));
  MP_ASSERT_OK_AND_ASSIGN(
      Image crop, DecodeImageFromFile(
                      JoinPath("./", kTestDataDirectory, "burger_crop.jpg")));
  // Extract both embeddings.
  MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& image_result,
                          image_embedder->Embed(image));
  MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& crop_result,
                          image_embedder->Embed(crop));
  // Check results. Stop on any fatal failure in the helper, because the
  // similarity computation below indexes embeddings(0).entries(0) directly.
  ASSERT_NO_FATAL_FAILURE(CheckMobileNetV3Result(image_result, false));
  ASSERT_NO_FATAL_FAILURE(CheckMobileNetV3Result(crop_result, false));
  // Check cosine similarity. EXPECT_NEAR compares as double; an unqualified
  // `abs` on a double may resolve to `int abs(int)`, truncating the difference
  // to 0 and making the check vacuous.
  MP_ASSERT_OK_AND_ASSIGN(
      double similarity,
      ImageEmbedder::CosineSimilarity(image_result.embeddings(0).entries(0),
                                      crop_result.embeddings(0).entries(0)));
  const double expected_similarity = 0.925519;
  EXPECT_NEAR(similarity, expected_similarity, kSimilarityTolerancy);
}
TEST_F(ImageModeTest, SucceedsWithL2Normalization) {
  auto options = std::make_unique<ImageEmbedderOptions>();
  options->base_options.model_asset_path =
      JoinPath("./", kTestDataDirectory, kMobileNetV3Embedder);
  options->embedder_options.l2_normalize = true;
  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageEmbedder> image_embedder,
                          ImageEmbedder::Create(std::move(options)));
  // Load images: one is a crop of the other.
  MP_ASSERT_OK_AND_ASSIGN(
      Image image,
      DecodeImageFromFile(JoinPath("./", kTestDataDirectory, "burger.jpg")));
  MP_ASSERT_OK_AND_ASSIGN(
      Image crop, DecodeImageFromFile(
                      JoinPath("./", kTestDataDirectory, "burger_crop.jpg")));
  // Extract both embeddings.
  MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& image_result,
                          image_embedder->Embed(image));
  MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& crop_result,
                          image_embedder->Embed(crop));
  // Check results. Stop on any fatal failure in the helper, because the
  // similarity computation below indexes embeddings(0).entries(0) directly.
  ASSERT_NO_FATAL_FAILURE(CheckMobileNetV3Result(image_result, false));
  ASSERT_NO_FATAL_FAILURE(CheckMobileNetV3Result(crop_result, false));
  // Check cosine similarity. EXPECT_NEAR compares as double; an unqualified
  // `abs` on a double may resolve to `int abs(int)`, truncating the difference
  // to 0 and making the check vacuous.
  MP_ASSERT_OK_AND_ASSIGN(
      double similarity,
      ImageEmbedder::CosineSimilarity(image_result.embeddings(0).entries(0),
                                      crop_result.embeddings(0).entries(0)));
  const double expected_similarity = 0.925519;
  EXPECT_NEAR(similarity, expected_similarity, kSimilarityTolerancy);
}
TEST_F(ImageModeTest, SucceedsWithQuantization) {
  auto options = std::make_unique<ImageEmbedderOptions>();
  options->base_options.model_asset_path =
      JoinPath("./", kTestDataDirectory, kMobileNetV3Embedder);
  options->embedder_options.quantize = true;
  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageEmbedder> image_embedder,
                          ImageEmbedder::Create(std::move(options)));
  // Load images: one is a crop of the other.
  MP_ASSERT_OK_AND_ASSIGN(
      Image image,
      DecodeImageFromFile(JoinPath("./", kTestDataDirectory, "burger.jpg")));
  MP_ASSERT_OK_AND_ASSIGN(
      Image crop, DecodeImageFromFile(
                      JoinPath("./", kTestDataDirectory, "burger_crop.jpg")));
  // Extract both embeddings.
  MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& image_result,
                          image_embedder->Embed(image));
  MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& crop_result,
                          image_embedder->Embed(crop));
  // Check results. Stop on any fatal failure in the helper, because the
  // similarity computation below indexes embeddings(0).entries(0) directly.
  ASSERT_NO_FATAL_FAILURE(CheckMobileNetV3Result(image_result, true));
  ASSERT_NO_FATAL_FAILURE(CheckMobileNetV3Result(crop_result, true));
  // Check cosine similarity. EXPECT_NEAR compares as double; an unqualified
  // `abs` on a double may resolve to `int abs(int)`, truncating the difference
  // to 0 and making the check vacuous.
  MP_ASSERT_OK_AND_ASSIGN(
      double similarity,
      ImageEmbedder::CosineSimilarity(image_result.embeddings(0).entries(0),
                                      crop_result.embeddings(0).entries(0)));
  const double expected_similarity = 0.926791;
  EXPECT_NEAR(similarity, expected_similarity, kSimilarityTolerancy);
}
TEST_F(ImageModeTest, SucceedsWithRegionOfInterest) {
  auto options = std::make_unique<ImageEmbedderOptions>();
  options->base_options.model_asset_path =
      JoinPath("./", kTestDataDirectory, kMobileNetV3Embedder);
  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageEmbedder> image_embedder,
                          ImageEmbedder::Create(std::move(options)));
  // Load images: one is a crop of the other.
  MP_ASSERT_OK_AND_ASSIGN(
      Image image,
      DecodeImageFromFile(JoinPath("./", kTestDataDirectory, "burger.jpg")));
  MP_ASSERT_OK_AND_ASSIGN(
      Image crop, DecodeImageFromFile(
                      JoinPath("./", kTestDataDirectory, "burger_crop.jpg")));
  // Region-of-interest in "burger.jpg" corresponding to "burger_crop.jpg".
  Rect roi{/*left=*/0, /*top=*/0, /*right=*/0.833333, /*bottom=*/1};
  ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0};
  // Extract both embeddings.
  MP_ASSERT_OK_AND_ASSIGN(
      const EmbeddingResult& image_result,
      image_embedder->Embed(image, image_processing_options));
  MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& crop_result,
                          image_embedder->Embed(crop));
  // Check results. Stop on any fatal failure in the helper, because the
  // similarity computation below indexes embeddings(0).entries(0) directly.
  ASSERT_NO_FATAL_FAILURE(CheckMobileNetV3Result(image_result, false));
  ASSERT_NO_FATAL_FAILURE(CheckMobileNetV3Result(crop_result, false));
  // Check cosine similarity. EXPECT_NEAR compares as double; an unqualified
  // `abs` on a double may resolve to `int abs(int)`, truncating the difference
  // to 0 and making the check vacuous.
  MP_ASSERT_OK_AND_ASSIGN(
      double similarity,
      ImageEmbedder::CosineSimilarity(image_result.embeddings(0).entries(0),
                                      crop_result.embeddings(0).entries(0)));
  const double expected_similarity = 0.999931;
  EXPECT_NEAR(similarity, expected_similarity, kSimilarityTolerancy);
}
TEST_F(ImageModeTest, SucceedsWithRotation) {
  auto options = std::make_unique<ImageEmbedderOptions>();
  options->base_options.model_asset_path =
      JoinPath("./", kTestDataDirectory, kMobileNetV3Embedder);
  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageEmbedder> image_embedder,
                          ImageEmbedder::Create(std::move(options)));
  // Load images: one is a rotated version of the other.
  MP_ASSERT_OK_AND_ASSIGN(
      Image image,
      DecodeImageFromFile(JoinPath("./", kTestDataDirectory, "burger.jpg")));
  MP_ASSERT_OK_AND_ASSIGN(Image rotated,
                          DecodeImageFromFile(JoinPath("./", kTestDataDirectory,
                                                       "burger_rotated.jpg")));
  ImageProcessingOptions image_processing_options;
  image_processing_options.rotation_degrees = -90;
  // Extract both embeddings.
  MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& image_result,
                          image_embedder->Embed(image));
  MP_ASSERT_OK_AND_ASSIGN(
      const EmbeddingResult& rotated_result,
      image_embedder->Embed(rotated, image_processing_options));
  // Check results. Stop on any fatal failure in the helper, because the
  // similarity computation below indexes embeddings(0).entries(0) directly.
  ASSERT_NO_FATAL_FAILURE(CheckMobileNetV3Result(image_result, false));
  ASSERT_NO_FATAL_FAILURE(CheckMobileNetV3Result(rotated_result, false));
  // Check cosine similarity. EXPECT_NEAR compares as double; an unqualified
  // `abs` on a double may resolve to `int abs(int)`, truncating the difference
  // to 0 and making the check vacuous.
  MP_ASSERT_OK_AND_ASSIGN(
      double similarity,
      ImageEmbedder::CosineSimilarity(image_result.embeddings(0).entries(0),
                                      rotated_result.embeddings(0).entries(0)));
  const double expected_similarity = 0.572265;
  EXPECT_NEAR(similarity, expected_similarity, kSimilarityTolerancy);
}
TEST_F(ImageModeTest, SucceedsWithRegionOfInterestAndRotation) {
  auto options = std::make_unique<ImageEmbedderOptions>();
  options->base_options.model_asset_path =
      JoinPath("./", kTestDataDirectory, kMobileNetV3Embedder);
  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageEmbedder> image_embedder,
                          ImageEmbedder::Create(std::move(options)));
  MP_ASSERT_OK_AND_ASSIGN(
      Image crop, DecodeImageFromFile(
                      JoinPath("./", kTestDataDirectory, "burger_crop.jpg")));
  MP_ASSERT_OK_AND_ASSIGN(Image rotated,
                          DecodeImageFromFile(JoinPath("./", kTestDataDirectory,
                                                       "burger_rotated.jpg")));
  // Region-of-interest corresponding to burger_crop.jpg.
  Rect roi{/*left=*/0, /*top=*/0, /*right=*/1, /*bottom=*/0.8333333};
  ImageProcessingOptions image_processing_options{roi,
                                                  /*rotation_degrees=*/-90};
  // Extract both embeddings.
  MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& crop_result,
                          image_embedder->Embed(crop));
  MP_ASSERT_OK_AND_ASSIGN(
      const EmbeddingResult& rotated_result,
      image_embedder->Embed(rotated, image_processing_options));
  // Check results. Stop on any fatal failure in the helper, because the
  // similarity computation below indexes embeddings(0).entries(0) directly.
  ASSERT_NO_FATAL_FAILURE(CheckMobileNetV3Result(crop_result, false));
  ASSERT_NO_FATAL_FAILURE(CheckMobileNetV3Result(rotated_result, false));
  // Check cosine similarity. EXPECT_NEAR compares as double; an unqualified
  // `abs` on a double may resolve to `int abs(int)`, truncating the difference
  // to 0 and making the check vacuous.
  MP_ASSERT_OK_AND_ASSIGN(
      double similarity,
      ImageEmbedder::CosineSimilarity(crop_result.embeddings(0).entries(0),
                                      rotated_result.embeddings(0).entries(0)));
  const double expected_similarity = 0.62838;
  EXPECT_NEAR(similarity, expected_similarity, kSimilarityTolerancy);
}
// Fixture for tests exercising the video running mode.
struct VideoModeTest : tflite_shims::testing::Test {};
TEST_F(VideoModeTest, FailsWithCallingWrongMethod) {
  MP_ASSERT_OK_AND_ASSIGN(
      Image image,
      DecodeImageFromFile(JoinPath("./", kTestDataDirectory, "burger.jpg")));
  auto embedder_options = std::make_unique<ImageEmbedderOptions>();
  embedder_options->base_options.model_asset_path =
      JoinPath("./", kTestDataDirectory, kMobileNetV3Embedder);
  embedder_options->running_mode = core::RunningMode::VIDEO;
  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageEmbedder> image_embedder,
                          ImageEmbedder::Create(std::move(embedder_options)));
  // Each call to an entry point of another running mode must be rejected as an
  // invalid argument carrying the wrong-mode payload; only the message differs.
  const auto expect_wrong_mode = [](const absl::Status& status,
                                    const char* expected_fragment) {
    EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
    EXPECT_THAT(status.message(), HasSubstr(expected_fragment));
    EXPECT_THAT(status.GetPayload(kMediaPipeTasksPayload),
                Optional(absl::Cord(absl::StrCat(
                    MediaPipeTasksStatus::kRunnerApiCalledInWrongModeError))));
  };
  expect_wrong_mode(image_embedder->Embed(image).status(),
                    "not initialized with the image mode");
  expect_wrong_mode(image_embedder->EmbedAsync(image, 0),
                    "not initialized with the live stream mode");
  MP_ASSERT_OK(image_embedder->Close());
}
TEST_F(VideoModeTest, FailsWithOutOfOrderInputTimestamps) {
  MP_ASSERT_OK_AND_ASSIGN(
      Image image,
      DecodeImageFromFile(JoinPath("./", kTestDataDirectory, "burger.jpg")));
  auto embedder_options = std::make_unique<ImageEmbedderOptions>();
  embedder_options->base_options.model_asset_path =
      JoinPath("./", kTestDataDirectory, kMobileNetV3Embedder);
  embedder_options->running_mode = core::RunningMode::VIDEO;
  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageEmbedder> image_embedder,
                          ImageEmbedder::Create(std::move(embedder_options)));
  // Frame t=1 is accepted; stepping back to t=0 must be rejected.
  MP_ASSERT_OK(image_embedder->EmbedForVideo(image, 1));
  const auto stale_frame = image_embedder->EmbedForVideo(image, 0);
  const absl::Status& status = stale_frame.status();
  EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
  EXPECT_THAT(status.message(),
              HasSubstr("timestamp must be monotonically increasing"));
  EXPECT_THAT(status.GetPayload(kMediaPipeTasksPayload),
              Optional(absl::Cord(absl::StrCat(
                  MediaPipeTasksStatus::kRunnerInvalidTimestampError))));
  // A later, in-order frame is still accepted after the rejection.
  MP_ASSERT_OK(image_embedder->EmbedForVideo(image, 2));
  MP_ASSERT_OK(image_embedder->Close());
}
TEST_F(VideoModeTest, Succeeds) {
  constexpr int kIterations = 100;
  MP_ASSERT_OK_AND_ASSIGN(
      Image image,
      DecodeImageFromFile(JoinPath("./", kTestDataDirectory, "burger.jpg")));
  auto options = std::make_unique<ImageEmbedderOptions>();
  options->base_options.model_asset_path =
      JoinPath("./", kTestDataDirectory, kMobileNetV3Embedder);
  options->running_mode = core::RunningMode::VIDEO;
  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageEmbedder> image_embedder,
                          ImageEmbedder::Create(std::move(options)));
  EmbeddingResult previous_results;
  for (int i = 0; i < kIterations; ++i) {
    MP_ASSERT_OK_AND_ASSIGN(auto results,
                            image_embedder->EmbedForVideo(image, i));
    // Stop on a malformed result: the similarity check indexes
    // embeddings(0).entries(0) directly.
    ASSERT_NO_FATAL_FAILURE(CheckMobileNetV3Result(results, false));
    if (i > 0) {
      MP_ASSERT_OK_AND_ASSIGN(double similarity,
                              ImageEmbedder::CosineSimilarity(
                                  results.embeddings(0).entries(0),
                                  previous_results.embeddings(0).entries(0)));
      // The same image is submitted every frame, so consecutive embeddings are
      // expected to be identical. EXPECT_NEAR compares as double; an
      // unqualified `abs` on a double may resolve to `int abs(int)` and make
      // the check vacuous.
      const double expected_similarity = 1.0;
      EXPECT_NEAR(similarity, expected_similarity, kSimilarityTolerancy);
    }
    // `results` is not used again in this iteration, so move instead of copy.
    previous_results = std::move(results);
  }
  MP_ASSERT_OK(image_embedder->Close());
}
// Fixture for tests exercising the live stream running mode.
struct LiveStreamModeTest : tflite_shims::testing::Test {};
TEST_F(LiveStreamModeTest, FailsWithCallingWrongMethod) {
  MP_ASSERT_OK_AND_ASSIGN(
      Image image,
      DecodeImageFromFile(JoinPath("./", kTestDataDirectory, "burger.jpg")));
  auto embedder_options = std::make_unique<ImageEmbedderOptions>();
  embedder_options->base_options.model_asset_path =
      JoinPath("./", kTestDataDirectory, kMobileNetV3Embedder);
  embedder_options->running_mode = core::RunningMode::LIVE_STREAM;
  // Live stream mode requires a callback; this one discards every result.
  embedder_options->result_callback = [](absl::StatusOr<EmbeddingResult>,
                                         const Image& image,
                                         int64 timestamp_ms) {};
  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageEmbedder> image_embedder,
                          ImageEmbedder::Create(std::move(embedder_options)));
  // Each call to an entry point of another running mode must be rejected as an
  // invalid argument carrying the wrong-mode payload; only the message differs.
  const auto expect_wrong_mode = [](const absl::Status& status,
                                    const char* expected_fragment) {
    EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
    EXPECT_THAT(status.message(), HasSubstr(expected_fragment));
    EXPECT_THAT(status.GetPayload(kMediaPipeTasksPayload),
                Optional(absl::Cord(absl::StrCat(
                    MediaPipeTasksStatus::kRunnerApiCalledInWrongModeError))));
  };
  expect_wrong_mode(image_embedder->Embed(image).status(),
                    "not initialized with the image mode");
  expect_wrong_mode(image_embedder->EmbedForVideo(image, 0).status(),
                    "not initialized with the video mode");
  MP_ASSERT_OK(image_embedder->Close());
}
TEST_F(LiveStreamModeTest, FailsWithOutOfOrderInputTimestamps) {
  MP_ASSERT_OK_AND_ASSIGN(
      Image image,
      DecodeImageFromFile(JoinPath("./", kTestDataDirectory, "burger.jpg")));
  auto embedder_options = std::make_unique<ImageEmbedderOptions>();
  embedder_options->base_options.model_asset_path =
      JoinPath("./", kTestDataDirectory, kMobileNetV3Embedder);
  embedder_options->running_mode = core::RunningMode::LIVE_STREAM;
  // Live stream mode requires a callback; this one discards every result.
  embedder_options->result_callback = [](absl::StatusOr<EmbeddingResult>,
                                         const Image& image,
                                         int64 timestamp_ms) {};
  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageEmbedder> image_embedder,
                          ImageEmbedder::Create(std::move(embedder_options)));
  // Submit t=1, then try to go back in time to t=0.
  MP_ASSERT_OK(image_embedder->EmbedAsync(image, 1));
  const auto stale_status = image_embedder->EmbedAsync(image, 0);
  EXPECT_EQ(stale_status.code(), absl::StatusCode::kInvalidArgument);
  EXPECT_THAT(stale_status.message(),
              HasSubstr("timestamp must be monotonically increasing"));
  EXPECT_THAT(stale_status.GetPayload(kMediaPipeTasksPayload),
              Optional(absl::Cord(absl::StrCat(
                  MediaPipeTasksStatus::kRunnerInvalidTimestampError))));
  // A later timestamp is still accepted after the rejection.
  MP_ASSERT_OK(image_embedder->EmbedAsync(image, 2));
  MP_ASSERT_OK(image_embedder->Close());
}
// One invocation of the live stream result callback: the embedding it
// delivered, the size of the image passed alongside it, and its timestamp.
struct LiveStreamModeResults {
  EmbeddingResult embedding_result;
  // {width, height} of the image handed to the callback.
  std::pair<int, int> image_size;
  // Default-initialized so a partially filled instance never holds garbage.
  int64 timestamp_ms = 0;
};
TEST_F(LiveStreamModeTest, Succeeds) {
  constexpr int kIterations = 100;
  MP_ASSERT_OK_AND_ASSIGN(
      Image image,
      DecodeImageFromFile(JoinPath("./", kTestDataDirectory, "burger.jpg")));
  std::vector<LiveStreamModeResults> results;
  auto options = std::make_unique<ImageEmbedderOptions>();
  options->base_options.model_asset_path =
      JoinPath("./", kTestDataDirectory, kMobileNetV3Embedder);
  options->running_mode = core::RunningMode::LIVE_STREAM;
  // Record every delivered result; they are inspected after Close() returns.
  options->result_callback =
      [&results](absl::StatusOr<EmbeddingResult> embedding_result,
                 const Image& image, int64 timestamp_ms) {
        MP_ASSERT_OK(embedding_result.status());
        results.push_back(
            {.embedding_result = std::move(embedding_result).value(),
             .image_size = {image.width(), image.height()},
             .timestamp_ms = timestamp_ms});
      };
  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageEmbedder> image_embedder,
                          ImageEmbedder::Create(std::move(options)));
  for (int i = 0; i < kIterations; ++i) {
    MP_ASSERT_OK(image_embedder->EmbedAsync(image, i));
  }
  MP_ASSERT_OK(image_embedder->Close());
  // Due to the flow limiter, the number of outputs is at most the number of
  // iterations (and may be smaller). Compare as unsigned to avoid mixing
  // signedness with results.size().
  ASSERT_LE(results.size(), static_cast<size_t>(kIterations));
  ASSERT_GT(results.size(), 0u);
  int64 timestamp_ms = -1;
  for (size_t i = 0; i < results.size(); ++i) {
    const auto& result = results[i];
    EXPECT_GT(result.timestamp_ms, timestamp_ms);
    timestamp_ms = result.timestamp_ms;
    EXPECT_EQ(result.image_size.first, image.width());
    EXPECT_EQ(result.image_size.second, image.height());
    // Stop on a malformed result: the similarity check indexes
    // embeddings(0).entries(0) directly.
    ASSERT_NO_FATAL_FAILURE(
        CheckMobileNetV3Result(result.embedding_result, false));
    if (i > 0) {
      MP_ASSERT_OK_AND_ASSIGN(
          double similarity,
          ImageEmbedder::CosineSimilarity(
              result.embedding_result.embeddings(0).entries(0),
              results[i - 1].embedding_result.embeddings(0).entries(0)));
      // The same image is submitted every time, so consecutive embeddings are
      // expected to be identical. EXPECT_NEAR compares as double; an
      // unqualified `abs` on a double may resolve to `int abs(int)` and make
      // the check vacuous.
      const double expected_similarity = 1.0;
      EXPECT_NEAR(similarity, expected_similarity, kSimilarityTolerancy);
    }
  }
}
} // namespace
} // namespace image_embedder
} // namespace vision
} // namespace tasks
} // namespace mediapipe