Merge pull request #4852 from kinaryml:c-text-embedder-api

PiperOrigin-RevId: 571065228
This commit is contained in:
Copybara-Service 2023-10-05 10:43:30 -07:00
commit d686b42b85
14 changed files with 814 additions and 0 deletions

View File

@ -71,3 +71,30 @@ cc_test(
"@com_google_googletest//:gtest_main",
],
)
cc_library(
name = "embedding_result",
hdrs = ["embedding_result.h"],
)
cc_library(
name = "embedding_result_converter",
srcs = ["embedding_result_converter.cc"],
hdrs = ["embedding_result_converter.h"],
deps = [
":embedding_result",
"//mediapipe/tasks/cc/components/containers:embedding_result",
],
)
cc_test(
name = "embedding_result_converter_test",
srcs = ["embedding_result_converter_test.cc"],
deps = [
":embedding_result",
":embedding_result_converter",
"//mediapipe/framework/port:gtest",
"//mediapipe/tasks/cc/components/containers:embedding_result",
"@com_google_googletest//:gtest_main",
],
)

View File

@ -0,0 +1,79 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_EMBEDDING_RESULT_H_
#define MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_EMBEDDING_RESULT_H_
#include <stdint.h>
#ifdef __cplusplus
extern "C" {
#endif
// Embedding result for a given embedder head.
//
// One and only one of the two 'float_embedding' and 'quantized_embedding' will
// contain data, based on whether or not the embedder was configured to perform
// scalar quantization.
struct Embedding {
// Floating-point embedding. `nullptr` if the embedder was configured to
// perform scalar-quantization.
float* float_embedding;
// Scalar-quantized embedding. `nullptr` if the embedder was not configured to
// perform scalar quantization.
char* quantized_embedding;
// Keep the count of embedding values.
uint32_t values_count;
// The index of the embedder head (i.e. output tensor) this embedding comes
// from. This is useful for multi-head models.
int head_index;
// The optional name of the embedder head, as provided in the TFLite Model
// Metadata [1] if present. This is useful for multi-head models.
// Defaults to nullptr.
//
// [1]: https://www.tensorflow.org/lite/convert/metadata
char* head_name;
};
// Defines embedding results of a model.
struct EmbeddingResult {
// The embedding results for each head of the model.
struct Embedding* embeddings;
// Keep the count of embeddings.
uint32_t embeddings_count;
// The optional timestamp (in milliseconds) of the start of the chunk of data
// corresponding to these results.
//
// This is only used for embedding extraction on time series (e.g. audio
// embedding). In these use cases, the amount of data to process might
// exceed the maximum size that the model can process: to solve this, the
// input data is split into multiple chunks starting at different timestamps.
int64_t timestamp_ms;
// Specifies whether the timestamp contains a valid value.
bool has_timestamp_ms;
};
#ifdef __cplusplus
} // extern C
#endif
#endif // MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_EMBEDDING_RESULT_H_

View File

@ -0,0 +1,84 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/c/components/containers/embedding_result_converter.h"
#include <algorithm>
#include <cstdint>
#include <cstdlib>
#include "mediapipe/tasks/c/components/containers/embedding_result.h"
#include "mediapipe/tasks/cc/components/containers/embedding_result.h"
namespace mediapipe::tasks::c::components::containers {
void CppConvertToEmbeddingResult(
const mediapipe::tasks::components::containers::EmbeddingResult& in,
EmbeddingResult* out) {
out->has_timestamp_ms = in.timestamp_ms.has_value();
out->timestamp_ms = out->has_timestamp_ms ? in.timestamp_ms.value() : 0;
out->embeddings_count = in.embeddings.size();
out->embeddings =
out->embeddings_count ? new Embedding[out->embeddings_count] : nullptr;
for (uint32_t i = 0; i < out->embeddings_count; ++i) {
auto embedding_in = in.embeddings[i];
auto& embedding_out = out->embeddings[i];
if (!embedding_in.float_embedding.empty()) {
// Handle float embeddings
embedding_out.values_count = embedding_in.float_embedding.size();
embedding_out.float_embedding = new float[embedding_out.values_count];
std::copy(embedding_in.float_embedding.begin(),
embedding_in.float_embedding.end(),
embedding_out.float_embedding);
embedding_out.quantized_embedding = nullptr;
} else if (!embedding_in.quantized_embedding.empty()) {
// Handle quantized embeddings
embedding_out.values_count = embedding_in.quantized_embedding.size();
embedding_out.quantized_embedding = new char[embedding_out.values_count];
std::copy(embedding_in.quantized_embedding.begin(),
embedding_in.quantized_embedding.end(),
embedding_out.quantized_embedding);
embedding_out.float_embedding = nullptr;
}
embedding_out.head_index = embedding_in.head_index;
embedding_out.head_name = embedding_in.head_name.has_value()
? strdup(embedding_in.head_name->c_str())
: nullptr;
}
}
void CppCloseEmbeddingResult(EmbeddingResult* in) {
for (uint32_t i = 0; i < in->embeddings_count; ++i) {
auto embedding_in = in->embeddings[i];
delete[] embedding_in.float_embedding;
delete[] embedding_in.quantized_embedding;
embedding_in.float_embedding = nullptr;
embedding_in.quantized_embedding = nullptr;
free(embedding_in.head_name);
}
delete[] in->embeddings;
in->embeddings = nullptr;
}
} // namespace mediapipe::tasks::c::components::containers

