Added files for the TextEmbedder C API and tests
This commit is contained in:
		
							parent
							
								
									5366aa9d0a
								
							
						
					
					
						commit
						3564fc0d9b
					
				| 
						 | 
				
			
			@ -71,3 +71,30 @@ cc_test(
 | 
			
		|||
        "@com_google_googletest//:gtest_main",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
cc_library(
 | 
			
		||||
    name = "embedding_result",
 | 
			
		||||
    hdrs = ["embedding_result.h"],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
cc_library(
 | 
			
		||||
    name = "embedding_result_converter",
 | 
			
		||||
    srcs = ["embedding_result_converter.cc"],
 | 
			
		||||
    hdrs = ["embedding_result_converter.h"],
 | 
			
		||||
    deps = [
 | 
			
		||||
        ":embedding_result",
 | 
			
		||||
        "//mediapipe/tasks/cc/components/containers:embedding_result",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
cc_test(
 | 
			
		||||
    name = "embedding_result_converter_test",
 | 
			
		||||
    srcs = ["embedding_result_converter_test.cc"],
 | 
			
		||||
    deps = [
 | 
			
		||||
        ":embedding_result",
 | 
			
		||||
        ":embedding_result_converter",
 | 
			
		||||
        "//mediapipe/framework/port:gtest",
 | 
			
		||||
        "//mediapipe/tasks/cc/components/containers:embedding_result",
 | 
			
		||||
        "@com_google_googletest//:gtest_main",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
							
								
								
									
										79
									
								
								mediapipe/tasks/c/components/containers/embedding_result.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										79
									
								
								mediapipe/tasks/c/components/containers/embedding_result.h
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,79 @@
 | 
			
		|||
/* Copyright 2023 The MediaPipe Authors.
 | 
			
		||||
 | 
			
		||||
Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
you may not use this file except in compliance with the License.
 | 
			
		||||
You may obtain a copy of the License at
 | 
			
		||||
 | 
			
		||||
    http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
 | 
			
		||||
Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
See the License for the specific language governing permissions and
 | 
			
		||||
limitations under the License.
 | 
			
		||||
==============================================================================*/
 | 
			
		||||
 | 
			
		||||
#ifndef MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_EMBEDDING_RESULT_H_
 | 
			
		||||
#define MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_EMBEDDING_RESULT_H_
 | 
			
		||||
 | 
			
		||||
#include <stdbool.h>
 | 
			
		||||
#include <stdint.h>
 | 
			
		||||
 | 
			
		||||
#ifdef __cplusplus
 | 
			
		||||
extern "C" {
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
// Embedding result for a given embedder head.
 | 
			
		||||
//
 | 
			
		||||
// One and only one of the two 'float_embedding' and 'quantized_embedding' will
 | 
			
		||||
// contain data, based on whether or not the embedder was configured to perform
 | 
			
		||||
// scalar quantization.
 | 
			
		||||
struct Embedding {
 | 
			
		||||
  // Floating-point embedding. Empty if the embedder was configured to perform
 | 
			
		||||
  // scalar-quantization.
 | 
			
		||||
  float* float_embedding;
 | 
			
		||||
 | 
			
		||||
  // Scalar-quantized embedding. Empty if the embedder was not configured to
 | 
			
		||||
  // perform scalar quantization.
 | 
			
		||||
  char* quantized_embedding;
 | 
			
		||||
 | 
			
		||||
  // Keep the count of embedding values.
 | 
			
		||||
  uint32_t values_count;
 | 
			
		||||
 | 
			
		||||
  // The index of the embedder head (i.e. output tensor) this embedding comes
 | 
			
		||||
  // from. This is useful for multi-head models.
 | 
			
		||||
  int head_index;
 | 
			
		||||
 | 
			
		||||
  // The optional name of the embedder head, as provided in the TFLite Model
 | 
			
		||||
  // Metadata [1] if present. This is useful for multi-head models.
 | 
			
		||||
  //
 | 
			
		||||
  // [1]: https://www.tensorflow.org/lite/convert/metadata
 | 
			
		||||
  char* head_name;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
// Defines embedding results of a model.
 | 
			
		||||
struct EmbeddingResult {
 | 
			
		||||
  // The embedding results for each head of the model.
 | 
			
		||||
  struct Embedding* embeddings;
 | 
			
		||||
 | 
			
		||||
  // Keep the count of embeddings.
 | 
			
		||||
  uint32_t embeddings_count;
 | 
			
		||||
 | 
			
		||||
  // The optional timestamp (in milliseconds) of the start of the chunk of data
 | 
			
		||||
  // corresponding to these results.
 | 
			
		||||
  //
 | 
			
		||||
  // This is only used for classification on time series (e.g. audio
 | 
			
		||||
  // classification). In these use cases, the amount of data to process might
 | 
			
		||||
  // exceed the maximum size that the model can process: to solve this, the
 | 
			
		||||
  // input data is split into multiple chunks starting at different timestamps.
 | 
			
		||||
  int64_t timestamp_ms;
 | 
			
		||||
 | 
			
		||||
  // Specifies whether the timestamp contains a valid value.
 | 
			
		||||
  bool has_timestamp_ms;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
#ifdef __cplusplus
 | 
			
		||||
}  // extern C
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
#endif  // MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_EMBEDDING_RESULT_H_
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,79 @@
 | 
			
		|||
/* Copyright 2023 The MediaPipe Authors.
 | 
			
		||||
 | 
			
		||||
Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
you may not use this file except in compliance with the License.
 | 
			
		||||
You may obtain a copy of the License at
 | 
			
		||||
 | 
			
		||||
    http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
 | 
			
		||||
Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
See the License for the specific language governing permissions and
 | 
			
		||||
limitations under the License.
 | 
			
		||||
==============================================================================*/
 | 
			
		||||
 | 
			
		||||
#include "mediapipe/tasks/c/components/containers/embedding_result_converter.h"
 | 
			
		||||
 | 
			
		||||
#include <cstdint>
 | 
			
		||||
#include <cstdlib>
 | 
			
		||||
 | 
			
		||||
#include "mediapipe/tasks/c/components/containers/embedding_result.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/components/containers/embedding_result.h"
 | 
			
		||||
 | 
			
		||||
namespace mediapipe::tasks::c::components::containers {
 | 
			
		||||
 | 
			
		||||
void CppConvertToEmbeddingResult(
 | 
			
		||||
    const mediapipe::tasks::components::containers::EmbeddingResult& in,
 | 
			
		||||
    EmbeddingResult* out) {
 | 
			
		||||
  out->has_timestamp_ms = in.timestamp_ms.has_value();
 | 
			
		||||
  out->timestamp_ms = out->has_timestamp_ms ? in.timestamp_ms.value() : 0;
 | 
			
		||||
 | 
			
		||||
  out->embeddings_count = in.embeddings.size();
 | 
			
		||||
  out->embeddings =
 | 
			
		||||
      out->embeddings_count ? new Embedding[out->embeddings_count] : nullptr;
 | 
			
		||||
 | 
			
		||||
  for (uint32_t i = 0; i < out->embeddings_count; ++i) {
 | 
			
		||||
    auto embedding_in = in.embeddings[i];
 | 
			
		||||
    auto& embedding_out = out->embeddings[i];
 | 
			
		||||
 | 
			
		||||
    if (!embedding_in.float_embedding.empty()) {
 | 
			
		||||
      embedding_out.values_count = embedding_in.float_embedding.size();
 | 
			
		||||
      embedding_out.float_embedding =
 | 
			
		||||
          embedding_out.values_count ? new float[embedding_out.values_count]
 | 
			
		||||
                                     : nullptr;
 | 
			
		||||
      std::copy(embedding_in.float_embedding.begin(),
 | 
			
		||||
                embedding_in.float_embedding.end(),
 | 
			
		||||
                embedding_out.float_embedding);
 | 
			
		||||
      embedding_out.quantized_embedding = nullptr;
 | 
			
		||||
    } else if (!embedding_in.quantized_embedding.empty()) {
 | 
			
		||||
      embedding_out.values_count = embedding_in.quantized_embedding.size();
 | 
			
		||||
      embedding_out.quantized_embedding =
 | 
			
		||||
          embedding_out.values_count ? new char[embedding_out.values_count]
 | 
			
		||||
                                     : nullptr;
 | 
			
		||||
      std::copy(embedding_in.quantized_embedding.begin(),
 | 
			
		||||
                embedding_in.quantized_embedding.end(),
 | 
			
		||||
                embedding_out.quantized_embedding);
 | 
			
		||||
      embedding_out.float_embedding = nullptr;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    embedding_out.head_index = embedding_in.head_index;
 | 
			
		||||
    embedding_out.head_name = embedding_in.head_name.has_value()
 | 
			
		||||
                                  ? strdup(embedding_in.head_name->c_str())
 | 
			
		||||
                                  : nullptr;
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void CppCloseEmbeddingResult(EmbeddingResult* in) {
 | 
			
		||||
  for (uint32_t i = 0; i < in->embeddings_count; ++i) {
 | 
			
		||||
    auto embedding_in = in->embeddings[i];
 | 
			
		||||
 | 
			
		||||
    delete[] embedding_in.float_embedding;
 | 
			
		||||
    delete[] embedding_in.quantized_embedding;
 | 
			
		||||
 | 
			
		||||
    free(embedding_in.head_name);
 | 
			
		||||
  }
 | 
			
		||||
  delete[] in->embeddings;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
}  // namespace mediapipe::tasks::c::components::containers
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,38 @@
 | 
			
		|||
/* Copyright 2023 The MediaPipe Authors.
 | 
			
		||||
 | 
			
		||||
Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
you may not use this file except in compliance with the License.
 | 
			
		||||
You may obtain a copy of the License at
 | 
			
		||||
 | 
			
		||||
    http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
 | 
			
		||||
Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
See the License for the specific language governing permissions and
 | 
			
		||||
limitations under the License.
 | 
			
		||||
==============================================================================*/
 | 
			
		||||
 | 
			
		||||
#ifndef MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_EMBEDDING_RESULT_CONVERTER_H_
 | 
			
		||||
#define MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_EMBEDDING_RESULT_CONVERTER_H_
 | 
			
		||||
 | 
			
		||||
#include "mediapipe/tasks/c/components/containers/embedding_result.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/components/containers/embedding_result.h"
 | 
			
		||||
 | 
			
		||||
namespace mediapipe::tasks::c::components::containers {
 | 
			
		||||
 | 
			
		||||
void CppConvertToEmbedding(
 | 
			
		||||
    const mediapipe::tasks::components::containers::EmbeddingResult& in,
 | 
			
		||||
    Embedding* out);
 | 
			
		||||
 | 
			
		||||
void CppConvertToEmbeddingResult(
 | 
			
		||||
    const mediapipe::tasks::components::containers::EmbeddingResult& in,
 | 
			
		||||
    EmbeddingResult* out);
 | 
			
		||||
 | 
			
		||||
void CppCloseEmbedding(Embedding* in);
 | 
			
		||||
 | 
			
		||||
void CppCloseEmbeddingResult(EmbeddingResult* in);
 | 
			
		||||
 | 
			
		||||
}  // namespace mediapipe::tasks::c::components::containers
 | 
			
		||||
 | 
			
		||||
#endif  // MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_EMBEDDING_RESULT_CONVERTER_H_
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,65 @@
 | 
			
		|||
/* Copyright 2023 The MediaPipe Authors.
 | 
			
		||||
 | 
			
		||||
Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
you may not use this file except in compliance with the License.
 | 
			
		||||
You may obtain a copy of the License at
 | 
			
		||||
 | 
			
		||||
    http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
 | 
			
		||||
Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
See the License for the specific language governing permissions and
 | 
			
		||||
limitations under the License.
 | 
			
		||||
==============================================================================*/
 | 
			
		||||
 | 
			
		||||
#include "mediapipe/tasks/c/components/containers/embedding_result_converter.h"
 | 
			
		||||
 | 
			
		||||
#include <cstdlib>
 | 
			
		||||
#include <optional>
 | 
			
		||||
#include <string>
 | 
			
		||||
 | 
			
		||||
#include "mediapipe/framework/port/gtest.h"
 | 
			
		||||
#include "mediapipe/tasks/c/components/containers/embedding_result.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/components/containers/embedding_result.h"
 | 
			
		||||
 | 
			
		||||
namespace mediapipe::tasks::c::components::containers {
 | 
			
		||||
 | 
			
		||||
TEST(EmbeddingResultConverterTest, ConvertsEmbeddingResultCustomEmbedding) {
 | 
			
		||||
  mediapipe::tasks::components::containers::EmbeddingResult
 | 
			
		||||
 | 
			
		||||
      cpp_embedding_result = {
 | 
			
		||||
          // Initializing embeddings vector
 | 
			
		||||
          {// First embedding
 | 
			
		||||
           {
 | 
			
		||||
               {0.1f, 0.2f, 0.3f, 0.4f, 0.5f},  // float embedding
 | 
			
		||||
               {},                              // quantized embedding (empty)
 | 
			
		||||
               0,                               // head index
 | 
			
		||||
               "foo"                            // head name
 | 
			
		||||
           },
 | 
			
		||||
           // Second embedding
 | 
			
		||||
           {
 | 
			
		||||
               {},                         // float embedding (empty)
 | 
			
		||||
               {127, 127, 127, 127, 127},  // quantized embedding
 | 
			
		||||
               1,                          // head index
 | 
			
		||||
               std::nullopt                // no head name
 | 
			
		||||
           }},
 | 
			
		||||
          // Initializing timestamp_ms
 | 
			
		||||
          42  // timestamp in ms
 | 
			
		||||
      };
 | 
			
		||||
 | 
			
		||||
  EmbeddingResult c_embedding_result;
 | 
			
		||||
  CppConvertToEmbeddingResult(cpp_embedding_result, &c_embedding_result);
 | 
			
		||||
  EXPECT_NE(c_embedding_result.embeddings, nullptr);
 | 
			
		||||
  EXPECT_EQ(c_embedding_result.embeddings_count, 2);
 | 
			
		||||
  EXPECT_NE(c_embedding_result.embeddings[0].float_embedding, nullptr);
 | 
			
		||||
  EXPECT_EQ(c_embedding_result.embeddings[0].values_count, 5);
 | 
			
		||||
  EXPECT_EQ(c_embedding_result.embeddings[0].head_index, 0);
 | 
			
		||||
  EXPECT_EQ(std::string(c_embedding_result.embeddings[0].head_name), "foo");
 | 
			
		||||
  EXPECT_EQ(c_embedding_result.timestamp_ms, 42);
 | 
			
		||||
  EXPECT_EQ(c_embedding_result.has_timestamp_ms, true);
 | 
			
		||||
 | 
			
		||||
  CppCloseEmbeddingResult(&c_embedding_result);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
}  // namespace mediapipe::tasks::c::components::containers
 | 
			
		||||
| 
						 | 
				
			
			@ -42,3 +42,30 @@ cc_test(
 | 
			
		|||
        "@com_google_googletest//:gtest_main",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
cc_library(
 | 
			
		||||
    name = "embedder_options",
 | 
			
		||||
    hdrs = ["embedder_options.h"],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
cc_library(
 | 
			
		||||
    name = "embedder_options_converter",
 | 
			
		||||
    srcs = ["embedder_options_converter.cc"],
 | 
			
		||||
    hdrs = ["embedder_options_converter.h"],
 | 
			
		||||
    deps = [
 | 
			
		||||
        ":embedder_options",
 | 
			
		||||
        "//mediapipe/tasks/cc/components/processors:embedder_options",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
cc_test(
 | 
			
		||||
    name = "embedder_options_converter_test",
 | 
			
		||||
    srcs = ["embedder_options_converter_test.cc"],
 | 
			
		||||
    deps = [
 | 
			
		||||
        ":embedder_options",
 | 
			
		||||
        ":embedder_options_converter",
 | 
			
		||||
        "//mediapipe/framework/port:gtest",
 | 
			
		||||
        "//mediapipe/tasks/cc/components/processors:embedder_options",
 | 
			
		||||
        "@com_google_googletest//:gtest_main",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
							
								
								
									
										44
									
								
								mediapipe/tasks/c/components/processors/embedder_options.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										44
									
								
								mediapipe/tasks/c/components/processors/embedder_options.h
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,44 @@
 | 
			
		|||
/* Copyright 2023 The MediaPipe Authors.
 | 
			
		||||
 | 
			
		||||
Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
you may not use this file except in compliance with the License.
 | 
			
		||||
You may obtain a copy of the License at
 | 
			
		||||
 | 
			
		||||
    http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
 | 
			
		||||
Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
See the License for the specific language governing permissions and
 | 
			
		||||
limitations under the License.
 | 
			
		||||
==============================================================================*/
 | 
			
		||||
 | 
			
		||||
#ifndef MEDIAPIPE_TASKS_C_COMPONENTS_PROCESSORS_EMBEDDER_OPTIONS_H_
 | 
			
		||||
#define MEDIAPIPE_TASKS_C_COMPONENTS_PROCESSORS_EMBEDDER_OPTIONS_H_
 | 
			
		||||
 | 
			
		||||
#include <stdint.h>
 | 
			
		||||
 | 
			
		||||
#ifdef __cplusplus
 | 
			
		||||
extern "C" {
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
// Embedder options for MediaPipe C embedding extraction tasks.
 | 
			
		||||
struct EmbedderOptions {
 | 
			
		||||
  // Whether to normalize the returned feature vector with L2 norm. Use this
 | 
			
		||||
  // option only if the model does not already contain a native L2_NORMALIZATION
 | 
			
		||||
  // TF Lite Op. In most cases, this is already the case and L2 norm is thus
 | 
			
		||||
  // achieved through TF Lite inference.
 | 
			
		||||
  bool l2_normalize;
 | 
			
		||||
 | 
			
		||||
  // Whether the returned embedding should be quantized to bytes via scalar
 | 
			
		||||
  // quantization. Embeddings are implicitly assumed to be unit-norm and
 | 
			
		||||
  // therefore any dimension is guaranteed to have a value in [-1.0, 1.0]. Use
 | 
			
		||||
  // the l2_normalize option if this is not the case.
 | 
			
		||||
  bool quantize;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
#ifdef __cplusplus
 | 
			
		||||
}  // extern C
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
#endif  // MEDIAPIPE_TASKS_C_COMPONENTS_PROCESSORS_EMBEDDER_OPTIONS_H_
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,32 @@
 | 
			
		|||
/* Copyright 2023 The MediaPipe Authors.
 | 
			
		||||
 | 
			
		||||
Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
you may not use this file except in compliance with the License.
 | 
			
		||||
You may obtain a copy of the License at
 | 
			
		||||
 | 
			
		||||
    http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
 | 
			
		||||
Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
See the License for the specific language governing permissions and
 | 
			
		||||
limitations under the License.
 | 
			
		||||
==============================================================================*/
 | 
			
		||||
 | 
			
		||||
#include <cstdint>
 | 
			
		||||
#include <string>
 | 
			
		||||
#include <vector>
 | 
			
		||||
 | 
			
		||||
#include "mediapipe/tasks/c/components/processors/embedder_options.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/components/processors/embedder_options.h"
 | 
			
		||||
 | 
			
		||||
namespace mediapipe::tasks::c::components::processors {
 | 
			
		||||
 | 
			
		||||
void CppConvertToEmbedderOptions(
 | 
			
		||||
    const EmbedderOptions& in,
 | 
			
		||||
    mediapipe::tasks::components::processors::EmbedderOptions* out) {
 | 
			
		||||
  out->l2_normalize = in.l2_normalize;
 | 
			
		||||
  out->quantize = in.quantize;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
}  // namespace mediapipe::tasks::c::components::processors
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,30 @@
 | 
			
		|||
/* Copyright 2023 The MediaPipe Authors.
 | 
			
		||||
 | 
			
		||||
Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
you may not use this file except in compliance with the License.
 | 
			
		||||
You may obtain a copy of the License at
 | 
			
		||||
 | 
			
		||||
    http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
 | 
			
		||||
Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
See the License for the specific language governing permissions and
 | 
			
		||||
limitations under the License.
 | 
			
		||||
==============================================================================*/
 | 
			
		||||
 | 
			
		||||
#ifndef MEDIAPIPE_TASKS_C_COMPONENTS_PROCESSORS_EMBEDDER_OPTIONS_CONVERTER_H_
 | 
			
		||||
#define MEDIAPIPE_TASKS_C_COMPONENTS_PROCESSORS_EMBEDDER_OPTIONS_CONVERTER_H_
 | 
			
		||||
 | 
			
		||||
#include "mediapipe/tasks/c/components/processors/embedder_options.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/components/processors/embedder_options.h"
 | 
			
		||||
 | 
			
		||||
namespace mediapipe::tasks::c::components::processors {
 | 
			
		||||
 | 
			
		||||
void CppConvertToEmbedderOptions(
 | 
			
		||||
    const EmbedderOptions& in,
 | 
			
		||||
    mediapipe::tasks::components::processors::EmbedderOptions* out);
 | 
			
		||||
 | 
			
		||||
}  // namespace mediapipe::tasks::c::components::processors
 | 
			
		||||
 | 
			
		||||
#endif  // MEDIAPIPE_TASKS_C_COMPONENTS_PROCESSORS_EMBEDDER_OPTIONS_CONVERTER_H_
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,45 @@
 | 
			
		|||
/* Copyright 2023 The MediaPipe Authors.
 | 
			
		||||
 | 
			
		||||
Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
you may not use this file except in compliance with the License.
 | 
			
		||||
You may obtain a copy of the License at
 | 
			
		||||
 | 
			
		||||
    http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
 | 
			
		||||
Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
See the License for the specific language governing permissions and
 | 
			
		||||
limitations under the License.
 | 
			
		||||
==============================================================================*/
 | 
			
		||||
 | 
			
		||||
#include "mediapipe/tasks/c/components/processors/embedder_options_converter.h"
 | 
			
		||||
 | 
			
		||||
#include <string>
 | 
			
		||||
#include <vector>
 | 
			
		||||
 | 
			
		||||
#include "mediapipe/framework/port/gtest.h"
 | 
			
		||||
#include "mediapipe/tasks/c/components/processors/embedder_options.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/components/processors/embedder_options.h"
 | 
			
		||||
 | 
			
		||||
namespace mediapipe::tasks::c::components::processors {
 | 
			
		||||
 | 
			
		||||
constexpr char kCategoryAllowlist[] = "fruit";
 | 
			
		||||
constexpr char kCategoryDenylist[] = "veggies";
 | 
			
		||||
constexpr char kDisplayNamesLocaleGerman[] = "de";
 | 
			
		||||
 | 
			
		||||
TEST(EmbedderOptionsConverterTest, ConvertsEmbedderOptionsCustomValues) {
 | 
			
		||||
  EmbedderOptions c_embedder_options = {
 | 
			
		||||
      /* l2_normalize= */ true,
 | 
			
		||||
      /* quantize= */ false
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  mediapipe::tasks::components::processors::EmbedderOptions
 | 
			
		||||
      cpp_embedder_options = {};
 | 
			
		||||
 | 
			
		||||
  CppConvertToEmbedderOptions(c_embedder_options, &cpp_embedder_options);
 | 
			
		||||
  EXPECT_EQ(cpp_embedder_options.l2_normalize, true);
 | 
			
		||||
  EXPECT_EQ(cpp_embedder_options.quantize, false);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
}  // namespace mediapipe::tasks::c::components::processors
 | 
			
		||||
							
								
								
									
										85
									
								
								mediapipe/tasks/c/text/text_embedder/BUILD
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										85
									
								
								mediapipe/tasks/c/text/text_embedder/BUILD
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,85 @@
 | 
			
		|||
# Copyright 2023 The MediaPipe Authors.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#      http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
 | 
			
		||||
package(default_visibility = ["//mediapipe/tasks:internal"])
 | 
			
		||||
 | 
			
		||||
licenses(["notice"])
 | 
			
		||||
 | 
			
		||||
cc_library(
 | 
			
		||||
    name = "text_embedder_lib",
 | 
			
		||||
    srcs = ["text_embedder.cc"],
 | 
			
		||||
    hdrs = ["text_embedder.h"],
 | 
			
		||||
    visibility = ["//visibility:public"],
 | 
			
		||||
    deps = [
 | 
			
		||||
        "//mediapipe/tasks/c/components/containers:embedding_result",
 | 
			
		||||
        "//mediapipe/tasks/c/components/containers:embedding_result_converter",
 | 
			
		||||
        "//mediapipe/tasks/c/components/processors:embedder_options",
 | 
			
		||||
        "//mediapipe/tasks/c/components/processors:embedder_options_converter",
 | 
			
		||||
        "//mediapipe/tasks/c/core:base_options",
 | 
			
		||||
        "//mediapipe/tasks/c/core:base_options_converter",
 | 
			
		||||
        "//mediapipe/tasks/cc/text/text_embedder",
 | 
			
		||||
        "@com_google_absl//absl/log:absl_log",
 | 
			
		||||
        "@com_google_absl//absl/status",
 | 
			
		||||
    ],
 | 
			
		||||
    alwayslink = 1,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
# bazel build -c opt --linkopt -s --strip always --define MEDIAPIPE_DISABLE_GPU=1 \
 | 
			
		||||
# //mediapipe/tasks/c/text/text_embedder:libtext_embedder.so
 | 
			
		||||
cc_binary(
 | 
			
		||||
    name = "libtext_embedder.so",
 | 
			
		||||
    linkopts = [
 | 
			
		||||
        "-Wl,-soname=libtext_embedder.so",
 | 
			
		||||
        "-fvisibility=hidden",
 | 
			
		||||
    ],
 | 
			
		||||
    linkshared = True,
 | 
			
		||||
    tags = [
 | 
			
		||||
        "manual",
 | 
			
		||||
        "nobuilder",
 | 
			
		||||
        "notap",
 | 
			
		||||
    ],
 | 
			
		||||
    deps = [":text_embedder_lib"],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
# bazel build --config darwin_arm64 -c opt --strip always --define MEDIAPIPE_DISABLE_GPU=1 \
 | 
			
		||||
# //mediapipe/tasks/c/text/text_embedder:libtext_embedder.dylib
 | 
			
		||||
cc_binary(
 | 
			
		||||
    name = "libtext_embedder.dylib",
 | 
			
		||||
    linkopts = [
 | 
			
		||||
        "-Wl,-install_name,libtext_embedder.dylib",
 | 
			
		||||
        "-fvisibility=hidden",
 | 
			
		||||
    ],
 | 
			
		||||
    linkshared = True,
 | 
			
		||||
    tags = [
 | 
			
		||||
        "manual",
 | 
			
		||||
        "nobuilder",
 | 
			
		||||
        "notap",
 | 
			
		||||
    ],
 | 
			
		||||
    deps = [":text_embedder_lib"],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
cc_test(
 | 
			
		||||
    name = "text_embedder_test",
 | 
			
		||||
    srcs = ["text_embedder_test.cc"],
 | 
			
		||||
    data = ["//mediapipe/tasks/testdata/text:mobilebert_embedding_model"],
 | 
			
		||||
    linkstatic = 1,
 | 
			
		||||
    deps = [
 | 
			
		||||
        ":text_embedder_lib",
 | 
			
		||||
        "//mediapipe/framework/deps:file_path",
 | 
			
		||||
        "//mediapipe/framework/port:gtest",
 | 
			
		||||
        "@com_google_absl//absl/flags:flag",
 | 
			
		||||
        "@com_google_absl//absl/strings",
 | 
			
		||||
        "@com_google_googletest//:gtest_main",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
							
								
								
									
										124
									
								
								mediapipe/tasks/c/text/text_embedder/text_embedder.cc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										124
									
								
								mediapipe/tasks/c/text/text_embedder/text_embedder.cc
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,124 @@
 | 
			
		|||
/* Copyright 2023 The MediaPipe Authors.
 | 
			
		||||
 | 
			
		||||
Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
you may not use this file except in compliance with the License.
 | 
			
		||||
You may obtain a copy of the License at
 | 
			
		||||
 | 
			
		||||
    http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
 | 
			
		||||
Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
See the License for the specific language governing permissions and
 | 
			
		||||
limitations under the License.
 | 
			
		||||
==============================================================================*/
 | 
			
		||||
 | 
			
		||||
#include "mediapipe/tasks/c/text/text_embedder/text_embedder.h"
 | 
			
		||||
 | 
			
		||||
#include <memory>
 | 
			
		||||
#include <utility>
 | 
			
		||||
 | 
			
		||||
#include "absl/log/absl_log.h"
 | 
			
		||||
#include "absl/status/status.h"
 | 
			
		||||
#include "mediapipe/tasks/c/components/containers/embedding_result_converter.h"
 | 
			
		||||
#include "mediapipe/tasks/c/components/processors/embedder_options.h"
 | 
			
		||||
#include "mediapipe/tasks/c/components/processors/embedder_options_converter.h"
 | 
			
		||||
#include "mediapipe/tasks/c/core/base_options.h"
 | 
			
		||||
#include "mediapipe/tasks/c/core/base_options_converter.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/text/text_embedder/text_embedder.h"
 | 
			
		||||
 | 
			
		||||
namespace mediapipe::tasks::c::text::text_embedder {
 | 
			
		||||
 | 
			
		||||
namespace {
 | 
			
		||||
 | 
			
		||||
using ::mediapipe::tasks::c::components::containers::
 | 
			
		||||
    CppCloseEmbeddingResult;
 | 
			
		||||
using ::mediapipe::tasks::c::components::containers::
 | 
			
		||||
    CppConvertToEmbeddingResult;
 | 
			
		||||
using ::mediapipe::tasks::c::components::processors::
 | 
			
		||||
    CppConvertToEmbedderOptions;
 | 
			
		||||
using ::mediapipe::tasks::c::core::CppConvertToBaseOptions;
 | 
			
		||||
using ::mediapipe::tasks::text::text_embedder::TextEmbedder;
 | 
			
		||||
 | 
			
		||||
int CppProcessError(absl::Status status, char** error_msg) {
 | 
			
		||||
  if (error_msg) {
 | 
			
		||||
    *error_msg = strdup(status.ToString().c_str());
 | 
			
		||||
  }
 | 
			
		||||
  return status.raw_code();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
}  // namespace
 | 
			
		||||
 | 
			
		||||
TextEmbedder* CppTextEmbedderCreate(const TextEmbedderOptions& options,
 | 
			
		||||
                                        char** error_msg) {
 | 
			
		||||
  auto cpp_options = std::make_unique<
 | 
			
		||||
      ::mediapipe::tasks::text::text_embedder::TextEmbedderOptions>();
 | 
			
		||||
 | 
			
		||||
  CppConvertToBaseOptions(options.base_options, &cpp_options->base_options);
 | 
			
		||||
  CppConvertToEmbedderOptions(options.embedder_options,
 | 
			
		||||
                                &cpp_options->embedder_options);
 | 
			
		||||
 | 
			
		||||
  auto embedder = TextEmbedder::Create(std::move(cpp_options));
 | 
			
		||||
  if (!embedder.ok()) {
 | 
			
		||||
    ABSL_LOG(ERROR) << "Failed to create TextEmbedder: "
 | 
			
		||||
                    << embedder.status();
 | 
			
		||||
    CppProcessError(embedder.status(), error_msg);
 | 
			
		||||
    return nullptr;
 | 
			
		||||
  }
 | 
			
		||||
  return embedder->release();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
int CppTextEmbedderEmbed(void* embedder, const char* utf8_str,
 | 
			
		||||
                              TextEmbedderResult* result, char** error_msg) {
 | 
			
		||||
  auto cpp_embedder = static_cast<TextEmbedder*>(embedder);
 | 
			
		||||
  auto cpp_result = cpp_embedder->Embed(utf8_str);
 | 
			
		||||
  if (!cpp_result.ok()) {
 | 
			
		||||
    ABSL_LOG(ERROR) << "Embedding extraction failed: " << cpp_result.status();
 | 
			
		||||
    return CppProcessError(cpp_result.status(), error_msg);
 | 
			
		||||
  }
 | 
			
		||||
  CppConvertToEmbeddingResult(*cpp_result, result);
 | 
			
		||||
  return 0;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void CppTextEmbedderCloseResult(TextEmbedderResult* result) {
 | 
			
		||||
  CppCloseEmbeddingResult(result);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
int CppTextEmbedderClose(void* embedder, char** error_msg) {
 | 
			
		||||
  auto cpp_embedder = static_cast<TextEmbedder*>(embedder);
 | 
			
		||||
  auto result = cpp_embedder->Close();
 | 
			
		||||
  if (!result.ok()) {
 | 
			
		||||
    ABSL_LOG(ERROR) << "Failed to close TextEmbedder: " << result;
 | 
			
		||||
    return CppProcessError(result, error_msg);
 | 
			
		||||
  }
 | 
			
		||||
  delete cpp_embedder;
 | 
			
		||||
  return 0;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
}  // namespace mediapipe::tasks::c::text::text_embedder
 | 
			
		||||
 | 
			
		||||
extern "C" {
 | 
			
		||||
 | 
			
		||||
void* text_embedder_create(struct TextEmbedderOptions* options,
 | 
			
		||||
                             char** error_msg) {
 | 
			
		||||
  return mediapipe::tasks::c::text::text_embedder::CppTextEmbedderCreate(
 | 
			
		||||
      *options, error_msg);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
int text_embedder_embed(void* embedder, const char* utf8_str,
 | 
			
		||||
                             TextEmbedderResult* result, char** error_msg) {
 | 
			
		||||
  return mediapipe::tasks::c::text::text_embedder::CppTextEmbedderEmbed(
 | 
			
		||||
      embedder, utf8_str, result, error_msg);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void text_embedder_close_result(TextEmbedderResult* result) {
 | 
			
		||||
  mediapipe::tasks::c::text::text_embedder::CppTextEmbedderCloseResult(
 | 
			
		||||
      result);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
int text_embedder_close(void* embedder, char** error_ms) {
 | 
			
		||||
  return mediapipe::tasks::c::text::text_embedder::CppTextEmbedderClose(
 | 
			
		||||
      embedder, error_ms);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
}  // extern "C"
 | 
			
		||||
							
								
								
									
										75
									
								
								mediapipe/tasks/c/text/text_embedder/text_embedder.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										75
									
								
								mediapipe/tasks/c/text/text_embedder/text_embedder.h
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,75 @@
 | 
			
		|||
/* Copyright 2023 The MediaPipe Authors.
 | 
			
		||||
 | 
			
		||||
Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
you may not use this file except in compliance with the License.
 | 
			
		||||
You may obtain a copy of the License at
 | 
			
		||||
 | 
			
		||||
    http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
 | 
			
		||||
Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
See the License for the specific language governing permissions and
 | 
			
		||||
limitations under the License.
 | 
			
		||||
==============================================================================*/
 | 
			
		||||
 | 
			
		||||
#ifndef MEDIAPIPE_TASKS_C_TEXT_TEXT_CLASSIFIER_TEXT_EMBEDDER_H_
 | 
			
		||||
#define MEDIAPIPE_TASKS_C_TEXT_TEXT_CLASSIFIER_TEXT_EMBEDDER_H_
 | 
			
		||||
 | 
			
		||||
#include "mediapipe/tasks/c/components/containers/embedding_result.h"
 | 
			
		||||
#include "mediapipe/tasks/c/components/processors/embedder_options.h"
 | 
			
		||||
#include "mediapipe/tasks/c/core/base_options.h"
 | 
			
		||||
 | 
			
		||||
#ifndef MP_EXPORT
 | 
			
		||||
#define MP_EXPORT __attribute__((visibility("default")))
 | 
			
		||||
#endif  // MP_EXPORT
 | 
			
		||||
 | 
			
		||||
#ifdef __cplusplus
 | 
			
		||||
extern "C" {
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
typedef struct EmbeddingResult TextEmbedderResult;
 | 
			
		||||
 | 
			
		||||
// The options for configuring a MediaPipe text embedder task.
 | 
			
		||||
struct TextEmbedderOptions {
 | 
			
		||||
  // Base options for configuring MediaPipe Tasks, such as specifying the model
 | 
			
		||||
  // file with metadata, accelerator options, op resolver, etc.
 | 
			
		||||
  struct BaseOptions base_options;
 | 
			
		||||
 | 
			
		||||
  // Options for configuring the embedder behavior, such as score threshold,
 | 
			
		||||
  // number of results, etc.
 | 
			
		||||
  struct EmbedderOptions embedder_options;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
// Creates a TextEmbedder from the provided `options`.
 | 
			
		||||
// Returns a pointer to the text embedder on success.
 | 
			
		||||
// If an error occurs, returns `nullptr` and sets the error parameter to an
 | 
			
		||||
// an error message (if `error_msg` is not nullptr). You must free the memory
 | 
			
		||||
// allocated for the error message.
 | 
			
		||||
MP_EXPORT void* text_embedder_create(struct TextEmbedderOptions* options,
 | 
			
		||||
                                       char** error_msg = nullptr);
 | 
			
		||||
 | 
			
		||||
// Performs embedding extraction on the input `text`. Returns `0` on success.
 | 
			
		||||
// If an error occurs, returns an error code and sets the error parameter to an
 | 
			
		||||
// an error message (if `error_msg` is not nullptr). You must free the memory
 | 
			
		||||
// allocated for the error message.
 | 
			
		||||
MP_EXPORT int text_embedder_embed(void* embedder, const char* utf8_str,
 | 
			
		||||
                                  TextEmbedderResult* result,
 | 
			
		||||
                                  char** error_msg = nullptr);
 | 
			
		||||
 | 
			
		||||
// Frees the memory allocated inside a TextEmbedderResult result. Does not
 | 
			
		||||
// free the result pointer itself.
 | 
			
		||||
MP_EXPORT void text_embedder_close_result(TextEmbedderResult* result);
 | 
			
		||||
 | 
			
		||||
// Shuts down the TextEmbedder when all the work is done. Frees all memory.
 | 
			
		||||
// If an error occurs, returns an error code and sets the error parameter to an
 | 
			
		||||
// an error message (if `error_msg` is not nullptr). You must free the memory
 | 
			
		||||
// allocated for the error message.
 | 
			
		||||
MP_EXPORT int text_embedder_close(void* embedder,
 | 
			
		||||
                                  char** error_msg = nullptr);
 | 
			
		||||
 | 
			
		||||
#ifdef __cplusplus
 | 
			
		||||
}  // extern C
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
#endif  // MEDIAPIPE_TASKS_C_TEXT_TEXT_CLASSIFIER_TEXT_EMBEDDER_H_
 | 
			
		||||
							
								
								
									
										80
									
								
								mediapipe/tasks/c/text/text_embedder/text_embedder_test.cc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										80
									
								
								mediapipe/tasks/c/text/text_embedder/text_embedder_test.cc
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,80 @@
 | 
			
		|||
/* Copyright 2023 The MediaPipe Authors.
 | 
			
		||||
 | 
			
		||||
Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
you may not use this file except in compliance with the License.
 | 
			
		||||
You may obtain a copy of the License at
 | 
			
		||||
 | 
			
		||||
    http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
 | 
			
		||||
Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
See the License for the specific language governing permissions and
 | 
			
		||||
limitations under the License.
 | 
			
		||||
==============================================================================*/
 | 
			
		||||
 | 
			
		||||
#include "mediapipe/tasks/c/text/text_embedder/text_embedder.h"
 | 
			
		||||
 | 
			
		||||
#include <cstdlib>
 | 
			
		||||
#include <string>
 | 
			
		||||
 | 
			
		||||
#include "absl/flags/flag.h"
 | 
			
		||||
#include "absl/strings/string_view.h"
 | 
			
		||||
#include "mediapipe/framework/deps/file_path.h"
 | 
			
		||||
#include "mediapipe/framework/port/gmock.h"
 | 
			
		||||
#include "mediapipe/framework/port/gtest.h"
 | 
			
		||||
 | 
			
		||||
namespace {
 | 
			
		||||
 | 
			
		||||
using ::mediapipe::file::JoinPath;
 | 
			
		||||
using testing::HasSubstr;
 | 
			
		||||
 | 
			
		||||
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/text/";
 | 
			
		||||
constexpr char kTestBertModelPath[] = "mobilebert_embedding_with_metadata.tflite";
 | 
			
		||||
constexpr char kTestString[] = "It's beautiful outside.";
 | 
			
		||||
constexpr float kPrecision = 1e-6;
 | 
			
		||||
 | 
			
		||||
std::string GetFullPath(absl::string_view file_name) {
 | 
			
		||||
  return JoinPath("./", kTestDataDirectory, file_name);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST(TextEmbedderTest, SmokeTest) {
 | 
			
		||||
  std::string model_path = GetFullPath(kTestBertModelPath);
 | 
			
		||||
  TextEmbedderOptions options = {
 | 
			
		||||
      /* base_options= */ {/* model_asset_buffer= */ nullptr,
 | 
			
		||||
                           /* model_asset_path= */ model_path.c_str()},
 | 
			
		||||
      /* embedder_options= */
 | 
			
		||||
      {/* l2_normalize= */ false,
 | 
			
		||||
       /* quantize= */ true},
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  void* embedder = text_embedder_create(&options);
 | 
			
		||||
  EXPECT_NE(embedder, nullptr);
 | 
			
		||||
 | 
			
		||||
  TextEmbedderResult result;
 | 
			
		||||
  text_embedder_embed(embedder, kTestString, &result);
 | 
			
		||||
  EXPECT_EQ(result.embeddings_count, 1);
 | 
			
		||||
  EXPECT_EQ(result.embeddings[0].values_count, 512);
 | 
			
		||||
 | 
			
		||||
  text_embedder_close_result(&result);
 | 
			
		||||
  text_embedder_close(embedder);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST(TextEmbedderTest, ErrorHandling) {
 | 
			
		||||
  // It is an error to set neither the asset buffer nor the path.
 | 
			
		||||
  TextEmbedderOptions options = {
 | 
			
		||||
      /* base_options= */ {/* model_asset_buffer= */ nullptr,
 | 
			
		||||
                           /* model_asset_path= */ nullptr},
 | 
			
		||||
      /* embedder_options= */ {},
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  char* error_msg;
 | 
			
		||||
  void* embedder = text_embedder_create(&options, &error_msg);
 | 
			
		||||
  EXPECT_EQ(embedder, nullptr);
 | 
			
		||||
 | 
			
		||||
  EXPECT_THAT(error_msg, HasSubstr("INVALID_ARGUMENT"));
 | 
			
		||||
 | 
			
		||||
  free(error_msg);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
}  // namespace
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user