Merge pull request #4943 from kinaryml:c-image-embedder-api

PiperOrigin-RevId: 580618718
This commit is contained in:
Copybara-Service 2023-11-08 12:35:05 -08:00
commit d4d30768be
14 changed files with 1048 additions and 44 deletions

View File

@ -66,6 +66,29 @@ void CppConvertToEmbeddingResult(
}
}
void CppConvertToCppEmbedding(
const Embedding& in, // C struct as input
mediapipe::tasks::components::containers::Embedding* out) {
// Handle float embeddings
if (in.float_embedding != nullptr) {
out->float_embedding.assign(in.float_embedding,
in.float_embedding + in.values_count);
}
// Handle quantized embeddings
if (in.quantized_embedding != nullptr) {
out->quantized_embedding.assign(in.quantized_embedding,
in.quantized_embedding + in.values_count);
}
out->head_index = in.head_index;
// Copy head_name if it is present.
if (in.head_name) {
out->head_name = std::string(in.head_name);
}
}
void CppCloseEmbeddingResult(EmbeddingResult* in) {
for (uint32_t i = 0; i < in->embeddings_count; ++i) {
auto embedding_in = in->embeddings[i];

View File

@ -29,6 +29,10 @@ void CppConvertToEmbeddingResult(
const mediapipe::tasks::components::containers::EmbeddingResult& in,
EmbeddingResult* out);
void CppConvertToCppEmbedding(
const Embedding& in,
mediapipe::tasks::components::containers::Embedding* out);
void CppCloseEmbedding(Embedding* in);
void CppCloseEmbeddingResult(EmbeddingResult* in);

View File

@ -28,6 +28,7 @@ cc_library(
"//mediapipe/tasks/c/components/processors:embedder_options_converter",
"//mediapipe/tasks/c/core:base_options",
"//mediapipe/tasks/c/core:base_options_converter",
"//mediapipe/tasks/cc/components/containers:embedding_result",
"//mediapipe/tasks/cc/text/text_embedder",
"@com_google_absl//absl/log:absl_log",
"@com_google_absl//absl/status",

View File

@ -20,9 +20,11 @@ limitations under the License.
#include "absl/log/absl_log.h"
#include "absl/status/status.h"
#include "mediapipe/tasks/c/components/containers/embedding_result.h"
#include "mediapipe/tasks/c/components/containers/embedding_result_converter.h"
#include "mediapipe/tasks/c/components/processors/embedder_options_converter.h"
#include "mediapipe/tasks/c/core/base_options_converter.h"
#include "mediapipe/tasks/cc/components/containers/embedding_result.h"
#include "mediapipe/tasks/cc/text/text_embedder/text_embedder.h"
namespace mediapipe::tasks::c::text::text_embedder {
@ -30,12 +32,14 @@ namespace mediapipe::tasks::c::text::text_embedder {
namespace {
using ::mediapipe::tasks::c::components::containers::CppCloseEmbeddingResult;
using ::mediapipe::tasks::c::components::containers::CppConvertToCppEmbedding;
using ::mediapipe::tasks::c::components::containers::
CppConvertToEmbeddingResult;
using ::mediapipe::tasks::c::components::processors::
CppConvertToEmbedderOptions;
using ::mediapipe::tasks::c::core::CppConvertToBaseOptions;
using ::mediapipe::tasks::text::text_embedder::TextEmbedder;
typedef ::mediapipe::tasks::components::containers::Embedding CppEmbedding;
int CppProcessError(absl::Status status, char** error_msg) {
if (error_msg) {
@ -91,6 +95,24 @@ int CppTextEmbedderClose(void* embedder, char** error_msg) {
return 0;
}
int CppTextEmbedderCosineSimilarity(const Embedding& u, const Embedding& v,
double* similarity, char** error_msg) {
CppEmbedding cpp_u;
CppConvertToCppEmbedding(u, &cpp_u);
CppEmbedding cpp_v;
CppConvertToCppEmbedding(v, &cpp_v);
auto status_or_similarity =
mediapipe::tasks::text::text_embedder::TextEmbedder::CosineSimilarity(
cpp_u, cpp_v);
if (status_or_similarity.ok()) {
*similarity = status_or_similarity.value();
} else {
ABSL_LOG(ERROR) << "Cannot compute cosine similarity.";
return CppProcessError(status_or_similarity.status(), error_msg);
}
return 0;
}
} // namespace mediapipe::tasks::c::text::text_embedder
extern "C" {
@ -116,4 +138,10 @@ int text_embedder_close(void* embedder, char** error_ms) {
embedder, error_ms);
}
int text_embedder_cosine_similarity(const Embedding& u, const Embedding& v,
double* similarity, char** error_msg) {
return mediapipe::tasks::c::text::text_embedder::
CppTextEmbedderCosineSimilarity(u, v, similarity, error_msg);
}
} // extern "C"

View File

@ -66,6 +66,17 @@ MP_EXPORT void text_embedder_close_result(TextEmbedderResult* result);
// allocated for the error message.
MP_EXPORT int text_embedder_close(void* embedder, char** error_msg);
// Utility function to compute cosine similarity [1] between two embeddings.
// May return an InvalidArgumentError if e.g. the embeddings are of different
// types (quantized vs. float), have different sizes, or have a an L2-norm of
// 0.
//
// [1]: https://en.wikipedia.org/wiki/Cosine_similarity
MP_EXPORT int text_embedder_cosine_similarity(const Embedding& u,
const Embedding& v,
double* similarity,
char** error_msg);
#ifdef __cplusplus
} // extern C
#endif

View File

@ -32,7 +32,12 @@ using testing::HasSubstr;
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/text/";
constexpr char kTestBertModelPath[] =
"mobilebert_embedding_with_metadata.tflite";
constexpr char kTestString[] = "It's beautiful outside.";
constexpr char kTestString0[] =
"When you go to this restaurant, they hold the pancake upside-down "
"before they hand it to you. It's a great gimmick.";
constexpr char kTestString1[] =
"Let's make a plan to steal the declaration of independence.";
constexpr float kPrecision = 1e-3;
std::string GetFullPath(absl::string_view file_name) {
return JoinPath("./", kTestDataDirectory, file_name);
@ -52,7 +57,7 @@ TEST(TextEmbedderTest, SmokeTest) {
EXPECT_NE(embedder, nullptr);
TextEmbedderResult result;
text_embedder_embed(embedder, kTestString, &result, /* error_msg */ nullptr);
text_embedder_embed(embedder, kTestString0, &result, /* error_msg */ nullptr);
EXPECT_EQ(result.embeddings_count, 1);
EXPECT_EQ(result.embeddings[0].values_count, 512);
@ -60,6 +65,40 @@ TEST(TextEmbedderTest, SmokeTest) {
text_embedder_close(embedder, /* error_msg */ nullptr);
}
TEST(TextEmbedderTest, SucceedsWithCosineSimilarity) {
std::string model_path = GetFullPath(kTestBertModelPath);
TextEmbedderOptions options = {
/* base_options= */ {/* model_asset_buffer= */ nullptr,
/* model_asset_buffer_count= */ 0,
/* model_asset_path= */ model_path.c_str()},
/* embedder_options= */
{/* l2_normalize= */ false,
/* quantize= */ false}};
void* embedder = text_embedder_create(&options,
/* error_msg */ nullptr);
EXPECT_NE(embedder, nullptr);
// Extract both embeddings.
TextEmbedderResult result0;
text_embedder_embed(embedder, kTestString0, &result0,
/* error_msg */ nullptr);
TextEmbedderResult result1;
text_embedder_embed(embedder, kTestString1, &result1,
/* error_msg */ nullptr);
// Check cosine similarity.
double similarity;
text_embedder_cosine_similarity(result0.embeddings[0], result1.embeddings[0],
&similarity, nullptr);
double expected_similarity = 0.98077;
EXPECT_LE(abs(similarity - expected_similarity), kPrecision);
text_embedder_close_result(&result0);
text_embedder_close_result(&result1);
text_embedder_close(embedder, /* error_msg */ nullptr);
}
TEST(TextEmbedderTest, ErrorHandling) {
// It is an error to set neither the asset buffer nor the path.
TextEmbedderOptions options = {

View File

@ -0,0 +1,22 @@
# Copyright 2023 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
cc_library(
name = "common",
hdrs = ["common.h"],
)

View File

@ -0,0 +1,68 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_C_VISION_CORE_COMMON_H_
#define MEDIAPIPE_TASKS_C_VISION_CORE_COMMON_H_
#include <cstdint>
#ifdef __cplusplus
extern "C" {
#endif
// Supported image formats.
enum ImageFormat {
UNKNOWN = 0,
SRGB = 1,
SRGBA = 2,
GRAY8 = 3,
SBGRA = 11 // compatible with Flutter `bgra8888` format.
};
// Supported processing modes.
enum RunningMode {
IMAGE = 1,
VIDEO = 2,
LIVE_STREAM = 3,
};
// Structure to hold image frame.
struct ImageFrame {
enum ImageFormat format;
const uint8_t* image_buffer;
int width;
int height;
};
// TODO: Add GPU buffer declaration and processing logic for it.
struct GpuBuffer {
int width;
int height;
};
// The object to contain an image, realizes `OneOf` concept.
struct MpImage {
enum { IMAGE_FRAME, GPU_BUFFER } type;
union {
struct ImageFrame image_frame;
struct GpuBuffer gpu_buffer;
};
};
#ifdef __cplusplus
} // extern C
#endif
#endif // MEDIAPIPE_TASKS_C_VISION_CORE_COMMON_H_

View File

@ -30,6 +30,7 @@ cc_library(
"//mediapipe/tasks/c/components/processors:classifier_options_converter",
"//mediapipe/tasks/c/core:base_options",
"//mediapipe/tasks/c/core:base_options_converter",
"//mediapipe/tasks/c/vision/core:common",
"//mediapipe/tasks/cc/vision/core:running_mode",
"//mediapipe/tasks/cc/vision/image_classifier",
"//mediapipe/tasks/cc/vision/utils:image_utils",

View File

@ -16,11 +16,10 @@ limitations under the License.
#ifndef MEDIAPIPE_TASKS_C_VISION_IMAGE_CLASSIFIER_IMAGE_CLASSIFIER_H_
#define MEDIAPIPE_TASKS_C_VISION_IMAGE_CLASSIFIER_IMAGE_CLASSIFIER_H_
#include <cstdint>
#include "mediapipe/tasks/c/components/containers/classification_result.h"
#include "mediapipe/tasks/c/components/processors/classifier_options.h"
#include "mediapipe/tasks/c/core/base_options.h"
#include "mediapipe/tasks/c/vision/core/common.h"
#ifndef MP_EXPORT
#define MP_EXPORT __attribute__((visibility("default")))
@ -32,46 +31,7 @@ extern "C" {
typedef ClassificationResult ImageClassifierResult;
// Supported image formats.
enum ImageFormat {
UNKNOWN = 0,
SRGB = 1,
SRGBA = 2,
GRAY8 = 3,
SBGRA = 11 // compatible with Flutter `bgra8888` format.
};
// Supported processing modes.
enum RunningMode {
IMAGE = 1,
VIDEO = 2,
LIVE_STREAM = 3,
};
// Structure to hold image frame.
struct ImageFrame {
enum ImageFormat format;
const uint8_t* image_buffer;
int width;
int height;
};
// TODO: Add GPU buffer declaration and proccessing logic for it.
struct GpuBuffer {
int width;
int height;
};
// The object to contain an image, realizes `OneOf` concept.
struct MpImage {
enum { IMAGE_FRAME, GPU_BUFFER } type;
union {
struct ImageFrame image_frame;
struct GpuBuffer gpu_buffer;
};
};
// The options for configuring a Mediapipe image classifier task.
// The options for configuring a MediaPipe image classifier task.
struct ImageClassifierOptions {
// Base options for configuring MediaPipe Tasks, such as specifying the model
// file with metadata, accelerator options, op resolver, etc.
@ -122,12 +82,39 @@ MP_EXPORT int image_classifier_classify_image(void* classifier,
ImageClassifierResult* result,
char** error_msg);
// Performs image classification on the provided video frame.
// Only use this method when the ImageClassifier is created with the video
// running mode.
// The image can be of any size with format RGB or RGBA. It's required to
// provide the video frame's timestamp (in milliseconds). The input timestamps
// must be monotonically increasing.
// If an error occurs, returns an error code and sets the error parameter to an
// an error message (if `error_msg` is not `nullptr`). You must free the memory
// allocated for the error message.
MP_EXPORT int image_classifier_classify_for_video(void* classifier,
const MpImage* image,
int64_t timestamp_ms,
ImageClassifierResult* result,
char** error_msg);
// Sends live image data to image classification, and the results will be
// available via the `result_callback` provided in the ImageClassifierOptions.
// Only use this method when the ImageClassifier is created with the live
// stream running mode.
// The image can be of any size with format RGB or RGBA. It's required to
// provide a timestamp (in milliseconds) to indicate when the input image is
// sent to the object detector. The input timestamps must be monotonically
// increasing.
// The `result_callback` provides:
// - The classification results as an ImageClassifierResult object.
// - The const reference to the corresponding input image that the image
// classifier runs on. Note that the const reference to the image will no
// longer be valid when the callback returns. To access the image data
// outside of the callback, callers need to make a copy of the image.
// - The input timestamp in milliseconds.
// If an error occurs, returns an error code and sets the error parameter to an
// an error message (if `error_msg` is not `nullptr`). You must free the memory
// allocated for the error message.
MP_EXPORT int image_classifier_classify_async(void* classifier,
const MpImage* image,
int64_t timestamp_ms,

View File

@ -0,0 +1,67 @@
# Copyright 2023 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
cc_library(
name = "image_embedder_lib",
srcs = ["image_embedder.cc"],
hdrs = ["image_embedder.h"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:image_frame",
"//mediapipe/tasks/c/components/containers:embedding_result",
"//mediapipe/tasks/c/components/containers:embedding_result_converter",
"//mediapipe/tasks/c/components/processors:embedder_options",
"//mediapipe/tasks/c/components/processors:embedder_options_converter",
"//mediapipe/tasks/c/core:base_options",
"//mediapipe/tasks/c/core:base_options_converter",
"//mediapipe/tasks/c/vision/core:common",
"//mediapipe/tasks/cc/components/containers:embedding_result",
"//mediapipe/tasks/cc/vision/core:running_mode",
"//mediapipe/tasks/cc/vision/image_embedder",
"//mediapipe/tasks/cc/vision/utils:image_utils",
"@com_google_absl//absl/log:absl_log",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
],
alwayslink = 1,
)
cc_test(
name = "image_embedder_test",
srcs = ["image_embedder_test.cc"],
data = [
"//mediapipe/framework/formats:image_frame_opencv",
"//mediapipe/framework/port:opencv_core",
"//mediapipe/framework/port:opencv_imgproc",
"//mediapipe/tasks/testdata/vision:test_images",
"//mediapipe/tasks/testdata/vision:test_models",
],
linkstatic = 1,
deps = [
":image_embedder_lib",
"//mediapipe/framework/deps:file_path",
"//mediapipe/framework/formats:image",
"//mediapipe/framework/port:gtest",
"//mediapipe/tasks/c/vision/core:common",
"//mediapipe/tasks/cc/vision/utils:image_utils",
"@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/strings",
"@com_google_googletest//:gtest_main",
],
)

View File

@ -0,0 +1,303 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/c/vision/image_embedder/image_embedder.h"
#include <cstdint>
#include <cstdlib>
#include <memory>
#include <utility>
#include "absl/log/absl_log.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/tasks/c/components/containers/embedding_result.h"
#include "mediapipe/tasks/c/components/containers/embedding_result_converter.h"
#include "mediapipe/tasks/c/components/processors/embedder_options_converter.h"
#include "mediapipe/tasks/c/core/base_options_converter.h"
#include "mediapipe/tasks/c/vision/core/common.h"
#include "mediapipe/tasks/cc/components/containers/embedding_result.h"
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
#include "mediapipe/tasks/cc/vision/image_embedder/image_embedder.h"
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
namespace mediapipe::tasks::c::vision::image_embedder {
namespace {
using ::mediapipe::tasks::c::components::containers::CppCloseEmbeddingResult;
using ::mediapipe::tasks::c::components::containers::CppConvertToCppEmbedding;
using ::mediapipe::tasks::c::components::containers::
CppConvertToEmbeddingResult;
using ::mediapipe::tasks::c::components::processors::
CppConvertToEmbedderOptions;
using ::mediapipe::tasks::c::core::CppConvertToBaseOptions;
using ::mediapipe::tasks::vision::CreateImageFromBuffer;
using ::mediapipe::tasks::vision::core::RunningMode;
using ::mediapipe::tasks::vision::image_embedder::ImageEmbedder;
typedef ::mediapipe::tasks::components::containers::Embedding CppEmbedding;
typedef ::mediapipe::tasks::vision::image_embedder::ImageEmbedderResult
CppImageEmbedderResult;
int CppProcessError(absl::Status status, char** error_msg) {
if (error_msg) {
*error_msg = strdup(status.ToString().c_str());
}
return status.raw_code();
}
} // namespace
ImageEmbedder* CppImageEmbedderCreate(const ImageEmbedderOptions& options,
char** error_msg) {
auto cpp_options = std::make_unique<
::mediapipe::tasks::vision::image_embedder::ImageEmbedderOptions>();
CppConvertToBaseOptions(options.base_options, &cpp_options->base_options);
CppConvertToEmbedderOptions(options.embedder_options,
&cpp_options->embedder_options);
cpp_options->running_mode = static_cast<RunningMode>(options.running_mode);
// Enable callback for processing live stream data when the running mode is
// set to RunningMode::LIVE_STREAM.
if (cpp_options->running_mode == RunningMode::LIVE_STREAM) {
if (options.result_callback == nullptr) {
const absl::Status status = absl::InvalidArgumentError(
"Provided null pointer to callback function.");
ABSL_LOG(ERROR) << "Failed to create ImageEmbedder: " << status;
CppProcessError(status, error_msg);
return nullptr;
}
ImageEmbedderOptions::result_callback_fn result_callback =
options.result_callback;
cpp_options->result_callback =
[result_callback](absl::StatusOr<CppImageEmbedderResult> cpp_result,
const Image& image, int64_t timestamp) {
char* error_msg = nullptr;
if (!cpp_result.ok()) {
ABSL_LOG(ERROR)
<< "Embedding extraction failed: " << cpp_result.status();
CppProcessError(cpp_result.status(), &error_msg);
result_callback(nullptr, MpImage(), timestamp, error_msg);
free(error_msg);
return;
}
// Result is valid for the lifetime of the callback function.
ImageEmbedderResult result;
CppConvertToEmbeddingResult(*cpp_result, &result);
const auto& image_frame = image.GetImageFrameSharedPtr();
const MpImage mp_image = {
.type = MpImage::IMAGE_FRAME,
.image_frame = {
.format = static_cast<::ImageFormat>(image_frame->Format()),
.image_buffer = image_frame->PixelData(),
.width = image_frame->Width(),
.height = image_frame->Height()}};
result_callback(&result, mp_image, timestamp,
/* error_msg= */ nullptr);
CppCloseEmbeddingResult(&result);
};
}
auto embedder = ImageEmbedder::Create(std::move(cpp_options));
if (!embedder.ok()) {
ABSL_LOG(ERROR) << "Failed to create ImageEmbedder: " << embedder.status();
CppProcessError(embedder.status(), error_msg);
return nullptr;
}
return embedder->release();
}
int CppImageEmbedderEmbed(void* embedder, const MpImage* image,
ImageEmbedderResult* result, char** error_msg) {
if (image->type == MpImage::GPU_BUFFER) {
const absl::Status status =
absl::InvalidArgumentError("GPU Buffer not supported yet.");
ABSL_LOG(ERROR) << "Embedding extraction failed: " << status.message();
return CppProcessError(status, error_msg);
}
const auto img = CreateImageFromBuffer(
static_cast<ImageFormat::Format>(image->image_frame.format),
image->image_frame.image_buffer, image->image_frame.width,
image->image_frame.height);
if (!img.ok()) {
ABSL_LOG(ERROR) << "Failed to create Image: " << img.status();
return CppProcessError(img.status(), error_msg);
}
auto cpp_embedder = static_cast<ImageEmbedder*>(embedder);
auto cpp_result = cpp_embedder->Embed(*img);
if (!cpp_result.ok()) {
ABSL_LOG(ERROR) << "Embedding extraction failed: " << cpp_result.status();
return CppProcessError(cpp_result.status(), error_msg);
}
CppConvertToEmbeddingResult(*cpp_result, result);
return 0;
}
int CppImageEmbedderEmbedForVideo(void* embedder, const MpImage* image,
int64_t timestamp_ms,
ImageEmbedderResult* result,
char** error_msg) {
if (image->type == MpImage::GPU_BUFFER) {
absl::Status status =
absl::InvalidArgumentError("GPU Buffer not supported yet");
ABSL_LOG(ERROR) << "Embedding extraction failed: " << status.message();
return CppProcessError(status, error_msg);
}
const auto img = CreateImageFromBuffer(
static_cast<ImageFormat::Format>(image->image_frame.format),
image->image_frame.image_buffer, image->image_frame.width,
image->image_frame.height);
if (!img.ok()) {
ABSL_LOG(ERROR) << "Failed to create Image: " << img.status();
return CppProcessError(img.status(), error_msg);
}
auto cpp_embedder = static_cast<ImageEmbedder*>(embedder);
auto cpp_result = cpp_embedder->EmbedForVideo(*img, timestamp_ms);
if (!cpp_result.ok()) {
ABSL_LOG(ERROR) << "Embedding extraction failed: " << cpp_result.status();
return CppProcessError(cpp_result.status(), error_msg);
}
CppConvertToEmbeddingResult(*cpp_result, result);
return 0;
}
int CppImageEmbedderEmbedAsync(void* embedder, const MpImage* image,
int64_t timestamp_ms, char** error_msg) {
if (image->type == MpImage::GPU_BUFFER) {
absl::Status status =
absl::InvalidArgumentError("GPU Buffer not supported yet");
ABSL_LOG(ERROR) << "Embedding extraction failed: " << status.message();
return CppProcessError(status, error_msg);
}
const auto img = CreateImageFromBuffer(
static_cast<ImageFormat::Format>(image->image_frame.format),
image->image_frame.image_buffer, image->image_frame.width,
image->image_frame.height);
if (!img.ok()) {
ABSL_LOG(ERROR) << "Failed to create Image: " << img.status();
return CppProcessError(img.status(), error_msg);
}
auto cpp_embedder = static_cast<ImageEmbedder*>(embedder);
auto cpp_result = cpp_embedder->EmbedAsync(*img, timestamp_ms);
if (!cpp_result.ok()) {
ABSL_LOG(ERROR) << "Data preparation for the embedding extraction failed: "
<< cpp_result;
return CppProcessError(cpp_result, error_msg);
}
return 0;
}
void CppImageEmbedderCloseResult(ImageEmbedderResult* result) {
CppCloseEmbeddingResult(result);
}
int CppImageEmbedderClose(void* embedder, char** error_msg) {
auto cpp_embedder = static_cast<ImageEmbedder*>(embedder);
auto result = cpp_embedder->Close();
if (!result.ok()) {
ABSL_LOG(ERROR) << "Failed to close ImageEmbedder: " << result;
return CppProcessError(result, error_msg);
}
delete cpp_embedder;
return 0;
}
int CppImageEmbedderCosineSimilarity(const Embedding& u, const Embedding& v,
double* similarity, char** error_msg) {
CppEmbedding cpp_u;
CppConvertToCppEmbedding(u, &cpp_u);
CppEmbedding cpp_v;
CppConvertToCppEmbedding(v, &cpp_v);
auto status_or_similarity =
mediapipe::tasks::vision::image_embedder::ImageEmbedder::CosineSimilarity(
cpp_u, cpp_v);
if (status_or_similarity.ok()) {
*similarity = status_or_similarity.value();
} else {
ABSL_LOG(ERROR) << "Cannot compute cosine similarity.";
return CppProcessError(status_or_similarity.status(), error_msg);
}
return 0;
}
} // namespace mediapipe::tasks::c::vision::image_embedder
extern "C" {
void* image_embedder_create(struct ImageEmbedderOptions* options,
char** error_msg) {
return mediapipe::tasks::c::vision::image_embedder::CppImageEmbedderCreate(
*options, error_msg);
}
int image_embedder_embed_image(void* embedder, const MpImage* image,
ImageEmbedderResult* result, char** error_msg) {
return mediapipe::tasks::c::vision::image_embedder::CppImageEmbedderEmbed(
embedder, image, result, error_msg);
}
int image_embedder_embed_for_video(void* embedder, const MpImage* image,
int64_t timestamp_ms,
ImageEmbedderResult* result,
char** error_msg) {
return mediapipe::tasks::c::vision::image_embedder::
CppImageEmbedderEmbedForVideo(embedder, image, timestamp_ms, result,
error_msg);
}
int image_embedder_embed_async(void* embedder, const MpImage* image,
int64_t timestamp_ms, char** error_msg) {
return mediapipe::tasks::c::vision::image_embedder::
CppImageEmbedderEmbedAsync(embedder, image, timestamp_ms, error_msg);
}
void image_embedder_close_result(ImageEmbedderResult* result) {
mediapipe::tasks::c::vision::image_embedder::CppImageEmbedderCloseResult(
result);
}
int image_embedder_close(void* embedder, char** error_msg) {
return mediapipe::tasks::c::vision::image_embedder::CppImageEmbedderClose(
embedder, error_msg);
}
int image_embedder_cosine_similarity(const Embedding& u, const Embedding& v,
double* similarity, char** error_msg) {
return mediapipe::tasks::c::vision::image_embedder::
CppImageEmbedderCosineSimilarity(u, v, similarity, error_msg);
}
} // extern "C"

View File

@ -0,0 +1,148 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_C_VISION_IMAGE_EMBEDDER_IMAGE_EMBEDDER_H_
#define MEDIAPIPE_TASKS_C_VISION_IMAGE_EMBEDDER_IMAGE_EMBEDDER_H_
#include <cstdint>
#include "mediapipe/tasks/c/components/containers/embedding_result.h"
#include "mediapipe/tasks/c/components/processors/embedder_options.h"
#include "mediapipe/tasks/c/core/base_options.h"
#include "mediapipe/tasks/c/vision/core/common.h"
#ifndef MP_EXPORT
#define MP_EXPORT __attribute__((visibility("default")))
#endif // MP_EXPORT
#ifdef __cplusplus
extern "C" {
#endif
typedef EmbeddingResult ImageEmbedderResult;
// The options for configuring a MediaPipe image embedder task.
struct ImageEmbedderOptions {
// Base options for configuring MediaPipe Tasks, such as specifying the model
// file with metadata, accelerator options, op resolver, etc.
struct BaseOptions base_options;
// The running mode of the task. Default to the image mode.
// Image embedder has three running modes:
// 1) The image mode for embedding image on single image inputs.
// 2) The video mode for embedding image on the decoded frames of a video.
// 3) The live stream mode for embedding image on the live stream of input
// data, such as from camera. In this mode, the "result_callback" below must
// be specified to receive the embedding results asynchronously.
RunningMode running_mode;
// Options for configuring the embedder behavior, such as l2_normalize and
// quantize.
struct EmbedderOptions embedder_options;
// The user-defined result callback for processing live stream data.
// The result callback should only be specified when the running mode is set
// to RunningMode::LIVE_STREAM. Arguments of the callback function include:
// the pointer to embedding result, the image that result was obtained
// on, the timestamp relevant to embedding extraction results and pointer to
// error message in case of any failure. The validity of the passed arguments
// is true for the lifetime of the callback function.
//
// A caller is responsible for closing image embedder result.
typedef void (*result_callback_fn)(ImageEmbedderResult* result,
const MpImage image, int64_t timestamp_ms,
char* error_msg);
result_callback_fn result_callback;
};
// Creates an ImageEmbedder from provided `options`.
// Returns a pointer to the image embedder on success.
// If an error occurs, returns `nullptr` and sets the error parameter to an
// an error message (if `error_msg` is not `nullptr`). You must free the memory
// allocated for the error message.
MP_EXPORT void* image_embedder_create(struct ImageEmbedderOptions* options,
char** error_msg);
// Performs embedding extraction on the input `image`. Returns `0` on success.
// If an error occurs, returns an error code and sets the error parameter to an
// an error message (if `error_msg` is not `nullptr`). You must free the memory
// allocated for the error message.
MP_EXPORT int image_embedder_embed_image(void* embedder, const MpImage* image,
ImageEmbedderResult* result,
char** error_msg);
// Performs embedding extraction on the provided video frame.
// Only use this method when the ImageEmbedder is created with the video
// running mode.
// The image can be of any size with format RGB or RGBA. It's required to
// provide the video frame's timestamp (in milliseconds). The input timestamps
// must be monotonically increasing.
// If an error occurs, returns an error code and sets the error parameter to an
// an error message (if `error_msg` is not `nullptr`). You must free the memory
// allocated for the error message.
MP_EXPORT int image_embedder_embed_for_video(void* embedder,
const MpImage* image,
int64_t timestamp_ms,
ImageEmbedderResult* result,
char** error_msg);
// Sends live image data to embedder, and the results will be available via
// the `result_callback` provided in the ImageEmbedderOptions.
// Only use this method when the ImageEmbedder is created with the live
// stream running mode.
// The image can be of any size with format RGB or RGBA. It's required to
// provide a timestamp (in milliseconds) to indicate when the input image is
// sent to the object detector. The input timestamps must be monotonically
// increasing.
// The `result_callback` provides
// - The embedding results as a `ImageEmbedderResult` object.
// - The const reference to the corresponding input image that the image
// embedder runs on. Note that the const reference to the image will no
// longer be valid when the callback returns. To access the image data
// outside of the callback, callers need to make a copy of the image.
// - The input timestamp in milliseconds.
// If an error occurs, returns an error code and sets the error parameter to an
// an error message (if `error_msg` is not `nullptr`). You must free the memory
// allocated for the error message.
MP_EXPORT int image_embedder_embed_async(void* embedder, const MpImage* image,
int64_t timestamp_ms,
char** error_msg);
// Frees the memory allocated inside a ImageEmbedderResult result.
// Does not free the result pointer itself.
MP_EXPORT void image_embedder_close_result(ImageEmbedderResult* result);
// Frees image embedder.
// If an error occurs, returns an error code and sets the error parameter to an
// an error message (if `error_msg` is not `nullptr`). You must free the memory
// allocated for the error message.
MP_EXPORT int image_embedder_close(void* embedder, char** error_msg);
// Utility function to compute cosine similarity [1] between two embeddings.
// May return an InvalidArgumentError if e.g. the embeddings are of different
// types (quantized vs. float), have different sizes, or have a an L2-norm of
// 0.
//
// [1]: https://en.wikipedia.org/wiki/Cosine_similarity
MP_EXPORT int image_embedder_cosine_similarity(const Embedding& u,
const Embedding& v,
double* similarity,
char** error_msg);
#ifdef __cplusplus
} // extern C
#endif
#endif // MEDIAPIPE_TASKS_C_VISION_IMAGE_EMBEDDER_IMAGE_EMBEDDER_H_

View File

@ -0,0 +1,302 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/c/vision/image_embedder/image_embedder.h"
#include <cstdint>
#include <cstdlib>
#include <string>
#include "absl/flags/flag.h"
#include "absl/strings/string_view.h"
#include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/tasks/c/vision/core/common.h"
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
namespace {
using ::mediapipe::file::JoinPath;
using ::mediapipe::tasks::vision::DecodeImageFromFile;
using testing::HasSubstr;
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/";
constexpr char kModelName[] = "mobilenet_v3_small_100_224_embedder.tflite";
constexpr char kImageFile[] = "burger.jpg";
constexpr float kPrecision = 1e-6;
constexpr int kIterations = 100;
std::string GetFullPath(absl::string_view file_name) {
return JoinPath("./", kTestDataDirectory, file_name);
}
// Utility function to check the sizes, head_index and head_names of a result
// produced by kMobileNetV3Embedder.
void CheckMobileNetV3Result(const ImageEmbedderResult& result, bool quantized) {
EXPECT_EQ(result.embeddings_count, 1);
EXPECT_EQ(result.embeddings[0].head_index, 0);
EXPECT_EQ(std::string{result.embeddings[0].head_name}, "feature");
if (quantized) {
EXPECT_EQ(result.embeddings[0].values_count, 1024);
} else {
EXPECT_EQ(result.embeddings[0].values_count, 1024);
}
}
TEST(ImageEmbedderTest, ImageModeTest) {
const auto image = DecodeImageFromFile(GetFullPath(kImageFile));
ASSERT_TRUE(image.ok());
const std::string model_path = GetFullPath(kModelName);
ImageEmbedderOptions options = {
/* base_options= */ {/* model_asset_buffer= */ nullptr,
/* model_asset_buffer_count= */ 0,
/* model_asset_path= */ model_path.c_str()},
/* running_mode= */ RunningMode::IMAGE,
/* embedder_options= */
{/* l2_normalize= */ true,
/* quantize= */ false}};
void* embedder = image_embedder_create(&options,
/* error_msg */ nullptr);
EXPECT_NE(embedder, nullptr);
const MpImage mp_image = {
.type = MpImage::IMAGE_FRAME,
.image_frame = {
.format = static_cast<ImageFormat>(
image->GetImageFrameSharedPtr()->Format()),
.image_buffer = image->GetImageFrameSharedPtr()->PixelData(),
.width = image->GetImageFrameSharedPtr()->Width(),
.height = image->GetImageFrameSharedPtr()->Height()}};
ImageEmbedderResult result;
image_embedder_embed_image(embedder, &mp_image, &result,
/* error_msg */ nullptr);
CheckMobileNetV3Result(result, false);
EXPECT_NEAR(result.embeddings[0].float_embedding[0], -0.0142344, kPrecision);
image_embedder_close_result(&result);
image_embedder_close(embedder, /* error_msg */ nullptr);
}
TEST(ImageEmbedderTest, SucceedsWithCosineSimilarity) {
const auto image = DecodeImageFromFile(GetFullPath("burger.jpg"));
ASSERT_TRUE(image.ok());
const auto crop = DecodeImageFromFile(GetFullPath("burger_crop.jpg"));
ASSERT_TRUE(crop.ok());
const std::string model_path = GetFullPath(kModelName);
ImageEmbedderOptions options = {
/* base_options= */ {/* model_asset_buffer= */ nullptr,
/* model_asset_buffer_count= */ 0,
/* model_asset_path= */ model_path.c_str()},
/* running_mode= */ RunningMode::IMAGE,
/* embedder_options= */
{/* l2_normalize= */ true,
/* quantize= */ false}};
void* embedder = image_embedder_create(&options,
/* error_msg */ nullptr);
EXPECT_NE(embedder, nullptr);
const MpImage mp_image = {
.type = MpImage::IMAGE_FRAME,
.image_frame = {
.format = static_cast<ImageFormat>(
image->GetImageFrameSharedPtr()->Format()),
.image_buffer = image->GetImageFrameSharedPtr()->PixelData(),
.width = image->GetImageFrameSharedPtr()->Width(),
.height = image->GetImageFrameSharedPtr()->Height()}};
const MpImage mp_crop = {
.type = MpImage::IMAGE_FRAME,
.image_frame = {
.format = static_cast<ImageFormat>(
crop->GetImageFrameSharedPtr()->Format()),
.image_buffer = crop->GetImageFrameSharedPtr()->PixelData(),
.width = crop->GetImageFrameSharedPtr()->Width(),
.height = crop->GetImageFrameSharedPtr()->Height()}};
// Extract both embeddings.
ImageEmbedderResult image_result;
image_embedder_embed_image(embedder, &mp_image, &image_result,
/* error_msg */ nullptr);
ImageEmbedderResult crop_result;
image_embedder_embed_image(embedder, &mp_crop, &crop_result,
/* error_msg */ nullptr);
// Check results.
CheckMobileNetV3Result(image_result, false);
CheckMobileNetV3Result(crop_result, false);
// Check cosine similarity.
double similarity;
image_embedder_cosine_similarity(image_result.embeddings[0],
crop_result.embeddings[0], &similarity,
/* error_msg */ nullptr);
double expected_similarity = 0.925519;
EXPECT_LE(abs(similarity - expected_similarity), kPrecision);
image_embedder_close_result(&image_result);
image_embedder_close_result(&crop_result);
image_embedder_close(embedder, /* error_msg */ nullptr);
}
TEST(ImageEmbedderTest, VideoModeTest) {
const auto image = DecodeImageFromFile(GetFullPath(kImageFile));
ASSERT_TRUE(image.ok());
const std::string model_path = GetFullPath(kModelName);
ImageEmbedderOptions options = {
/* base_options= */ {/* model_asset_buffer= */ nullptr,
/* model_asset_buffer_count= */ 0,
/* model_asset_path= */ model_path.c_str()},
/* running_mode= */ RunningMode::VIDEO,
/* embedder_options= */
{/* l2_normalize= */ true,
/* quantize= */ false}};
void* embedder = image_embedder_create(&options,
/* error_msg */ nullptr);
EXPECT_NE(embedder, nullptr);
const auto& image_frame = image->GetImageFrameSharedPtr();
const MpImage mp_image = {
.type = MpImage::IMAGE_FRAME,
.image_frame = {.format = static_cast<ImageFormat>(image_frame->Format()),
.image_buffer = image_frame->PixelData(),
.width = image_frame->Width(),
.height = image_frame->Height()}};
for (int i = 0; i < kIterations; ++i) {
ImageEmbedderResult result;
image_embedder_embed_for_video(embedder, &mp_image, i, &result,
/* error_msg */ nullptr);
CheckMobileNetV3Result(result, false);
EXPECT_NEAR(result.embeddings[0].float_embedding[0], -0.0142344,
kPrecision);
image_embedder_close_result(&result);
}
image_embedder_close(embedder, /* error_msg */ nullptr);
}
// A structure to support LiveStreamModeTest below. This structure holds a
// static method `Fn` for a callback function of C API. A `static` qualifier
// allows to take an address of the method to follow API style. Another static
// struct member is `last_timestamp` that is used to verify that current
// timestamp is greater than the previous one.
struct LiveStreamModeCallback {
static int64_t last_timestamp;
static void Fn(ImageEmbedderResult* embedder_result, const MpImage image,
int64_t timestamp, char* error_msg) {
ASSERT_NE(embedder_result, nullptr);
ASSERT_EQ(error_msg, nullptr);
CheckMobileNetV3Result(*embedder_result, false);
EXPECT_NEAR(embedder_result->embeddings[0].float_embedding[0], -0.0142344,
kPrecision);
EXPECT_GT(image.image_frame.width, 0);
EXPECT_GT(image.image_frame.height, 0);
EXPECT_GT(timestamp, last_timestamp);
last_timestamp++;
}
};
int64_t LiveStreamModeCallback::last_timestamp = -1;
TEST(ImageEmbedderTest, LiveStreamModeTest) {
const auto image = DecodeImageFromFile(GetFullPath(kImageFile));
ASSERT_TRUE(image.ok());
const std::string model_path = GetFullPath(kModelName);
ImageEmbedderOptions options = {
/* base_options= */ {/* model_asset_buffer= */ nullptr,
/* model_asset_buffer_count= */ 0,
/* model_asset_path= */ model_path.c_str()},
/* running_mode= */ RunningMode::LIVE_STREAM,
/* embedder_options= */
{/* l2_normalize= */ true,
/* quantize= */ false},
/* result_callback= */ LiveStreamModeCallback::Fn,
};
void* embedder = image_embedder_create(&options,
/* error_msg */ nullptr);
EXPECT_NE(embedder, nullptr);
const auto& image_frame = image->GetImageFrameSharedPtr();
const MpImage mp_image = {
.type = MpImage::IMAGE_FRAME,
.image_frame = {.format = static_cast<ImageFormat>(image_frame->Format()),
.image_buffer = image_frame->PixelData(),
.width = image_frame->Width(),
.height = image_frame->Height()}};
for (int i = 0; i < kIterations; ++i) {
EXPECT_GE(image_embedder_embed_async(embedder, &mp_image, i,
/* error_msg */ nullptr),
0);
}
image_embedder_close(embedder, /* error_msg */ nullptr);
// Due to the flow limiter, the total of outputs might be smaller than the
// number of iterations.
EXPECT_LE(LiveStreamModeCallback::last_timestamp, kIterations);
EXPECT_GT(LiveStreamModeCallback::last_timestamp, 0);
}
TEST(ImageEmbedderTest, InvalidArgumentHandling) {
// It is an error to set neither the asset buffer nor the path.
ImageEmbedderOptions options = {
/* base_options= */ {/* model_asset_buffer= */ nullptr,
/* model_asset_buffer_count= */ 0,
/* model_asset_path= */ nullptr},
/* embedder_options= */ {},
};
char* error_msg;
void* embedder = image_embedder_create(&options, &error_msg);
EXPECT_EQ(embedder, nullptr);
EXPECT_THAT(error_msg, HasSubstr("ExternalFile must specify"));
free(error_msg);
}
TEST(ImageEmbedderTest, FailedEmbeddingHandling) {
const std::string model_path = GetFullPath(kModelName);
ImageEmbedderOptions options = {
/* base_options= */ {/* model_asset_buffer= */ nullptr,
/* model_asset_buffer_count= */ 0,
/* model_asset_path= */ model_path.c_str()},
/* running_mode= */ RunningMode::IMAGE,
/* embedder_options= */
{/* l2_normalize= */ false,
/* quantize= */ false},
};
void* embedder = image_embedder_create(&options,
/* error_msg */ nullptr);
EXPECT_NE(embedder, nullptr);
const MpImage mp_image = {.type = MpImage::GPU_BUFFER, .gpu_buffer = {}};
ImageEmbedderResult result;
char* error_msg;
image_embedder_embed_image(embedder, &mp_image, &result, &error_msg);
EXPECT_THAT(error_msg, HasSubstr("GPU Buffer not supported yet."));
free(error_msg);
image_embedder_close(embedder, /* error_msg */ nullptr);
}
} // namespace