Resolved some issues

This commit is contained in:
Kinar 2023-10-05 03:54:54 -07:00
parent 753ba916a1
commit 92e13d43e4
7 changed files with 18 additions and 11 deletions

View File

@ -29,11 +29,11 @@ extern "C" {
// contain data, based on whether or not the embedder was configured to perform // contain data, based on whether or not the embedder was configured to perform
// scalar quantization. // scalar quantization.
struct Embedding { struct Embedding {
// Floating-point embedding. Empty if the embedder was configured to perform // Floating-point embedding. Empty/nullptr if the embedder was configured to perform
// scalar-quantization. // scalar-quantization.
float* float_embedding; float* float_embedding;
// Scalar-quantized embedding. Empty if the embedder was not configured to // Scalar-quantized embedding. Empty/nullptr if the embedder was not configured to
// perform scalar quantization. // perform scalar quantization.
char* quantized_embedding; char* quantized_embedding;
@ -46,6 +46,7 @@ struct Embedding {
// The optional name of the embedder head, as provided in the TFLite Model // The optional name of the embedder head, as provided in the TFLite Model
// Metadata [1] if present. This is useful for multi-head models. // Metadata [1] if present. This is useful for multi-head models.
// Defaults to nullptr.
// //
// [1]: https://www.tensorflow.org/lite/convert/metadata // [1]: https://www.tensorflow.org/lite/convert/metadata
char* head_name; char* head_name;

View File

@ -70,10 +70,13 @@ void CppCloseEmbeddingResult(EmbeddingResult* in) {
delete[] embedding_in.float_embedding; delete[] embedding_in.float_embedding;
delete[] embedding_in.quantized_embedding; delete[] embedding_in.quantized_embedding;
embedding_in.float_embedding = nullptr;
embedding_in.quantized_embedding = nullptr;
free(embedding_in.head_name); free(embedding_in.head_name);
} }
delete[] in->embeddings; delete[] in->embeddings;
in.head_name = nullptr;
} }
} // namespace mediapipe::tasks::c::components::containers } // namespace mediapipe::tasks::c::components::containers

View File