View File

@ -0,0 +1,38 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_EMBEDDING_RESULT_CONVERTER_H_
#define MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_EMBEDDING_RESULT_CONVERTER_H_
#include "mediapipe/tasks/c/components/containers/embedding_result.h"
#include "mediapipe/tasks/cc/components/containers/embedding_result.h"
namespace mediapipe::tasks::c::components::containers {
void CppConvertToEmbedding(
const mediapipe::tasks::components::containers::EmbeddingResult& in,
Embedding* out);
void CppConvertToEmbeddingResult(
const mediapipe::tasks::components::containers::EmbeddingResult& in,
EmbeddingResult* out);
void CppCloseEmbedding(Embedding* in);
void CppCloseEmbeddingResult(EmbeddingResult* in);
} // namespace mediapipe::tasks::c::components::containers
#endif // MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_EMBEDDING_RESULT_CONVERTER_H_

View File

@ -0,0 +1,66 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/c/components/containers/embedding_result_converter.h"
#include <optional>
#include <string>
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/tasks/c/components/containers/embedding_result.h"
#include "mediapipe/tasks/cc/components/containers/embedding_result.h"
namespace mediapipe::tasks::c::components::containers {
TEST(EmbeddingResultConverterTest, ConvertsEmbeddingResultCustomEmbedding) {
mediapipe::tasks::components::containers::EmbeddingResult
cpp_embedding_result = {
// Initializing embeddings vector
{// First embedding
{
{0.1f, 0.2f, 0.3f, 0.4f, 0.5f}, // float embedding
{}, // quantized embedding (empty)
0, // head index
"foo" // head name
},
// Second embedding
{
{}, // float embedding (empty)
{127, 127, 127, 127, 127}, // quantized embedding
1, // head index
std::nullopt // no head name
}},
// Initializing timestamp_ms
42 // timestamp in ms
};
EmbeddingResult c_embedding_result;
CppConvertToEmbeddingResult(cpp_embedding_result, &c_embedding_result);
EXPECT_NE(c_embedding_result.embeddings, nullptr);
EXPECT_EQ(c_embedding_result.embeddings_count, 2);
EXPECT_NE(c_embedding_result.embeddings[0].float_embedding, nullptr);
EXPECT_EQ(c_embedding_result.embeddings[0].values_count, 5);
EXPECT_EQ(c_embedding_result.embeddings[0].head_index, 0);
EXPECT_NE(c_embedding_result.embeddings[1].quantized_embedding, nullptr);
EXPECT_EQ(c_embedding_result.embeddings[1].values_count, 5);
EXPECT_EQ(c_embedding_result.embeddings[1].head_index, 1);
EXPECT_EQ(std::string(c_embedding_result.embeddings[0].head_name), "foo");
EXPECT_EQ(c_embedding_result.timestamp_ms, 42);
EXPECT_EQ(c_embedding_result.has_timestamp_ms, true);
CppCloseEmbeddingResult(&c_embedding_result);
}
} // namespace mediapipe::tasks::c::components::containers

View File

@ -42,3 +42,30 @@ cc_test(
"@com_google_googletest//:gtest_main",
],
)
cc_library(
name = "embedder_options",
hdrs = ["embedder_options.h"],
)
cc_library(
name = "embedder_options_converter",
srcs = ["embedder_options_converter.cc"],
hdrs = ["embedder_options_converter.h"],
deps = [
":embedder_options",
"//mediapipe/tasks/cc/components/processors:embedder_options",
],
)
cc_test(
name = "embedder_options_converter_test",
srcs = ["embedder_options_converter_test.cc"],
deps = [
":embedder_options",
":embedder_options_converter",
"//mediapipe/framework/port:gtest",
"//mediapipe/tasks/cc/components/processors:embedder_options",
"@com_google_googletest//:gtest_main",
],
)

