diff --git a/mediapipe/tasks/c/components/containers/BUILD b/mediapipe/tasks/c/components/containers/BUILD index fd00261f2..0d89c820e 100644 --- a/mediapipe/tasks/c/components/containers/BUILD +++ b/mediapipe/tasks/c/components/containers/BUILD @@ -71,3 +71,30 @@ cc_test( "@com_google_googletest//:gtest_main", ], ) + +cc_library( + name = "embedding_result", + hdrs = ["embedding_result.h"], +) + +cc_library( + name = "embedding_result_converter", + srcs = ["embedding_result_converter.cc"], + hdrs = ["embedding_result_converter.h"], + deps = [ + ":embedding_result", + "//mediapipe/tasks/cc/components/containers:embedding_result", + ], +) + +cc_test( + name = "embedding_result_converter_test", + srcs = ["embedding_result_converter_test.cc"], + deps = [ + ":embedding_result", + ":embedding_result_converter", + "//mediapipe/framework/port:gtest", + "//mediapipe/tasks/cc/components/containers:embedding_result", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/mediapipe/tasks/c/components/containers/embedding_result.h b/mediapipe/tasks/c/components/containers/embedding_result.h new file mode 100644 index 000000000..62735628a --- /dev/null +++ b/mediapipe/tasks/c/components/containers/embedding_result.h @@ -0,0 +1,79 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_EMBEDDING_RESULT_H_ +#define MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_EMBEDDING_RESULT_H_ + +#include + +#ifdef __cplusplus +extern "C" { +#endif + +// Embedding result for a given embedder head. +// +// One and only one of the two 'float_embedding' and 'quantized_embedding' will +// contain data, based on whether or not the embedder was configured to perform +// scalar quantization. +struct Embedding { + // Floating-point embedding. `nullptr` if the embedder was configured to + // perform scalar-quantization. + float* float_embedding; + + // Scalar-quantized embedding. `nullptr` if the embedder was not configured to + // perform scalar quantization. + char* quantized_embedding; + + // Keep the count of embedding values. + uint32_t values_count; + + // The index of the embedder head (i.e. output tensor) this embedding comes + // from. This is useful for multi-head models. + int head_index; + + // The optional name of the embedder head, as provided in the TFLite Model + // Metadata [1] if present. This is useful for multi-head models. + // Defaults to nullptr. + // + // [1]: https://www.tensorflow.org/lite/convert/metadata + char* head_name; +}; + +// Defines embedding results of a model. +struct EmbeddingResult { + // The embedding results for each head of the model. + struct Embedding* embeddings; + + // Keep the count of embeddings. + uint32_t embeddings_count; + + // The optional timestamp (in milliseconds) of the start of the chunk of data + // corresponding to these results. + // + // This is only used for embedding extraction on time series (e.g. audio + // embedding). In these use cases, the amount of data to process might + // exceed the maximum size that the model can process: to solve this, the + // input data is split into multiple chunks starting at different timestamps. + int64_t timestamp_ms; + + // Specifies whether the timestamp contains a valid value. + bool has_timestamp_ms; +}; + +#ifdef __cplusplus +} // extern C +#endif + +#endif // MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_EMBEDDING_RESULT_H_ diff --git a/mediapipe/tasks/c/components/containers/embedding_result_converter.cc b/mediapipe/tasks/c/components/containers/embedding_result_converter.cc new file mode 100644 index 000000000..ba72c0994 --- /dev/null +++ b/mediapipe/tasks/c/components/containers/embedding_result_converter.cc @@ -0,0 +1,84 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/tasks/c/components/containers/embedding_result_converter.h" + +#include +#include +#include + +#include "mediapipe/tasks/c/components/containers/embedding_result.h" +#include "mediapipe/tasks/cc/components/containers/embedding_result.h" + +namespace mediapipe::tasks::c::components::containers { + +void CppConvertToEmbeddingResult( + const mediapipe::tasks::components::containers::EmbeddingResult& in, + EmbeddingResult* out) { + out->has_timestamp_ms = in.timestamp_ms.has_value(); + out->timestamp_ms = out->has_timestamp_ms ? in.timestamp_ms.value() : 0; + + out->embeddings_count = in.embeddings.size(); + out->embeddings = + out->embeddings_count ? new Embedding[out->embeddings_count] : nullptr; + + for (uint32_t i = 0; i < out->embeddings_count; ++i) { + auto embedding_in = in.embeddings[i]; + auto& embedding_out = out->embeddings[i]; + if (!embedding_in.float_embedding.empty()) { + // Handle float embeddings + embedding_out.values_count = embedding_in.float_embedding.size(); + embedding_out.float_embedding = new float[embedding_out.values_count]; + + std::copy(embedding_in.float_embedding.begin(), + embedding_in.float_embedding.end(), + embedding_out.float_embedding); + + embedding_out.quantized_embedding = nullptr; + } else if (!embedding_in.quantized_embedding.empty()) { + // Handle quantized embeddings + embedding_out.values_count = embedding_in.quantized_embedding.size(); + embedding_out.quantized_embedding = new char[embedding_out.values_count]; + + std::copy(embedding_in.quantized_embedding.begin(), + embedding_in.quantized_embedding.end(), + embedding_out.quantized_embedding); + + embedding_out.float_embedding = nullptr; + } + + embedding_out.head_index = embedding_in.head_index; + embedding_out.head_name = embedding_in.head_name.has_value() + ? strdup(embedding_in.head_name->c_str()) + : nullptr; + } +} + +void CppCloseEmbeddingResult(EmbeddingResult* in) { + for (uint32_t i = 0; i < in->embeddings_count; ++i) { + auto embedding_in = in->embeddings[i]; + + delete[] embedding_in.float_embedding; + delete[] embedding_in.quantized_embedding; + embedding_in.float_embedding = nullptr; + embedding_in.quantized_embedding = nullptr; + + free(embedding_in.head_name); + } + delete[] in->embeddings; + in->embeddings = nullptr; +} + +} // namespace mediapipe::tasks::c::components::containers diff --git a/mediapipe/tasks/c/components/containers/embedding_result_converter.h b/mediapipe/tasks/c/components/containers/embedding_result_converter.h new file mode 100644 index 000000000..15bcdbdd0 --- /dev/null +++ b/mediapipe/tasks/c/components/containers/embedding_result_converter.h @@ -0,0 +1,38 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_EMBEDDING_RESULT_CONVERTER_H_ +#define MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_EMBEDDING_RESULT_CONVERTER_H_ + +#include "mediapipe/tasks/c/components/containers/embedding_result.h" +#include "mediapipe/tasks/cc/components/containers/embedding_result.h" + +namespace mediapipe::tasks::c::components::containers { + +void CppConvertToEmbedding( + const mediapipe::tasks::components::containers::EmbeddingResult& in, + Embedding* out); + +void CppConvertToEmbeddingResult( + const mediapipe::tasks::components::containers::EmbeddingResult& in, + EmbeddingResult* out); + +void CppCloseEmbedding(Embedding* in); + +void CppCloseEmbeddingResult(EmbeddingResult* in); + +} // namespace mediapipe::tasks::c::components::containers + +#endif // MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_EMBEDDING_RESULT_CONVERTER_H_ diff --git a/mediapipe/tasks/c/components/containers/embedding_result_converter_test.cc b/mediapipe/tasks/c/components/containers/embedding_result_converter_test.cc new file mode 100644 index 000000000..0d8f6545a --- /dev/null +++ b/mediapipe/tasks/c/components/containers/embedding_result_converter_test.cc @@ -0,0 +1,66 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/tasks/c/components/containers/embedding_result_converter.h" + +#include +#include + +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/tasks/c/components/containers/embedding_result.h" +#include "mediapipe/tasks/cc/components/containers/embedding_result.h" + +namespace mediapipe::tasks::c::components::containers { + +TEST(EmbeddingResultConverterTest, ConvertsEmbeddingResultCustomEmbedding) { + mediapipe::tasks::components::containers::EmbeddingResult + cpp_embedding_result = { + // Initializing embeddings vector + {// First embedding + { + {0.1f, 0.2f, 0.3f, 0.4f, 0.5f}, // float embedding + {}, // quantized embedding (empty) + 0, // head index + "foo" // head name + }, + // Second embedding + { + {}, // float embedding (empty) + {127, 127, 127, 127, 127}, // quantized embedding + 1, // head index + std::nullopt // no head name + }}, + // Initializing timestamp_ms + 42 // timestamp in ms + }; + + EmbeddingResult c_embedding_result; + CppConvertToEmbeddingResult(cpp_embedding_result, &c_embedding_result); + EXPECT_NE(c_embedding_result.embeddings, nullptr); + EXPECT_EQ(c_embedding_result.embeddings_count, 2); + EXPECT_NE(c_embedding_result.embeddings[0].float_embedding, nullptr); + EXPECT_EQ(c_embedding_result.embeddings[0].values_count, 5); + EXPECT_EQ(c_embedding_result.embeddings[0].head_index, 0); + EXPECT_NE(c_embedding_result.embeddings[1].quantized_embedding, nullptr); + EXPECT_EQ(c_embedding_result.embeddings[1].values_count, 5); + EXPECT_EQ(c_embedding_result.embeddings[1].head_index, 1); + EXPECT_EQ(std::string(c_embedding_result.embeddings[0].head_name), "foo"); + EXPECT_EQ(c_embedding_result.timestamp_ms, 42); + EXPECT_EQ(c_embedding_result.has_timestamp_ms, true); + + CppCloseEmbeddingResult(&c_embedding_result); +} + +} // namespace mediapipe::tasks::c::components::containers diff --git a/mediapipe/tasks/c/components/processors/BUILD b/mediapipe/tasks/c/components/processors/BUILD index 5794769d2..dbc7c82da 100644 --- a/mediapipe/tasks/c/components/processors/BUILD +++ b/mediapipe/tasks/c/components/processors/BUILD @@ -42,3 +42,30 @@ cc_test( "@com_google_googletest//:gtest_main", ], ) + +cc_library( + name = "embedder_options", + hdrs = ["embedder_options.h"], +) + +cc_library( + name = "embedder_options_converter", + srcs = ["embedder_options_converter.cc"], + hdrs = ["embedder_options_converter.h"], + deps = [ + ":embedder_options", + "//mediapipe/tasks/cc/components/processors:embedder_options", + ], +) + +cc_test( + name = "embedder_options_converter_test", + srcs = ["embedder_options_converter_test.cc"], + deps = [ + ":embedder_options", + ":embedder_options_converter", + "//mediapipe/framework/port:gtest", + "//mediapipe/tasks/cc/components/processors:embedder_options", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/mediapipe/tasks/c/components/processors/embedder_options.h b/mediapipe/tasks/c/components/processors/embedder_options.h new file mode 100644 index 000000000..99466dcb0 --- /dev/null +++ b/mediapipe/tasks/c/components/processors/embedder_options.h @@ -0,0 +1,42 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_C_COMPONENTS_PROCESSORS_EMBEDDER_OPTIONS_H_ +#define MEDIAPIPE_TASKS_C_COMPONENTS_PROCESSORS_EMBEDDER_OPTIONS_H_ + +#ifdef __cplusplus +extern "C" { +#endif + +// Embedder options for MediaPipe C embedding extraction tasks. +struct EmbedderOptions { + // Whether to normalize the returned feature vector with L2 norm. Use this + // option only if the model does not already contain a native L2_NORMALIZATION + // TF Lite Op. In most cases, this is already the case and L2 norm is thus + // achieved through TF Lite inference. + bool l2_normalize; + + // Whether the returned embedding should be quantized to bytes via scalar + // quantization. Embeddings are implicitly assumed to be unit-norm and + // therefore any dimension is guaranteed to have a value in [-1.0, 1.0]. Use + // the l2_normalize option if this is not the case. + bool quantize; +}; + +#ifdef __cplusplus +} // extern C +#endif + +#endif // MEDIAPIPE_TASKS_C_COMPONENTS_PROCESSORS_EMBEDDER_OPTIONS_H_ diff --git a/mediapipe/tasks/c/components/processors/embedder_options_converter.cc b/mediapipe/tasks/c/components/processors/embedder_options_converter.cc new file mode 100644 index 000000000..db1ac87ab --- /dev/null +++ b/mediapipe/tasks/c/components/processors/embedder_options_converter.cc @@ -0,0 +1,28 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/tasks/c/components/processors/embedder_options.h" +#include "mediapipe/tasks/cc/components/processors/embedder_options.h" + +namespace mediapipe::tasks::c::components::processors { + +void CppConvertToEmbedderOptions( + const EmbedderOptions& in, + mediapipe::tasks::components::processors::EmbedderOptions* out) { + out->l2_normalize = in.l2_normalize; + out->quantize = in.quantize; +} + +} // namespace mediapipe::tasks::c::components::processors diff --git a/mediapipe/tasks/c/components/processors/embedder_options_converter.h b/mediapipe/tasks/c/components/processors/embedder_options_converter.h new file mode 100644 index 000000000..16b3c52ee --- /dev/null +++ b/mediapipe/tasks/c/components/processors/embedder_options_converter.h @@ -0,0 +1,30 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_C_COMPONENTS_PROCESSORS_EMBEDDER_OPTIONS_CONVERTER_H_ +#define MEDIAPIPE_TASKS_C_COMPONENTS_PROCESSORS_EMBEDDER_OPTIONS_CONVERTER_H_ + +#include "mediapipe/tasks/c/components/processors/embedder_options.h" +#include "mediapipe/tasks/cc/components/processors/embedder_options.h" + +namespace mediapipe::tasks::c::components::processors { + +void CppConvertToEmbedderOptions( + const EmbedderOptions& in, + mediapipe::tasks::components::processors::EmbedderOptions* out); + +} // namespace mediapipe::tasks::c::components::processors + +#endif // MEDIAPIPE_TASKS_C_COMPONENTS_PROCESSORS_EMBEDDER_OPTIONS_CONVERTER_H_ diff --git a/mediapipe/tasks/c/components/processors/embedder_options_converter_test.cc b/mediapipe/tasks/c/components/processors/embedder_options_converter_test.cc new file mode 100644 index 000000000..34187aaee --- /dev/null +++ b/mediapipe/tasks/c/components/processors/embedder_options_converter_test.cc @@ -0,0 +1,36 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/tasks/c/components/processors/embedder_options_converter.h" + +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/tasks/c/components/processors/embedder_options.h" +#include "mediapipe/tasks/cc/components/processors/embedder_options.h" + +namespace mediapipe::tasks::c::components::processors { + +TEST(EmbedderOptionsConverterTest, ConvertsEmbedderOptionsCustomValues) { + EmbedderOptions c_embedder_options = {/* l2_normalize= */ true, + /* quantize= */ false}; + + mediapipe::tasks::components::processors::EmbedderOptions + cpp_embedder_options = {}; + + CppConvertToEmbedderOptions(c_embedder_options, &cpp_embedder_options); + EXPECT_EQ(cpp_embedder_options.l2_normalize, true); + EXPECT_EQ(cpp_embedder_options.quantize, false); +} + +} // namespace mediapipe::tasks::c::components::processors diff --git a/mediapipe/tasks/c/text/text_embedder/BUILD b/mediapipe/tasks/c/text/text_embedder/BUILD new file mode 100644 index 000000000..28a743eb8 --- /dev/null +++ b/mediapipe/tasks/c/text/text_embedder/BUILD @@ -0,0 +1,85 @@ +# Copyright 2023 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +cc_library( + name = "text_embedder_lib", + srcs = ["text_embedder.cc"], + hdrs = ["text_embedder.h"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/tasks/c/components/containers:embedding_result", + "//mediapipe/tasks/c/components/containers:embedding_result_converter", + "//mediapipe/tasks/c/components/processors:embedder_options", + "//mediapipe/tasks/c/components/processors:embedder_options_converter", + "//mediapipe/tasks/c/core:base_options", + "//mediapipe/tasks/c/core:base_options_converter", + "//mediapipe/tasks/cc/text/text_embedder", + "@com_google_absl//absl/log:absl_log", + "@com_google_absl//absl/status", + ], + alwayslink = 1, +) + +# bazel build -c opt --linkopt -s --strip always --define MEDIAPIPE_DISABLE_GPU=1 \ +# //mediapipe/tasks/c/text/text_embedder:libtext_embedder.so +cc_binary( + name = "libtext_embedder.so", + linkopts = [ + "-Wl,-soname=libtext_embedder.so", + "-fvisibility=hidden", + ], + linkshared = True, + tags = [ + "manual", + "nobuilder", + "notap", + ], + deps = [":text_embedder_lib"], +) + +# bazel build --config darwin_arm64 -c opt --strip always --define MEDIAPIPE_DISABLE_GPU=1 \ +# //mediapipe/tasks/c/text/text_embedder:libtext_embedder.dylib +cc_binary( + name = "libtext_embedder.dylib", + linkopts = [ + "-Wl,-install_name,libtext_embedder.dylib", + "-fvisibility=hidden", + ], + linkshared = True, + tags = [ + "manual", + "nobuilder", + "notap", + ], + deps = [":text_embedder_lib"], +) + +cc_test( + name = "text_embedder_test", + srcs = ["text_embedder_test.cc"], + data = ["//mediapipe/tasks/testdata/text:mobilebert_embedding_model"], + linkstatic = 1, + deps = [ + ":text_embedder_lib", + "//mediapipe/framework/deps:file_path", + "//mediapipe/framework/port:gtest", + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/mediapipe/tasks/c/text/text_embedder/text_embedder.cc b/mediapipe/tasks/c/text/text_embedder/text_embedder.cc new file mode 100644 index 000000000..c98b958f5 --- /dev/null +++ b/mediapipe/tasks/c/text/text_embedder/text_embedder.cc @@ -0,0 +1,119 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/tasks/c/text/text_embedder/text_embedder.h" + +#include +#include + +#include "absl/log/absl_log.h" +#include "absl/status/status.h" +#include "mediapipe/tasks/c/components/containers/embedding_result_converter.h" +#include "mediapipe/tasks/c/components/processors/embedder_options_converter.h" +#include "mediapipe/tasks/c/core/base_options_converter.h" +#include "mediapipe/tasks/cc/text/text_embedder/text_embedder.h" + +namespace mediapipe::tasks::c::text::text_embedder { + +namespace { + +using ::mediapipe::tasks::c::components::containers::CppCloseEmbeddingResult; +using ::mediapipe::tasks::c::components::containers:: + CppConvertToEmbeddingResult; +using ::mediapipe::tasks::c::components::processors:: + CppConvertToEmbedderOptions; +using ::mediapipe::tasks::c::core::CppConvertToBaseOptions; +using ::mediapipe::tasks::text::text_embedder::TextEmbedder; + +int CppProcessError(absl::Status status, char** error_msg) { + if (error_msg) { + *error_msg = strdup(status.ToString().c_str()); + } + return status.raw_code(); +} + +} // namespace + +TextEmbedder* CppTextEmbedderCreate(const TextEmbedderOptions& options, + char** error_msg) { + auto cpp_options = std::make_unique< + ::mediapipe::tasks::text::text_embedder::TextEmbedderOptions>(); + + CppConvertToBaseOptions(options.base_options, &cpp_options->base_options); + CppConvertToEmbedderOptions(options.embedder_options, + &cpp_options->embedder_options); + + auto embedder = TextEmbedder::Create(std::move(cpp_options)); + if (!embedder.ok()) { + ABSL_LOG(ERROR) << "Failed to create TextEmbedder: " << embedder.status(); + CppProcessError(embedder.status(), error_msg); + return nullptr; + } + return embedder->release(); +} + +int CppTextEmbedderEmbed(void* embedder, const char* utf8_str, + TextEmbedderResult* result, char** error_msg) { + auto cpp_embedder = static_cast(embedder); + auto cpp_result = cpp_embedder->Embed(utf8_str); + if (!cpp_result.ok()) { + ABSL_LOG(ERROR) << "Embedding extraction failed: " << cpp_result.status(); + return CppProcessError(cpp_result.status(), error_msg); + } + CppConvertToEmbeddingResult(*cpp_result, result); + return 0; +} + +void CppTextEmbedderCloseResult(TextEmbedderResult* result) { + CppCloseEmbeddingResult(result); +} + +int CppTextEmbedderClose(void* embedder, char** error_msg) { + auto cpp_embedder = static_cast(embedder); + auto result = cpp_embedder->Close(); + if (!result.ok()) { + ABSL_LOG(ERROR) << "Failed to close TextEmbedder: " << result; + return CppProcessError(result, error_msg); + } + delete cpp_embedder; + return 0; +} + +} // namespace mediapipe::tasks::c::text::text_embedder + +extern "C" { + +void* text_embedder_create(struct TextEmbedderOptions* options, + char** error_msg) { + return mediapipe::tasks::c::text::text_embedder::CppTextEmbedderCreate( + *options, error_msg); +} + +int text_embedder_embed(void* embedder, const char* utf8_str, + TextEmbedderResult* result, char** error_msg) { + return mediapipe::tasks::c::text::text_embedder::CppTextEmbedderEmbed( + embedder, utf8_str, result, error_msg); +} + +void text_embedder_close_result(TextEmbedderResult* result) { + mediapipe::tasks::c::text::text_embedder::CppTextEmbedderCloseResult(result); +} + +int text_embedder_close(void* embedder, char** error_ms) { + return mediapipe::tasks::c::text::text_embedder::CppTextEmbedderClose( + embedder, error_ms); +} + +} // extern "C" diff --git a/mediapipe/tasks/c/text/text_embedder/text_embedder.h b/mediapipe/tasks/c/text/text_embedder/text_embedder.h new file mode 100644 index 000000000..c9ccf816b --- /dev/null +++ b/mediapipe/tasks/c/text/text_embedder/text_embedder.h @@ -0,0 +1,74 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_C_TEXT_TEXT_EMBEDDER_TEXT_EMBEDDER_H_ +#define MEDIAPIPE_TASKS_C_TEXT_TEXT_EMBEDDER_TEXT_EMBEDDER_H_ + +#include "mediapipe/tasks/c/components/containers/embedding_result.h" +#include "mediapipe/tasks/c/components/processors/embedder_options.h" +#include "mediapipe/tasks/c/core/base_options.h" + +#ifndef MP_EXPORT +#define MP_EXPORT __attribute__((visibility("default"))) +#endif // MP_EXPORT + +#ifdef __cplusplus +extern "C" { +#endif + +typedef struct EmbeddingResult TextEmbedderResult; + +// The options for configuring a MediaPipe text embedder task. +struct TextEmbedderOptions { + // Base options for configuring MediaPipe Tasks, such as specifying the model + // file with metadata, accelerator options, op resolver, etc. + struct BaseOptions base_options; + + // Options for configuring the embedder behavior, such as l2_normalize + // and quantize. + struct EmbedderOptions embedder_options; +}; + +// Creates a TextEmbedder from the provided `options`. +// Returns a pointer to the text embedder on success. +// If an error occurs, returns `nullptr` and sets the error parameter to an +// an error message (if `error_msg` is not `nullptr`). You must free the memory +// allocated for the error message. +MP_EXPORT void* text_embedder_create(struct TextEmbedderOptions* options, + char** error_msg = nullptr); + +// Performs embedding extraction on the input `text`. Returns `0` on success. +// If an error occurs, returns an error code and sets the error parameter to an +// an error message (if `error_msg` is not `nullptr`). You must free the memory +// allocated for the error message. +MP_EXPORT int text_embedder_embed(void* embedder, const char* utf8_str, + TextEmbedderResult* result, + char** error_msg = nullptr); + +// Frees the memory allocated inside a TextEmbedderResult result. Does not +// free the result pointer itself. +MP_EXPORT void text_embedder_close_result(TextEmbedderResult* result); + +// Shuts down the TextEmbedder when all the work is done. Frees all memory. +// If an error occurs, returns an error code and sets the error parameter to an +// an error message (if `error_msg` is not `nullptr`). You must free the memory +// allocated for the error message. +MP_EXPORT int text_embedder_close(void* embedder, char** error_msg = nullptr); + +#ifdef __cplusplus +} // extern C +#endif + +#endif // MEDIAPIPE_TASKS_C_TEXT_TEXT_EMBEDDER_TEXT_EMBEDDER_H_ diff --git a/mediapipe/tasks/c/text/text_embedder/text_embedder_test.cc b/mediapipe/tasks/c/text/text_embedder/text_embedder_test.cc new file mode 100644 index 000000000..c823e01b4 --- /dev/null +++ b/mediapipe/tasks/c/text/text_embedder/text_embedder_test.cc @@ -0,0 +1,79 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/tasks/c/text/text_embedder/text_embedder.h" + +#include +#include + +#include "absl/flags/flag.h" +#include "absl/strings/string_view.h" +#include "mediapipe/framework/deps/file_path.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" + +namespace { + +using ::mediapipe::file::JoinPath; +using testing::HasSubstr; + +constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/text/"; +constexpr char kTestBertModelPath[] = + "mobilebert_embedding_with_metadata.tflite"; +constexpr char kTestString[] = "It's beautiful outside."; + +std::string GetFullPath(absl::string_view file_name) { + return JoinPath("./", kTestDataDirectory, file_name); +} + +TEST(TextEmbedderTest, SmokeTest) { + std::string model_path = GetFullPath(kTestBertModelPath); + TextEmbedderOptions options = { + /* base_options= */ {/* model_asset_buffer= */ nullptr, + /* model_asset_path= */ model_path.c_str()}, + /* embedder_options= */ + {/* l2_normalize= */ false, /* quantize= */ true}, + }; + + void* embedder = text_embedder_create(&options); + EXPECT_NE(embedder, nullptr); + + TextEmbedderResult result; + text_embedder_embed(embedder, kTestString, &result); + EXPECT_EQ(result.embeddings_count, 1); + EXPECT_EQ(result.embeddings[0].values_count, 512); + + text_embedder_close_result(&result); + text_embedder_close(embedder); +} + +TEST(TextEmbedderTest, ErrorHandling) { + // It is an error to set neither the asset buffer nor the path. + TextEmbedderOptions options = { + /* base_options= */ {/* model_asset_buffer= */ nullptr, + /* model_asset_path= */ nullptr}, + /* embedder_options= */ {}, + }; + + char* error_msg; + void* embedder = text_embedder_create(&options, &error_msg); + EXPECT_EQ(embedder, nullptr); + + EXPECT_THAT(error_msg, HasSubstr("INVALID_ARGUMENT")); + + free(error_msg); +} + +} // namespace