Updated Text Embedder API

This commit is contained in:
kinaryml 2022-11-12 07:42:46 -08:00
parent acd9c280c0
commit a8103629c7
2 changed files with 4 additions and 5 deletions

View File

@ -43,7 +43,6 @@ py_library(
"text_embedder.py", "text_embedder.py",
], ],
deps = [ deps = [
"//mediapipe/python:_framework_bindings",
"//mediapipe/python:packet_creator", "//mediapipe/python:packet_creator",
"//mediapipe/python:packet_getter", "//mediapipe/python:packet_getter",
"//mediapipe/tasks/cc/text/text_embedder/proto:text_embedder_graph_options_py_pb2", "//mediapipe/tasks/cc/text/text_embedder/proto:text_embedder_graph_options_py_pb2",

View File

@ -14,11 +14,9 @@
"""MediaPipe text embedder task.""" """MediaPipe text embedder task."""
import dataclasses import dataclasses
from typing import Callable, Mapping, Optional
from mediapipe.python import packet_creator from mediapipe.python import packet_creator
from mediapipe.python import packet_getter from mediapipe.python import packet_getter
from mediapipe.python._framework_bindings import packet as packet_module
from mediapipe.tasks.cc.text.text_embedder.proto import text_embedder_graph_options_pb2 from mediapipe.tasks.cc.text.text_embedder.proto import text_embedder_graph_options_pb2
from mediapipe.tasks.cc.components.containers.proto import embeddings_pb2 from mediapipe.tasks.cc.components.containers.proto import embeddings_pb2
from mediapipe.tasks.python.components.processors import embedder_options from mediapipe.tasks.python.components.processors import embedder_options
@ -70,7 +68,7 @@ class TextEmbedder(base_text_task_api.BaseTextTaskApi):
"""Class that performs embedding extraction on text.""" """Class that performs embedding extraction on text."""
@classmethod @classmethod
def create_from_model_path(cls, model_path: str) -> 'ImageEmbedder': def create_from_model_path(cls, model_path: str) -> 'TextEmbedder':
"""Creates an `TextEmbedder` object from a TensorFlow Lite model and the """Creates an `TextEmbedder` object from a TensorFlow Lite model and the
default `TextEmbedderOptions`. default `TextEmbedderOptions`.
@ -147,7 +145,9 @@ class TextEmbedder(base_text_task_api.BaseTextTaskApi):
def cosine_similarity(u: embedding_result_module.Embedding, def cosine_similarity(u: embedding_result_module.Embedding,
v: embedding_result_module.Embedding) -> float: v: embedding_result_module.Embedding) -> float:
"""Utility function to compute cosine similarity [1] between two embedding """Utility function to compute cosine similarity [1] between two embedding
entries. May return an InvalidArgumentError if e.g. the feature vectors are entries.
May return an InvalidArgumentError if e.g. the feature vectors are
of different types (quantized vs. float), have different sizes, or have a of different types (quantized vs. float), have different sizes, or have a
an L2-norm of 0. an L2-norm of 0.