Added files needed for the text embedder's implementation and tests
This commit is contained in:
parent
0ac604d507
commit
1604908a59
|
@ -98,6 +98,7 @@ cc_library(
|
||||||
"//conditions:default": [
|
"//conditions:default": [
|
||||||
"//mediapipe/tasks/cc/audio/audio_classifier:audio_classifier_graph",
|
"//mediapipe/tasks/cc/audio/audio_classifier:audio_classifier_graph",
|
||||||
"//mediapipe/tasks/cc/text/text_classifier:text_classifier_graph",
|
"//mediapipe/tasks/cc/text/text_classifier:text_classifier_graph",
|
||||||
|
"//mediapipe/tasks/cc/text/text_embedder:text_embedder_graph",
|
||||||
],
|
],
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
|
|
|
@ -104,3 +104,12 @@ py_library(
|
||||||
"//mediapipe/tasks/python/core:optional_dependencies",
|
"//mediapipe/tasks/python/core:optional_dependencies",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "embedding_result",
|
||||||
|
srcs = ["embedding_result.py"],
|
||||||
|
deps = [
|
||||||
|
"//mediapipe/tasks/cc/components/containers/proto:embeddings_py_pb2",
|
||||||
|
"//mediapipe/tasks/python/core:optional_dependencies",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
210
mediapipe/tasks/python/components/containers/embedding_result.py
Normal file
210
mediapipe/tasks/python/components/containers/embedding_result.py
Normal file
|
@ -0,0 +1,210 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Embeddings data class."""
|
||||||
|
|
||||||
|
import dataclasses
|
||||||
|
from typing import Any, Optional, List
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from mediapipe.tasks.cc.components.containers.proto import embeddings_pb2
|
||||||
|
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
|
||||||
|
|
||||||
|
_FloatEmbeddingProto = embeddings_pb2.FloatEmbedding
|
||||||
|
_QuantizedEmbeddingProto = embeddings_pb2.QuantizedEmbedding
|
||||||
|
_EmbeddingProto = embeddings_pb2.Embedding
|
||||||
|
_EmbeddingResultProto = embeddings_pb2.EmbeddingResult
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class FloatEmbedding:
|
||||||
|
"""Defines a dense floating-point embedding.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
values: A NumPy array indicating the raw output of the embedding layer.
|
||||||
|
"""
|
||||||
|
|
||||||
|
values: np.ndarray
|
||||||
|
|
||||||
|
@doc_controls.do_not_generate_docs
|
||||||
|
def to_pb2(self) -> _FloatEmbeddingProto:
|
||||||
|
"""Generates a FloatEmbedding protobuf object."""
|
||||||
|
return _FloatEmbeddingProto(values=self.values)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@doc_controls.do_not_generate_docs
|
||||||
|
def create_from_pb2(
|
||||||
|
cls, pb2_obj: _FloatEmbeddingProto) -> 'FloatEmbedding':
|
||||||
|
"""Creates a `FloatEmbedding` object from the given protobuf object."""
|
||||||
|
return FloatEmbedding(values=np.array(pb2_obj.value_float, dtype=float))
|
||||||
|
|
||||||
|
def __eq__(self, other: Any) -> bool:
|
||||||
|
"""Checks if this object is equal to the given object.
|
||||||
|
Args:
|
||||||
|
other: The object to be compared with.
|
||||||
|
Returns:
|
||||||
|
True if the objects are equal.
|
||||||
|
"""
|
||||||
|
if not isinstance(other, FloatEmbedding):
|
||||||
|
return False
|
||||||
|
|
||||||
|
return self.to_pb2().__eq__(other.to_pb2())
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class QuantizedEmbedding:
|
||||||
|
"""Defines a dense scalar-quantized embedding.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
values: A NumPy array indicating the raw output of the embedding layer.
|
||||||
|
"""
|
||||||
|
|
||||||
|
values: np.ndarray
|
||||||
|
|
||||||
|
@doc_controls.do_not_generate_docs
|
||||||
|
def to_pb2(self) -> _QuantizedEmbeddingProto:
|
||||||
|
"""Generates a QuantizedEmbedding protobuf object."""
|
||||||
|
return _QuantizedEmbeddingProto(values=self.values)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@doc_controls.do_not_generate_docs
|
||||||
|
def create_from_pb2(
|
||||||
|
cls, pb2_obj: _QuantizedEmbeddingProto) -> 'QuantizedEmbedding':
|
||||||
|
"""Creates a `QuantizedEmbedding` object from the given protobuf object."""
|
||||||
|
return QuantizedEmbedding(
|
||||||
|
values=np.array(bytearray(pb2_obj.value_string), dtype=np.uint8))
|
||||||
|
|
||||||
|
def __eq__(self, other: Any) -> bool:
|
||||||
|
"""Checks if this object is equal to the given object.
|
||||||
|
Args:
|
||||||
|
other: The object to be compared with.
|
||||||
|
Returns:
|
||||||
|
True if the objects are equal.
|
||||||
|
"""
|
||||||
|
if not isinstance(other, QuantizedEmbedding):
|
||||||
|
return False
|
||||||
|
|
||||||
|
return self.to_pb2().__eq__(other.to_pb2())
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class Embedding:
|
||||||
|
"""Embedding result for a given embedder head.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
embedding: The actual embedding, either floating-point or scalar-quantized.
|
||||||
|
head_index: The index of the embedder head that produced this embedding.
|
||||||
|
This is useful for multi-head models.
|
||||||
|
head_name: The name of the embedder head, which is the corresponding tensor
|
||||||
|
metadata name (if any). This is useful for multi-head models.
|
||||||
|
"""
|
||||||
|
|
||||||
|
embedding: np.ndarray
|
||||||
|
head_index: Optional[int] = None
|
||||||
|
head_name: Optional[str] = None
|
||||||
|
|
||||||
|
@doc_controls.do_not_generate_docs
|
||||||
|
def to_pb2(self) -> _EmbeddingProto:
|
||||||
|
"""Generates a Embedding protobuf object."""
|
||||||
|
|
||||||
|
if self.embedding.dtype == float:
|
||||||
|
return _EmbeddingProto(float_embedding=self.embedding,
|
||||||
|
head_index=self.head_index,
|
||||||
|
head_name=self.head_name)
|
||||||
|
|
||||||
|
elif self.embedding.dtype == np.uint8:
|
||||||
|
return _EmbeddingProto(quantized_embedding=bytes(self.embedding),
|
||||||
|
head_index=self.head_index,
|
||||||
|
head_name=self.head_name)
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise ValueError("Invalid dtype. Only float and np.uint8 are supported.")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@doc_controls.do_not_generate_docs
|
||||||
|
def create_from_pb2(
|
||||||
|
cls, pb2_obj: _EmbeddingProto) -> 'Embedding':
|
||||||
|
"""Creates a `Embedding` object from the given protobuf object."""
|
||||||
|
|
||||||
|
quantized_embedding = np.array(
|
||||||
|
bytearray(pb2_obj.quantized_embedding.values))
|
||||||
|
float_embedding = np.array(pb2_obj.float_embedding.values, dtype=float)
|
||||||
|
|
||||||
|
if len(quantized_embedding) == 0:
|
||||||
|
return Embedding(embedding=float_embedding,
|
||||||
|
head_index=pb2_obj.head_index,
|
||||||
|
head_name=pb2_obj.head_name)
|
||||||
|
else:
|
||||||
|
return Embedding(embedding=quantized_embedding,
|
||||||
|
head_index=pb2_obj.head_index,
|
||||||
|
head_name=pb2_obj.head_name)
|
||||||
|
|
||||||
|
def __eq__(self, other: Any) -> bool:
|
||||||
|
"""Checks if this object is equal to the given object.
|
||||||
|
Args:
|
||||||
|
other: The object to be compared with.
|
||||||
|
Returns:
|
||||||
|
True if the objects are equal.
|
||||||
|
"""
|
||||||
|
if not isinstance(other, Embedding):
|
||||||
|
return False
|
||||||
|
|
||||||
|
return self.to_pb2().__eq__(other.to_pb2())
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class EmbeddingResult:
|
||||||
|
"""Embedding results for a given embedder model.
|
||||||
|
Attributes:
|
||||||
|
embeddings: A list of `Embedding` objects.
|
||||||
|
timestamp_ms: The optional timestamp (in milliseconds) of the start of the
|
||||||
|
chunk of data corresponding to these results. This is only used for
|
||||||
|
embedding extraction on time series (e.g. audio embedding). In these use
|
||||||
|
cases, the amount of data to process might exceed the maximum size that
|
||||||
|
the model can process: to solve this, the input data is split into
|
||||||
|
multiple chunks starting at different timestamps.
|
||||||
|
"""
|
||||||
|
|
||||||
|
embeddings: List[Embedding]
|
||||||
|
timestamp_ms: Optional[int] = None
|
||||||
|
|
||||||
|
@doc_controls.do_not_generate_docs
|
||||||
|
def to_pb2(self) -> _EmbeddingResultProto:
|
||||||
|
"""Generates a EmbeddingResult protobuf object."""
|
||||||
|
return _EmbeddingResultProto(
|
||||||
|
embeddings=[
|
||||||
|
embedding.to_pb2() for embedding in self.embeddings
|
||||||
|
])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@doc_controls.do_not_generate_docs
|
||||||
|
def create_from_pb2(
|
||||||
|
cls, pb2_obj: _EmbeddingResultProto) -> 'EmbeddingResult':
|
||||||
|
"""Creates a `EmbeddingResult` object from the given protobuf object."""
|
||||||
|
return EmbeddingResult(
|
||||||
|
embeddings=[
|
||||||
|
Embedding.create_from_pb2(embedding)
|
||||||
|
for embedding in pb2_obj.embeddings
|
||||||
|
])
|
||||||
|
|
||||||
|
def __eq__(self, other: Any) -> bool:
|
||||||
|
"""Checks if this object is equal to the given object.
|
||||||
|
Args:
|
||||||
|
other: The object to be compared with.
|
||||||
|
Returns:
|
||||||
|
True if the objects are equal.
|
||||||
|
"""
|
||||||
|
if not isinstance(other, EmbeddingResult):
|
||||||
|
return False
|
||||||
|
|
||||||
|
return self.to_pb2().__eq__(other.to_pb2())
|
|
@ -28,3 +28,12 @@ py_library(
|
||||||
"//mediapipe/tasks/python/core:optional_dependencies",
|
"//mediapipe/tasks/python/core:optional_dependencies",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "embedder_options",
|
||||||
|
srcs = ["embedder_options.py"],
|
||||||
|
deps = [
|
||||||
|
"//mediapipe/tasks/cc/components/processors/proto:embedder_options_py_pb2",
|
||||||
|
"//mediapipe/tasks/python/core:optional_dependencies",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
|
@ -0,0 +1,70 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Embedder options data class."""
|
||||||
|
|
||||||
|
import dataclasses
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from mediapipe.tasks.cc.components.processors.proto import embedder_options_pb2
|
||||||
|
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
|
||||||
|
|
||||||
|
_EmbedderOptionsProto = embedder_options_pb2.EmbedderOptions
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class EmbedderOptions:
|
||||||
|
"""Shared options used by all embedding extraction tasks.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
l2_normalize: Whether to normalize the returned feature vector with L2 norm.
|
||||||
|
Use this option only if the model does not already contain a native
|
||||||
|
L2_NORMALIZATION TF Lite Op. In most cases, this is already the case and
|
||||||
|
L2 norm is thus achieved through TF Lite inference.
|
||||||
|
quantize: Whether the returned embedding should be quantized to bytes via
|
||||||
|
scalar quantization. Embeddings are implicitly assumed to be unit-norm and
|
||||||
|
therefore any dimension is guaranteed to have a value in [-1.0, 1.0]. Use
|
||||||
|
the l2_normalize option if this is not the case.
|
||||||
|
"""
|
||||||
|
|
||||||
|
l2_normalize: Optional[bool] = None
|
||||||
|
quantize: Optional[bool] = None
|
||||||
|
|
||||||
|
@doc_controls.do_not_generate_docs
|
||||||
|
def to_pb2(self) -> _EmbedderOptionsProto:
|
||||||
|
"""Generates a EmbedderOptions protobuf object."""
|
||||||
|
return _EmbedderOptionsProto(
|
||||||
|
l2_normalize=self.l2_normalize,
|
||||||
|
quantize=self.quantize)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@doc_controls.do_not_generate_docs
|
||||||
|
def create_from_pb2(cls, pb2_obj: _EmbedderOptionsProto) -> 'EmbedderOptions':
|
||||||
|
"""Creates a `EmbedderOptions` object from the given protobuf object."""
|
||||||
|
return EmbedderOptions(
|
||||||
|
l2_normalize=pb2_obj.l2_normalize,
|
||||||
|
quantize=pb2_obj.quantize)
|
||||||
|
|
||||||
|
def __eq__(self, other: Any) -> bool:
|
||||||
|
"""Checks if this object is equal to the given object.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
other: The object to be compared with.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if the objects are equal.
|
||||||
|
"""
|
||||||
|
if not isinstance(other, EmbedderOptions):
|
||||||
|
return False
|
||||||
|
|
||||||
|
return self.to_pb2().__eq__(other.to_pb2())
|
28
mediapipe/tasks/python/components/utils/BUILD
Normal file
28
mediapipe/tasks/python/components/utils/BUILD
Normal file
|
@ -0,0 +1,28 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
# Placeholder for internal Python strict library compatibility macro.
|
||||||
|
|
||||||
|
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||||
|
|
||||||
|
licenses(["notice"])
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "cosine_similarity",
|
||||||
|
srcs = ["cosine_similarity.py"],
|
||||||
|
deps = [
|
||||||
|
"//mediapipe/tasks/python/components/containers:embedding_result",
|
||||||
|
"//mediapipe/tasks/python/components/processors:embedder_options",
|
||||||
|
],
|
||||||
|
)
|
13
mediapipe/tasks/python/components/utils/__init__.py
Normal file
13
mediapipe/tasks/python/components/utils/__init__.py
Normal file
|
@ -0,0 +1,13 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
61
mediapipe/tasks/python/components/utils/cosine_similarity.py
Normal file
61
mediapipe/tasks/python/components/utils/cosine_similarity.py
Normal file
|
@ -0,0 +1,61 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Cosine similarity utilities."""
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from mediapipe.tasks.python.components.containers import embedding_result
|
||||||
|
from mediapipe.tasks.python.components.processors import embedder_options
|
||||||
|
|
||||||
|
_Embedding = embedding_result.Embedding
|
||||||
|
_EmbedderOptions = embedder_options.EmbedderOptions
|
||||||
|
|
||||||
|
|
||||||
|
def _compute_cosine_similarity(u, v):
|
||||||
|
if len(u.embedding) <= 0:
|
||||||
|
raise ValueError("Cannot compute cosing similarity on empty embeddings.")
|
||||||
|
|
||||||
|
norm_u = np.linalg.norm(u.embedding)
|
||||||
|
norm_v = np.linalg.norm(v.embedding)
|
||||||
|
|
||||||
|
if norm_u <= 0 or norm_v <= 0:
|
||||||
|
raise ValueError(
|
||||||
|
"Cannot compute cosine similarity on embedding with 0 norm.")
|
||||||
|
|
||||||
|
return np.dot(u.embedding, v.embedding.T) / (norm_u * norm_v)
|
||||||
|
|
||||||
|
|
||||||
|
def cosine_similarity(u: _Embedding, v: _Embedding) -> float:
|
||||||
|
"""Utility function to compute cosine similarity between two embedding.
|
||||||
|
May return an InvalidArgumentError if e.g. the feature vectors are of
|
||||||
|
different types (quantized vs. float), have different sizes, or have an
|
||||||
|
L2-norm of 0.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
u: An embedding.
|
||||||
|
v: An embedding.
|
||||||
|
"""
|
||||||
|
if len(u.embedding) != len(v.embedding):
|
||||||
|
raise ValueError(f"Cannot compute cosine similarity between embeddings "
|
||||||
|
f"of different sizes "
|
||||||
|
f"({len(u.embedding)} vs. {len(v.embedding)}).")
|
||||||
|
|
||||||
|
if u.embedding.dtype == float and v.embedding.dtype == float:
|
||||||
|
return _compute_cosine_similarity(u, v)
|
||||||
|
|
||||||
|
if u.embedding.dtype == np.uint8 and v.embedding.dtype == np.uint8:
|
||||||
|
return _compute_cosine_similarity(u, v)
|
||||||
|
|
||||||
|
raise ValueError("Cannot compute cosine similarity between quantized and "
|
||||||
|
"float embeddings.")
|
|
@ -34,3 +34,20 @@ py_test(
|
||||||
"//mediapipe/tasks/python/text:text_classifier",
|
"//mediapipe/tasks/python/text:text_classifier",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
py_test(
|
||||||
|
name = "text_embedder_test",
|
||||||
|
srcs = ["text_embedder_test.py"],
|
||||||
|
data = [
|
||||||
|
"//mediapipe/tasks/testdata/text:mobilebert_embedding_model",
|
||||||
|
"//mediapipe/tasks/testdata/text:regex_embedding_with_metadata",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
"//mediapipe/tasks/python/components/processors:embedder_options",
|
||||||
|
"//mediapipe/tasks/python/components/utils:cosine_similarity",
|
||||||
|
"//mediapipe/tasks/python/components/containers:embedding_result",
|
||||||
|
"//mediapipe/tasks/python/core:base_options",
|
||||||
|
"//mediapipe/tasks/python/test:test_utils",
|
||||||
|
"//mediapipe/tasks/python/text:text_embedder",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
207
mediapipe/tasks/python/test/text/text_embedder_test.py
Normal file
207
mediapipe/tasks/python/test/text/text_embedder_test.py
Normal file
|
@ -0,0 +1,207 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Tests for text embedder."""
|
||||||
|
|
||||||
|
import enum
|
||||||
|
import os
|
||||||
|
from unittest import mock
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from absl.testing import absltest
|
||||||
|
from absl.testing import parameterized
|
||||||
|
|
||||||
|
from mediapipe.tasks.python.components.processors import embedder_options as embedder_options_module
|
||||||
|
from mediapipe.tasks.python.components.containers import embedding_result as embedding_result_module
|
||||||
|
from mediapipe.tasks.python.core import base_options as base_options_module
|
||||||
|
from mediapipe.tasks.python.test import test_utils
|
||||||
|
from mediapipe.tasks.python.text import text_embedder
|
||||||
|
|
||||||
|
ImageEmbedderResult = embedding_result_module.EmbeddingResult
|
||||||
|
_BaseOptions = base_options_module.BaseOptions
|
||||||
|
_EmbedderOptions = embedder_options_module.EmbedderOptions
|
||||||
|
_FloatEmbedding = embedding_result_module.FloatEmbedding
|
||||||
|
_QuantizedEmbedding = embedding_result_module.QuantizedEmbedding
|
||||||
|
_Embedding = embedding_result_module.Embedding
|
||||||
|
_TextEmbedder = text_embedder.TextEmbedder
|
||||||
|
_TextEmbedderOptions = text_embedder.TextEmbedderOptions
|
||||||
|
|
||||||
|
_BERT_MODEL_FILE = 'mobilebert_embedding_with_metadata.tflite'
|
||||||
|
_REGEX_MODEL_FILE = 'regex_one_embedding_with_metadata.tflite'
|
||||||
|
_TEST_DATA_DIR = 'mediapipe/tasks/testdata/text'
|
||||||
|
# Tolerance for embedding vector coordinate values.
|
||||||
|
_EPSILON = 1e-4
|
||||||
|
# Tolerance for cosine similarity evaluation.
|
||||||
|
_SIMILARITY_TOLERANCE = 1e-6
|
||||||
|
|
||||||
|
|
||||||
|
class ModelFileType(enum.Enum):
|
||||||
|
FILE_CONTENT = 1
|
||||||
|
FILE_NAME = 2
|
||||||
|
|
||||||
|
|
||||||
|
class ImageEmbedderTest(parameterized.TestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
super().setUp()
|
||||||
|
self.model_path = test_utils.get_test_data_path(
|
||||||
|
os.path.join(_TEST_DATA_DIR, _BERT_MODEL_FILE))
|
||||||
|
|
||||||
|
def test_create_from_file_succeeds_with_valid_model_path(self):
|
||||||
|
# Creates with default option and valid model file successfully.
|
||||||
|
with _TextEmbedder.create_from_model_path(self.model_path) as embedder:
|
||||||
|
self.assertIsInstance(embedder, _TextEmbedder)
|
||||||
|
|
||||||
|
def test_create_from_options_succeeds_with_valid_model_path(self):
|
||||||
|
# Creates with options containing model file successfully.
|
||||||
|
base_options = _BaseOptions(model_asset_path=self.model_path)
|
||||||
|
options = _TextEmbedderOptions(base_options=base_options)
|
||||||
|
with _TextEmbedder.create_from_options(options) as embedder:
|
||||||
|
self.assertIsInstance(embedder, _TextEmbedder)
|
||||||
|
|
||||||
|
def test_create_from_options_fails_with_invalid_model_path(self):
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
RuntimeError, 'Unable to open file at /path/to/invalid/model.tflite'):
|
||||||
|
base_options = _BaseOptions(
|
||||||
|
model_asset_path='/path/to/invalid/model.tflite')
|
||||||
|
options = _TextEmbedderOptions(base_options=base_options)
|
||||||
|
_TextEmbedder.create_from_options(options)
|
||||||
|
|
||||||
|
def test_create_from_options_succeeds_with_valid_model_content(self):
|
||||||
|
# Creates with options containing model content successfully.
|
||||||
|
with open(self.model_path, 'rb') as f:
|
||||||
|
base_options = _BaseOptions(model_asset_buffer=f.read())
|
||||||
|
options = _TextEmbedderOptions(base_options=base_options)
|
||||||
|
embedder = _TextEmbedder.create_from_options(options)
|
||||||
|
self.assertIsInstance(embedder, _TextEmbedder)
|
||||||
|
|
||||||
|
def _check_embedding_value(self, result, expected_first_value):
|
||||||
|
# Check embedding first value.
|
||||||
|
self.assertAlmostEqual(result.embeddings[0].embedding[0],
|
||||||
|
expected_first_value, delta=_EPSILON)
|
||||||
|
|
||||||
|
def _check_embedding_size(self, result, quantize, expected_embedding_size):
|
||||||
|
# Check embedding size.
|
||||||
|
self.assertLen(result.embeddings, 1)
|
||||||
|
embedding_result = result.embeddings[0]
|
||||||
|
self.assertLen(embedding_result.embedding, expected_embedding_size)
|
||||||
|
if quantize:
|
||||||
|
self.assertEqual(embedding_result.embedding.dtype, np.uint8)
|
||||||
|
else:
|
||||||
|
self.assertEqual(embedding_result.embedding.dtype, float)
|
||||||
|
|
||||||
|
def _check_cosine_similarity(self, result0, result1, expected_similarity):
|
||||||
|
# Checks cosine similarity.
|
||||||
|
similarity = _TextEmbedder.cosine_similarity(
|
||||||
|
result0.embeddings[0], result1.embeddings[0])
|
||||||
|
self.assertAlmostEqual(similarity, expected_similarity,
|
||||||
|
delta=_SIMILARITY_TOLERANCE)
|
||||||
|
|
||||||
|
@parameterized.parameters(
|
||||||
|
(False, False, _BERT_MODEL_FILE, ModelFileType.FILE_NAME,
|
||||||
|
0.969514, 512, (19.9016, 22.626251)),
|
||||||
|
(True, False, _BERT_MODEL_FILE, ModelFileType.FILE_NAME,
|
||||||
|
0.969514, 512, (0.0585837, 0.0723035)),
|
||||||
|
(False, False, _REGEX_MODEL_FILE, ModelFileType.FILE_NAME,
|
||||||
|
0.999937, 16, (0.0309356, 0.0312863)),
|
||||||
|
(True, False, _REGEX_MODEL_FILE, ModelFileType.FILE_CONTENT,
|
||||||
|
0.999937, 16, (0.549632, 0.552879)),
|
||||||
|
)
|
||||||
|
def test_embed(self, l2_normalize, quantize, model_name, model_file_type,
|
||||||
|
expected_similarity, expected_size, expected_first_values):
|
||||||
|
# Creates embedder.
|
||||||
|
model_path = test_utils.get_test_data_path(
|
||||||
|
os.path.join(_TEST_DATA_DIR, model_name))
|
||||||
|
if model_file_type is ModelFileType.FILE_NAME:
|
||||||
|
base_options = _BaseOptions(model_asset_path=model_path)
|
||||||
|
elif model_file_type is ModelFileType.FILE_CONTENT:
|
||||||
|
with open(model_path, 'rb') as f:
|
||||||
|
model_content = f.read()
|
||||||
|
base_options = _BaseOptions(model_asset_buffer=model_content)
|
||||||
|
else:
|
||||||
|
# Should never happen
|
||||||
|
raise ValueError('model_file_type is invalid.')
|
||||||
|
|
||||||
|
embedder_options = _EmbedderOptions(l2_normalize=l2_normalize,
|
||||||
|
quantize=quantize)
|
||||||
|
options = _TextEmbedderOptions(
|
||||||
|
base_options=base_options, embedder_options=embedder_options)
|
||||||
|
embedder = _TextEmbedder.create_from_options(options)
|
||||||
|
|
||||||
|
# Extracts both embeddings.
|
||||||
|
positive_text0 = "it's a charming and often affecting journey"
|
||||||
|
positive_text1 = "what a great and fantastic trip"
|
||||||
|
|
||||||
|
result0 = embedder.embed(positive_text0)
|
||||||
|
result1 = embedder.embed(positive_text1)
|
||||||
|
|
||||||
|
# Checks embeddings and cosine similarity.
|
||||||
|
expected_result0_value, expected_result1_value = expected_first_values
|
||||||
|
self._check_embedding_size(result0, quantize, expected_size)
|
||||||
|
self._check_embedding_size(result1, quantize, expected_size)
|
||||||
|
self._check_embedding_value(result0, expected_result0_value)
|
||||||
|
self._check_embedding_value(result1, expected_result1_value)
|
||||||
|
self._check_cosine_similarity(result0, result1, expected_similarity)
|
||||||
|
# Closes the embedder explicitly when the embedder is not used in
|
||||||
|
# a context.
|
||||||
|
embedder.close()
|
||||||
|
|
||||||
|
@parameterized.parameters(
|
||||||
|
(False, False, _BERT_MODEL_FILE, ModelFileType.FILE_NAME,
|
||||||
|
0.969514, 512, (19.9016, 22.626251)),
|
||||||
|
(True, False, _BERT_MODEL_FILE, ModelFileType.FILE_NAME,
|
||||||
|
0.969514, 512, (0.0585837, 0.0723035)),
|
||||||
|
(False, False, _REGEX_MODEL_FILE, ModelFileType.FILE_NAME,
|
||||||
|
0.999937, 16, (0.0309356, 0.0312863)),
|
||||||
|
(True, False, _REGEX_MODEL_FILE, ModelFileType.FILE_CONTENT,
|
||||||
|
0.999937, 16, (0.549632, 0.552879)),
|
||||||
|
)
|
||||||
|
def test_embed_in_context(self, l2_normalize, quantize, model_name,
|
||||||
|
model_file_type, expected_similarity,
|
||||||
|
expected_size, expected_first_values):
|
||||||
|
# Creates embedder.
|
||||||
|
model_path = test_utils.get_test_data_path(
|
||||||
|
os.path.join(_TEST_DATA_DIR, model_name))
|
||||||
|
if model_file_type is ModelFileType.FILE_NAME:
|
||||||
|
base_options = _BaseOptions(model_asset_path=model_path)
|
||||||
|
elif model_file_type is ModelFileType.FILE_CONTENT:
|
||||||
|
with open(model_path, 'rb') as f:
|
||||||
|
model_content = f.read()
|
||||||
|
base_options = _BaseOptions(model_asset_buffer=model_content)
|
||||||
|
else:
|
||||||
|
# Should never happen
|
||||||
|
raise ValueError('model_file_type is invalid.')
|
||||||
|
|
||||||
|
embedder_options = _EmbedderOptions(l2_normalize=l2_normalize,
|
||||||
|
quantize=quantize)
|
||||||
|
options = _TextEmbedderOptions(
|
||||||
|
base_options=base_options, embedder_options=embedder_options)
|
||||||
|
with _TextEmbedder.create_from_options(options) as embedder:
|
||||||
|
# Extracts both embeddings.
|
||||||
|
positive_text0 = "it's a charming and often affecting journey"
|
||||||
|
positive_text1 = "what a great and fantastic trip"
|
||||||
|
|
||||||
|
result0 = embedder.embed(positive_text0)
|
||||||
|
result1 = embedder.embed(positive_text1)
|
||||||
|
|
||||||
|
# Checks embeddings and cosine similarity.
|
||||||
|
expected_result0_value, expected_result1_value = expected_first_values
|
||||||
|
self._check_embedding_size(result0, quantize, expected_size)
|
||||||
|
self._check_embedding_size(result1, quantize, expected_size)
|
||||||
|
self._check_embedding_value(result0, expected_result0_value)
|
||||||
|
self._check_embedding_value(result1, expected_result1_value)
|
||||||
|
self._check_cosine_similarity(result0, result1, expected_similarity)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
absltest.main()
|
|
@ -36,3 +36,23 @@ py_library(
|
||||||
"//mediapipe/tasks/python/text/core:base_text_task_api",
|
"//mediapipe/tasks/python/text/core:base_text_task_api",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "text_embedder",
|
||||||
|
srcs = [
|
||||||
|
"text_embedder.py",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
"//mediapipe/python:_framework_bindings",
|
||||||
|
"//mediapipe/python:packet_creator",
|
||||||
|
"//mediapipe/python:packet_getter",
|
||||||
|
"//mediapipe/tasks/cc/text/text_embedder/proto:text_embedder_graph_options_py_pb2",
|
||||||
|
"//mediapipe/tasks/cc/components/containers/proto:embeddings_py_pb2",
|
||||||
|
"//mediapipe/tasks/python/components/containers:embedding_result",
|
||||||
|
"//mediapipe/tasks/python/components/processors:embedder_options",
|
||||||
|
"//mediapipe/tasks/python/core:base_options",
|
||||||
|
"//mediapipe/tasks/python/core:optional_dependencies",
|
||||||
|
"//mediapipe/tasks/python/core:task_info",
|
||||||
|
"//mediapipe/tasks/python/text/core:base_text_task_api",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
166
mediapipe/tasks/python/text/text_embedder.py
Normal file
166
mediapipe/tasks/python/text/text_embedder.py
Normal file
|
@ -0,0 +1,166 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""MediaPipe text embedder task."""
|
||||||
|
|
||||||
|
import dataclasses
|
||||||
|
from typing import Callable, Mapping, Optional
|
||||||
|
|
||||||
|
from mediapipe.python import packet_creator
|
||||||
|
from mediapipe.python import packet_getter
|
||||||
|
from mediapipe.python._framework_bindings import packet as packet_module
|
||||||
|
from mediapipe.tasks.cc.text.text_embedder.proto import text_embedder_graph_options_pb2
|
||||||
|
from mediapipe.tasks.cc.components.containers.proto import embeddings_pb2
|
||||||
|
from mediapipe.tasks.python.components.processors import embedder_options
|
||||||
|
from mediapipe.tasks.python.components.utils import cosine_similarity
|
||||||
|
from mediapipe.tasks.python.components.containers import embedding_result as embedding_result_module
|
||||||
|
from mediapipe.tasks.python.core import base_options as base_options_module
|
||||||
|
from mediapipe.tasks.python.core import task_info as task_info_module
|
||||||
|
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
|
||||||
|
from mediapipe.tasks.python.text.core import base_text_task_api
|
||||||
|
|
||||||
|
TextEmbedderResult = embedding_result_module.EmbeddingResult
|
||||||
|
_BaseOptions = base_options_module.BaseOptions
|
||||||
|
_TextEmbedderGraphOptionsProto = text_embedder_graph_options_pb2.TextEmbedderGraphOptions
|
||||||
|
_EmbedderOptions = embedder_options.EmbedderOptions
|
||||||
|
_TaskInfo = task_info_module.TaskInfo
|
||||||
|
|
||||||
|
_EMBEDDINGS_OUT_STREAM_NAME = 'embeddings_out'
|
||||||
|
_EMBEDDINGS_TAG = 'EMBEDDINGS'
|
||||||
|
_TEXT_IN_STREAM_NAME = 'text_in'
|
||||||
|
_TEXT_TAG = 'TEXT'
|
||||||
|
_TASK_GRAPH_NAME = 'mediapipe.tasks.text.text_embedder.TextEmbedderGraph'
|
||||||
|
_MICRO_SECONDS_PER_MILLISECOND = 1000
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class TextEmbedderOptions:
|
||||||
|
"""Options for the text embedder task.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
base_options: Base options for the text embedder task.
|
||||||
|
embedder_options: Options for the text embedder task.
|
||||||
|
"""
|
||||||
|
base_options: _BaseOptions
|
||||||
|
embedder_options: _EmbedderOptions = _EmbedderOptions()
|
||||||
|
|
||||||
|
@doc_controls.do_not_generate_docs
|
||||||
|
def to_pb2(self) -> _TextEmbedderGraphOptionsProto:
|
||||||
|
"""Generates an TextEmbedderOptions protobuf object."""
|
||||||
|
base_options_proto = self.base_options.to_pb2()
|
||||||
|
embedder_options_proto = self.embedder_options.to_pb2()
|
||||||
|
|
||||||
|
return _TextEmbedderGraphOptionsProto(
|
||||||
|
base_options=base_options_proto,
|
||||||
|
embedder_options=embedder_options_proto
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TextEmbedder(base_text_task_api.BaseTextTaskApi):
|
||||||
|
"""Class that performs embedding extraction on text."""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create_from_model_path(cls, model_path: str) -> 'ImageEmbedder':
|
||||||
|
"""Creates an `TextEmbedder` object from a TensorFlow Lite model and the
|
||||||
|
default `TextEmbedderOptions`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_path: Path to the model.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`TextEmbedder` object that's created from the model file and the default
|
||||||
|
`TextEmbedderOptions`.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If failed to create `TextEmbedder` object from the provided
|
||||||
|
file such as invalid file path.
|
||||||
|
RuntimeError: If other types of error occurred.
|
||||||
|
"""
|
||||||
|
base_options = _BaseOptions(model_asset_path=model_path)
|
||||||
|
options = TextEmbedderOptions(base_options=base_options)
|
||||||
|
return cls.create_from_options(options)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create_from_options(cls,
|
||||||
|
options: TextEmbedderOptions) -> 'TextEmbedder':
|
||||||
|
"""Creates the `TextEmbedder` object from text embedder options.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
options: Options for the text embedder task.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`TextEmbedder` object that's created from `options`.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If failed to create `TextEmbedder` object from
|
||||||
|
`TextEmbedderOptions` such as missing the model.
|
||||||
|
RuntimeError: If other types of error occurred.
|
||||||
|
"""
|
||||||
|
task_info = _TaskInfo(
|
||||||
|
task_graph=_TASK_GRAPH_NAME,
|
||||||
|
input_streams=[':'.join([_TEXT_TAG, _TEXT_IN_STREAM_NAME])],
|
||||||
|
output_streams=[
|
||||||
|
':'.join([
|
||||||
|
_EMBEDDINGS_TAG,
|
||||||
|
_EMBEDDINGS_OUT_STREAM_NAME
|
||||||
|
])
|
||||||
|
],
|
||||||
|
task_options=options)
|
||||||
|
return cls(task_info.generate_graph_config())
|
||||||
|
|
||||||
|
def embed(
|
||||||
|
self,
|
||||||
|
text: str,
|
||||||
|
) -> TextEmbedderResult:
|
||||||
|
"""Performs text embedding extraction on the provided text.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: The input text.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
An embedding result object that contains a list of embeddings.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If any of the input arguments is invalid.
|
||||||
|
RuntimeError: If text embedder failed to run.
|
||||||
|
"""
|
||||||
|
output_packets = self._runner.process(
|
||||||
|
{_TEXT_IN_STREAM_NAME: packet_creator.create_string(text)})
|
||||||
|
|
||||||
|
embedding_result_proto = embeddings_pb2.EmbeddingResult()
|
||||||
|
embedding_result_proto.CopyFrom(
|
||||||
|
packet_getter.get_proto(output_packets[_EMBEDDINGS_OUT_STREAM_NAME]))
|
||||||
|
|
||||||
|
return TextEmbedderResult.create_from_pb2(embedding_result_proto)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def cosine_similarity(u: embedding_result_module.Embedding,
|
||||||
|
v: embedding_result_module.Embedding) -> float:
|
||||||
|
"""Utility function to compute cosine similarity [1] between two embedding
|
||||||
|
entries. May return an InvalidArgumentError if e.g. the feature vectors are
|
||||||
|
of different types (quantized vs. float), have different sizes, or have a
|
||||||
|
an L2-norm of 0.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
u: An embedding entry.
|
||||||
|
v: An embedding entry.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The cosine similarity for the two embeddings.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: May return an error if e.g. the feature vectors are of
|
||||||
|
different types (quantized vs. float), have different sizes, or have
|
||||||
|
an L2-norm of 0
|
||||||
|
"""
|
||||||
|
return cosine_similarity.cosine_similarity(u, v)
|
Loading…
Reference in New Issue
Block a user