Resolved some issues
This commit is contained in:
parent
753ba916a1
commit
92e13d43e4
|
@ -29,11 +29,11 @@ extern "C" {
|
||||||
// contain data, based on whether or not the embedder was configured to perform
|
// contain data, based on whether or not the embedder was configured to perform
|
||||||
// scalar quantization.
|
// scalar quantization.
|
||||||
struct Embedding {
|
struct Embedding {
|
||||||
// Floating-point embedding. Empty if the embedder was configured to perform
|
// Floating-point embedding. Empty/nullptr if the embedder was configured to perform
|
||||||
// scalar-quantization.
|
// scalar-quantization.
|
||||||
float* float_embedding;
|
float* float_embedding;
|
||||||
|
|
||||||
// Scalar-quantized embedding. Empty if the embedder was not configured to
|
// Scalar-quantized embedding. Empty/nullptr if the embedder was not configured to
|
||||||
// perform scalar quantization.
|
// perform scalar quantization.
|
||||||
char* quantized_embedding;
|
char* quantized_embedding;
|
||||||
|
|
||||||
|
@ -46,6 +46,7 @@ struct Embedding {
|
||||||
|
|
||||||
// The optional name of the embedder head, as provided in the TFLite Model
|
// The optional name of the embedder head, as provided in the TFLite Model
|
||||||
// Metadata [1] if present. This is useful for multi-head models.
|
// Metadata [1] if present. This is useful for multi-head models.
|
||||||
|
// Defaults to nullptr.
|
||||||
//
|
//
|
||||||
// [1]: https://www.tensorflow.org/lite/convert/metadata
|
// [1]: https://www.tensorflow.org/lite/convert/metadata
|
||||||
char* head_name;
|
char* head_name;
|
||||||
|
|
|
@ -70,10 +70,13 @@ void CppCloseEmbeddingResult(EmbeddingResult* in) {
|
||||||
|
|
||||||
delete[] embedding_in.float_embedding;
|
delete[] embedding_in.float_embedding;
|
||||||
delete[] embedding_in.quantized_embedding;
|
delete[] embedding_in.quantized_embedding;
|
||||||
|
embedding_in.float_embedding = nullptr;
|
||||||
|
embedding_in.quantized_embedding = nullptr;
|
||||||
|
|
||||||
free(embedding_in.head_name);
|
free(embedding_in.head_name);
|
||||||
}
|
}
|
||||||
delete[] in->embeddings;
|
delete[] in->embeddings;
|
||||||
|
in.head_name = nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mediapipe::tasks::c::components::containers
|
} // namespace mediapipe::tasks::c::components::containers
|
||||||
|
|
|
@ -55,6 +55,9 @@ TEST(EmbeddingResultConverterTest, ConvertsEmbeddingResultCustomEmbedding) {
|
||||||
EXPECT_NE(c_embedding_result.embeddings[0].float_embedding, nullptr);
|
EXPECT_NE(c_embedding_result.embeddings[0].float_embedding, nullptr);
|
||||||
EXPECT_EQ(c_embedding_result.embeddings[0].values_count, 5);
|
EXPECT_EQ(c_embedding_result.embeddings[0].values_count, 5);
|
||||||
EXPECT_EQ(c_embedding_result.embeddings[0].head_index, 0);
|
EXPECT_EQ(c_embedding_result.embeddings[0].head_index, 0);
|
||||||
|
EXPECT_NE(c_embedding_result.embeddings[1].quantized_embedding, nullptr);
|
||||||
|
EXPECT_EQ(c_embedding_result.embeddings[1].values_count, 5);
|
||||||
|
EXPECT_EQ(c_embedding_result.embeddings[1].head_index, 0);
|
||||||
EXPECT_EQ(std::string(c_embedding_result.embeddings[0].head_name), "foo");
|
EXPECT_EQ(std::string(c_embedding_result.embeddings[0].head_name), "foo");
|
||||||
EXPECT_EQ(c_embedding_result.timestamp_ms, 42);
|
EXPECT_EQ(c_embedding_result.timestamp_ms, 42);
|
||||||
EXPECT_EQ(c_embedding_result.has_timestamp_ms, true);
|
EXPECT_EQ(c_embedding_result.has_timestamp_ms, true);
|
||||||
|
|
|
@ -17,6 +17,7 @@ limitations under the License.
|
||||||
#define MEDIAPIPE_TASKS_C_COMPONENTS_PROCESSORS_EMBEDDER_OPTIONS_H_
|
#define MEDIAPIPE_TASKS_C_COMPONENTS_PROCESSORS_EMBEDDER_OPTIONS_H_
|
||||||
|
|
||||||
#include <stdint.h>
|
#include <stdint.h>
|
||||||
|
#include <stdbool.h>
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
extern "C" {
|
extern "C" {
|
||||||
|
|
|
@ -100,13 +100,13 @@ int CppTextEmbedderClose(void* embedder, char** error_msg) {
|
||||||
extern "C" {
|
extern "C" {
|
||||||
|
|
||||||
void* text_embedder_create(struct TextEmbedderOptions* options,
|
void* text_embedder_create(struct TextEmbedderOptions* options,
|
||||||
char** error_msg) {
|
char** error_msg) {
|
||||||
return mediapipe::tasks::c::text::text_embedder::CppTextEmbedderCreate(
|
return mediapipe::tasks::c::text::text_embedder::CppTextEmbedderCreate(
|
||||||
*options, error_msg);
|
*options, error_msg);
|
||||||
}
|
}
|
||||||
|
|
||||||
int text_embedder_embed(void* embedder, const char* utf8_str,
|
int text_embedder_embed(void* embedder, const char* utf8_str,
|
||||||
TextEmbedderResult* result, char** error_msg) {
|
TextEmbedderResult* result, char** error_msg) {
|
||||||
return mediapipe::tasks::c::text::text_embedder::CppTextEmbedderEmbed(
|
return mediapipe::tasks::c::text::text_embedder::CppTextEmbedderEmbed(
|
||||||
embedder, utf8_str, result, error_msg);
|
embedder, utf8_str, result, error_msg);
|
||||||
}
|
}
|
||||||
|
|
|
@ -36,22 +36,22 @@ struct TextEmbedderOptions {
|
||||||
// file with metadata, accelerator options, op resolver, etc.
|
// file with metadata, accelerator options, op resolver, etc.
|
||||||
struct BaseOptions base_options;
|
struct BaseOptions base_options;
|
||||||
|
|
||||||
// Options for configuring the embedder behavior, such as score threshold,
|
// Options for configuring the embedder behavior, such as l2_normalize
|
||||||
// number of results, etc.
|
// and quantize.
|
||||||
struct EmbedderOptions embedder_options;
|
struct EmbedderOptions embedder_options;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Creates a TextEmbedder from the provided `options`.
|
// Creates a TextEmbedder from the provided `options`.
|
||||||
// Returns a pointer to the text embedder on success.
|
// Returns a pointer to the text embedder on success.
|
||||||
// If an error occurs, returns `nullptr` and sets the error parameter to an
|
// If an error occurs, returns `nullptr` and sets the error parameter to an
|
||||||
// an error message (if `error_msg` is not nullptr). You must free the memory
|
// an error message (if `error_msg` is not `nullptr`). You must free the memory
|
||||||
// allocated for the error message.
|
// allocated for the error message.
|
||||||
MP_EXPORT void* text_embedder_create(struct TextEmbedderOptions* options,
|
MP_EXPORT void* text_embedder_create(struct TextEmbedderOptions* options,
|
||||||
char** error_msg = nullptr);
|
char** error_msg = nullptr);
|
||||||
|
|
||||||
// Performs embedding extraction on the input `text`. Returns `0` on success.
|
// Performs embedding extraction on the input `text`. Returns `0` on success.
|
||||||
// If an error occurs, returns an error code and sets the error parameter to an
|
// If an error occurs, returns an error code and sets the error parameter to an
|
||||||
// an error message (if `error_msg` is not nullptr). You must free the memory
|
// an error message (if `error_msg` is not `nullptr`). You must free the memory
|
||||||
// allocated for the error message.
|
// allocated for the error message.
|
||||||
MP_EXPORT int text_embedder_embed(void* embedder, const char* utf8_str,
|
MP_EXPORT int text_embedder_embed(void* embedder, const char* utf8_str,
|
||||||
TextEmbedderResult* result,
|
TextEmbedderResult* result,
|
||||||
|
@ -63,7 +63,7 @@ MP_EXPORT void text_embedder_close_result(TextEmbedderResult* result);
|
||||||
|
|
||||||
// Shuts down the TextEmbedder when all the work is done. Frees all memory.
|
// Shuts down the TextEmbedder when all the work is done. Frees all memory.
|
||||||
// If an error occurs, returns an error code and sets the error parameter to an
|
// If an error occurs, returns an error code and sets the error parameter to an
|
||||||
// an error message (if `error_msg` is not nullptr). You must free the memory
|
// an error message (if `error_msg` is not `nullptr`). You must free the memory
|
||||||
// allocated for the error message.
|
// allocated for the error message.
|
||||||
MP_EXPORT int text_embedder_close(void* embedder,
|
MP_EXPORT int text_embedder_close(void* embedder,
|
||||||
char** error_msg = nullptr);
|
char** error_msg = nullptr);
|
||||||
|
|
|
@ -44,8 +44,7 @@ TEST(TextEmbedderTest, SmokeTest) {
|
||||||
/* base_options= */ {/* model_asset_buffer= */ nullptr,
|
/* base_options= */ {/* model_asset_buffer= */ nullptr,
|
||||||
/* model_asset_path= */ model_path.c_str()},
|
/* model_asset_path= */ model_path.c_str()},
|
||||||
/* embedder_options= */
|
/* embedder_options= */
|
||||||
{/* l2_normalize= */ false,
|
{/* l2_normalize= */ false, /* quantize= */ true},
|
||||||
/* quantize= */ true},
|
|
||||||
};
|
};
|
||||||
|
|
||||||
void* embedder = text_embedder_create(&options);
|
void* embedder = text_embedder_create(&options);
|
||||||
|
|
Loading…
Reference in New Issue
Block a user