View File

@ -0,0 +1,42 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_C_COMPONENTS_PROCESSORS_EMBEDDER_OPTIONS_H_
#define MEDIAPIPE_TASKS_C_COMPONENTS_PROCESSORS_EMBEDDER_OPTIONS_H_
#ifdef __cplusplus
extern "C" {
#endif
// Embedder options for MediaPipe C embedding extraction tasks.
struct EmbedderOptions {
// Whether to normalize the returned feature vector with L2 norm. Use this
// option only if the model does not already contain a native L2_NORMALIZATION
// TF Lite Op. In most cases, this is already the case and L2 norm is thus
// achieved through TF Lite inference.
bool l2_normalize;
// Whether the returned embedding should be quantized to bytes via scalar
// quantization. Embeddings are implicitly assumed to be unit-norm and
// therefore any dimension is guaranteed to have a value in [-1.0, 1.0]. Use
// the l2_normalize option if this is not the case.
bool quantize;
};
#ifdef __cplusplus
} // extern C
#endif
#endif // MEDIAPIPE_TASKS_C_COMPONENTS_PROCESSORS_EMBEDDER_OPTIONS_H_

View File

@ -0,0 +1,28 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/c/components/processors/embedder_options.h"
#include "mediapipe/tasks/cc/components/processors/embedder_options.h"
namespace mediapipe::tasks::c::components::processors {
void CppConvertToEmbedderOptions(
const EmbedderOptions& in,
mediapipe::tasks::components::processors::EmbedderOptions* out) {
out->l2_normalize = in.l2_normalize;
out->quantize = in.quantize;
}
} // namespace mediapipe::tasks::c::components::processors

View File

@ -0,0 +1,30 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_C_COMPONENTS_PROCESSORS_EMBEDDER_OPTIONS_CONVERTER_H_
#define MEDIAPIPE_TASKS_C_COMPONENTS_PROCESSORS_EMBEDDER_OPTIONS_CONVERTER_H_
#include "mediapipe/tasks/c/components/processors/embedder_options.h"
#include "mediapipe/tasks/cc/components/processors/embedder_options.h"
namespace mediapipe::tasks::c::components::processors {
void CppConvertToEmbedderOptions(
const EmbedderOptions& in,
mediapipe::tasks::components::processors::EmbedderOptions* out);
} // namespace mediapipe::tasks::c::components::processors
#endif // MEDIAPIPE_TASKS_C_COMPONENTS_PROCESSORS_EMBEDDER_OPTIONS_CONVERTER_H_

View File

@ -0,0 +1,36 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/c/components/processors/embedder_options_converter.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/tasks/c/components/processors/embedder_options.h"
#include "mediapipe/tasks/cc/components/processors/embedder_options.h"
namespace mediapipe::tasks::c::components::processors {
TEST(EmbedderOptionsConverterTest, ConvertsEmbedderOptionsCustomValues) {
EmbedderOptions c_embedder_options = {/* l2_normalize= */ true,
/* quantize= */ false};
mediapipe::tasks::components::processors::EmbedderOptions
cpp_embedder_options = {};
CppConvertToEmbedderOptions(c_embedder_options, &cpp_embedder_options);
EXPECT_EQ(cpp_embedder_options.l2_normalize, true);
EXPECT_EQ(cpp_embedder_options.quantize, false);
}
} // namespace mediapipe::tasks::c::components::processors

View File

@ -0,0 +1,85 @@
# Copyright 2023 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
cc_library(
name = "text_embedder_lib",
srcs = ["text_embedder.cc"],
hdrs = ["text_embedder.h"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/tasks/c/components/containers:embedding_result",
"//mediapipe/tasks/c/components/containers:embedding_result_converter",
"//mediapipe/tasks/c/components/processors:embedder_options",
"//mediapipe/tasks/c/components/processors:embedder_options_converter",
"//mediapipe/tasks/c/core:base_options",
"//mediapipe/tasks/c/core:base_options_converter",
"//mediapipe/tasks/cc/text/text_embedder",
"@com_google_absl//absl/log:absl_log",
"@com_google_absl//absl/status",
],
alwayslink = 1,
)
# bazel build -c opt --linkopt -s --strip always --define MEDIAPIPE_DISABLE_GPU=1 \
# //mediapipe/tasks/c/text/text_embedder:libtext_embedder.so
cc_binary(
name = "libtext_embedder.so",
linkopts = [
"-Wl,-soname=libtext_embedder.so",
"-fvisibility=hidden",
],
linkshared = True,
tags = [
"manual",
"nobuilder",
"notap",
],
deps = [":text_embedder_lib"],
)
# bazel build --config darwin_arm64 -c opt --strip always --define MEDIAPIPE_DISABLE_GPU=1 \
# //mediapipe/tasks/c/text/text_embedder:libtext_embedder.dylib
cc_binary(
name = "libtext_embedder.dylib",
linkopts = [
"-Wl,-install_name,libtext_embedder.dylib",
"-fvisibility=hidden",
],
linkshared = True,
tags = [
"manual",
"nobuilder",
"notap",
],
deps = [":text_embedder_lib"],
)
cc_test(
name = "text_embedder_test",
srcs = ["text_embedder_test.cc"],
data = ["//mediapipe/tasks/testdata/text:mobilebert_embedding_model"],
linkstatic = 1,
deps = [
":text_embedder_lib",
"//mediapipe/framework/deps:file_path",
"//mediapipe/framework/port:gtest",
"@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/strings",
"@com_google_googletest//:gtest_main",
],
)

