Resolved some issues
This commit is contained in:
parent
753ba916a1
commit
92e13d43e4
|
@ -29,11 +29,11 @@ extern "C" {
|
|||
// contain data, based on whether or not the embedder was configured to perform
|
||||
// scalar quantization.
|
||||
struct Embedding {
|
||||
// Floating-point embedding. Empty if the embedder was configured to perform
|
||||
// Floating-point embedding. Empty/nullptr if the embedder was configured to perform
|
||||
// scalar-quantization.
|
||||
float* float_embedding;
|
||||
|
||||
// Scalar-quantized embedding. Empty if the embedder was not configured to
|
||||
// Scalar-quantized embedding. Empty/nullptr if the embedder was not configured to
|
||||
// perform scalar quantization.
|
||||
char* quantized_embedding;
|
||||
|
||||
|
@ -46,6 +46,7 @@ struct Embedding {
|
|||
|
||||
// The optional name of the embedder head, as provided in the TFLite Model
|
||||
// Metadata [1] if present. This is useful for multi-head models.
|
||||
// Defaults to nullptr.
|
||||
//
|
||||
// [1]: https://www.tensorflow.org/lite/convert/metadata
|
||||
char* head_name;
|
||||
|
|
|
@ -70,10 +70,13 @@ void CppCloseEmbeddingResult(EmbeddingResult* in) {
|
|||
|
||||
delete[] embedding_in.float_embedding;
|
||||
delete[] embedding_in.quantized_embedding;
|
||||
embedding_in.float_embedding = nullptr;
|
||||
embedding_in.quantized_embedding = nullptr;
|
||||
|
||||
free(embedding_in.head_name);
|
||||
}
|
||||
delete[] in->embeddings;
|
||||
in.head_name = nullptr;
|
||||
}
|
||||
|
||||
} // namespace mediapipe::tasks::c::components::containers
|
||||
|
|
|
@ -55,6 +55,9 @@ TEST(EmbeddingResultConverterTest, ConvertsEmbeddingResultCustomEmbedding) {
|
|||
EXPECT_NE(c_embedding_result.embeddings[0].float_embedding, nullptr);
|
||||
EXPECT_EQ(c_embedding_result.embeddings[0].values_count, 5);
|
||||
EXPECT_EQ(c_embedding_result.embeddings[0].head_index, 0);
|
||||
EXPECT_NE(c_embedding_result.embeddings[1].quantized_embedding, nullptr);
|
||||
EXPECT_EQ(c_embedding_result.embeddings[1].values_count, 5);
|
||||
EXPECT_EQ(c_embedding_result.embeddings[1].head_index, 0);
|
||||
EXPECT_EQ(std::string(c_embedding_result.embeddings[0].head_name), "foo");
|
||||
EXPECT_EQ(c_embedding_result.timestamp_ms, 42);
|
||||
EXPECT_EQ(c_embedding_result.has_timestamp_ms, true);
|
||||
|
|
|
@ -17,6 +17,7 @@ limitations under the License.
|
|||
#define MEDIAPIPE_TASKS_C_COMPONENTS_PROCESSORS_EMBEDDER_OPTIONS_H_
|
||||
|
||||
#include <stdint.h>
|
||||
#include <stdbool.h>
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
|
|
|
@ -100,13 +100,13 @@ int CppTextEmbedderClose(void* embedder, char** error_msg) {
|
|||
extern "C" {
|
||||
|
||||
void* text_embedder_create(struct TextEmbedderOptions* options,
|
||||
char** error_msg) {
|
||||
char** error_msg) {
|
||||
return mediapipe::tasks::c::text::text_embedder::CppTextEmbedderCreate(
|
||||
*options, error_msg);
|
||||
}
|
||||
|
||||
int text_embedder_embed(void* embedder, const char* utf8_str,
|
||||
TextEmbedderResult* result, char** error_msg) {
|
||||
TextEmbedderResult* result, char** error_msg) {
|
||||
return mediapipe::tasks::c::text::text_embedder::CppTextEmbedderEmbed(
|
||||
embedder, utf8_str, result, error_msg);
|
||||
}
|
||||
|
|
|
@ -36,22 +36,22 @@ struct TextEmbedderOptions {
|
|||
// file with metadata, accelerator options, op resolver, etc.
|
||||
struct BaseOptions base_options;
|
||||
|
||||
// Options for configuring the embedder behavior, such as score threshold,
|
||||
// number of results, etc.
|
||||
// Options for configuring the embedder behavior, such as l2_normalize
|
||||
// and quantize.
|
||||
struct EmbedderOptions embedder_options;
|
||||
};
|
||||
|
||||
// Creates a TextEmbedder from the provided `options`.
|
||||
// Returns a pointer to the text embedder on success.
|
||||
// If an error occurs, returns `nullptr` and sets the error parameter to an
|
||||
// an error message (if `error_msg` is not nullptr). You must free the memory
|
||||
// an error message (if `error_msg` is not `nullptr`). You must free the memory
|
||||
// allocated for the error message.
|
||||
MP_EXPORT void* text_embedder_create(struct TextEmbedderOptions* options,
|
||||
char** error_msg = nullptr);
|
||||
|
||||
// Performs embedding extraction on the input `text`. Returns `0` on success.
|
||||
// If an error occurs, returns an error code and sets the error parameter to an
|
||||
// an error message (if `error_msg` is not nullptr). You must free the memory
|
||||
// an error message (if `error_msg` is not `nullptr`). You must free the memory
|
||||
// allocated for the error message.
|
||||
MP_EXPORT int text_embedder_embed(void* embedder, const char* utf8_str,
|
||||
TextEmbedderResult* result,
|
||||
|
@ -63,7 +63,7 @@ MP_EXPORT void text_embedder_close_result(TextEmbedderResult* result);
|
|||
|
||||
// Shuts down the TextEmbedder when all the work is done. Frees all memory.
|
||||
// If an error occurs, returns an error code and sets the error parameter to an
|
||||
// an error message (if `error_msg` is not nullptr). You must free the memory
|
||||
// an error message (if `error_msg` is not `nullptr`). You must free the memory
|
||||
// allocated for the error message.
|
||||
MP_EXPORT int text_embedder_close(void* embedder,
|
||||
char** error_msg = nullptr);
|
||||
|
|
|
@ -44,8 +44,7 @@ TEST(TextEmbedderTest, SmokeTest) {
|
|||
/* base_options= */ {/* model_asset_buffer= */ nullptr,
|
||||
/* model_asset_path= */ model_path.c_str()},
|
||||
/* embedder_options= */
|
||||
{/* l2_normalize= */ false,
|
||||
/* quantize= */ true},
|
||||
{/* l2_normalize= */ false, /* quantize= */ true},
|
||||
};
|
||||
|
||||
void* embedder = text_embedder_create(&options);
|
||||
|
|
Loading…
Reference in New Issue
Block a user