@ -55,6 +55,9 @@ TEST(EmbeddingResultConverterTest, ConvertsEmbeddingResultCustomEmbedding) {
EXPECT_NE(c_embedding_result.embeddings[0].float_embedding, nullptr); EXPECT_NE(c_embedding_result.embeddings[0].float_embedding, nullptr);
EXPECT_EQ(c_embedding_result.embeddings[0].values_count, 5); EXPECT_EQ(c_embedding_result.embeddings[0].values_count, 5);
EXPECT_EQ(c_embedding_result.embeddings[0].head_index, 0); EXPECT_EQ(c_embedding_result.embeddings[0].head_index, 0);
EXPECT_NE(c_embedding_result.embeddings[1].quantized_embedding, nullptr);
EXPECT_EQ(c_embedding_result.embeddings[1].values_count, 5);
EXPECT_EQ(c_embedding_result.embeddings[1].head_index, 0);
EXPECT_EQ(std::string(c_embedding_result.embeddings[0].head_name), "foo"); EXPECT_EQ(std::string(c_embedding_result.embeddings[0].head_name), "foo");
EXPECT_EQ(c_embedding_result.timestamp_ms, 42); EXPECT_EQ(c_embedding_result.timestamp_ms, 42);
EXPECT_EQ(c_embedding_result.has_timestamp_ms, true); EXPECT_EQ(c_embedding_result.has_timestamp_ms, true);

View File

@ -17,6 +17,7 @@ limitations under the License.
#define MEDIAPIPE_TASKS_C_COMPONENTS_PROCESSORS_EMBEDDER_OPTIONS_H_ #define MEDIAPIPE_TASKS_C_COMPONENTS_PROCESSORS_EMBEDDER_OPTIONS_H_
#include <stdint.h> #include <stdint.h>
#include <stdbool.h>
#ifdef __cplusplus #ifdef __cplusplus
extern "C" { extern "C" {

View File

@ -100,13 +100,13 @@ int CppTextEmbedderClose(void* embedder, char** error_msg) {
extern "C" { extern "C" {
void* text_embedder_create(struct TextEmbedderOptions* options, void* text_embedder_create(struct TextEmbedderOptions* options,
char** error_msg) { char** error_msg) {
return mediapipe::tasks::c::text::text_embedder::CppTextEmbedderCreate( return mediapipe::tasks::c::text::text_embedder::CppTextEmbedderCreate(
*options, error_msg); *options, error_msg);
} }
int text_embedder_embed(void* embedder, const char* utf8_str, int text_embedder_embed(void* embedder, const char* utf8_str,
TextEmbedderResult* result, char** error_msg) { TextEmbedderResult* result, char** error_msg) {
return mediapipe::tasks::c::text::text_embedder::CppTextEmbedderEmbed( return mediapipe::tasks::c::text::text_embedder::CppTextEmbedderEmbed(
embedder, utf8_str, result, error_msg); embedder, utf8_str, result, error_msg);
} }

View File

@ -36,22 +36,22 @@ struct TextEmbedderOptions {
// file with metadata, accelerator options, op resolver, etc. // file with metadata, accelerator options, op resolver, etc.
struct BaseOptions base_options; struct BaseOptions base_options;
// Options for configuring the embedder behavior, such as score threshold, // Options for configuring the embedder behavior, such as l2_normalize
// number of results, etc. // and quantize.
struct EmbedderOptions embedder_options; struct EmbedderOptions embedder_options;
}; };
// Creates a TextEmbedder from the provided `options`. // Creates a TextEmbedder from the provided `options`.
// Returns a pointer to the text embedder on success. // Returns a pointer to the text embedder on success.
// If an error occurs, returns `nullptr` and sets the error parameter to an // If an error occurs, returns `nullptr` and sets the error parameter to an
// an error message (if `error_msg` is not nullptr). You must free the memory // an error message (if `error_msg` is not `nullptr`). You must free the memory
// allocated for the error message. // allocated for the error message.
MP_EXPORT void* text_embedder_create(struct TextEmbedderOptions* options, MP_EXPORT void* text_embedder_create(struct TextEmbedderOptions* options,
char** error_msg = nullptr); char** error_msg = nullptr);
// Performs embedding extraction on the input `text`. Returns `0` on success. // Performs embedding extraction on the input `text`. Returns `0` on success.
// If an error occurs, returns an error code and sets the error parameter to an // If an error occurs, returns an error code and sets the error parameter to an
// an error message (if `error_msg` is not nullptr). You must free the memory // an error message (if `error_msg` is not `nullptr`). You must free the memory
// allocated for the error message. // allocated for the error message.
MP_EXPORT int text_embedder_embed(void* embedder, const char* utf8_str, MP_EXPORT int text_embedder_embed(void* embedder, const char* utf8_str,
TextEmbedderResult* result, TextEmbedderResult* result,
@ -63,7 +63,7 @@ MP_EXPORT void text_embedder_close_result(TextEmbedderResult* result);
// Shuts down the TextEmbedder when all the work is done. Frees all memory. // Shuts down the TextEmbedder when all the work is done. Frees all memory.
// If an error occurs, returns an error code and sets the error parameter to an // If an error occurs, returns an error code and sets the error parameter to an
// an error message (if `error_msg` is not nullptr). You must free the memory // an error message (if `error_msg` is not `nullptr`). You must free the memory
// allocated for the error message. // allocated for the error message.
MP_EXPORT int text_embedder_close(void* embedder, MP_EXPORT int text_embedder_close(void* embedder,
char** error_msg = nullptr); char** error_msg = nullptr);

View File

@ -44,8 +44,7 @@ TEST(TextEmbedderTest, SmokeTest) {
/* base_options= */ {/* model_asset_buffer= */ nullptr, /* base_options= */ {/* model_asset_buffer= */ nullptr,
/* model_asset_path= */ model_path.c_str()}, /* model_asset_path= */ model_path.c_str()},
/* embedder_options= */ /* embedder_options= */
{/* l2_normalize= */ false, {/* l2_normalize= */ false, /* quantize= */ true},
/* quantize= */ true},
}; };
void* embedder = text_embedder_create(&options); void* embedder = text_embedder_create(&options);