Merge pull request #4852 from kinaryml:c-text-embedder-api
PiperOrigin-RevId: 571065228
This commit is contained in:
commit
d686b42b85
|
@ -71,3 +71,30 @@ cc_test(
|
||||||
"@com_google_googletest//:gtest_main",
|
"@com_google_googletest//:gtest_main",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "embedding_result",
|
||||||
|
hdrs = ["embedding_result.h"],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "embedding_result_converter",
|
||||||
|
srcs = ["embedding_result_converter.cc"],
|
||||||
|
hdrs = ["embedding_result_converter.h"],
|
||||||
|
deps = [
|
||||||
|
":embedding_result",
|
||||||
|
"//mediapipe/tasks/cc/components/containers:embedding_result",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_test(
|
||||||
|
name = "embedding_result_converter_test",
|
||||||
|
srcs = ["embedding_result_converter_test.cc"],
|
||||||
|
deps = [
|
||||||
|
":embedding_result",
|
||||||
|
":embedding_result_converter",
|
||||||
|
"//mediapipe/framework/port:gtest",
|
||||||
|
"//mediapipe/tasks/cc/components/containers:embedding_result",
|
||||||
|
"@com_google_googletest//:gtest_main",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
79
mediapipe/tasks/c/components/containers/embedding_result.h
Normal file
79
mediapipe/tasks/c/components/containers/embedding_result.h
Normal file
|
@ -0,0 +1,79 @@
|
||||||
|
/* Copyright 2023 The MediaPipe Authors.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_EMBEDDING_RESULT_H_
|
||||||
|
#define MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_EMBEDDING_RESULT_H_
|
||||||
|
|
||||||
|
#include <stdbool.h>
#include <stdint.h>
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
extern "C" {
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// Embedding result for a given embedder head.
|
||||||
|
//
|
||||||
|
// One and only one of the two 'float_embedding' and 'quantized_embedding' will
|
||||||
|
// contain data, based on whether or not the embedder was configured to perform
|
||||||
|
// scalar quantization.
|
||||||
|
struct Embedding {
|
||||||
|
// Floating-point embedding. `nullptr` if the embedder was configured to
|
||||||
|
// perform scalar-quantization.
|
||||||
|
float* float_embedding;
|
||||||
|
|
||||||
|
// Scalar-quantized embedding. `nullptr` if the embedder was not configured to
|
||||||
|
// perform scalar quantization.
|
||||||
|
char* quantized_embedding;
|
||||||
|
|
||||||
|
// Keep the count of embedding values.
|
||||||
|
uint32_t values_count;
|
||||||
|
|
||||||
|
// The index of the embedder head (i.e. output tensor) this embedding comes
|
||||||
|
// from. This is useful for multi-head models.
|
||||||
|
int head_index;
|
||||||
|
|
||||||
|
// The optional name of the embedder head, as provided in the TFLite Model
|
||||||
|
// Metadata [1] if present. This is useful for multi-head models.
|
||||||
|
// Defaults to nullptr.
|
||||||
|
//
|
||||||
|
// [1]: https://www.tensorflow.org/lite/convert/metadata
|
||||||
|
char* head_name;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Defines embedding results of a model.
|
||||||
|
struct EmbeddingResult {
|
||||||
|
// The embedding results for each head of the model.
|
||||||
|
struct Embedding* embeddings;
|
||||||
|
|
||||||
|
// Keep the count of embeddings.
|
||||||
|
uint32_t embeddings_count;
|
||||||
|
|
||||||
|
// The optional timestamp (in milliseconds) of the start of the chunk of data
|
||||||
|
// corresponding to these results.
|
||||||
|
//
|
||||||
|
// This is only used for embedding extraction on time series (e.g. audio
|
||||||
|
// embedding). In these use cases, the amount of data to process might
|
||||||
|
// exceed the maximum size that the model can process: to solve this, the
|
||||||
|
// input data is split into multiple chunks starting at different timestamps.
|
||||||
|
int64_t timestamp_ms;
|
||||||
|
|
||||||
|
// Specifies whether the timestamp contains a valid value.
|
||||||
|
bool has_timestamp_ms;
|
||||||
|
};
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
} // extern C
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#endif // MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_EMBEDDING_RESULT_H_
|
|
@ -0,0 +1,84 @@
|
||||||
|
/* Copyright 2023 The MediaPipe Authors.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "mediapipe/tasks/c/components/containers/embedding_result_converter.h"
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <cstdint>
|
||||||
|
#include <cstdlib>
|
||||||
|
|
||||||
|
#include "mediapipe/tasks/c/components/containers/embedding_result.h"
|
||||||
|
#include "mediapipe/tasks/cc/components/containers/embedding_result.h"
|
||||||
|
|
||||||
|
namespace mediapipe::tasks::c::components::containers {
|
||||||
|
|
||||||
|
void CppConvertToEmbeddingResult(
|
||||||
|
const mediapipe::tasks::components::containers::EmbeddingResult& in,
|
||||||
|
EmbeddingResult* out) {
|
||||||
|
out->has_timestamp_ms = in.timestamp_ms.has_value();
|
||||||
|
out->timestamp_ms = out->has_timestamp_ms ? in.timestamp_ms.value() : 0;
|
||||||
|
|
||||||
|
out->embeddings_count = in.embeddings.size();
|
||||||
|
out->embeddings =
|
||||||
|
out->embeddings_count ? new Embedding[out->embeddings_count] : nullptr;
|
||||||
|
|
||||||
|
for (uint32_t i = 0; i < out->embeddings_count; ++i) {
|
||||||
|
auto embedding_in = in.embeddings[i];
|
||||||
|
auto& embedding_out = out->embeddings[i];
|
||||||
|
if (!embedding_in.float_embedding.empty()) {
|
||||||
|
// Handle float embeddings
|
||||||
|
embedding_out.values_count = embedding_in.float_embedding.size();
|
||||||
|
embedding_out.float_embedding = new float[embedding_out.values_count];
|
||||||
|
|
||||||
|
std::copy(embedding_in.float_embedding.begin(),
|
||||||
|
embedding_in.float_embedding.end(),
|
||||||
|
embedding_out.float_embedding);
|
||||||
|
|
||||||
|
embedding_out.quantized_embedding = nullptr;
|
||||||
|
} else if (!embedding_in.quantized_embedding.empty()) {
|
||||||
|
// Handle quantized embeddings
|
||||||
|
embedding_out.values_count = embedding_in.quantized_embedding.size();
|
||||||
|
embedding_out.quantized_embedding = new char[embedding_out.values_count];
|
||||||
|
|
||||||
|
std::copy(embedding_in.quantized_embedding.begin(),
|
||||||
|
embedding_in.quantized_embedding.end(),
|
||||||
|
embedding_out.quantized_embedding);
|
||||||
|
|
||||||
|
embedding_out.float_embedding = nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
embedding_out.head_index = embedding_in.head_index;
|
||||||
|
embedding_out.head_name = embedding_in.head_name.has_value()
|
||||||
|
? strdup(embedding_in.head_name->c_str())
|
||||||
|
: nullptr;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Releases every allocation owned by `in` and leaves it in an empty state
// (`embeddings == nullptr`, `embeddings_count == 0`).
//
// Value arrays and the `embeddings` array are released with `delete[]`, head
// names with `free()`. Because the count is reset together with the pointer,
// calling this again on the same struct is a no-op instead of indexing a
// null array. A null `in` is ignored.
void CppCloseEmbeddingResult(EmbeddingResult* in) {
  if (in == nullptr) {
    return;
  }

  for (uint32_t i = 0; i < in->embeddings_count; ++i) {
    // No need to null the element's members: the array that holds them is
    // deleted right after the loop.
    const Embedding& embedding = in->embeddings[i];

    delete[] embedding.float_embedding;
    delete[] embedding.quantized_embedding;
    free(embedding.head_name);
  }

  delete[] in->embeddings;
  in->embeddings = nullptr;
  in->embeddings_count = 0;
}
|
||||||
|
|
||||||
|
} // namespace mediapipe::tasks::c::components::containers
|
|
@ -0,0 +1,38 @@
|
||||||
|
/* Copyright 2023 The MediaPipe Authors.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_EMBEDDING_RESULT_CONVERTER_H_
|
||||||
|
#define MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_EMBEDDING_RESULT_CONVERTER_H_
|
||||||
|
|
||||||
|
#include "mediapipe/tasks/c/components/containers/embedding_result.h"
|
||||||
|
#include "mediapipe/tasks/cc/components/containers/embedding_result.h"
|
||||||
|
|
||||||
|
namespace mediapipe::tasks::c::components::containers {
|
||||||
|
|
||||||
|
void CppConvertToEmbedding(
|
||||||
|
const mediapipe::tasks::components::containers::EmbeddingResult& in,
|
||||||
|
Embedding* out);
|
||||||
|
|
||||||
|
void CppConvertToEmbeddingResult(
|
||||||
|
const mediapipe::tasks::components::containers::EmbeddingResult& in,
|
||||||
|
EmbeddingResult* out);
|
||||||
|
|
||||||
|
void CppCloseEmbedding(Embedding* in);
|
||||||
|
|
||||||
|
void CppCloseEmbeddingResult(EmbeddingResult* in);
|
||||||
|
|
||||||
|
} // namespace mediapipe::tasks::c::components::containers
|
||||||
|
|
||||||
|
#endif // MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_EMBEDDING_RESULT_CONVERTER_H_
|
|
@ -0,0 +1,66 @@
|
||||||
|
/* Copyright 2023 The MediaPipe Authors.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "mediapipe/tasks/c/components/containers/embedding_result_converter.h"
|
||||||
|
|
||||||
|
#include <optional>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "mediapipe/framework/port/gtest.h"
|
||||||
|
#include "mediapipe/tasks/c/components/containers/embedding_result.h"
|
||||||
|
#include "mediapipe/tasks/cc/components/containers/embedding_result.h"
|
||||||
|
|
||||||
|
namespace mediapipe::tasks::c::components::containers {
|
||||||
|
|
||||||
|
// Converts a two-head C++ result (one float head named "foo", one quantized
// head without a name) and checks the fields of the resulting C struct.
TEST(EmbeddingResultConverterTest, ConvertsEmbeddingResultCustomEmbedding) {
  mediapipe::tasks::components::containers::EmbeddingResult
      cpp_embedding_result = {
          // Initializing embeddings vector
          {// First embedding
           {
               {0.1f, 0.2f, 0.3f, 0.4f, 0.5f},  // float embedding
               {},                              // quantized embedding (empty)
               0,                               // head index
               "foo"                            // head name
           },
           // Second embedding
           {
               {},                         // float embedding (empty)
               {127, 127, 127, 127, 127},  // quantized embedding
               1,                          // head index
               std::nullopt                // no head name
           }},
          // Initializing timestamp_ms
          42  // timestamp in ms
      };

  EmbeddingResult c_embedding_result;
  CppConvertToEmbeddingResult(cpp_embedding_result, &c_embedding_result);

  // Result-level fields.
  EXPECT_NE(c_embedding_result.embeddings, nullptr);
  EXPECT_EQ(c_embedding_result.embeddings_count, 2);
  EXPECT_TRUE(c_embedding_result.has_timestamp_ms);
  EXPECT_EQ(c_embedding_result.timestamp_ms, 42);

  // First head: float embedding named "foo".
  const Embedding& float_head = c_embedding_result.embeddings[0];
  EXPECT_NE(float_head.float_embedding, nullptr);
  EXPECT_EQ(float_head.values_count, 5);
  EXPECT_EQ(float_head.head_index, 0);
  EXPECT_EQ(std::string(float_head.head_name), "foo");

  // Second head: quantized embedding.
  const Embedding& quantized_head = c_embedding_result.embeddings[1];
  EXPECT_NE(quantized_head.quantized_embedding, nullptr);
  EXPECT_EQ(quantized_head.values_count, 5);
  EXPECT_EQ(quantized_head.head_index, 1);

  CppCloseEmbeddingResult(&c_embedding_result);
}
|
||||||
|
|
||||||
|
} // namespace mediapipe::tasks::c::components::containers
|
|
@ -42,3 +42,30 @@ cc_test(
|
||||||
"@com_google_googletest//:gtest_main",
|
"@com_google_googletest//:gtest_main",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "embedder_options",
|
||||||
|
hdrs = ["embedder_options.h"],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "embedder_options_converter",
|
||||||
|
srcs = ["embedder_options_converter.cc"],
|
||||||
|
hdrs = ["embedder_options_converter.h"],
|
||||||
|
deps = [
|
||||||
|
":embedder_options",
|
||||||
|
"//mediapipe/tasks/cc/components/processors:embedder_options",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_test(
|
||||||
|
name = "embedder_options_converter_test",
|
||||||
|
srcs = ["embedder_options_converter_test.cc"],
|
||||||
|
deps = [
|
||||||
|
":embedder_options",
|
||||||
|
":embedder_options_converter",
|
||||||
|
"//mediapipe/framework/port:gtest",
|
||||||
|
"//mediapipe/tasks/cc/components/processors:embedder_options",
|
||||||
|
"@com_google_googletest//:gtest_main",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
42
mediapipe/tasks/c/components/processors/embedder_options.h
Normal file
42
mediapipe/tasks/c/components/processors/embedder_options.h
Normal file
|
@ -0,0 +1,42 @@
|
||||||
|
/* Copyright 2023 The MediaPipe Authors.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef MEDIAPIPE_TASKS_C_COMPONENTS_PROCESSORS_EMBEDDER_OPTIONS_H_
|
||||||
|
#define MEDIAPIPE_TASKS_C_COMPONENTS_PROCESSORS_EMBEDDER_OPTIONS_H_
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
extern "C" {
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// Embedder options for MediaPipe C embedding extraction tasks.
|
||||||
|
struct EmbedderOptions {
|
||||||
|
// Whether to normalize the returned feature vector with L2 norm. Use this
|
||||||
|
// option only if the model does not already contain a native L2_NORMALIZATION
|
||||||
|
// TF Lite Op. In most cases, this is already the case and L2 norm is thus
|
||||||
|
// achieved through TF Lite inference.
|
||||||
|
bool l2_normalize;
|
||||||
|
|
||||||
|
// Whether the returned embedding should be quantized to bytes via scalar
|
||||||
|
// quantization. Embeddings are implicitly assumed to be unit-norm and
|
||||||
|
// therefore any dimension is guaranteed to have a value in [-1.0, 1.0]. Use
|
||||||
|
// the l2_normalize option if this is not the case.
|
||||||
|
bool quantize;
|
||||||
|
};
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
} // extern C
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#endif // MEDIAPIPE_TASKS_C_COMPONENTS_PROCESSORS_EMBEDDER_OPTIONS_H_
|
|
@ -0,0 +1,28 @@
|
||||||
|
/* Copyright 2023 The MediaPipe Authors.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "mediapipe/tasks/c/components/processors/embedder_options.h"
|
||||||
|
#include "mediapipe/tasks/cc/components/processors/embedder_options.h"
|
||||||
|
|
||||||
|
namespace mediapipe::tasks::c::components::processors {
|
||||||
|
|
||||||
|
void CppConvertToEmbedderOptions(
|
||||||
|
const EmbedderOptions& in,
|
||||||
|
mediapipe::tasks::components::processors::EmbedderOptions* out) {
|
||||||
|
out->l2_normalize = in.l2_normalize;
|
||||||
|
out->quantize = in.quantize;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mediapipe::tasks::c::components::processors
|
|
@ -0,0 +1,30 @@
|
||||||
|
/* Copyright 2023 The MediaPipe Authors.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef MEDIAPIPE_TASKS_C_COMPONENTS_PROCESSORS_EMBEDDER_OPTIONS_CONVERTER_H_
|
||||||
|
#define MEDIAPIPE_TASKS_C_COMPONENTS_PROCESSORS_EMBEDDER_OPTIONS_CONVERTER_H_
|
||||||
|
|
||||||
|
#include "mediapipe/tasks/c/components/processors/embedder_options.h"
|
||||||
|
#include "mediapipe/tasks/cc/components/processors/embedder_options.h"
|
||||||
|
|
||||||
|
namespace mediapipe::tasks::c::components::processors {
|
||||||
|
|
||||||
|
void CppConvertToEmbedderOptions(
|
||||||
|
const EmbedderOptions& in,
|
||||||
|
mediapipe::tasks::components::processors::EmbedderOptions* out);
|
||||||
|
|
||||||
|
} // namespace mediapipe::tasks::c::components::processors
|
||||||
|
|
||||||
|
#endif // MEDIAPIPE_TASKS_C_COMPONENTS_PROCESSORS_EMBEDDER_OPTIONS_CONVERTER_H_
|
|
@ -0,0 +1,36 @@
|
||||||
|
/* Copyright 2023 The MediaPipe Authors.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "mediapipe/tasks/c/components/processors/embedder_options_converter.h"
|
||||||
|
|
||||||
|
#include "mediapipe/framework/port/gtest.h"
|
||||||
|
#include "mediapipe/tasks/c/components/processors/embedder_options.h"
|
||||||
|
#include "mediapipe/tasks/cc/components/processors/embedder_options.h"
|
||||||
|
|
||||||
|
namespace mediapipe::tasks::c::components::processors {
|
||||||
|
|
||||||
|
// Converts C options with l2_normalize=true / quantize=false and checks that
// the C++ options carry the same values.
TEST(EmbedderOptionsConverterTest, ConvertsEmbedderOptionsCustomValues) {
  EmbedderOptions c_embedder_options;
  c_embedder_options.l2_normalize = true;
  c_embedder_options.quantize = false;

  mediapipe::tasks::components::processors::EmbedderOptions
      cpp_embedder_options = {};
  CppConvertToEmbedderOptions(c_embedder_options, &cpp_embedder_options);

  EXPECT_TRUE(cpp_embedder_options.l2_normalize);
  EXPECT_FALSE(cpp_embedder_options.quantize);
}
|
||||||
|
|
||||||
|
} // namespace mediapipe::tasks::c::components::processors
|
85
mediapipe/tasks/c/text/text_embedder/BUILD
Normal file
85
mediapipe/tasks/c/text/text_embedder/BUILD
Normal file
|
@ -0,0 +1,85 @@
|
||||||
|
# Copyright 2023 The MediaPipe Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||||
|
|
||||||
|
licenses(["notice"])
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "text_embedder_lib",
|
||||||
|
srcs = ["text_embedder.cc"],
|
||||||
|
hdrs = ["text_embedder.h"],
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
|
deps = [
|
||||||
|
"//mediapipe/tasks/c/components/containers:embedding_result",
|
||||||
|
"//mediapipe/tasks/c/components/containers:embedding_result_converter",
|
||||||
|
"//mediapipe/tasks/c/components/processors:embedder_options",
|
||||||
|
"//mediapipe/tasks/c/components/processors:embedder_options_converter",
|
||||||
|
"//mediapipe/tasks/c/core:base_options",
|
||||||
|
"//mediapipe/tasks/c/core:base_options_converter",
|
||||||
|
"//mediapipe/tasks/cc/text/text_embedder",
|
||||||
|
"@com_google_absl//absl/log:absl_log",
|
||||||
|
"@com_google_absl//absl/status",
|
||||||
|
],
|
||||||
|
alwayslink = 1,
|
||||||
|
)
|
||||||
|
|
||||||
|
# bazel build -c opt --linkopt -s --strip always --define MEDIAPIPE_DISABLE_GPU=1 \
|
||||||
|
# //mediapipe/tasks/c/text/text_embedder:libtext_embedder.so
|
||||||
|
cc_binary(
|
||||||
|
name = "libtext_embedder.so",
|
||||||
|
linkopts = [
|
||||||
|
"-Wl,-soname=libtext_embedder.so",
|
||||||
|
"-fvisibility=hidden",
|
||||||
|
],
|
||||||
|
linkshared = True,
|
||||||
|
tags = [
|
||||||
|
"manual",
|
||||||
|
"nobuilder",
|
||||||
|
"notap",
|
||||||
|
],
|
||||||
|
deps = [":text_embedder_lib"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# bazel build --config darwin_arm64 -c opt --strip always --define MEDIAPIPE_DISABLE_GPU=1 \
|
||||||
|
# //mediapipe/tasks/c/text/text_embedder:libtext_embedder.dylib
|
||||||
|
cc_binary(
|
||||||
|
name = "libtext_embedder.dylib",
|
||||||
|
linkopts = [
|
||||||
|
"-Wl,-install_name,libtext_embedder.dylib",
|
||||||
|
"-fvisibility=hidden",
|
||||||
|
],
|
||||||
|
linkshared = True,
|
||||||
|
tags = [
|
||||||
|
"manual",
|
||||||
|
"nobuilder",
|
||||||
|
"notap",
|
||||||
|
],
|
||||||
|
deps = [":text_embedder_lib"],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_test(
|
||||||
|
name = "text_embedder_test",
|
||||||
|
srcs = ["text_embedder_test.cc"],
|
||||||
|
data = ["//mediapipe/tasks/testdata/text:mobilebert_embedding_model"],
|
||||||
|
linkstatic = 1,
|
||||||
|
deps = [
|
||||||
|
":text_embedder_lib",
|
||||||
|
"//mediapipe/framework/deps:file_path",
|
||||||
|
"//mediapipe/framework/port:gtest",
|
||||||
|
"@com_google_absl//absl/flags:flag",
|
||||||
|
"@com_google_absl//absl/strings",
|
||||||
|
"@com_google_googletest//:gtest_main",
|
||||||
|
],
|
||||||
|
)
|
119
mediapipe/tasks/c/text/text_embedder/text_embedder.cc
Normal file
119
mediapipe/tasks/c/text/text_embedder/text_embedder.cc
Normal file
|
@ -0,0 +1,119 @@
|
||||||
|
/* Copyright 2023 The MediaPipe Authors.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "mediapipe/tasks/c/text/text_embedder/text_embedder.h"
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
|
#include "absl/log/absl_log.h"
|
||||||
|
#include "absl/status/status.h"
|
||||||
|
#include "mediapipe/tasks/c/components/containers/embedding_result_converter.h"
|
||||||
|
#include "mediapipe/tasks/c/components/processors/embedder_options_converter.h"
|
||||||
|
#include "mediapipe/tasks/c/core/base_options_converter.h"
|
||||||
|
#include "mediapipe/tasks/cc/text/text_embedder/text_embedder.h"
|
||||||
|
|
||||||
|
namespace mediapipe::tasks::c::text::text_embedder {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
using ::mediapipe::tasks::c::components::containers::CppCloseEmbeddingResult;
|
||||||
|
using ::mediapipe::tasks::c::components::containers::
|
||||||
|
CppConvertToEmbeddingResult;
|
||||||
|
using ::mediapipe::tasks::c::components::processors::
|
||||||
|
CppConvertToEmbedderOptions;
|
||||||
|
using ::mediapipe::tasks::c::core::CppConvertToBaseOptions;
|
||||||
|
using ::mediapipe::tasks::text::text_embedder::TextEmbedder;
|
||||||
|
|
||||||
|
// Reports `status` to a C caller: when `error_msg` is non-null it receives a
// strdup()'d copy of the status text, and the raw status code is returned.
int CppProcessError(absl::Status status, char** error_msg) {
  if (error_msg != nullptr) {
    const auto message = status.ToString();
    *error_msg = strdup(message.c_str());
  }
  return status.raw_code();
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
// Builds a C++ TextEmbedder from the C `options`.
//
// Returns the embedder (ownership passes to the caller) on success. On
// failure the error is logged, reported through `error_msg` via
// CppProcessError(), and `nullptr` is returned.
TextEmbedder* CppTextEmbedderCreate(const TextEmbedderOptions& options,
                                    char** error_msg) {
  using CppTextEmbedderOptions =
      ::mediapipe::tasks::text::text_embedder::TextEmbedderOptions;

  auto cpp_options = std::make_unique<CppTextEmbedderOptions>();
  CppConvertToBaseOptions(options.base_options, &cpp_options->base_options);
  CppConvertToEmbedderOptions(options.embedder_options,
                              &cpp_options->embedder_options);

  auto embedder = TextEmbedder::Create(std::move(cpp_options));
  if (embedder.ok()) {
    return embedder->release();
  }

  ABSL_LOG(ERROR) << "Failed to create TextEmbedder: " << embedder.status();
  CppProcessError(embedder.status(), error_msg);
  return nullptr;
}
|
||||||
|
|
||||||
|
int CppTextEmbedderEmbed(void* embedder, const char* utf8_str,
|
||||||
|
TextEmbedderResult* result, char** error_msg) {
|
||||||
|
auto cpp_embedder = static_cast<TextEmbedder*>(embedder);
|
||||||
|
auto cpp_result = cpp_embedder->Embed(utf8_str);
|
||||||
|
if (!cpp_result.ok()) {
|
||||||
|
ABSL_LOG(ERROR) << "Embedding extraction failed: " << cpp_result.status();
|
||||||
|
return CppProcessError(cpp_result.status(), error_msg);
|
||||||
|
}
|
||||||
|
CppConvertToEmbeddingResult(*cpp_result, result);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Frees the memory held by `result` by forwarding to
// CppCloseEmbeddingResult().
void CppTextEmbedderCloseResult(TextEmbedderResult* result) {
  CppCloseEmbeddingResult(result);
}
|
||||||
|
|
||||||
|
int CppTextEmbedderClose(void* embedder, char** error_msg) {
|
||||||
|
auto cpp_embedder = static_cast<TextEmbedder*>(embedder);
|
||||||
|
auto result = cpp_embedder->Close();
|
||||||
|
if (!result.ok()) {
|
||||||
|
ABSL_LOG(ERROR) << "Failed to close TextEmbedder: " << result;
|
||||||
|
return CppProcessError(result, error_msg);
|
||||||
|
}
|
||||||
|
delete cpp_embedder;
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mediapipe::tasks::c::text::text_embedder
|
||||||
|
|
||||||
|
extern "C" {
|
||||||
|
|
||||||
|
void* text_embedder_create(struct TextEmbedderOptions* options,
|
||||||
|
char** error_msg) {
|
||||||
|
return mediapipe::tasks::c::text::text_embedder::CppTextEmbedderCreate(
|
||||||
|
*options, error_msg);
|
||||||
|
}
|
||||||
|
|
||||||
|
int text_embedder_embed(void* embedder, const char* utf8_str,
|
||||||
|
TextEmbedderResult* result, char** error_msg) {
|
||||||
|
return mediapipe::tasks::c::text::text_embedder::CppTextEmbedderEmbed(
|
||||||
|
embedder, utf8_str, result, error_msg);
|
||||||
|
}
|
||||||
|
|
||||||
|
// C entry point: forwards to CppTextEmbedderCloseResult() to free `result`.
void text_embedder_close_result(TextEmbedderResult* result) {
  using ::mediapipe::tasks::c::text::text_embedder::CppTextEmbedderCloseResult;
  CppTextEmbedderCloseResult(result);
}
|
||||||
|
|
||||||
|
int text_embedder_close(void* embedder, char** error_ms) {
|
||||||
|
return mediapipe::tasks::c::text::text_embedder::CppTextEmbedderClose(
|
||||||
|
embedder, error_ms);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // extern "C"
|
74
mediapipe/tasks/c/text/text_embedder/text_embedder.h
Normal file
74
mediapipe/tasks/c/text/text_embedder/text_embedder.h
Normal file
|
@ -0,0 +1,74 @@
|
||||||
|
/* Copyright 2023 The MediaPipe Authors.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef MEDIAPIPE_TASKS_C_TEXT_TEXT_EMBEDDER_TEXT_EMBEDDER_H_
|
||||||
|
#define MEDIAPIPE_TASKS_C_TEXT_TEXT_EMBEDDER_TEXT_EMBEDDER_H_
|
||||||
|
|
||||||
|
#include "mediapipe/tasks/c/components/containers/embedding_result.h"
|
||||||
|
#include "mediapipe/tasks/c/components/processors/embedder_options.h"
|
||||||
|
#include "mediapipe/tasks/c/core/base_options.h"
|
||||||
|
|
||||||
|
#ifndef MP_EXPORT
|
||||||
|
#define MP_EXPORT __attribute__((visibility("default")))
|
||||||
|
#endif // MP_EXPORT
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
extern "C" {
|
||||||
|
#endif
|
||||||
|
|
||||||
|
typedef struct EmbeddingResult TextEmbedderResult;
|
||||||
|
|
||||||
|
// The options for configuring a MediaPipe text embedder task.
|
||||||
|
struct TextEmbedderOptions {
|
||||||
|
// Base options for configuring MediaPipe Tasks, such as specifying the model
|
||||||
|
// file with metadata, accelerator options, op resolver, etc.
|
||||||
|
struct BaseOptions base_options;
|
||||||
|
|
||||||
|
// Options for configuring the embedder behavior, such as l2_normalize
|
||||||
|
// and quantize.
|
||||||
|
struct EmbedderOptions embedder_options;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Creates a TextEmbedder from the provided `options`.
|
||||||
|
// Returns a pointer to the text embedder on success.
|
||||||
|
// If an error occurs, returns `nullptr` and sets the error parameter to an
|
||||||
|
// error message (if `error_msg` is not `nullptr`). You must free the memory
|
||||||
|
// allocated for the error message.
|
||||||
|
MP_EXPORT void* text_embedder_create(struct TextEmbedderOptions* options,
|
||||||
|
char** error_msg = nullptr);
|
||||||
|
|
||||||
|
// Performs embedding extraction on the input `text`. Returns `0` on success.
|
||||||
|
// If an error occurs, returns an error code and sets the error parameter to an
|
||||||
|
// error message (if `error_msg` is not `nullptr`). You must free the memory
|
||||||
|
// allocated for the error message.
|
||||||
|
MP_EXPORT int text_embedder_embed(void* embedder, const char* utf8_str,
|
||||||
|
TextEmbedderResult* result,
|
||||||
|
char** error_msg = nullptr);
|
||||||
|
|
||||||
|
// Frees the memory allocated inside a TextEmbedderResult result. Does not
|
||||||
|
// free the result pointer itself.
|
||||||
|
MP_EXPORT void text_embedder_close_result(TextEmbedderResult* result);
|
||||||
|
|
||||||
|
// Shuts down the TextEmbedder when all the work is done. Frees all memory.
|
||||||
|
// If an error occurs, returns an error code and sets the error parameter to an
|
||||||
|
// error message (if `error_msg` is not `nullptr`). You must free the memory
|
||||||
|
// allocated for the error message.
|
||||||
|
MP_EXPORT int text_embedder_close(void* embedder, char** error_msg = nullptr);
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
} // extern C
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#endif // MEDIAPIPE_TASKS_C_TEXT_TEXT_EMBEDDER_TEXT_EMBEDDER_H_
|
79
mediapipe/tasks/c/text/text_embedder/text_embedder_test.cc
Normal file
79
mediapipe/tasks/c/text/text_embedder/text_embedder_test.cc
Normal file
|
@ -0,0 +1,79 @@
|
||||||
|
/* Copyright 2023 The MediaPipe Authors.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "mediapipe/tasks/c/text/text_embedder/text_embedder.h"
|
||||||
|
|
||||||
|
#include <cstdlib>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "absl/flags/flag.h"
|
||||||
|
#include "absl/strings/string_view.h"
|
||||||
|
#include "mediapipe/framework/deps/file_path.h"
|
||||||
|
#include "mediapipe/framework/port/gmock.h"
|
||||||
|
#include "mediapipe/framework/port/gtest.h"
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
using ::mediapipe::file::JoinPath;
|
||||||
|
using testing::HasSubstr;
|
||||||
|
|
||||||
|
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/text/";
|
||||||
|
constexpr char kTestBertModelPath[] =
|
||||||
|
"mobilebert_embedding_with_metadata.tflite";
|
||||||
|
constexpr char kTestString[] = "It's beautiful outside.";
|
||||||
|
|
||||||
|
// Builds the path of a test asset by joining "./", the test data directory
// and `file_name`.
std::string GetFullPath(absl::string_view file_name) {
  std::string full_path = JoinPath("./", kTestDataDirectory, file_name);
  return full_path;
}
|
||||||
|
|
||||||
|
// Creates an embedder from the MobileBERT test model, embeds a short
// sentence, and checks that exactly one embedding with 512 values comes back.
TEST(TextEmbedderTest, SmokeTest) {
  std::string model_path = GetFullPath(kTestBertModelPath);
  TextEmbedderOptions options = {
      /* base_options= */ {/* model_asset_buffer= */ nullptr,
                           /* model_asset_path= */ model_path.c_str()},
      /* embedder_options= */
      {/* l2_normalize= */ false, /* quantize= */ true},
  };

  void* embedder = text_embedder_create(&options);
  // Fatal assertion: all calls below take the embedder handle, so there is
  // nothing meaningful to test if creation failed.
  ASSERT_NE(embedder, nullptr);

  // `result` is left uninitialized here; only inspect and release it when the
  // embed call reports success (the header documents `0` as success).
  TextEmbedderResult result;
  const int embed_status = text_embedder_embed(embedder, kTestString, &result);
  EXPECT_EQ(embed_status, 0);
  if (embed_status == 0) {
    EXPECT_EQ(result.embeddings_count, 1);
    EXPECT_EQ(result.embeddings[0].values_count, 512);

    text_embedder_close_result(&result);
  }

  // Always shut the embedder down, even if embedding failed.
  text_embedder_close(embedder);
}
|
||||||
|
|
||||||
|
// Checks that creating an embedder without any model source fails and reports
// an INVALID_ARGUMENT error message.
TEST(TextEmbedderTest, ErrorHandling) {
  // It is an error to set neither the asset buffer nor the path.
  TextEmbedderOptions options = {
      /* base_options= */ {/* model_asset_buffer= */ nullptr,
                           /* model_asset_path= */ nullptr},
      /* embedder_options= */ {},
  };

  // Start from nullptr so that a missing error message is detectable instead
  // of reading (and later freeing) an indeterminate pointer.
  char* error_msg = nullptr;
  void* embedder = text_embedder_create(&options, &error_msg);
  EXPECT_EQ(embedder, nullptr);

  // Fatal assertion: the matcher below must not be handed a null C string.
  ASSERT_NE(error_msg, nullptr);
  EXPECT_THAT(error_msg, HasSubstr("INVALID_ARGUMENT"));

  free(error_msg);
}
|
||||||
|
|
||||||
|
} // namespace
|
Loading…
Reference in New Issue
Block a user