Updated Text Embedder API
This commit is contained in:
parent
acd9c280c0
commit
a8103629c7
|
@ -43,7 +43,6 @@ py_library(
|
||||||
"text_embedder.py",
|
"text_embedder.py",
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
"//mediapipe/python:_framework_bindings",
|
|
||||||
"//mediapipe/python:packet_creator",
|
"//mediapipe/python:packet_creator",
|
||||||
"//mediapipe/python:packet_getter",
|
"//mediapipe/python:packet_getter",
|
||||||
"//mediapipe/tasks/cc/text/text_embedder/proto:text_embedder_graph_options_py_pb2",
|
"//mediapipe/tasks/cc/text/text_embedder/proto:text_embedder_graph_options_py_pb2",
|
||||||
|
|
|
@ -14,11 +14,9 @@
|
||||||
"""MediaPipe text embedder task."""
|
"""MediaPipe text embedder task."""
|
||||||
|
|
||||||
import dataclasses
|
import dataclasses
|
||||||
from typing import Callable, Mapping, Optional
|
|
||||||
|
|
||||||
from mediapipe.python import packet_creator
|
from mediapipe.python import packet_creator
|
||||||
from mediapipe.python import packet_getter
|
from mediapipe.python import packet_getter
|
||||||
from mediapipe.python._framework_bindings import packet as packet_module
|
|
||||||
from mediapipe.tasks.cc.text.text_embedder.proto import text_embedder_graph_options_pb2
|
from mediapipe.tasks.cc.text.text_embedder.proto import text_embedder_graph_options_pb2
|
||||||
from mediapipe.tasks.cc.components.containers.proto import embeddings_pb2
|
from mediapipe.tasks.cc.components.containers.proto import embeddings_pb2
|
||||||
from mediapipe.tasks.python.components.processors import embedder_options
|
from mediapipe.tasks.python.components.processors import embedder_options
|
||||||
|
@ -70,7 +68,7 @@ class TextEmbedder(base_text_task_api.BaseTextTaskApi):
|
||||||
"""Class that performs embedding extraction on text."""
|
"""Class that performs embedding extraction on text."""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create_from_model_path(cls, model_path: str) -> 'ImageEmbedder':
|
def create_from_model_path(cls, model_path: str) -> 'TextEmbedder':
|
||||||
"""Creates an `TextEmbedder` object from a TensorFlow Lite model and the
|
"""Creates an `TextEmbedder` object from a TensorFlow Lite model and the
|
||||||
default `TextEmbedderOptions`.
|
default `TextEmbedderOptions`.
|
||||||
|
|
||||||
|
@ -147,7 +145,9 @@ class TextEmbedder(base_text_task_api.BaseTextTaskApi):
|
||||||
def cosine_similarity(u: embedding_result_module.Embedding,
|
def cosine_similarity(u: embedding_result_module.Embedding,
|
||||||
v: embedding_result_module.Embedding) -> float:
|
v: embedding_result_module.Embedding) -> float:
|
||||||
"""Utility function to compute cosine similarity [1] between two embedding
|
"""Utility function to compute cosine similarity [1] between two embedding
|
||||||
entries. May return an InvalidArgumentError if e.g. the feature vectors are
|
entries.
|
||||||
|
|
||||||
|
May return an InvalidArgumentError if e.g. the feature vectors are
|
||||||
of different types (quantized vs. float), have different sizes, or have a
|
of different types (quantized vs. float), have different sizes, or have a
|
||||||
an L2-norm of 0.
|
an L2-norm of 0.
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user