Revised image embedder implementation
This commit is contained in:
parent
ba1ee5b404
commit
664d9c49e7
|
@ -92,7 +92,6 @@ cc_library(
|
|||
"//mediapipe/tasks/cc/vision/image_segmenter:image_segmenter_graph",
|
||||
"//mediapipe/tasks/cc/vision/object_detector:object_detector_graph",
|
||||
"//mediapipe/tasks/cc/vision/image_embedder:image_embedder_graph",
|
||||
],
|
||||
] + select({
|
||||
# TODO: Build text_classifier_graph on Windows.
|
||||
"//mediapipe:windows": [],
|
||||
|
|
|
@ -22,8 +22,7 @@ from mediapipe.tasks.python.core.optional_dependencies import doc_controls
|
|||
|
||||
_FloatEmbeddingProto = embeddings_pb2.FloatEmbedding
|
||||
_QuantizedEmbeddingProto = embeddings_pb2.QuantizedEmbedding
|
||||
_EmbeddingEntryProto = embeddings_pb2.EmbeddingEntry
|
||||
_EmbeddingsProto = embeddings_pb2.Embeddings
|
||||
_EmbeddingProto = embeddings_pb2.Embedding
|
||||
_EmbeddingResultProto = embeddings_pb2.EmbeddingResult
|
||||
|
||||
|
||||
|
@ -99,28 +98,34 @@ class QuantizedEmbedding:
|
|||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class EmbeddingEntry:
|
||||
"""Floating-point or scalar-quantized embedding with an optional timestamp.
|
||||
class Embedding:
|
||||
"""Embedding result for a given embedder head.
|
||||
|
||||
Attributes:
|
||||
embedding: The actual embedding, either floating-point or scalar-quantized.
|
||||
timestamp_ms: The optional timestamp (in milliseconds) associated to the
|
||||
embedding entry. This is useful for time series use cases, e.g. audio
|
||||
embedding.
|
||||
head_index: The index of the embedder head that produced this embedding.
|
||||
This is useful for multi-head models.
|
||||
head_name: The name of the embedder head, which is the corresponding tensor
|
||||
metadata name (if any). This is useful for multi-head models.
|
||||
"""
|
||||
|
||||
embedding: np.ndarray
|
||||
timestamp_ms: Optional[int] = None
|
||||
head_index: int
|
||||
head_name: str
|
||||
|
||||
@doc_controls.do_not_generate_docs
|
||||
def to_pb2(self) -> _EmbeddingEntryProto:
|
||||
"""Generates a EmbeddingEntry protobuf object."""
|
||||
def to_pb2(self) -> _EmbeddingProto:
|
||||
"""Generates a Embedding protobuf object."""
|
||||
|
||||
if self.embedding.dtype == float:
|
||||
return _EmbeddingEntryProto(float_embedding=self.embedding)
|
||||
return _EmbeddingProto(float_embedding=self.embedding,
|
||||
head_index=self.head_index,
|
||||
head_name=self.head_name)
|
||||
|
||||
elif self.embedding.dtype == np.uint8:
|
||||
return _EmbeddingEntryProto(quantized_embedding=bytes(self.embedding))
|
||||
return _EmbeddingProto(quantized_embedding=bytes(self.embedding),
|
||||
head_index=self.head_index,
|
||||
head_name=self.head_name)
|
||||
|
||||
else:
|
||||
raise ValueError("Invalid dtype. Only float and np.uint8 are supported.")
|
||||
|
@ -128,63 +133,19 @@ class EmbeddingEntry:
|
|||
@classmethod
|
||||
@doc_controls.do_not_generate_docs
|
||||
def create_from_pb2(
|
||||
cls, pb2_obj: _EmbeddingEntryProto) -> 'EmbeddingEntry':
|
||||
"""Creates a `EmbeddingEntry` object from the given protobuf object."""
|
||||
cls, pb2_obj: _EmbeddingProto) -> 'Embedding':
|
||||
"""Creates a `Embedding` object from the given protobuf object."""
|
||||
|
||||
quantized_embedding = np.array(
|
||||
bytearray(pb2_obj.quantized_embedding.values))
|
||||
float_embedding = np.array(pb2_obj.float_embedding.values, dtype=float)
|
||||
|
||||
if len(quantized_embedding) == 0:
|
||||
return EmbeddingEntry(embedding=float_embedding)
|
||||
return Embedding(embedding=float_embedding,
|
||||
head_index=pb2_obj.head_index,
|
||||
head_name=pb2_obj.head_name)
|
||||
else:
|
||||
return EmbeddingEntry(embedding=quantized_embedding)
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
"""Checks if this object is equal to the given object.
|
||||
Args:
|
||||
other: The object to be compared with.
|
||||
Returns:
|
||||
True if the objects are equal.
|
||||
"""
|
||||
if not isinstance(other, EmbeddingEntry):
|
||||
return False
|
||||
|
||||
return self.to_pb2().__eq__(other.to_pb2())
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Embeddings:
|
||||
"""Embeddings for a given embedder head.
|
||||
Attributes:
|
||||
entries: A list of `ClassificationEntry` objects.
|
||||
head_index: The index of the embedder head that produced this embedding.
|
||||
This is useful for multi-head models.
|
||||
head_name: The name of the embedder head, which is the corresponding tensor
|
||||
metadata name (if any). This is useful for multi-head models.
|
||||
"""
|
||||
|
||||
entries: List[EmbeddingEntry]
|
||||
head_index: int
|
||||
head_name: str
|
||||
|
||||
@doc_controls.do_not_generate_docs
|
||||
def to_pb2(self) -> _EmbeddingsProto:
|
||||
"""Generates a Embeddings protobuf object."""
|
||||
return _EmbeddingsProto(
|
||||
entries=[entry.to_pb2() for entry in self.entries],
|
||||
head_index=self.head_index,
|
||||
head_name=self.head_name)
|
||||
|
||||
@classmethod
|
||||
@doc_controls.do_not_generate_docs
|
||||
def create_from_pb2(cls, pb2_obj: _EmbeddingsProto) -> 'Embeddings':
|
||||
"""Creates a `Embeddings` object from the given protobuf object."""
|
||||
return Embeddings(
|
||||
entries=[
|
||||
EmbeddingEntry.create_from_pb2(entry)
|
||||
for entry in pb2_obj.entries
|
||||
],
|
||||
return Embedding(embedding=quantized_embedding,
|
||||
head_index=pb2_obj.head_index,
|
||||
head_name=pb2_obj.head_name)
|
||||
|
||||
|
@ -195,7 +156,7 @@ class Embeddings:
|
|||
Returns:
|
||||
True if the objects are equal.
|
||||
"""
|
||||
if not isinstance(other, Embeddings):
|
||||
if not isinstance(other, Embedding):
|
||||
return False
|
||||
|
||||
return self.to_pb2().__eq__(other.to_pb2())
|
||||
|
@ -203,12 +164,12 @@ class Embeddings:
|
|||
|
||||
@dataclasses.dataclass
|
||||
class EmbeddingResult:
|
||||
"""Contains one set of results per embedder head.
|
||||
"""Embedding results for a given embedder model.
|
||||
Attributes:
|
||||
embeddings: A list of `Embeddings` objects.
|
||||
embeddings: A list of `Embedding` objects.
|
||||
"""
|
||||
|
||||
embeddings: List[Embeddings]
|
||||
embeddings: List[Embedding]
|
||||
|
||||
@doc_controls.do_not_generate_docs
|
||||
def to_pb2(self) -> _EmbeddingResultProto:
|
||||
|
@ -225,7 +186,7 @@ class EmbeddingResult:
|
|||
"""Creates a `EmbeddingResult` object from the given protobuf object."""
|
||||
return EmbeddingResult(
|
||||
embeddings=[
|
||||
Embeddings.create_from_pb2(embedding)
|
||||
Embedding.create_from_pb2(embedding)
|
||||
for embedding in pb2_obj.embeddings
|
||||
])
|
||||
|
||||
|
|
|
@ -28,3 +28,12 @@ py_library(
|
|||
"//mediapipe/tasks/python/core:optional_dependencies",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "embedder_options",
|
||||
srcs = ["embedder_options.py"],
|
||||
deps = [
|
||||
"//mediapipe/tasks/cc/components/processors/proto:embedder_options_py_pb2",
|
||||
"//mediapipe/tasks/python/core:optional_dependencies",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
import dataclasses
|
||||
from typing import Any, Optional
|
||||
|
||||
from mediapipe.tasks.cc.components.proto import embedder_options_pb2
|
||||
from mediapipe.tasks.cc.components.processors.proto import embedder_options_pb2
|
||||
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
|
||||
|
||||
_EmbedderOptionsProto = embedder_options_pb2.EmbedderOptions
|
|
@ -1,28 +0,0 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Placeholder for internal Python strict library compatibility macro.
|
||||
|
||||
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
py_library(
|
||||
name = "embedder_options",
|
||||
srcs = ["embedder_options.py"],
|
||||
deps = [
|
||||
"//mediapipe/tasks/cc/components/proto:embedder_options_py_pb2",
|
||||
"//mediapipe/tasks/python/core:optional_dependencies",
|
||||
],
|
||||
)
|
|
@ -1,13 +0,0 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
|
@ -23,6 +23,6 @@ py_library(
|
|||
srcs = ["cosine_similarity.py"],
|
||||
deps = [
|
||||
"//mediapipe/tasks/python/components/containers:embeddings",
|
||||
"//mediapipe/tasks/python/components/proto:embedder_options",
|
||||
"//mediapipe/tasks/python/components/processors:embedder_options",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -16,9 +16,9 @@
|
|||
import numpy as np
|
||||
|
||||
from mediapipe.tasks.python.components.containers import embeddings
|
||||
from mediapipe.tasks.python.components.proto import embedder_options
|
||||
from mediapipe.tasks.python.components.processors import embedder_options
|
||||
|
||||
_EmbeddingEntry = embeddings.EmbeddingEntry
|
||||
_Embedding = embeddings.Embedding
|
||||
_EmbedderOptions = embedder_options.EmbedderOptions
|
||||
|
||||
|
||||
|
@ -36,15 +36,15 @@ def _compute_cosine_similarity(u, v):
|
|||
return np.dot(u.embedding, v.embedding.T) / (norm_u * norm_v)
|
||||
|
||||
|
||||
def cosine_similarity(u: _EmbeddingEntry, v: _EmbeddingEntry) -> float:
|
||||
"""Utility function to compute cosine similarity between two embedding
|
||||
entries. May return an InvalidArgumentError if e.g. the feature vectors are
|
||||
of different types (quantized vs. float), have different sizes, or have an
|
||||
def cosine_similarity(u: _Embedding, v: _Embedding) -> float:
|
||||
"""Utility function to compute cosine similarity between two embedding.
|
||||
May return an InvalidArgumentError if e.g. the feature vectors are of
|
||||
different types (quantized vs. float), have different sizes, or have an
|
||||
L2-norm of 0.
|
||||
|
||||
Args:
|
||||
u: An embedding entry.
|
||||
v: An embedding entry.
|
||||
u: An embedding.
|
||||
v: An embedding.
|
||||
"""
|
||||
if len(u.embedding) != len(v.embedding):
|
||||
raise ValueError(f"Cannot compute cosine similarity between embeddings "
|
||||
|
|
|
@ -83,7 +83,7 @@ py_test(
|
|||
],
|
||||
deps = [
|
||||
"//mediapipe/python:_framework_bindings",
|
||||
"//mediapipe/tasks/python/components/proto:embedder_options",
|
||||
"//mediapipe/tasks/python/components/processors:embedder_options",
|
||||
"//mediapipe/tasks/python/components/utils:cosine_similarity",
|
||||
"//mediapipe/tasks/python/components/containers:embeddings",
|
||||
"//mediapipe/tasks/python/components/containers:rect",
|
||||
|
|
|
@ -22,7 +22,7 @@ from absl.testing import absltest
|
|||
from absl.testing import parameterized
|
||||
|
||||
from mediapipe.python._framework_bindings import image as image_module
|
||||
from mediapipe.tasks.python.components.proto import embedder_options as embedder_options_module
|
||||
from mediapipe.tasks.python.components.processors import embedder_options as embedder_options_module
|
||||
from mediapipe.tasks.python.components.containers import embeddings as embeddings_module
|
||||
from mediapipe.tasks.python.components.containers import rect
|
||||
from mediapipe.tasks.python.core import base_options as base_options_module
|
||||
|
@ -36,8 +36,7 @@ _BaseOptions = base_options_module.BaseOptions
|
|||
_EmbedderOptions = embedder_options_module.EmbedderOptions
|
||||
_FloatEmbedding = embeddings_module.FloatEmbedding
|
||||
_QuantizedEmbedding = embeddings_module.QuantizedEmbedding
|
||||
_EmbeddingEntry = embeddings_module.EmbeddingEntry
|
||||
_Embeddings = embeddings_module.Embeddings
|
||||
_Embedding = embeddings_module.Embedding
|
||||
_EmbeddingResult = embeddings_module.EmbeddingResult
|
||||
_Image = image_module.Image
|
||||
_ImageEmbedder = image_embedder.ImageEmbedder
|
||||
|
@ -81,12 +80,12 @@ class ImageEmbedderTest(parameterized.TestCase):
|
|||
# Check embedding sizes.
|
||||
def _check_embedding_size(result):
|
||||
self.assertLen(result.embeddings, 1)
|
||||
embedding_entry = result.embeddings[0].entries[0]
|
||||
self.assertLen(embedding_entry.embedding, 1024)
|
||||
embedding_result = result.embeddings[0]
|
||||
self.assertLen(embedding_result.embedding, 1024)
|
||||
if quantize:
|
||||
self.assertEqual(embedding_entry.embedding.dtype, np.uint8)
|
||||
self.assertEqual(embedding_result.embedding.dtype, np.uint8)
|
||||
else:
|
||||
self.assertEqual(embedding_entry.embedding.dtype, float)
|
||||
self.assertEqual(embedding_result.embedding.dtype, float)
|
||||
|
||||
# Checks results sizes.
|
||||
_check_embedding_size(result0)
|
||||
|
@ -94,7 +93,7 @@ class ImageEmbedderTest(parameterized.TestCase):
|
|||
|
||||
# Checks cosine similarity.
|
||||
similarity = _ImageEmbedder.cosine_similarity(
|
||||
result0.embeddings[0].entries[0], result1.embeddings[0].entries[0])
|
||||
result0.embeddings[0], result1.embeddings[0])
|
||||
self.assertAlmostEqual(similarity, expected_similarity,
|
||||
delta=_SIMILARITY_TOLERANCE)
|
||||
|
||||
|
@ -134,7 +133,7 @@ class ImageEmbedderTest(parameterized.TestCase):
|
|||
crop_result = embedder.embed(self.test_cropped_image)
|
||||
|
||||
# Check embedding value.
|
||||
self.assertAlmostEqual(image_result.embeddings[0].entries[0].embedding[0],
|
||||
self.assertAlmostEqual(image_result.embeddings[0].embedding[0],
|
||||
expected_first_value)
|
||||
|
||||
# Checks cosine similarity.
|
||||
|
|
|
@ -114,6 +114,7 @@ py_library(
|
|||
"//mediapipe/python:packet_getter",
|
||||
"//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_py_pb2",
|
||||
"//mediapipe/tasks/python/components/containers:embeddings",
|
||||
"//mediapipe/tasks/python/components/processors:classifier_options",
|
||||
"//mediapipe/tasks/python/core:base_options",
|
||||
"//mediapipe/tasks/python/core:optional_dependencies",
|
||||
"//mediapipe/tasks/python/core:task_info",
|
||||
|
|
|
@ -22,7 +22,7 @@ from mediapipe.python._framework_bindings import image as image_module
|
|||
from mediapipe.python._framework_bindings import packet as packet_module
|
||||
from mediapipe.python._framework_bindings import task_runner as task_runner_module
|
||||
from mediapipe.tasks.cc.vision.image_embedder.proto import image_embedder_graph_options_pb2
|
||||
from mediapipe.tasks.python.components.proto import embedder_options
|
||||
from mediapipe.tasks.python.components.processors import embedder_options
|
||||
from mediapipe.tasks.python.components.utils import cosine_similarity
|
||||
from mediapipe.tasks.python.components.containers import embeddings as embeddings_module
|
||||
from mediapipe.tasks.python.core import base_options as base_options_module
|
||||
|
@ -32,6 +32,7 @@ from mediapipe.tasks.python.vision.core import base_vision_task_api
|
|||
from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
|
||||
from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module
|
||||
|
||||
_ImageEmbedderResult = embeddings_module.EmbeddingResult
|
||||
_BaseOptions = base_options_module.BaseOptions
|
||||
_ImageEmbedderGraphOptionsProto = image_embedder_graph_options_pb2.ImageEmbedderGraphOptions
|
||||
_EmbedderOptions = embedder_options.EmbedderOptions
|
||||
|
@ -40,8 +41,8 @@ _TaskInfo = task_info_module.TaskInfo
|
|||
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
|
||||
_TaskRunner = task_runner_module.TaskRunner
|
||||
|
||||
_EMBEDDING_RESULT_OUT_STREAM_NAME = 'embedding_result_out'
|
||||
_EMBEDDING_RESULT_TAG = 'EMBEDDING_RESULT'
|
||||
_EMBEDDINGS_OUT_STREAM_NAME = 'embeddings_out'
|
||||
_EMBEDDINGS_TAG = 'EMBEDDINGS'
|
||||
_IMAGE_IN_STREAM_NAME = 'image_in'
|
||||
_IMAGE_OUT_STREAM_NAME = 'image_out'
|
||||
_IMAGE_TAG = 'IMAGE'
|
||||
|
@ -140,15 +141,15 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi):
|
|||
if output_packets[_IMAGE_OUT_STREAM_NAME].is_empty():
|
||||
return
|
||||
embedding_result_proto = packet_getter.get_proto(
|
||||
output_packets[_EMBEDDING_RESULT_OUT_STREAM_NAME])
|
||||
output_packets[_EMBEDDINGS_OUT_STREAM_NAME])
|
||||
|
||||
embedding_result = embeddings_module.EmbeddingResult([
|
||||
embeddings_module.Embeddings.create_from_pb2(embedding)
|
||||
embeddings = embeddings_module.EmbeddingResult([
|
||||
embeddings_module.Embedding.create_from_pb2(embedding)
|
||||
for embedding in embedding_result_proto.embeddings
|
||||
])
|
||||
image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME])
|
||||
timestamp = output_packets[_IMAGE_OUT_STREAM_NAME].timestamp
|
||||
options.result_callback(embedding_result, image,
|
||||
options.result_callback(embeddings, image,
|
||||
timestamp.value // _MICRO_SECONDS_PER_MILLISECOND)
|
||||
|
||||
task_info = _TaskInfo(
|
||||
|
@ -158,8 +159,8 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi):
|
|||
':'.join([_NORM_RECT_TAG, _NORM_RECT_STREAM_NAME]),
|
||||
],
|
||||
output_streams=[
|
||||
':'.join([_EMBEDDING_RESULT_TAG,
|
||||
_EMBEDDING_RESULT_OUT_STREAM_NAME]),
|
||||
':'.join([_EMBEDDINGS_TAG,
|
||||
_EMBEDDINGS_OUT_STREAM_NAME]),
|
||||
':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME])
|
||||
],
|
||||
task_options=options)
|
||||
|
@ -173,7 +174,7 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi):
|
|||
self,
|
||||
image: image_module.Image,
|
||||
image_processing_options: Optional[_ImageProcessingOptions] = None
|
||||
) -> embeddings_module.EmbeddingResult:
|
||||
) -> _ImageEmbedderResult:
|
||||
"""Performs image embedding extraction on the provided MediaPipe Image.
|
||||
Extraction is performed on the region of interest specified by the `roi`
|
||||
argument if provided, or on the entire image otherwise.
|
||||
|
@ -195,10 +196,10 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi):
|
|||
_NORM_RECT_STREAM_NAME: packet_creator.create_proto(
|
||||
normalized_rect.to_pb2())})
|
||||
embedding_result_proto = packet_getter.get_proto(
|
||||
output_packets[_EMBEDDING_RESULT_OUT_STREAM_NAME])
|
||||
output_packets[_EMBEDDINGS_OUT_STREAM_NAME])
|
||||
|
||||
return embeddings_module.EmbeddingResult([
|
||||
embeddings_module.Embeddings.create_from_pb2(embedding)
|
||||
embeddings_module.Embedding.create_from_pb2(embedding)
|
||||
for embedding in embedding_result_proto.embeddings
|
||||
])
|
||||
|
||||
|
@ -206,7 +207,7 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi):
|
|||
self, image: image_module.Image,
|
||||
timestamp_ms: int,
|
||||
image_processing_options: Optional[_ImageProcessingOptions] = None
|
||||
) -> embeddings_module.EmbeddingResult:
|
||||
) -> _ImageEmbedderResult:
|
||||
"""Performs image embedding extraction on the provided video frames.
|
||||
Extraction is performed on the region of interested specified by the `roi`
|
||||
argument if provided, or on the entire image otherwise.
|
||||
|
@ -237,10 +238,10 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi):
|
|||
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND)
|
||||
})
|
||||
embedding_result_proto = packet_getter.get_proto(
|
||||
output_packets[_EMBEDDING_RESULT_OUT_STREAM_NAME])
|
||||
output_packets[_EMBEDDINGS_OUT_STREAM_NAME])
|
||||
|
||||
return embeddings_module.EmbeddingResult([
|
||||
embeddings_module.Embeddings.create_from_pb2(embedding)
|
||||
embeddings_module.Embedding.create_from_pb2(embedding)
|
||||
for embedding in embedding_result_proto.embeddings
|
||||
])
|
||||
|
||||
|
@ -289,8 +290,8 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi):
|
|||
})
|
||||
|
||||
@staticmethod
|
||||
def cosine_similarity(u: embeddings_module.EmbeddingEntry,
|
||||
v: embeddings_module.EmbeddingEntry) -> float:
|
||||
def cosine_similarity(u: embeddings_module.Embedding,
|
||||
v: embeddings_module.Embedding) -> float:
|
||||
"""Utility function to compute cosine similarity [1] between two embedding
|
||||
entries. May return an InvalidArgumentError if e.g. the feature vectors are
|
||||
of different types (quantized vs. float), have different sizes, or have a
|
||||
|
|
Loading…
Reference in New Issue
Block a user