Merge pull request #4943 from kinaryml:c-image-embedder-api
PiperOrigin-RevId: 580618718
This commit is contained in:
commit
d4d30768be
|
@ -66,6 +66,29 @@ void CppConvertToEmbeddingResult(
|
|||
}
|
||||
}
|
||||
|
||||
void CppConvertToCppEmbedding(
|
||||
const Embedding& in, // C struct as input
|
||||
mediapipe::tasks::components::containers::Embedding* out) {
|
||||
// Handle float embeddings
|
||||
if (in.float_embedding != nullptr) {
|
||||
out->float_embedding.assign(in.float_embedding,
|
||||
in.float_embedding + in.values_count);
|
||||
}
|
||||
|
||||
// Handle quantized embeddings
|
||||
if (in.quantized_embedding != nullptr) {
|
||||
out->quantized_embedding.assign(in.quantized_embedding,
|
||||
in.quantized_embedding + in.values_count);
|
||||
}
|
||||
|
||||
out->head_index = in.head_index;
|
||||
|
||||
// Copy head_name if it is present.
|
||||
if (in.head_name) {
|
||||
out->head_name = std::string(in.head_name);
|
||||
}
|
||||
}
|
||||
|
||||
void CppCloseEmbeddingResult(EmbeddingResult* in) {
|
||||
for (uint32_t i = 0; i < in->embeddings_count; ++i) {
|
||||
auto embedding_in = in->embeddings[i];
|
||||
|
|
|
@ -29,6 +29,10 @@ void CppConvertToEmbeddingResult(
|
|||
const mediapipe::tasks::components::containers::EmbeddingResult& in,
|
||||
EmbeddingResult* out);
|
||||
|
||||
void CppConvertToCppEmbedding(
|
||||
const Embedding& in,
|
||||
mediapipe::tasks::components::containers::Embedding* out);
|
||||
|
||||
void CppCloseEmbedding(Embedding* in);
|
||||
|
||||
void CppCloseEmbeddingResult(EmbeddingResult* in);
|
||||
|
|
|
@ -28,6 +28,7 @@ cc_library(
|
|||
"//mediapipe/tasks/c/components/processors:embedder_options_converter",
|
||||
"//mediapipe/tasks/c/core:base_options",
|
||||
"//mediapipe/tasks/c/core:base_options_converter",
|
||||
"//mediapipe/tasks/cc/components/containers:embedding_result",
|
||||
"//mediapipe/tasks/cc/text/text_embedder",
|
||||
"@com_google_absl//absl/log:absl_log",
|
||||
"@com_google_absl//absl/status",
|
||||
|
|
|
@ -20,9 +20,11 @@ limitations under the License.
|
|||
|
||||
#include "absl/log/absl_log.h"
|
||||
#include "absl/status/status.h"
|
||||
#include "mediapipe/tasks/c/components/containers/embedding_result.h"
|
||||
#include "mediapipe/tasks/c/components/containers/embedding_result_converter.h"
|
||||
#include "mediapipe/tasks/c/components/processors/embedder_options_converter.h"
|
||||
#include "mediapipe/tasks/c/core/base_options_converter.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/embedding_result.h"
|
||||
#include "mediapipe/tasks/cc/text/text_embedder/text_embedder.h"
|
||||
|
||||
namespace mediapipe::tasks::c::text::text_embedder {
|
||||
|
@ -30,12 +32,14 @@ namespace mediapipe::tasks::c::text::text_embedder {
|
|||
namespace {
|
||||
|
||||
using ::mediapipe::tasks::c::components::containers::CppCloseEmbeddingResult;
|
||||
using ::mediapipe::tasks::c::components::containers::CppConvertToCppEmbedding;
|
||||
using ::mediapipe::tasks::c::components::containers::
|
||||
CppConvertToEmbeddingResult;
|
||||
using ::mediapipe::tasks::c::components::processors::
|
||||
CppConvertToEmbedderOptions;
|
||||
using ::mediapipe::tasks::c::core::CppConvertToBaseOptions;
|
||||
using ::mediapipe::tasks::text::text_embedder::TextEmbedder;
|
||||
typedef ::mediapipe::tasks::components::containers::Embedding CppEmbedding;
|
||||
|
||||
int CppProcessError(absl::Status status, char** error_msg) {
|
||||
if (error_msg) {
|
||||
|
@ -91,6 +95,24 @@ int CppTextEmbedderClose(void* embedder, char** error_msg) {
|
|||
return 0;
|
||||
}
|
||||
|
||||
int CppTextEmbedderCosineSimilarity(const Embedding& u, const Embedding& v,
|
||||
double* similarity, char** error_msg) {
|
||||
CppEmbedding cpp_u;
|
||||
CppConvertToCppEmbedding(u, &cpp_u);
|
||||
CppEmbedding cpp_v;
|
||||
CppConvertToCppEmbedding(v, &cpp_v);
|
||||
auto status_or_similarity =
|
||||
mediapipe::tasks::text::text_embedder::TextEmbedder::CosineSimilarity(
|
||||
cpp_u, cpp_v);
|
||||
if (status_or_similarity.ok()) {
|
||||
*similarity = status_or_similarity.value();
|
||||
} else {
|
||||
ABSL_LOG(ERROR) << "Cannot compute cosine similarity.";
|
||||
return CppProcessError(status_or_similarity.status(), error_msg);
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
} // namespace mediapipe::tasks::c::text::text_embedder
|
||||
|
||||
extern "C" {
|
||||
|
@ -116,4 +138,10 @@ int text_embedder_close(void* embedder, char** error_ms) {
|
|||
embedder, error_ms);
|
||||
}
|
||||
|
||||
int text_embedder_cosine_similarity(const Embedding& u, const Embedding& v,
|
||||
double* similarity, char** error_msg) {
|
||||
return mediapipe::tasks::c::text::text_embedder::
|
||||
CppTextEmbedderCosineSimilarity(u, v, similarity, error_msg);
|
||||
}
|
||||
|
||||
} // extern "C"
|
||||
|
|
|
@ -66,6 +66,17 @@ MP_EXPORT void text_embedder_close_result(TextEmbedderResult* result);
|
|||
// allocated for the error message.
|
||||
MP_EXPORT int text_embedder_close(void* embedder, char** error_msg);
|
||||
|
||||
// Utility function to compute cosine similarity [1] between two embeddings.
|
||||
// May return an InvalidArgumentError if e.g. the embeddings are of different
|
||||
// types (quantized vs. float), have different sizes, or have a an L2-norm of
|
||||
// 0.
|
||||
//
|
||||
// [1]: https://en.wikipedia.org/wiki/Cosine_similarity
|
||||
MP_EXPORT int text_embedder_cosine_similarity(const Embedding& u,
|
||||
const Embedding& v,
|
||||
double* similarity,
|
||||
char** error_msg);
|
||||
|
||||
#ifdef __cplusplus
|
||||
} // extern C
|
||||
#endif
|
||||
|
|
|
@ -32,7 +32,12 @@ using testing::HasSubstr;
|
|||
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/text/";
|
||||
constexpr char kTestBertModelPath[] =
|
||||
"mobilebert_embedding_with_metadata.tflite";
|
||||
constexpr char kTestString[] = "It's beautiful outside.";
|
||||
constexpr char kTestString0[] =
|
||||
"When you go to this restaurant, they hold the pancake upside-down "
|
||||
"before they hand it to you. It's a great gimmick.";
|
||||
constexpr char kTestString1[] =
|
||||
"Let's make a plan to steal the declaration of independence.";
|
||||
constexpr float kPrecision = 1e-3;
|
||||
|
||||
std::string GetFullPath(absl::string_view file_name) {
|
||||
return JoinPath("./", kTestDataDirectory, file_name);
|
||||
|
@ -52,7 +57,7 @@ TEST(TextEmbedderTest, SmokeTest) {
|
|||
EXPECT_NE(embedder, nullptr);
|
||||
|
||||
TextEmbedderResult result;
|
||||
text_embedder_embed(embedder, kTestString, &result, /* error_msg */ nullptr);
|
||||
text_embedder_embed(embedder, kTestString0, &result, /* error_msg */ nullptr);
|
||||
EXPECT_EQ(result.embeddings_count, 1);
|
||||
EXPECT_EQ(result.embeddings[0].values_count, 512);
|
||||
|
||||
|
@ -60,6 +65,40 @@ TEST(TextEmbedderTest, SmokeTest) {
|
|||
text_embedder_close(embedder, /* error_msg */ nullptr);
|
||||
}
|
||||
|
||||
TEST(TextEmbedderTest, SucceedsWithCosineSimilarity) {
|
||||
std::string model_path = GetFullPath(kTestBertModelPath);
|
||||
TextEmbedderOptions options = {
|
||||
/* base_options= */ {/* model_asset_buffer= */ nullptr,
|
||||
/* model_asset_buffer_count= */ 0,
|
||||
/* model_asset_path= */ model_path.c_str()},
|
||||
/* embedder_options= */
|
||||
{/* l2_normalize= */ false,
|
||||
/* quantize= */ false}};
|
||||
|
||||
void* embedder = text_embedder_create(&options,
|
||||
/* error_msg */ nullptr);
|
||||
EXPECT_NE(embedder, nullptr);
|
||||
|
||||
// Extract both embeddings.
|
||||
TextEmbedderResult result0;
|
||||
text_embedder_embed(embedder, kTestString0, &result0,
|
||||
/* error_msg */ nullptr);
|
||||
TextEmbedderResult result1;
|
||||
text_embedder_embed(embedder, kTestString1, &result1,
|
||||
/* error_msg */ nullptr);
|
||||
|
||||
// Check cosine similarity.
|
||||
double similarity;
|
||||
text_embedder_cosine_similarity(result0.embeddings[0], result1.embeddings[0],
|
||||
&similarity, nullptr);
|
||||
double expected_similarity = 0.98077;
|
||||
EXPECT_LE(abs(similarity - expected_similarity), kPrecision);
|
||||
|
||||
text_embedder_close_result(&result0);
|
||||
text_embedder_close_result(&result1);
|
||||
text_embedder_close(embedder, /* error_msg */ nullptr);
|
||||
}
|
||||
|
||||
TEST(TextEmbedderTest, ErrorHandling) {
|
||||
// It is an error to set neither the asset buffer nor the path.
|
||||
TextEmbedderOptions options = {
|
||||
|
|
22
mediapipe/tasks/c/vision/core/BUILD
Normal file
22
mediapipe/tasks/c/vision/core/BUILD
Normal file
|
@ -0,0 +1,22 @@
|
|||
# Copyright 2023 The MediaPipe Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
cc_library(
|
||||
name = "common",
|
||||
hdrs = ["common.h"],
|
||||
)
|
68
mediapipe/tasks/c/vision/core/common.h
Normal file
68
mediapipe/tasks/c/vision/core/common.h
Normal file
|
@ -0,0 +1,68 @@
|
|||
/* Copyright 2023 The MediaPipe Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef MEDIAPIPE_TASKS_C_VISION_CORE_COMMON_H_
|
||||
#define MEDIAPIPE_TASKS_C_VISION_CORE_COMMON_H_
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
// Supported image formats.
|
||||
enum ImageFormat {
|
||||
UNKNOWN = 0,
|
||||
SRGB = 1,
|
||||
SRGBA = 2,
|
||||
GRAY8 = 3,
|
||||
SBGRA = 11 // compatible with Flutter `bgra8888` format.
|
||||
};
|
||||
|
||||
// Supported processing modes.
|
||||
enum RunningMode {
|
||||
IMAGE = 1,
|
||||
VIDEO = 2,
|
||||
LIVE_STREAM = 3,
|
||||
};
|
||||
|
||||
// Structure to hold image frame.
|
||||
struct ImageFrame {
|
||||
enum ImageFormat format;
|
||||
const uint8_t* image_buffer;
|
||||
int width;
|
||||
int height;
|
||||
};
|
||||
|
||||
// TODO: Add GPU buffer declaration and processing logic for it.
|
||||
struct GpuBuffer {
|
||||
int width;
|
||||
int height;
|
||||
};
|
||||
|
||||
// The object to contain an image, realizes `OneOf` concept.
|
||||
struct MpImage {
|
||||
enum { IMAGE_FRAME, GPU_BUFFER } type;
|
||||
union {
|
||||
struct ImageFrame image_frame;
|
||||
struct GpuBuffer gpu_buffer;
|
||||
};
|
||||
};
|
||||
|
||||
#ifdef __cplusplus
|
||||
} // extern C
|
||||
#endif
|
||||
|
||||
#endif // MEDIAPIPE_TASKS_C_VISION_CORE_COMMON_H_
|
|
@ -30,6 +30,7 @@ cc_library(
|
|||
"//mediapipe/tasks/c/components/processors:classifier_options_converter",
|
||||
"//mediapipe/tasks/c/core:base_options",
|
||||
"//mediapipe/tasks/c/core:base_options_converter",
|
||||
"//mediapipe/tasks/c/vision/core:common",
|
||||
"//mediapipe/tasks/cc/vision/core:running_mode",
|
||||
"//mediapipe/tasks/cc/vision/image_classifier",
|
||||
"//mediapipe/tasks/cc/vision/utils:image_utils",
|
||||
|
|
|
@ -16,11 +16,10 @@ limitations under the License.
|
|||
#ifndef MEDIAPIPE_TASKS_C_VISION_IMAGE_CLASSIFIER_IMAGE_CLASSIFIER_H_
|
||||
#define MEDIAPIPE_TASKS_C_VISION_IMAGE_CLASSIFIER_IMAGE_CLASSIFIER_H_
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
#include "mediapipe/tasks/c/components/containers/classification_result.h"
|
||||
#include "mediapipe/tasks/c/components/processors/classifier_options.h"
|
||||
#include "mediapipe/tasks/c/core/base_options.h"
|
||||
#include "mediapipe/tasks/c/vision/core/common.h"
|
||||
|
||||
#ifndef MP_EXPORT
|
||||
#define MP_EXPORT __attribute__((visibility("default")))
|
||||
|
@ -32,46 +31,7 @@ extern "C" {
|
|||
|
||||
typedef ClassificationResult ImageClassifierResult;
|
||||
|
||||
// Supported image formats.
|
||||
enum ImageFormat {
|
||||
UNKNOWN = 0,
|
||||
SRGB = 1,
|
||||
SRGBA = 2,
|
||||
GRAY8 = 3,
|
||||
SBGRA = 11 // compatible with Flutter `bgra8888` format.
|
||||
};
|
||||
|
||||
// Supported processing modes.
|
||||
enum RunningMode {
|
||||
IMAGE = 1,
|
||||
VIDEO = 2,
|
||||
LIVE_STREAM = 3,
|
||||
};
|
||||
|
||||
// Structure to hold image frame.
|
||||
struct ImageFrame {
|
||||
enum ImageFormat format;
|
||||
const uint8_t* image_buffer;
|
||||
int width;
|
||||
int height;
|
||||
};
|
||||
|
||||
// TODO: Add GPU buffer declaration and proccessing logic for it.
|
||||
struct GpuBuffer {
|
||||
int width;
|
||||
int height;
|
||||
};
|
||||
|
||||
// The object to contain an image, realizes `OneOf` concept.
|
||||
struct MpImage {
|
||||
enum { IMAGE_FRAME, GPU_BUFFER } type;
|
||||
union {
|
||||
struct ImageFrame image_frame;
|
||||
struct GpuBuffer gpu_buffer;
|
||||
};
|
||||
};
|
||||
|
||||
// The options for configuring a Mediapipe image classifier task.
|
||||
// The options for configuring a MediaPipe image classifier task.
|
||||
struct ImageClassifierOptions {
|
||||
// Base options for configuring MediaPipe Tasks, such as specifying the model
|
||||
// file with metadata, accelerator options, op resolver, etc.
|
||||
|
@ -122,12 +82,39 @@ MP_EXPORT int image_classifier_classify_image(void* classifier,
|
|||
ImageClassifierResult* result,
|
||||
char** error_msg);
|
||||
|
||||
// Performs image classification on the provided video frame.
|
||||
// Only use this method when the ImageClassifier is created with the video
|
||||
// running mode.
|
||||
// The image can be of any size with format RGB or RGBA. It's required to
|
||||
// provide the video frame's timestamp (in milliseconds). The input timestamps
|
||||
// must be monotonically increasing.
|
||||
// If an error occurs, returns an error code and sets the error parameter to an
|
||||
// an error message (if `error_msg` is not `nullptr`). You must free the memory
|
||||
// allocated for the error message.
|
||||
MP_EXPORT int image_classifier_classify_for_video(void* classifier,
|
||||
const MpImage* image,
|
||||
int64_t timestamp_ms,
|
||||
ImageClassifierResult* result,
|
||||
char** error_msg);
|
||||
|
||||
// Sends live image data to image classification, and the results will be
|
||||
// available via the `result_callback` provided in the ImageClassifierOptions.
|
||||
// Only use this method when the ImageClassifier is created with the live
|
||||
// stream running mode.
|
||||
// The image can be of any size with format RGB or RGBA. It's required to
|
||||
// provide a timestamp (in milliseconds) to indicate when the input image is
|
||||
// sent to the object detector. The input timestamps must be monotonically
|
||||
// increasing.
|
||||
// The `result_callback` provides:
|
||||
// - The classification results as an ImageClassifierResult object.
|
||||
// - The const reference to the corresponding input image that the image
|
||||
// classifier runs on. Note that the const reference to the image will no
|
||||
// longer be valid when the callback returns. To access the image data
|
||||
// outside of the callback, callers need to make a copy of the image.
|
||||
// - The input timestamp in milliseconds.
|
||||
// If an error occurs, returns an error code and sets the error parameter to an
|
||||
// an error message (if `error_msg` is not `nullptr`). You must free the memory
|
||||
// allocated for the error message.
|
||||
MP_EXPORT int image_classifier_classify_async(void* classifier,
|
||||
const MpImage* image,
|
||||
int64_t timestamp_ms,
|
||||
|
|
67
mediapipe/tasks/c/vision/image_embedder/BUILD
Normal file
67
mediapipe/tasks/c/vision/image_embedder/BUILD
Normal file
|
@ -0,0 +1,67 @@
|
|||
# Copyright 2023 The MediaPipe Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
cc_library(
|
||||
name = "image_embedder_lib",
|
||||
srcs = ["image_embedder.cc"],
|
||||
hdrs = ["image_embedder.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework/formats:image",
|
||||
"//mediapipe/framework/formats:image_frame",
|
||||
"//mediapipe/tasks/c/components/containers:embedding_result",
|
||||
"//mediapipe/tasks/c/components/containers:embedding_result_converter",
|
||||
"//mediapipe/tasks/c/components/processors:embedder_options",
|
||||
"//mediapipe/tasks/c/components/processors:embedder_options_converter",
|
||||
"//mediapipe/tasks/c/core:base_options",
|
||||
"//mediapipe/tasks/c/core:base_options_converter",
|
||||
"//mediapipe/tasks/c/vision/core:common",
|
||||
"//mediapipe/tasks/cc/components/containers:embedding_result",
|
||||
"//mediapipe/tasks/cc/vision/core:running_mode",
|
||||
"//mediapipe/tasks/cc/vision/image_embedder",
|
||||
"//mediapipe/tasks/cc/vision/utils:image_utils",
|
||||
"@com_google_absl//absl/log:absl_log",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "image_embedder_test",
|
||||
srcs = ["image_embedder_test.cc"],
|
||||
data = [
|
||||
"//mediapipe/framework/formats:image_frame_opencv",
|
||||
"//mediapipe/framework/port:opencv_core",
|
||||
"//mediapipe/framework/port:opencv_imgproc",
|
||||
"//mediapipe/tasks/testdata/vision:test_images",
|
||||
"//mediapipe/tasks/testdata/vision:test_models",
|
||||
],
|
||||
linkstatic = 1,
|
||||
deps = [
|
||||
":image_embedder_lib",
|
||||
"//mediapipe/framework/deps:file_path",
|
||||
"//mediapipe/framework/formats:image",
|
||||
"//mediapipe/framework/port:gtest",
|
||||
"//mediapipe/tasks/c/vision/core:common",
|
||||
"//mediapipe/tasks/cc/vision/utils:image_utils",
|
||||
"@com_google_absl//absl/flags:flag",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_googletest//:gtest_main",
|
||||
],
|
||||
)
|
303
mediapipe/tasks/c/vision/image_embedder/image_embedder.cc
Normal file
303
mediapipe/tasks/c/vision/image_embedder/image_embedder.cc
Normal file
|
@ -0,0 +1,303 @@
|
|||
/* Copyright 2023 The MediaPipe Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mediapipe/tasks/c/vision/image_embedder/image_embedder.h"
|
||||
|
||||
#include <cstdint>
|
||||
#include <cstdlib>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
|
||||
#include "absl/log/absl_log.h"
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "mediapipe/framework/formats/image.h"
|
||||
#include "mediapipe/framework/formats/image_frame.h"
|
||||
#include "mediapipe/tasks/c/components/containers/embedding_result.h"
|
||||
#include "mediapipe/tasks/c/components/containers/embedding_result_converter.h"
|
||||
#include "mediapipe/tasks/c/components/processors/embedder_options_converter.h"
|
||||
#include "mediapipe/tasks/c/core/base_options_converter.h"
|
||||
#include "mediapipe/tasks/c/vision/core/common.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/embedding_result.h"
|
||||
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
|
||||
#include "mediapipe/tasks/cc/vision/image_embedder/image_embedder.h"
|
||||
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
|
||||
|
||||
namespace mediapipe::tasks::c::vision::image_embedder {
|
||||
|
||||
namespace {
|
||||
|
||||
using ::mediapipe::tasks::c::components::containers::CppCloseEmbeddingResult;
|
||||
using ::mediapipe::tasks::c::components::containers::CppConvertToCppEmbedding;
|
||||
using ::mediapipe::tasks::c::components::containers::
|
||||
CppConvertToEmbeddingResult;
|
||||
using ::mediapipe::tasks::c::components::processors::
|
||||
CppConvertToEmbedderOptions;
|
||||
using ::mediapipe::tasks::c::core::CppConvertToBaseOptions;
|
||||
using ::mediapipe::tasks::vision::CreateImageFromBuffer;
|
||||
using ::mediapipe::tasks::vision::core::RunningMode;
|
||||
using ::mediapipe::tasks::vision::image_embedder::ImageEmbedder;
|
||||
typedef ::mediapipe::tasks::components::containers::Embedding CppEmbedding;
|
||||
typedef ::mediapipe::tasks::vision::image_embedder::ImageEmbedderResult
|
||||
CppImageEmbedderResult;
|
||||
|
||||
int CppProcessError(absl::Status status, char** error_msg) {
|
||||
if (error_msg) {
|
||||
*error_msg = strdup(status.ToString().c_str());
|
||||
}
|
||||
return status.raw_code();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
ImageEmbedder* CppImageEmbedderCreate(const ImageEmbedderOptions& options,
|
||||
char** error_msg) {
|
||||
auto cpp_options = std::make_unique<
|
||||
::mediapipe::tasks::vision::image_embedder::ImageEmbedderOptions>();
|
||||
|
||||
CppConvertToBaseOptions(options.base_options, &cpp_options->base_options);
|
||||
CppConvertToEmbedderOptions(options.embedder_options,
|
||||
&cpp_options->embedder_options);
|
||||
cpp_options->running_mode = static_cast<RunningMode>(options.running_mode);
|
||||
|
||||
// Enable callback for processing live stream data when the running mode is
|
||||
// set to RunningMode::LIVE_STREAM.
|
||||
if (cpp_options->running_mode == RunningMode::LIVE_STREAM) {
|
||||
if (options.result_callback == nullptr) {
|
||||
const absl::Status status = absl::InvalidArgumentError(
|
||||
"Provided null pointer to callback function.");
|
||||
ABSL_LOG(ERROR) << "Failed to create ImageEmbedder: " << status;
|
||||
CppProcessError(status, error_msg);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
ImageEmbedderOptions::result_callback_fn result_callback =
|
||||
options.result_callback;
|
||||
cpp_options->result_callback =
|
||||
[result_callback](absl::StatusOr<CppImageEmbedderResult> cpp_result,
|
||||
const Image& image, int64_t timestamp) {
|
||||
char* error_msg = nullptr;
|
||||
|
||||
if (!cpp_result.ok()) {
|
||||
ABSL_LOG(ERROR)
|
||||
<< "Embedding extraction failed: " << cpp_result.status();
|
||||
CppProcessError(cpp_result.status(), &error_msg);
|
||||
result_callback(nullptr, MpImage(), timestamp, error_msg);
|
||||
free(error_msg);
|
||||
return;
|
||||
}
|
||||
|
||||
// Result is valid for the lifetime of the callback function.
|
||||
ImageEmbedderResult result;
|
||||
CppConvertToEmbeddingResult(*cpp_result, &result);
|
||||
|
||||
const auto& image_frame = image.GetImageFrameSharedPtr();
|
||||
const MpImage mp_image = {
|
||||
.type = MpImage::IMAGE_FRAME,
|
||||
.image_frame = {
|
||||
.format = static_cast<::ImageFormat>(image_frame->Format()),
|
||||
.image_buffer = image_frame->PixelData(),
|
||||
.width = image_frame->Width(),
|
||||
.height = image_frame->Height()}};
|
||||
|
||||
result_callback(&result, mp_image, timestamp,
|
||||
/* error_msg= */ nullptr);
|
||||
|
||||
CppCloseEmbeddingResult(&result);
|
||||
};
|
||||
}
|
||||
|
||||
auto embedder = ImageEmbedder::Create(std::move(cpp_options));
|
||||
if (!embedder.ok()) {
|
||||
ABSL_LOG(ERROR) << "Failed to create ImageEmbedder: " << embedder.status();
|
||||
CppProcessError(embedder.status(), error_msg);
|
||||
return nullptr;
|
||||
}
|
||||
return embedder->release();
|
||||
}
|
||||
|
||||
int CppImageEmbedderEmbed(void* embedder, const MpImage* image,
|
||||
ImageEmbedderResult* result, char** error_msg) {
|
||||
if (image->type == MpImage::GPU_BUFFER) {
|
||||
const absl::Status status =
|
||||
absl::InvalidArgumentError("GPU Buffer not supported yet.");
|
||||
|
||||
ABSL_LOG(ERROR) << "Embedding extraction failed: " << status.message();
|
||||
return CppProcessError(status, error_msg);
|
||||
}
|
||||
|
||||
const auto img = CreateImageFromBuffer(
|
||||
static_cast<ImageFormat::Format>(image->image_frame.format),
|
||||
image->image_frame.image_buffer, image->image_frame.width,
|
||||
image->image_frame.height);
|
||||
|
||||
if (!img.ok()) {
|
||||
ABSL_LOG(ERROR) << "Failed to create Image: " << img.status();
|
||||
return CppProcessError(img.status(), error_msg);
|
||||
}
|
||||
|
||||
auto cpp_embedder = static_cast<ImageEmbedder*>(embedder);
|
||||
auto cpp_result = cpp_embedder->Embed(*img);
|
||||
if (!cpp_result.ok()) {
|
||||
ABSL_LOG(ERROR) << "Embedding extraction failed: " << cpp_result.status();
|
||||
return CppProcessError(cpp_result.status(), error_msg);
|
||||
}
|
||||
CppConvertToEmbeddingResult(*cpp_result, result);
|
||||
return 0;
|
||||
}
|
||||
|
||||
int CppImageEmbedderEmbedForVideo(void* embedder, const MpImage* image,
|
||||
int64_t timestamp_ms,
|
||||
ImageEmbedderResult* result,
|
||||
char** error_msg) {
|
||||
if (image->type == MpImage::GPU_BUFFER) {
|
||||
absl::Status status =
|
||||
absl::InvalidArgumentError("GPU Buffer not supported yet");
|
||||
|
||||
ABSL_LOG(ERROR) << "Embedding extraction failed: " << status.message();
|
||||
return CppProcessError(status, error_msg);
|
||||
}
|
||||
|
||||
const auto img = CreateImageFromBuffer(
|
||||
static_cast<ImageFormat::Format>(image->image_frame.format),
|
||||
image->image_frame.image_buffer, image->image_frame.width,
|
||||
image->image_frame.height);
|
||||
|
||||
if (!img.ok()) {
|
||||
ABSL_LOG(ERROR) << "Failed to create Image: " << img.status();
|
||||
return CppProcessError(img.status(), error_msg);
|
||||
}
|
||||
|
||||
auto cpp_embedder = static_cast<ImageEmbedder*>(embedder);
|
||||
auto cpp_result = cpp_embedder->EmbedForVideo(*img, timestamp_ms);
|
||||
if (!cpp_result.ok()) {
|
||||
ABSL_LOG(ERROR) << "Embedding extraction failed: " << cpp_result.status();
|
||||
return CppProcessError(cpp_result.status(), error_msg);
|
||||
}
|
||||
CppConvertToEmbeddingResult(*cpp_result, result);
|
||||
return 0;
|
||||
}
|
||||
|
||||
int CppImageEmbedderEmbedAsync(void* embedder, const MpImage* image,
|
||||
int64_t timestamp_ms, char** error_msg) {
|
||||
if (image->type == MpImage::GPU_BUFFER) {
|
||||
absl::Status status =
|
||||
absl::InvalidArgumentError("GPU Buffer not supported yet");
|
||||
|
||||
ABSL_LOG(ERROR) << "Embedding extraction failed: " << status.message();
|
||||
return CppProcessError(status, error_msg);
|
||||
}
|
||||
|
||||
const auto img = CreateImageFromBuffer(
|
||||
static_cast<ImageFormat::Format>(image->image_frame.format),
|
||||
image->image_frame.image_buffer, image->image_frame.width,
|
||||
image->image_frame.height);
|
||||
|
||||
if (!img.ok()) {
|
||||
ABSL_LOG(ERROR) << "Failed to create Image: " << img.status();
|
||||
return CppProcessError(img.status(), error_msg);
|
||||
}
|
||||
|
||||
auto cpp_embedder = static_cast<ImageEmbedder*>(embedder);
|
||||
auto cpp_result = cpp_embedder->EmbedAsync(*img, timestamp_ms);
|
||||
if (!cpp_result.ok()) {
|
||||
ABSL_LOG(ERROR) << "Data preparation for the embedding extraction failed: "
|
||||
<< cpp_result;
|
||||
return CppProcessError(cpp_result, error_msg);
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
void CppImageEmbedderCloseResult(ImageEmbedderResult* result) {
|
||||
CppCloseEmbeddingResult(result);
|
||||
}
|
||||
|
||||
int CppImageEmbedderClose(void* embedder, char** error_msg) {
|
||||
auto cpp_embedder = static_cast<ImageEmbedder*>(embedder);
|
||||
auto result = cpp_embedder->Close();
|
||||
if (!result.ok()) {
|
||||
ABSL_LOG(ERROR) << "Failed to close ImageEmbedder: " << result;
|
||||
return CppProcessError(result, error_msg);
|
||||
}
|
||||
delete cpp_embedder;
|
||||
return 0;
|
||||
}
|
||||
|
||||
int CppImageEmbedderCosineSimilarity(const Embedding& u, const Embedding& v,
|
||||
double* similarity, char** error_msg) {
|
||||
CppEmbedding cpp_u;
|
||||
CppConvertToCppEmbedding(u, &cpp_u);
|
||||
CppEmbedding cpp_v;
|
||||
CppConvertToCppEmbedding(v, &cpp_v);
|
||||
auto status_or_similarity =
|
||||
mediapipe::tasks::vision::image_embedder::ImageEmbedder::CosineSimilarity(
|
||||
cpp_u, cpp_v);
|
||||
if (status_or_similarity.ok()) {
|
||||
*similarity = status_or_similarity.value();
|
||||
} else {
|
||||
ABSL_LOG(ERROR) << "Cannot compute cosine similarity.";
|
||||
return CppProcessError(status_or_similarity.status(), error_msg);
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
} // namespace mediapipe::tasks::c::vision::image_embedder
|
||||
|
||||
extern "C" {
|
||||
|
||||
void* image_embedder_create(struct ImageEmbedderOptions* options,
|
||||
char** error_msg) {
|
||||
return mediapipe::tasks::c::vision::image_embedder::CppImageEmbedderCreate(
|
||||
*options, error_msg);
|
||||
}
|
||||
|
||||
int image_embedder_embed_image(void* embedder, const MpImage* image,
|
||||
ImageEmbedderResult* result, char** error_msg) {
|
||||
return mediapipe::tasks::c::vision::image_embedder::CppImageEmbedderEmbed(
|
||||
embedder, image, result, error_msg);
|
||||
}
|
||||
|
||||
int image_embedder_embed_for_video(void* embedder, const MpImage* image,
|
||||
int64_t timestamp_ms,
|
||||
ImageEmbedderResult* result,
|
||||
char** error_msg) {
|
||||
return mediapipe::tasks::c::vision::image_embedder::
|
||||
CppImageEmbedderEmbedForVideo(embedder, image, timestamp_ms, result,
|
||||
error_msg);
|
||||
}
|
||||
|
||||
int image_embedder_embed_async(void* embedder, const MpImage* image,
|
||||
int64_t timestamp_ms, char** error_msg) {
|
||||
return mediapipe::tasks::c::vision::image_embedder::
|
||||
CppImageEmbedderEmbedAsync(embedder, image, timestamp_ms, error_msg);
|
||||
}
|
||||
|
||||
void image_embedder_close_result(ImageEmbedderResult* result) {
|
||||
mediapipe::tasks::c::vision::image_embedder::CppImageEmbedderCloseResult(
|
||||
result);
|
||||
}
|
||||
|
||||
int image_embedder_close(void* embedder, char** error_msg) {
|
||||
return mediapipe::tasks::c::vision::image_embedder::CppImageEmbedderClose(
|
||||
embedder, error_msg);
|
||||
}
|
||||
|
||||
int image_embedder_cosine_similarity(const Embedding& u, const Embedding& v,
|
||||
double* similarity, char** error_msg) {
|
||||
return mediapipe::tasks::c::vision::image_embedder::
|
||||
CppImageEmbedderCosineSimilarity(u, v, similarity, error_msg);
|
||||
}
|
||||
|
||||
} // extern "C"
|
148
mediapipe/tasks/c/vision/image_embedder/image_embedder.h
Normal file
148
mediapipe/tasks/c/vision/image_embedder/image_embedder.h
Normal file
|
@ -0,0 +1,148 @@
|
|||
/* Copyright 2023 The MediaPipe Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef MEDIAPIPE_TASKS_C_VISION_IMAGE_EMBEDDER_IMAGE_EMBEDDER_H_
|
||||
#define MEDIAPIPE_TASKS_C_VISION_IMAGE_EMBEDDER_IMAGE_EMBEDDER_H_
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
#include "mediapipe/tasks/c/components/containers/embedding_result.h"
|
||||
#include "mediapipe/tasks/c/components/processors/embedder_options.h"
|
||||
#include "mediapipe/tasks/c/core/base_options.h"
|
||||
#include "mediapipe/tasks/c/vision/core/common.h"
|
||||
|
||||
#ifndef MP_EXPORT
|
||||
#define MP_EXPORT __attribute__((visibility("default")))
|
||||
#endif // MP_EXPORT
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
typedef EmbeddingResult ImageEmbedderResult;
|
||||
|
||||
// The options for configuring a MediaPipe image embedder task.
|
||||
struct ImageEmbedderOptions {
|
||||
// Base options for configuring MediaPipe Tasks, such as specifying the model
|
||||
// file with metadata, accelerator options, op resolver, etc.
|
||||
struct BaseOptions base_options;
|
||||
|
||||
// The running mode of the task. Default to the image mode.
|
||||
// Image embedder has three running modes:
|
||||
// 1) The image mode for embedding image on single image inputs.
|
||||
// 2) The video mode for embedding image on the decoded frames of a video.
|
||||
// 3) The live stream mode for embedding image on the live stream of input
|
||||
// data, such as from camera. In this mode, the "result_callback" below must
|
||||
// be specified to receive the embedding results asynchronously.
|
||||
RunningMode running_mode;
|
||||
|
||||
// Options for configuring the embedder behavior, such as l2_normalize and
|
||||
// quantize.
|
||||
struct EmbedderOptions embedder_options;
|
||||
|
||||
// The user-defined result callback for processing live stream data.
|
||||
// The result callback should only be specified when the running mode is set
|
||||
// to RunningMode::LIVE_STREAM. Arguments of the callback function include:
|
||||
// the pointer to embedding result, the image that result was obtained
|
||||
// on, the timestamp relevant to embedding extraction results and pointer to
|
||||
// error message in case of any failure. The validity of the passed arguments
|
||||
// is true for the lifetime of the callback function.
|
||||
//
|
||||
// A caller is responsible for closing image embedder result.
|
||||
typedef void (*result_callback_fn)(ImageEmbedderResult* result,
|
||||
const MpImage image, int64_t timestamp_ms,
|
||||
char* error_msg);
|
||||
result_callback_fn result_callback;
|
||||
};
|
||||
|
||||
// Creates an ImageEmbedder from provided `options`.
|
||||
// Returns a pointer to the image embedder on success.
|
||||
// If an error occurs, returns `nullptr` and sets the error parameter to an
|
||||
// an error message (if `error_msg` is not `nullptr`). You must free the memory
|
||||
// allocated for the error message.
|
||||
MP_EXPORT void* image_embedder_create(struct ImageEmbedderOptions* options,
|
||||
char** error_msg);
|
||||
|
||||
// Performs embedding extraction on the input `image`. Returns `0` on success.
|
||||
// If an error occurs, returns an error code and sets the error parameter to an
|
||||
// an error message (if `error_msg` is not `nullptr`). You must free the memory
|
||||
// allocated for the error message.
|
||||
MP_EXPORT int image_embedder_embed_image(void* embedder, const MpImage* image,
|
||||
ImageEmbedderResult* result,
|
||||
char** error_msg);
|
||||
|
||||
// Performs embedding extraction on the provided video frame.
|
||||
// Only use this method when the ImageEmbedder is created with the video
|
||||
// running mode.
|
||||
// The image can be of any size with format RGB or RGBA. It's required to
|
||||
// provide the video frame's timestamp (in milliseconds). The input timestamps
|
||||
// must be monotonically increasing.
|
||||
// If an error occurs, returns an error code and sets the error parameter to an
|
||||
// an error message (if `error_msg` is not `nullptr`). You must free the memory
|
||||
// allocated for the error message.
|
||||
MP_EXPORT int image_embedder_embed_for_video(void* embedder,
|
||||
const MpImage* image,
|
||||
int64_t timestamp_ms,
|
||||
ImageEmbedderResult* result,
|
||||
char** error_msg);
|
||||
|
||||
// Sends live image data to embedder, and the results will be available via
|
||||
// the `result_callback` provided in the ImageEmbedderOptions.
|
||||
// Only use this method when the ImageEmbedder is created with the live
|
||||
// stream running mode.
|
||||
// The image can be of any size with format RGB or RGBA. It's required to
|
||||
// provide a timestamp (in milliseconds) to indicate when the input image is
|
||||
// sent to the object detector. The input timestamps must be monotonically
|
||||
// increasing.
|
||||
// The `result_callback` provides
|
||||
// - The embedding results as a `ImageEmbedderResult` object.
|
||||
// - The const reference to the corresponding input image that the image
|
||||
// embedder runs on. Note that the const reference to the image will no
|
||||
// longer be valid when the callback returns. To access the image data
|
||||
// outside of the callback, callers need to make a copy of the image.
|
||||
// - The input timestamp in milliseconds.
|
||||
// If an error occurs, returns an error code and sets the error parameter to an
|
||||
// an error message (if `error_msg` is not `nullptr`). You must free the memory
|
||||
// allocated for the error message.
|
||||
MP_EXPORT int image_embedder_embed_async(void* embedder, const MpImage* image,
|
||||
int64_t timestamp_ms,
|
||||
char** error_msg);
|
||||
|
||||
// Frees the memory allocated inside a ImageEmbedderResult result.
|
||||
// Does not free the result pointer itself.
|
||||
MP_EXPORT void image_embedder_close_result(ImageEmbedderResult* result);
|
||||
|
||||
// Frees image embedder.
|
||||
// If an error occurs, returns an error code and sets the error parameter to an
|
||||
// an error message (if `error_msg` is not `nullptr`). You must free the memory
|
||||
// allocated for the error message.
|
||||
MP_EXPORT int image_embedder_close(void* embedder, char** error_msg);
|
||||
|
||||
// Utility function to compute cosine similarity [1] between two embeddings.
|
||||
// May return an InvalidArgumentError if e.g. the embeddings are of different
|
||||
// types (quantized vs. float), have different sizes, or have a an L2-norm of
|
||||
// 0.
|
||||
//
|
||||
// [1]: https://en.wikipedia.org/wiki/Cosine_similarity
|
||||
MP_EXPORT int image_embedder_cosine_similarity(const Embedding& u,
|
||||
const Embedding& v,
|
||||
double* similarity,
|
||||
char** error_msg);
|
||||
|
||||
#ifdef __cplusplus
|
||||
} // extern C
|
||||
#endif
|
||||
|
||||
#endif // MEDIAPIPE_TASKS_C_VISION_IMAGE_EMBEDDER_IMAGE_EMBEDDER_H_
|
302
mediapipe/tasks/c/vision/image_embedder/image_embedder_test.cc
Normal file
302
mediapipe/tasks/c/vision/image_embedder/image_embedder_test.cc
Normal file
|
@ -0,0 +1,302 @@
|
|||
/* Copyright 2023 The MediaPipe Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mediapipe/tasks/c/vision/image_embedder/image_embedder.h"
|
||||
|
||||
#include <cstdint>
|
||||
#include <cstdlib>
|
||||
#include <string>
|
||||
|
||||
#include "absl/flags/flag.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "mediapipe/framework/deps/file_path.h"
|
||||
#include "mediapipe/framework/formats/image.h"
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/tasks/c/vision/core/common.h"
|
||||
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
|
||||
|
||||
namespace {
|
||||
|
||||
using ::mediapipe::file::JoinPath;
|
||||
using ::mediapipe::tasks::vision::DecodeImageFromFile;
|
||||
using testing::HasSubstr;
|
||||
|
||||
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/";
|
||||
constexpr char kModelName[] = "mobilenet_v3_small_100_224_embedder.tflite";
|
||||
constexpr char kImageFile[] = "burger.jpg";
|
||||
constexpr float kPrecision = 1e-6;
|
||||
constexpr int kIterations = 100;
|
||||
|
||||
std::string GetFullPath(absl::string_view file_name) {
|
||||
return JoinPath("./", kTestDataDirectory, file_name);
|
||||
}
|
||||
|
||||
// Utility function to check the sizes, head_index and head_names of a result
|
||||
// produced by kMobileNetV3Embedder.
|
||||
void CheckMobileNetV3Result(const ImageEmbedderResult& result, bool quantized) {
|
||||
EXPECT_EQ(result.embeddings_count, 1);
|
||||
EXPECT_EQ(result.embeddings[0].head_index, 0);
|
||||
EXPECT_EQ(std::string{result.embeddings[0].head_name}, "feature");
|
||||
if (quantized) {
|
||||
EXPECT_EQ(result.embeddings[0].values_count, 1024);
|
||||
} else {
|
||||
EXPECT_EQ(result.embeddings[0].values_count, 1024);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(ImageEmbedderTest, ImageModeTest) {
|
||||
const auto image = DecodeImageFromFile(GetFullPath(kImageFile));
|
||||
ASSERT_TRUE(image.ok());
|
||||
|
||||
const std::string model_path = GetFullPath(kModelName);
|
||||
ImageEmbedderOptions options = {
|
||||
/* base_options= */ {/* model_asset_buffer= */ nullptr,
|
||||
/* model_asset_buffer_count= */ 0,
|
||||
/* model_asset_path= */ model_path.c_str()},
|
||||
/* running_mode= */ RunningMode::IMAGE,
|
||||
/* embedder_options= */
|
||||
{/* l2_normalize= */ true,
|
||||
/* quantize= */ false}};
|
||||
|
||||
void* embedder = image_embedder_create(&options,
|
||||
/* error_msg */ nullptr);
|
||||
EXPECT_NE(embedder, nullptr);
|
||||
|
||||
const MpImage mp_image = {
|
||||
.type = MpImage::IMAGE_FRAME,
|
||||
.image_frame = {
|
||||
.format = static_cast<ImageFormat>(
|
||||
image->GetImageFrameSharedPtr()->Format()),
|
||||
.image_buffer = image->GetImageFrameSharedPtr()->PixelData(),
|
||||
.width = image->GetImageFrameSharedPtr()->Width(),
|
||||
.height = image->GetImageFrameSharedPtr()->Height()}};
|
||||
|
||||
ImageEmbedderResult result;
|
||||
image_embedder_embed_image(embedder, &mp_image, &result,
|
||||
/* error_msg */ nullptr);
|
||||
CheckMobileNetV3Result(result, false);
|
||||
EXPECT_NEAR(result.embeddings[0].float_embedding[0], -0.0142344, kPrecision);
|
||||
image_embedder_close_result(&result);
|
||||
image_embedder_close(embedder, /* error_msg */ nullptr);
|
||||
}
|
||||
|
||||
TEST(ImageEmbedderTest, SucceedsWithCosineSimilarity) {
|
||||
const auto image = DecodeImageFromFile(GetFullPath("burger.jpg"));
|
||||
ASSERT_TRUE(image.ok());
|
||||
const auto crop = DecodeImageFromFile(GetFullPath("burger_crop.jpg"));
|
||||
ASSERT_TRUE(crop.ok());
|
||||
|
||||
const std::string model_path = GetFullPath(kModelName);
|
||||
ImageEmbedderOptions options = {
|
||||
/* base_options= */ {/* model_asset_buffer= */ nullptr,
|
||||
/* model_asset_buffer_count= */ 0,
|
||||
/* model_asset_path= */ model_path.c_str()},
|
||||
/* running_mode= */ RunningMode::IMAGE,
|
||||
/* embedder_options= */
|
||||
{/* l2_normalize= */ true,
|
||||
/* quantize= */ false}};
|
||||
|
||||
void* embedder = image_embedder_create(&options,
|
||||
/* error_msg */ nullptr);
|
||||
EXPECT_NE(embedder, nullptr);
|
||||
|
||||
const MpImage mp_image = {
|
||||
.type = MpImage::IMAGE_FRAME,
|
||||
.image_frame = {
|
||||
.format = static_cast<ImageFormat>(
|
||||
image->GetImageFrameSharedPtr()->Format()),
|
||||
.image_buffer = image->GetImageFrameSharedPtr()->PixelData(),
|
||||
.width = image->GetImageFrameSharedPtr()->Width(),
|
||||
.height = image->GetImageFrameSharedPtr()->Height()}};
|
||||
|
||||
const MpImage mp_crop = {
|
||||
.type = MpImage::IMAGE_FRAME,
|
||||
.image_frame = {
|
||||
.format = static_cast<ImageFormat>(
|
||||
crop->GetImageFrameSharedPtr()->Format()),
|
||||
.image_buffer = crop->GetImageFrameSharedPtr()->PixelData(),
|
||||
.width = crop->GetImageFrameSharedPtr()->Width(),
|
||||
.height = crop->GetImageFrameSharedPtr()->Height()}};
|
||||
|
||||
// Extract both embeddings.
|
||||
ImageEmbedderResult image_result;
|
||||
image_embedder_embed_image(embedder, &mp_image, &image_result,
|
||||
/* error_msg */ nullptr);
|
||||
ImageEmbedderResult crop_result;
|
||||
image_embedder_embed_image(embedder, &mp_crop, &crop_result,
|
||||
/* error_msg */ nullptr);
|
||||
|
||||
// Check results.
|
||||
CheckMobileNetV3Result(image_result, false);
|
||||
CheckMobileNetV3Result(crop_result, false);
|
||||
// Check cosine similarity.
|
||||
double similarity;
|
||||
image_embedder_cosine_similarity(image_result.embeddings[0],
|
||||
crop_result.embeddings[0], &similarity,
|
||||
/* error_msg */ nullptr);
|
||||
double expected_similarity = 0.925519;
|
||||
EXPECT_LE(abs(similarity - expected_similarity), kPrecision);
|
||||
image_embedder_close_result(&image_result);
|
||||
image_embedder_close_result(&crop_result);
|
||||
image_embedder_close(embedder, /* error_msg */ nullptr);
|
||||
}
|
||||
|
||||
TEST(ImageEmbedderTest, VideoModeTest) {
|
||||
const auto image = DecodeImageFromFile(GetFullPath(kImageFile));
|
||||
ASSERT_TRUE(image.ok());
|
||||
|
||||
const std::string model_path = GetFullPath(kModelName);
|
||||
ImageEmbedderOptions options = {
|
||||
/* base_options= */ {/* model_asset_buffer= */ nullptr,
|
||||
/* model_asset_buffer_count= */ 0,
|
||||
/* model_asset_path= */ model_path.c_str()},
|
||||
/* running_mode= */ RunningMode::VIDEO,
|
||||
/* embedder_options= */
|
||||
{/* l2_normalize= */ true,
|
||||
/* quantize= */ false}};
|
||||
|
||||
void* embedder = image_embedder_create(&options,
|
||||
/* error_msg */ nullptr);
|
||||
EXPECT_NE(embedder, nullptr);
|
||||
|
||||
const auto& image_frame = image->GetImageFrameSharedPtr();
|
||||
const MpImage mp_image = {
|
||||
.type = MpImage::IMAGE_FRAME,
|
||||
.image_frame = {.format = static_cast<ImageFormat>(image_frame->Format()),
|
||||
.image_buffer = image_frame->PixelData(),
|
||||
.width = image_frame->Width(),
|
||||
.height = image_frame->Height()}};
|
||||
|
||||
for (int i = 0; i < kIterations; ++i) {
|
||||
ImageEmbedderResult result;
|
||||
image_embedder_embed_for_video(embedder, &mp_image, i, &result,
|
||||
/* error_msg */ nullptr);
|
||||
CheckMobileNetV3Result(result, false);
|
||||
EXPECT_NEAR(result.embeddings[0].float_embedding[0], -0.0142344,
|
||||
kPrecision);
|
||||
image_embedder_close_result(&result);
|
||||
}
|
||||
image_embedder_close(embedder, /* error_msg */ nullptr);
|
||||
}
|
||||
|
||||
// A structure to support LiveStreamModeTest below. This structure holds a
|
||||
// static method `Fn` for a callback function of C API. A `static` qualifier
|
||||
// allows to take an address of the method to follow API style. Another static
|
||||
// struct member is `last_timestamp` that is used to verify that current
|
||||
// timestamp is greater than the previous one.
|
||||
struct LiveStreamModeCallback {
|
||||
static int64_t last_timestamp;
|
||||
static void Fn(ImageEmbedderResult* embedder_result, const MpImage image,
|
||||
int64_t timestamp, char* error_msg) {
|
||||
ASSERT_NE(embedder_result, nullptr);
|
||||
ASSERT_EQ(error_msg, nullptr);
|
||||
CheckMobileNetV3Result(*embedder_result, false);
|
||||
EXPECT_NEAR(embedder_result->embeddings[0].float_embedding[0], -0.0142344,
|
||||
kPrecision);
|
||||
EXPECT_GT(image.image_frame.width, 0);
|
||||
EXPECT_GT(image.image_frame.height, 0);
|
||||
EXPECT_GT(timestamp, last_timestamp);
|
||||
last_timestamp++;
|
||||
}
|
||||
};
|
||||
int64_t LiveStreamModeCallback::last_timestamp = -1;
|
||||
|
||||
TEST(ImageEmbedderTest, LiveStreamModeTest) {
|
||||
const auto image = DecodeImageFromFile(GetFullPath(kImageFile));
|
||||
ASSERT_TRUE(image.ok());
|
||||
|
||||
const std::string model_path = GetFullPath(kModelName);
|
||||
|
||||
ImageEmbedderOptions options = {
|
||||
/* base_options= */ {/* model_asset_buffer= */ nullptr,
|
||||
/* model_asset_buffer_count= */ 0,
|
||||
/* model_asset_path= */ model_path.c_str()},
|
||||
/* running_mode= */ RunningMode::LIVE_STREAM,
|
||||
/* embedder_options= */
|
||||
{/* l2_normalize= */ true,
|
||||
/* quantize= */ false},
|
||||
/* result_callback= */ LiveStreamModeCallback::Fn,
|
||||
};
|
||||
|
||||
void* embedder = image_embedder_create(&options,
|
||||
/* error_msg */ nullptr);
|
||||
EXPECT_NE(embedder, nullptr);
|
||||
|
||||
const auto& image_frame = image->GetImageFrameSharedPtr();
|
||||
const MpImage mp_image = {
|
||||
.type = MpImage::IMAGE_FRAME,
|
||||
.image_frame = {.format = static_cast<ImageFormat>(image_frame->Format()),
|
||||
.image_buffer = image_frame->PixelData(),
|
||||
.width = image_frame->Width(),
|
||||
.height = image_frame->Height()}};
|
||||
|
||||
for (int i = 0; i < kIterations; ++i) {
|
||||
EXPECT_GE(image_embedder_embed_async(embedder, &mp_image, i,
|
||||
/* error_msg */ nullptr),
|
||||
0);
|
||||
}
|
||||
image_embedder_close(embedder, /* error_msg */ nullptr);
|
||||
|
||||
// Due to the flow limiter, the total of outputs might be smaller than the
|
||||
// number of iterations.
|
||||
EXPECT_LE(LiveStreamModeCallback::last_timestamp, kIterations);
|
||||
EXPECT_GT(LiveStreamModeCallback::last_timestamp, 0);
|
||||
}
|
||||
|
||||
TEST(ImageEmbedderTest, InvalidArgumentHandling) {
|
||||
// It is an error to set neither the asset buffer nor the path.
|
||||
ImageEmbedderOptions options = {
|
||||
/* base_options= */ {/* model_asset_buffer= */ nullptr,
|
||||
/* model_asset_buffer_count= */ 0,
|
||||
/* model_asset_path= */ nullptr},
|
||||
/* embedder_options= */ {},
|
||||
};
|
||||
|
||||
char* error_msg;
|
||||
void* embedder = image_embedder_create(&options, &error_msg);
|
||||
EXPECT_EQ(embedder, nullptr);
|
||||
|
||||
EXPECT_THAT(error_msg, HasSubstr("ExternalFile must specify"));
|
||||
|
||||
free(error_msg);
|
||||
}
|
||||
|
||||
TEST(ImageEmbedderTest, FailedEmbeddingHandling) {
|
||||
const std::string model_path = GetFullPath(kModelName);
|
||||
ImageEmbedderOptions options = {
|
||||
/* base_options= */ {/* model_asset_buffer= */ nullptr,
|
||||
/* model_asset_buffer_count= */ 0,
|
||||
/* model_asset_path= */ model_path.c_str()},
|
||||
/* running_mode= */ RunningMode::IMAGE,
|
||||
/* embedder_options= */
|
||||
{/* l2_normalize= */ false,
|
||||
/* quantize= */ false},
|
||||
};
|
||||
|
||||
void* embedder = image_embedder_create(&options,
|
||||
/* error_msg */ nullptr);
|
||||
EXPECT_NE(embedder, nullptr);
|
||||
|
||||
const MpImage mp_image = {.type = MpImage::GPU_BUFFER, .gpu_buffer = {}};
|
||||
ImageEmbedderResult result;
|
||||
char* error_msg;
|
||||
image_embedder_embed_image(embedder, &mp_image, &result, &error_msg);
|
||||
EXPECT_THAT(error_msg, HasSubstr("GPU Buffer not supported yet."));
|
||||
free(error_msg);
|
||||
image_embedder_close(embedder, /* error_msg */ nullptr);
|
||||
}
|
||||
|
||||
} // namespace
|
Loading…
Reference in New Issue
Block a user