View File

@ -0,0 +1,119 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/c/text/text_embedder/text_embedder.h"
#include <memory>
#include <utility>
#include "absl/log/absl_log.h"
#include "absl/status/status.h"
#include "mediapipe/tasks/c/components/containers/embedding_result_converter.h"
#include "mediapipe/tasks/c/components/processors/embedder_options_converter.h"
#include "mediapipe/tasks/c/core/base_options_converter.h"
#include "mediapipe/tasks/cc/text/text_embedder/text_embedder.h"
namespace mediapipe::tasks::c::text::text_embedder {
namespace {
using ::mediapipe::tasks::c::components::containers::CppCloseEmbeddingResult;
using ::mediapipe::tasks::c::components::containers::
CppConvertToEmbeddingResult;
using ::mediapipe::tasks::c::components::processors::
CppConvertToEmbedderOptions;
using ::mediapipe::tasks::c::core::CppConvertToBaseOptions;
using ::mediapipe::tasks::text::text_embedder::TextEmbedder;
int CppProcessError(absl::Status status, char** error_msg) {
if (error_msg) {
*error_msg = strdup(status.ToString().c_str());
}
return status.raw_code();
}
} // namespace
TextEmbedder* CppTextEmbedderCreate(const TextEmbedderOptions& options,
char** error_msg) {
auto cpp_options = std::make_unique<
::mediapipe::tasks::text::text_embedder::TextEmbedderOptions>();
CppConvertToBaseOptions(options.base_options, &cpp_options->base_options);
CppConvertToEmbedderOptions(options.embedder_options,
&cpp_options->embedder_options);
auto embedder = TextEmbedder::Create(std::move(cpp_options));
if (!embedder.ok()) {
ABSL_LOG(ERROR) << "Failed to create TextEmbedder: " << embedder.status();
CppProcessError(embedder.status(), error_msg);
return nullptr;
}
return embedder->release();
}
int CppTextEmbedderEmbed(void* embedder, const char* utf8_str,
TextEmbedderResult* result, char** error_msg) {
auto cpp_embedder = static_cast<TextEmbedder*>(embedder);
auto cpp_result = cpp_embedder->Embed(utf8_str);
if (!cpp_result.ok()) {
ABSL_LOG(ERROR) << "Embedding extraction failed: " << cpp_result.status();
return CppProcessError(cpp_result.status(), error_msg);
}
CppConvertToEmbeddingResult(*cpp_result, result);
return 0;
}
void CppTextEmbedderCloseResult(TextEmbedderResult* result) {
CppCloseEmbeddingResult(result);
}
int CppTextEmbedderClose(void* embedder, char** error_msg) {
auto cpp_embedder = static_cast<TextEmbedder*>(embedder);
auto result = cpp_embedder->Close();
if (!result.ok()) {
ABSL_LOG(ERROR) << "Failed to close TextEmbedder: " << result;
return CppProcessError(result, error_msg);
}
delete cpp_embedder;
return 0;
}
} // namespace mediapipe::tasks::c::text::text_embedder
extern "C" {
void* text_embedder_create(struct TextEmbedderOptions* options,
char** error_msg) {
return mediapipe::tasks::c::text::text_embedder::CppTextEmbedderCreate(
*options, error_msg);
}
int text_embedder_embed(void* embedder, const char* utf8_str,
TextEmbedderResult* result, char** error_msg) {
return mediapipe::tasks::c::text::text_embedder::CppTextEmbedderEmbed(
embedder, utf8_str, result, error_msg);
}
void text_embedder_close_result(TextEmbedderResult* result) {
mediapipe::tasks::c::text::text_embedder::CppTextEmbedderCloseResult(result);
}
int text_embedder_close(void* embedder, char** error_ms) {
return mediapipe::tasks::c::text::text_embedder::CppTextEmbedderClose(
embedder, error_ms);
}
} // extern "C"

