Merge pull request #4852 from kinaryml:c-text-embedder-api
PiperOrigin-RevId: 571065228
This commit is contained in:
commit
d686b42b85
|
@ -71,3 +71,30 @@ cc_test(
|
|||
"@com_google_googletest//:gtest_main",
|
||||
],
|
||||
)
|
||||
|
||||
# Header-only target exposing the C structs that describe embedding results.
cc_library(
    name = "embedding_result",
    hdrs = ["embedding_result.h"],
)
|
||||
|
||||
# Helpers that copy the C++ embedding result container into its C counterpart.
cc_library(
    name = "embedding_result_converter",
    srcs = ["embedding_result_converter.cc"],
    hdrs = ["embedding_result_converter.h"],
    deps = [
        ":embedding_result",
        "//mediapipe/tasks/cc/components/containers:embedding_result",
    ],
)
|
||||
|
||||
# Unit test for the C++ -> C embedding result conversion helpers.
cc_test(
    name = "embedding_result_converter_test",
    srcs = ["embedding_result_converter_test.cc"],
    deps = [
        ":embedding_result",
        ":embedding_result_converter",
        "//mediapipe/framework/port:gtest",
        "//mediapipe/tasks/cc/components/containers:embedding_result",
        "@com_google_googletest//:gtest_main",
    ],
)
|
||||
|
|
79
mediapipe/tasks/c/components/containers/embedding_result.h
Normal file
79
mediapipe/tasks/c/components/containers/embedding_result.h
Normal file
|
@ -0,0 +1,79 @@
|
|||
/* Copyright 2023 The MediaPipe Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_EMBEDDING_RESULT_H_
|
||||
#define MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_EMBEDDING_RESULT_H_
|
||||
|
||||
#include <stdbool.h>
#include <stdint.h>
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
// Embedding result for a given embedder head.
|
||||
//
|
||||
// One and only one of the two 'float_embedding' and 'quantized_embedding' will
|
||||
// contain data, based on whether or not the embedder was configured to perform
|
||||
// scalar quantization.
|
||||
struct Embedding {
|
||||
// Floating-point embedding. `nullptr` if the embedder was configured to
|
||||
// perform scalar-quantization.
|
||||
float* float_embedding;
|
||||
|
||||
// Scalar-quantized embedding. `nullptr` if the embedder was not configured to
|
||||
// perform scalar quantization.
|
||||
char* quantized_embedding;
|
||||
|
||||
// Keep the count of embedding values.
|
||||
uint32_t values_count;
|
||||
|
||||
// The index of the embedder head (i.e. output tensor) this embedding comes
|
||||
// from. This is useful for multi-head models.
|
||||
int head_index;
|
||||
|
||||
// The optional name of the embedder head, as provided in the TFLite Model
|
||||
// Metadata [1] if present. This is useful for multi-head models.
|
||||
// Defaults to nullptr.
|
||||
//
|
||||
// [1]: https://www.tensorflow.org/lite/convert/metadata
|
||||
char* head_name;
|
||||
};
|
||||
|
||||
// Defines embedding results of a model.
|
||||
struct EmbeddingResult {
|
||||
// The embedding results for each head of the model.
|
||||
struct Embedding* embeddings;
|
||||
|
||||
// Keep the count of embeddings.
|
||||
uint32_t embeddings_count;
|
||||
|
||||
// The optional timestamp (in milliseconds) of the start of the chunk of data
|
||||
// corresponding to these results.
|
||||
//
|
||||
// This is only used for embedding extraction on time series (e.g. audio
|
||||
// embedding). In these use cases, the amount of data to process might
|
||||
// exceed the maximum size that the model can process: to solve this, the
|
||||
// input data is split into multiple chunks starting at different timestamps.
|
||||
int64_t timestamp_ms;
|
||||
|
||||
// Specifies whether the timestamp contains a valid value.
|
||||
bool has_timestamp_ms;
|
||||
};
|
||||
|
||||
#ifdef __cplusplus
|
||||
} // extern C
|
||||
#endif
|
||||
|
||||
#endif // MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_EMBEDDING_RESULT_H_
|
|
@ -0,0 +1,84 @@
|
|||
/* Copyright 2023 The MediaPipe Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mediapipe/tasks/c/components/containers/embedding_result_converter.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdint>
|
||||
#include <cstdlib>
|
||||
|
||||
#include "mediapipe/tasks/c/components/containers/embedding_result.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/embedding_result.h"
|
||||
|
||||
namespace mediapipe::tasks::c::components::containers {
|
||||
|
||||
void CppConvertToEmbeddingResult(
|
||||
const mediapipe::tasks::components::containers::EmbeddingResult& in,
|
||||
EmbeddingResult* out) {
|
||||
out->has_timestamp_ms = in.timestamp_ms.has_value();
|
||||
out->timestamp_ms = out->has_timestamp_ms ? in.timestamp_ms.value() : 0;
|
||||
|
||||
out->embeddings_count = in.embeddings.size();
|
||||
out->embeddings =
|
||||
out->embeddings_count ? new Embedding[out->embeddings_count] : nullptr;
|
||||
|
||||
for (uint32_t i = 0; i < out->embeddings_count; ++i) {
|
||||
auto embedding_in = in.embeddings[i];
|
||||
auto& embedding_out = out->embeddings[i];
|
||||
if (!embedding_in.float_embedding.empty()) {
|
||||
// Handle float embeddings
|
||||
embedding_out.values_count = embedding_in.float_embedding.size();
|
||||
embedding_out.float_embedding = new float[embedding_out.values_count];
|
||||
|
||||
std::copy(embedding_in.float_embedding.begin(),
|
||||
embedding_in.float_embedding.end(),
|
||||
embedding_out.float_embedding);
|
||||
|
||||
embedding_out.quantized_embedding = nullptr;
|
||||
} else if (!embedding_in.quantized_embedding.empty()) {
|
||||
// Handle quantized embeddings
|
||||
embedding_out.values_count = embedding_in.quantized_embedding.size();
|
||||
embedding_out.quantized_embedding = new char[embedding_out.values_count];
|
||||
|
||||
std::copy(embedding_in.quantized_embedding.begin(),
|
||||
embedding_in.quantized_embedding.end(),
|
||||
embedding_out.quantized_embedding);
|
||||
|
||||
embedding_out.float_embedding = nullptr;
|
||||
}
|
||||
|
||||
embedding_out.head_index = embedding_in.head_index;
|
||||
embedding_out.head_name = embedding_in.head_name.has_value()
|
||||
? strdup(embedding_in.head_name->c_str())
|
||||
: nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
void CppCloseEmbeddingResult(EmbeddingResult* in) {
|
||||
for (uint32_t i = 0; i < in->embeddings_count; ++i) {
|
||||
auto embedding_in = in->embeddings[i];
|
||||
|
||||
delete[] embedding_in.float_embedding;
|
||||
delete[] embedding_in.quantized_embedding;
|
||||
embedding_in.float_embedding = nullptr;
|
||||
embedding_in.quantized_embedding = nullptr;
|
||||
|
||||
free(embedding_in.head_name);
|
||||
}
|
||||
delete[] in->embeddings;
|
||||
in->embeddings = nullptr;
|
||||
}
|
||||
|
||||
} // namespace mediapipe::tasks::c::components::containers
|
|
@ -0,0 +1,38 @@
|
|||
/* Copyright 2023 The MediaPipe Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_EMBEDDING_RESULT_CONVERTER_H_
|
||||
#define MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_EMBEDDING_RESULT_CONVERTER_H_
|
||||
|
||||
#include "mediapipe/tasks/c/components/containers/embedding_result.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/embedding_result.h"
|
||||
|
||||
namespace mediapipe::tasks::c::components::containers {
|
||||
|
||||
void CppConvertToEmbedding(
|
||||
const mediapipe::tasks::components::containers::EmbeddingResult& in,
|
||||
Embedding* out);
|
||||
|
||||
void CppConvertToEmbeddingResult(
|
||||
const mediapipe::tasks::components::containers::EmbeddingResult& in,
|
||||
EmbeddingResult* out);
|
||||
|
||||
void CppCloseEmbedding(Embedding* in);
|
||||
|
||||
void CppCloseEmbeddingResult(EmbeddingResult* in);
|
||||
|
||||
} // namespace mediapipe::tasks::c::components::containers
|
||||
|
||||
#endif // MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_EMBEDDING_RESULT_CONVERTER_H_
|
|
@ -0,0 +1,66 @@
|
|||
/* Copyright 2023 The MediaPipe Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mediapipe/tasks/c/components/containers/embedding_result_converter.h"
|
||||
|
||||
#include <optional>
|
||||
#include <string>
|
||||
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/tasks/c/components/containers/embedding_result.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/embedding_result.h"
|
||||
|
||||
namespace mediapipe::tasks::c::components::containers {
|
||||
|
||||
// Converts a two-head C++ result (one float head named "foo", one unnamed
// quantized head) and verifies that counts, values, null fields and the
// timestamp all arrive in the C struct.
TEST(EmbeddingResultConverterTest, ConvertsEmbeddingResultCustomEmbedding) {
  mediapipe::tasks::components::containers::EmbeddingResult
      cpp_embedding_result = {
          // Initializing embeddings vector
          {// First embedding
           {
               {0.1f, 0.2f, 0.3f, 0.4f, 0.5f},  // float embedding
               {},   // quantized embedding (empty)
               0,    // head index
               "foo"  // head name
           },
           // Second embedding
           {
               {},                         // float embedding (empty)
               {127, 127, 127, 127, 127},  // quantized embedding
               1,                          // head index
               std::nullopt                // no head name
           }},
          // Initializing timestamp_ms
          42  // timestamp in ms
      };

  EmbeddingResult c_embedding_result;
  CppConvertToEmbeddingResult(cpp_embedding_result, &c_embedding_result);

  // Fatal checks: everything below indexes into `embeddings`.
  ASSERT_NE(c_embedding_result.embeddings, nullptr);
  ASSERT_EQ(c_embedding_result.embeddings_count, 2u);

  // First head: float values, named "foo".
  const Embedding& float_head = c_embedding_result.embeddings[0];
  ASSERT_NE(float_head.float_embedding, nullptr);
  EXPECT_EQ(float_head.quantized_embedding, nullptr);
  ASSERT_EQ(float_head.values_count, 5u);
  EXPECT_FLOAT_EQ(float_head.float_embedding[0], 0.1f);
  EXPECT_FLOAT_EQ(float_head.float_embedding[4], 0.5f);
  EXPECT_EQ(float_head.head_index, 0);
  ASSERT_NE(float_head.head_name, nullptr);
  EXPECT_EQ(std::string(float_head.head_name), "foo");

  // Second head: quantized values, no name.
  const Embedding& quantized_head = c_embedding_result.embeddings[1];
  ASSERT_NE(quantized_head.quantized_embedding, nullptr);
  EXPECT_EQ(quantized_head.float_embedding, nullptr);
  ASSERT_EQ(quantized_head.values_count, 5u);
  for (uint32_t i = 0; i < quantized_head.values_count; ++i) {
    EXPECT_EQ(quantized_head.quantized_embedding[i], 127);
  }
  EXPECT_EQ(quantized_head.head_index, 1);
  EXPECT_EQ(quantized_head.head_name, nullptr);

  EXPECT_EQ(c_embedding_result.timestamp_ms, 42);
  EXPECT_EQ(c_embedding_result.has_timestamp_ms, true);

  CppCloseEmbeddingResult(&c_embedding_result);
}
|
||||
|
||||
} // namespace mediapipe::tasks::c::components::containers
|
|
@ -42,3 +42,30 @@ cc_test(
|
|||
"@com_google_googletest//:gtest_main",
|
||||
],
|
||||
)
|
||||
|
||||
# Header-only target exposing the C `EmbedderOptions` struct.
cc_library(
    name = "embedder_options",
    hdrs = ["embedder_options.h"],
)
|
||||
|
||||
# Helper that copies the C embedder options into their C++ counterpart.
cc_library(
    name = "embedder_options_converter",
    srcs = ["embedder_options_converter.cc"],
    hdrs = ["embedder_options_converter.h"],
    deps = [
        ":embedder_options",
        "//mediapipe/tasks/cc/components/processors:embedder_options",
    ],
)
|
||||
|
||||
# Unit test for the C -> C++ embedder options conversion helper.
cc_test(
    name = "embedder_options_converter_test",
    srcs = ["embedder_options_converter_test.cc"],
    deps = [
        ":embedder_options",
        ":embedder_options_converter",
        "//mediapipe/framework/port:gtest",
        "//mediapipe/tasks/cc/components/processors:embedder_options",
        "@com_google_googletest//:gtest_main",
    ],
)
|
||||
|
|
42
mediapipe/tasks/c/components/processors/embedder_options.h
Normal file
42
mediapipe/tasks/c/components/processors/embedder_options.h
Normal file
|
@ -0,0 +1,42 @@
|
|||
/* Copyright 2023 The MediaPipe Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef MEDIAPIPE_TASKS_C_COMPONENTS_PROCESSORS_EMBEDDER_OPTIONS_H_
|
||||
#define MEDIAPIPE_TASKS_C_COMPONENTS_PROCESSORS_EMBEDDER_OPTIONS_H_
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
// Embedder options for MediaPipe C embedding extraction tasks.
|
||||
struct EmbedderOptions {
|
||||
// Whether to normalize the returned feature vector with L2 norm. Use this
|
||||
// option only if the model does not already contain a native L2_NORMALIZATION
|
||||
// TF Lite Op. In most cases, this is already the case and L2 norm is thus
|
||||
// achieved through TF Lite inference.
|
||||
bool l2_normalize;
|
||||
|
||||
// Whether the returned embedding should be quantized to bytes via scalar
|
||||
// quantization. Embeddings are implicitly assumed to be unit-norm and
|
||||
// therefore any dimension is guaranteed to have a value in [-1.0, 1.0]. Use
|
||||
// the l2_normalize option if this is not the case.
|
||||
bool quantize;
|
||||
};
|
||||
|
||||
#ifdef __cplusplus
|
||||
} // extern C
|
||||
#endif
|
||||
|
||||
#endif // MEDIAPIPE_TASKS_C_COMPONENTS_PROCESSORS_EMBEDDER_OPTIONS_H_
|
|
@ -0,0 +1,28 @@
|
|||
/* Copyright 2023 The MediaPipe Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mediapipe/tasks/c/components/processors/embedder_options.h"
|
||||
#include "mediapipe/tasks/cc/components/processors/embedder_options.h"
|
||||
|
||||
namespace mediapipe::tasks::c::components::processors {
|
||||
|
||||
void CppConvertToEmbedderOptions(
|
||||
const EmbedderOptions& in,
|
||||
mediapipe::tasks::components::processors::EmbedderOptions* out) {
|
||||
out->l2_normalize = in.l2_normalize;
|
||||
out->quantize = in.quantize;
|
||||
}
|
||||
|
||||
} // namespace mediapipe::tasks::c::components::processors
|
|
@ -0,0 +1,30 @@
|
|||
/* Copyright 2023 The MediaPipe Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef MEDIAPIPE_TASKS_C_COMPONENTS_PROCESSORS_EMBEDDER_OPTIONS_CONVERTER_H_
|
||||
#define MEDIAPIPE_TASKS_C_COMPONENTS_PROCESSORS_EMBEDDER_OPTIONS_CONVERTER_H_
|
||||
|
||||
#include "mediapipe/tasks/c/components/processors/embedder_options.h"
|
||||
#include "mediapipe/tasks/cc/components/processors/embedder_options.h"
|
||||
|
||||
namespace mediapipe::tasks::c::components::processors {
|
||||
|
||||
void CppConvertToEmbedderOptions(
|
||||
const EmbedderOptions& in,
|
||||
mediapipe::tasks::components::processors::EmbedderOptions* out);
|
||||
|
||||
} // namespace mediapipe::tasks::c::components::processors
|
||||
|
||||
#endif // MEDIAPIPE_TASKS_C_COMPONENTS_PROCESSORS_EMBEDDER_OPTIONS_CONVERTER_H_
|
|
@ -0,0 +1,36 @@
|
|||
/* Copyright 2023 The MediaPipe Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mediapipe/tasks/c/components/processors/embedder_options_converter.h"
|
||||
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/tasks/c/components/processors/embedder_options.h"
|
||||
#include "mediapipe/tasks/cc/components/processors/embedder_options.h"
|
||||
|
||||
namespace mediapipe::tasks::c::components::processors {
|
||||
|
||||
// Checks that both flags survive the C -> C++ options conversion.
TEST(EmbedderOptionsConverterTest, ConvertsEmbedderOptionsCustomValues) {
  EmbedderOptions c_options = {/* l2_normalize= */ true,
                               /* quantize= */ false};
  mediapipe::tasks::components::processors::EmbedderOptions converted = {};

  CppConvertToEmbedderOptions(c_options, &converted);

  EXPECT_EQ(converted.l2_normalize, true);
  EXPECT_EQ(converted.quantize, false);
}
|
||||
|
||||
} // namespace mediapipe::tasks::c::components::processors
|
85
mediapipe/tasks/c/text/text_embedder/BUILD
Normal file
85
mediapipe/tasks/c/text/text_embedder/BUILD
Normal file
|
@ -0,0 +1,85 @@
|
|||
# Copyright 2023 The MediaPipe Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
# C API of the text embedder task, implemented on top of the C++ task.
cc_library(
    name = "text_embedder_lib",
    srcs = ["text_embedder.cc"],
    hdrs = ["text_embedder.h"],
    visibility = ["//visibility:public"],
    deps = [
        "//mediapipe/tasks/c/components/containers:embedding_result",
        "//mediapipe/tasks/c/components/containers:embedding_result_converter",
        "//mediapipe/tasks/c/components/processors:embedder_options",
        "//mediapipe/tasks/c/components/processors:embedder_options_converter",
        "//mediapipe/tasks/c/core:base_options",
        "//mediapipe/tasks/c/core:base_options_converter",
        "//mediapipe/tasks/cc/text/text_embedder",
        "@com_google_absl//absl/log:absl_log",
        "@com_google_absl//absl/status",
    ],
    alwayslink = 1,
)
|
||||
|
||||
# Shared library (ELF) packaging of the text embedder C API. Build with:
#
# bazel build -c opt --linkopt -s --strip always --define MEDIAPIPE_DISABLE_GPU=1 \
# //mediapipe/tasks/c/text/text_embedder:libtext_embedder.so
cc_binary(
    name = "libtext_embedder.so",
    linkopts = [
        "-Wl,-soname=libtext_embedder.so",
        "-fvisibility=hidden",
    ],
    linkshared = True,
    tags = [
        "manual",
        "nobuilder",
        "notap",
    ],
    deps = [":text_embedder_lib"],
)
|
||||
|
||||
# Shared library (macOS dylib) packaging of the text embedder C API. Build with:
#
# bazel build --config darwin_arm64 -c opt --strip always --define MEDIAPIPE_DISABLE_GPU=1 \
# //mediapipe/tasks/c/text/text_embedder:libtext_embedder.dylib
cc_binary(
    name = "libtext_embedder.dylib",
    linkopts = [
        "-Wl,-install_name,libtext_embedder.dylib",
        "-fvisibility=hidden",
    ],
    linkshared = True,
    tags = [
        "manual",
        "nobuilder",
        "notap",
    ],
    deps = [":text_embedder_lib"],
)
|
||||
|
||||
# End-to-end test of the C API against the MobileBERT embedding test model.
cc_test(
    name = "text_embedder_test",
    srcs = ["text_embedder_test.cc"],
    data = ["//mediapipe/tasks/testdata/text:mobilebert_embedding_model"],
    linkstatic = 1,
    deps = [
        ":text_embedder_lib",
        "//mediapipe/framework/deps:file_path",
        "//mediapipe/framework/port:gtest",
        "@com_google_absl//absl/flags:flag",
        "@com_google_absl//absl/strings",
        "@com_google_googletest//:gtest_main",
    ],
)
|
119
mediapipe/tasks/c/text/text_embedder/text_embedder.cc
Normal file
119
mediapipe/tasks/c/text/text_embedder/text_embedder.cc
Normal file
|
@ -0,0 +1,119 @@
|
|||
/* Copyright 2023 The MediaPipe Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mediapipe/tasks/c/text/text_embedder/text_embedder.h"
|
||||
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
|
||||
#include "absl/log/absl_log.h"
|
||||
#include "absl/status/status.h"
|
||||
#include "mediapipe/tasks/c/components/containers/embedding_result_converter.h"
|
||||
#include "mediapipe/tasks/c/components/processors/embedder_options_converter.h"
|
||||
#include "mediapipe/tasks/c/core/base_options_converter.h"
|
||||
#include "mediapipe/tasks/cc/text/text_embedder/text_embedder.h"
|
||||
|
||||
namespace mediapipe::tasks::c::text::text_embedder {
|
||||
|
||||
namespace {
|
||||
|
||||
using ::mediapipe::tasks::c::components::containers::CppCloseEmbeddingResult;
|
||||
using ::mediapipe::tasks::c::components::containers::
|
||||
CppConvertToEmbeddingResult;
|
||||
using ::mediapipe::tasks::c::components::processors::
|
||||
CppConvertToEmbedderOptions;
|
||||
using ::mediapipe::tasks::c::core::CppConvertToBaseOptions;
|
||||
using ::mediapipe::tasks::text::text_embedder::TextEmbedder;
|
||||
|
||||
// Translates `status` into the C API error convention.
//
// When `error_msg` is non-null, `*error_msg` receives a strdup()-allocated
// copy of `status.ToString()`. Returns the raw status code of `status`.
int CppProcessError(absl::Status status, char** error_msg) {
  if (error_msg != nullptr) {
    const auto message = status.ToString();
    *error_msg = strdup(message.c_str());
  }
  return status.raw_code();
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Builds a C++ `TextEmbedder` from the C `options`.
//
// Returns a raw pointer released from the created task on success. On
// failure, logs the status, reports it through `error_msg` (see
// `CppProcessError`) and returns nullptr.
TextEmbedder* CppTextEmbedderCreate(const TextEmbedderOptions& options,
                                    char** error_msg) {
  using CppTextEmbedderOptions =
      ::mediapipe::tasks::text::text_embedder::TextEmbedderOptions;

  // Translate the C option structs into their C++ counterparts.
  auto converted_options = std::make_unique<CppTextEmbedderOptions>();
  CppConvertToBaseOptions(options.base_options,
                          &converted_options->base_options);
  CppConvertToEmbedderOptions(options.embedder_options,
                              &converted_options->embedder_options);

  auto embedder_or = TextEmbedder::Create(std::move(converted_options));
  if (embedder_or.ok()) {
    return embedder_or->release();
  }

  ABSL_LOG(ERROR) << "Failed to create TextEmbedder: " << embedder_or.status();
  CppProcessError(embedder_or.status(), error_msg);
  return nullptr;
}
|
||||
|
||||
int CppTextEmbedderEmbed(void* embedder, const char* utf8_str,
|
||||
TextEmbedderResult* result, char** error_msg) {
|
||||
auto cpp_embedder = static_cast<TextEmbedder*>(embedder);
|
||||
auto cpp_result = cpp_embedder->Embed(utf8_str);
|
||||
if (!cpp_result.ok()) {
|
||||
ABSL_LOG(ERROR) << "Embedding extraction failed: " << cpp_result.status();
|
||||
return CppProcessError(cpp_result.status(), error_msg);
|
||||
}
|
||||
CppConvertToEmbeddingResult(*cpp_result, result);
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Releases the memory held inside `result` by delegating to
// `CppCloseEmbeddingResult`.
void CppTextEmbedderCloseResult(TextEmbedderResult* result) {
  TextEmbedderResult* const embedding_result = result;
  CppCloseEmbeddingResult(embedding_result);
}
|
||||
|
||||
int CppTextEmbedderClose(void* embedder, char** error_msg) {
|
||||
auto cpp_embedder = static_cast<TextEmbedder*>(embedder);
|
||||
auto result = cpp_embedder->Close();
|
||||
if (!result.ok()) {
|
||||
ABSL_LOG(ERROR) << "Failed to close TextEmbedder: " << result;
|
||||
return CppProcessError(result, error_msg);
|
||||
}
|
||||
delete cpp_embedder;
|
||||
return 0;
|
||||
}
|
||||
|
||||
} // namespace mediapipe::tasks::c::text::text_embedder
|
||||
|
||||
extern "C" {
|
||||
|
||||
void* text_embedder_create(struct TextEmbedderOptions* options,
|
||||
char** error_msg) {
|
||||
return mediapipe::tasks::c::text::text_embedder::CppTextEmbedderCreate(
|
||||
*options, error_msg);
|
||||
}
|
||||
|
||||
int text_embedder_embed(void* embedder, const char* utf8_str,
|
||||
TextEmbedderResult* result, char** error_msg) {
|
||||
return mediapipe::tasks::c::text::text_embedder::CppTextEmbedderEmbed(
|
||||
embedder, utf8_str, result, error_msg);
|
||||
}
|
||||
|
||||
// C entry point: forwards to `CppTextEmbedderCloseResult`.
void text_embedder_close_result(TextEmbedderResult* result) {
  namespace impl = mediapipe::tasks::c::text::text_embedder;
  impl::CppTextEmbedderCloseResult(result);
}
|
||||
|
||||
int text_embedder_close(void* embedder, char** error_ms) {
|
||||
return mediapipe::tasks::c::text::text_embedder::CppTextEmbedderClose(
|
||||
embedder, error_ms);
|
||||
}
|
||||
|
||||
} // extern "C"
|
74
mediapipe/tasks/c/text/text_embedder/text_embedder.h
Normal file
74
mediapipe/tasks/c/text/text_embedder/text_embedder.h
Normal file
|
@ -0,0 +1,74 @@
|
|||
/* Copyright 2023 The MediaPipe Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef MEDIAPIPE_TASKS_C_TEXT_TEXT_EMBEDDER_TEXT_EMBEDDER_H_
|
||||
#define MEDIAPIPE_TASKS_C_TEXT_TEXT_EMBEDDER_TEXT_EMBEDDER_H_
|
||||
|
||||
#include "mediapipe/tasks/c/components/containers/embedding_result.h"
|
||||
#include "mediapipe/tasks/c/components/processors/embedder_options.h"
|
||||
#include "mediapipe/tasks/c/core/base_options.h"
|
||||
|
||||
// Marks a symbol as exported from the shared library. May be pre-defined by
// the build to override the default chosen here.
#ifndef MP_EXPORT
#if defined(_MSC_VER)
// MSVC does not understand the GCC/Clang `visibility` attribute.
#define MP_EXPORT __declspec(dllexport)
#else
#define MP_EXPORT __attribute__((visibility("default")))
#endif  // _MSC_VER
#endif  // MP_EXPORT
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
typedef struct EmbeddingResult TextEmbedderResult;
|
||||
|
||||
// The options for configuring a MediaPipe text embedder task.
|
||||
struct TextEmbedderOptions {
|
||||
// Base options for configuring MediaPipe Tasks, such as specifying the model
|
||||
// file with metadata, accelerator options, op resolver, etc.
|
||||
struct BaseOptions base_options;
|
||||
|
||||
// Options for configuring the embedder behavior, such as l2_normalize
|
||||
// and quantize.
|
||||
struct EmbedderOptions embedder_options;
|
||||
};
|
||||
|
||||
// Creates a TextEmbedder from the provided `options`.
|
||||
// Returns a pointer to the text embedder on success.
|
||||
// If an error occurs, returns `nullptr` and sets the error parameter to an
|
||||
// an error message (if `error_msg` is not `nullptr`). You must free the memory
|
||||
// allocated for the error message.
|
||||
MP_EXPORT void* text_embedder_create(struct TextEmbedderOptions* options,
|
||||
char** error_msg = nullptr);
|
||||
|
||||
// Performs embedding extraction on the input `text`. Returns `0` on success.
|
||||
// If an error occurs, returns an error code and sets the error parameter to an
|
||||
// an error message (if `error_msg` is not `nullptr`). You must free the memory
|
||||
// allocated for the error message.
|
||||
MP_EXPORT int text_embedder_embed(void* embedder, const char* utf8_str,
|
||||
TextEmbedderResult* result,
|
||||
char** error_msg = nullptr);
|
||||
|
||||
// Frees the memory allocated inside a TextEmbedderResult result. Does not
|
||||
// free the result pointer itself.
|
||||
MP_EXPORT void text_embedder_close_result(TextEmbedderResult* result);
|
||||
|
||||
// Shuts down the TextEmbedder when all the work is done. Frees all memory.
|
||||
// If an error occurs, returns an error code and sets the error parameter to an
|
||||
// an error message (if `error_msg` is not `nullptr`). You must free the memory
|
||||
// allocated for the error message.
|
||||
MP_EXPORT int text_embedder_close(void* embedder, char** error_msg = nullptr);
|
||||
|
||||
#ifdef __cplusplus
|
||||
} // extern C
|
||||
#endif
|
||||
|
||||
#endif // MEDIAPIPE_TASKS_C_TEXT_TEXT_EMBEDDER_TEXT_EMBEDDER_H_
|
79
mediapipe/tasks/c/text/text_embedder/text_embedder_test.cc
Normal file
79
mediapipe/tasks/c/text/text_embedder/text_embedder_test.cc
Normal file
|
@ -0,0 +1,79 @@
|
|||
/* Copyright 2023 The MediaPipe Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mediapipe/tasks/c/text/text_embedder/text_embedder.h"
|
||||
|
||||
#include <cstdlib>
|
||||
#include <string>
|
||||
|
||||
#include "absl/flags/flag.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "mediapipe/framework/deps/file_path.h"
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
|
||||
namespace {
|
||||
|
||||
using ::mediapipe::file::JoinPath;
|
||||
using testing::HasSubstr;
|
||||
|
||||
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/text/";
|
||||
constexpr char kTestBertModelPath[] =
|
||||
"mobilebert_embedding_with_metadata.tflite";
|
||||
constexpr char kTestString[] = "It's beautiful outside.";
|
||||
|
||||
// Returns the path of test data file `file_name`, built by joining "./", the
// test data directory and the file name.
std::string GetFullPath(absl::string_view file_name) {
  std::string full_path = JoinPath("./", kTestDataDirectory, file_name);
  return full_path;
}
|
||||
|
||||
// Creates an embedder for the test model with quantization enabled, embeds
// one sentence and checks the shape of the result.
TEST(TextEmbedderTest, SmokeTest) {
  std::string model_path = GetFullPath(kTestBertModelPath);
  TextEmbedderOptions options = {
      /* base_options= */ {/* model_asset_buffer= */ nullptr,
                           /* model_asset_path= */ model_path.c_str()},
      /* embedder_options= */
      {/* l2_normalize= */ false, /* quantize= */ true},
  };

  // Fatal check: the handle is used by every call below.
  void* embedder = text_embedder_create(&options);
  ASSERT_NE(embedder, nullptr);

  // Zero-initialize so the struct is well defined even if embedding fails,
  // and require success before inspecting the result.
  TextEmbedderResult result = {};
  ASSERT_EQ(text_embedder_embed(embedder, kTestString, &result), 0);
  ASSERT_EQ(result.embeddings_count, 1u);
  ASSERT_NE(result.embeddings, nullptr);
  EXPECT_EQ(result.embeddings[0].values_count, 512u);

  text_embedder_close_result(&result);
  EXPECT_EQ(text_embedder_close(embedder), 0);
}
|
||||
|
||||
// Creating an embedder without a model must fail and report an
// INVALID_ARGUMENT error message.
TEST(TextEmbedderTest, ErrorHandling) {
  // It is an error to set neither the asset buffer nor the path.
  TextEmbedderOptions options = {
      /* base_options= */ {/* model_asset_buffer= */ nullptr,
                           /* model_asset_path= */ nullptr},
      /* embedder_options= */ {},
  };

  // Start from null so that a missing message is detected instead of an
  // indeterminate pointer being matched against and passed to free().
  char* error_msg = nullptr;
  void* embedder = text_embedder_create(&options, &error_msg);
  EXPECT_EQ(embedder, nullptr);

  ASSERT_NE(error_msg, nullptr);
  EXPECT_THAT(error_msg, HasSubstr("INVALID_ARGUMENT"));

  free(error_msg);
}
|
||||
|
||||
} // namespace
|
Loading…
Reference in New Issue
Block a user