View File

@ -0,0 +1,74 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_C_TEXT_TEXT_EMBEDDER_TEXT_EMBEDDER_H_
#define MEDIAPIPE_TASKS_C_TEXT_TEXT_EMBEDDER_TEXT_EMBEDDER_H_
#include "mediapipe/tasks/c/components/containers/embedding_result.h"
#include "mediapipe/tasks/c/components/processors/embedder_options.h"
#include "mediapipe/tasks/c/core/base_options.h"
#ifndef MP_EXPORT
#define MP_EXPORT __attribute__((visibility("default")))
#endif // MP_EXPORT
#ifdef __cplusplus
extern "C" {
#endif
typedef struct EmbeddingResult TextEmbedderResult;
// The options for configuring a MediaPipe text embedder task.
struct TextEmbedderOptions {
// Base options for configuring MediaPipe Tasks, such as specifying the model
// file with metadata, accelerator options, op resolver, etc.
struct BaseOptions base_options;
// Options for configuring the embedder behavior, such as l2_normalize
// and quantize.
struct EmbedderOptions embedder_options;
};
// Creates a TextEmbedder from the provided `options`.
// Returns a pointer to the text embedder on success.
// If an error occurs, returns `nullptr` and sets the error parameter to an
// an error message (if `error_msg` is not `nullptr`). You must free the memory
// allocated for the error message.
MP_EXPORT void* text_embedder_create(struct TextEmbedderOptions* options,
char** error_msg = nullptr);
// Performs embedding extraction on the input `text`. Returns `0` on success.
// If an error occurs, returns an error code and sets the error parameter to an
// an error message (if `error_msg` is not `nullptr`). You must free the memory
// allocated for the error message.
MP_EXPORT int text_embedder_embed(void* embedder, const char* utf8_str,
TextEmbedderResult* result,
char** error_msg = nullptr);
// Frees the memory allocated inside a TextEmbedderResult result. Does not
// free the result pointer itself.
MP_EXPORT void text_embedder_close_result(TextEmbedderResult* result);
// Shuts down the TextEmbedder when all the work is done. Frees all memory.
// If an error occurs, returns an error code and sets the error parameter to an
// an error message (if `error_msg` is not `nullptr`). You must free the memory
// allocated for the error message.
MP_EXPORT int text_embedder_close(void* embedder, char** error_msg = nullptr);
#ifdef __cplusplus
} // extern C
#endif
#endif // MEDIAPIPE_TASKS_C_TEXT_TEXT_EMBEDDER_TEXT_EMBEDDER_H_

View File

@ -0,0 +1,79 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/c/text/text_embedder/text_embedder.h"
#include <cstdlib>
#include <string>
#include "absl/flags/flag.h"
#include "absl/strings/string_view.h"
#include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
namespace {
using ::mediapipe::file::JoinPath;
using testing::HasSubstr;
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/text/";
constexpr char kTestBertModelPath[] =
"mobilebert_embedding_with_metadata.tflite";
constexpr char kTestString[] = "It's beautiful outside.";
std::string GetFullPath(absl::string_view file_name) {
return JoinPath("./", kTestDataDirectory, file_name);
}
TEST(TextEmbedderTest, SmokeTest) {
std::string model_path = GetFullPath(kTestBertModelPath);
TextEmbedderOptions options = {
/* base_options= */ {/* model_asset_buffer= */ nullptr,
/* model_asset_path= */ model_path.c_str()},
/* embedder_options= */
{/* l2_normalize= */ false, /* quantize= */ true},
};
void* embedder = text_embedder_create(&options);
EXPECT_NE(embedder, nullptr);
TextEmbedderResult result;
text_embedder_embed(embedder, kTestString, &result);
EXPECT_EQ(result.embeddings_count, 1);
EXPECT_EQ(result.embeddings[0].values_count, 512);
text_embedder_close_result(&result);
text_embedder_close(embedder);
}
TEST(TextEmbedderTest, ErrorHandling) {
// It is an error to set neither the asset buffer nor the path.
TextEmbedderOptions options = {
/* base_options= */ {/* model_asset_buffer= */ nullptr,
/* model_asset_path= */ nullptr},
/* embedder_options= */ {},
};
char* error_msg;
void* embedder = text_embedder_create(&options, &error_msg);
EXPECT_EQ(embedder, nullptr);
EXPECT_THAT(error_msg, HasSubstr("INVALID_ARGUMENT"));
free(error_msg);
}
} // namespace