Revised image embedder implementation
This commit is contained in:
parent
ba1ee5b404
commit
664d9c49e7
|
@ -92,7 +92,6 @@ cc_library(
|
||||||
"//mediapipe/tasks/cc/vision/image_segmenter:image_segmenter_graph",
|
"//mediapipe/tasks/cc/vision/image_segmenter:image_segmenter_graph",
|
||||||
"//mediapipe/tasks/cc/vision/object_detector:object_detector_graph",
|
"//mediapipe/tasks/cc/vision/object_detector:object_detector_graph",
|
||||||
"//mediapipe/tasks/cc/vision/image_embedder:image_embedder_graph",
|
"//mediapipe/tasks/cc/vision/image_embedder:image_embedder_graph",
|
||||||
],
|
|
||||||
] + select({
|
] + select({
|
||||||
# TODO: Build text_classifier_graph on Windows.
|
# TODO: Build text_classifier_graph on Windows.
|
||||||
"//mediapipe:windows": [],
|
"//mediapipe:windows": [],
|
||||||
|
|
|
@ -22,8 +22,7 @@ from mediapipe.tasks.python.core.optional_dependencies import doc_controls
|
||||||
|
|
||||||
_FloatEmbeddingProto = embeddings_pb2.FloatEmbedding
|
_FloatEmbeddingProto = embeddings_pb2.FloatEmbedding
|
||||||
_QuantizedEmbeddingProto = embeddings_pb2.QuantizedEmbedding
|
_QuantizedEmbeddingProto = embeddings_pb2.QuantizedEmbedding
|
||||||
_EmbeddingEntryProto = embeddings_pb2.EmbeddingEntry
|
_EmbeddingProto = embeddings_pb2.Embedding
|
||||||
_EmbeddingsProto = embeddings_pb2.Embeddings
|
|
||||||
_EmbeddingResultProto = embeddings_pb2.EmbeddingResult
|
_EmbeddingResultProto = embeddings_pb2.EmbeddingResult
|
||||||
|
|
||||||
|
|
||||||
|
@ -99,28 +98,34 @@ class QuantizedEmbedding:
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class EmbeddingEntry:
|
class Embedding:
|
||||||
"""Floating-point or scalar-quantized embedding with an optional timestamp.
|
"""Embedding result for a given embedder head.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
embedding: The actual embedding, either floating-point or scalar-quantized.
|
embedding: The actual embedding, either floating-point or scalar-quantized.
|
||||||
timestamp_ms: The optional timestamp (in milliseconds) associated to the
|
head_index: The index of the embedder head that produced this embedding.
|
||||||
embedding entry. This is useful for time series use cases, e.g. audio
|
This is useful for multi-head models.
|
||||||
embedding.
|
head_name: The name of the embedder head, which is the corresponding tensor
|
||||||
|
metadata name (if any). This is useful for multi-head models.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
embedding: np.ndarray
|
embedding: np.ndarray
|
||||||
timestamp_ms: Optional[int] = None
|
head_index: int
|
||||||
|
head_name: str
|
||||||
|
|
||||||
@doc_controls.do_not_generate_docs
|
@doc_controls.do_not_generate_docs
|
||||||
def to_pb2(self) -> _EmbeddingEntryProto:
|
def to_pb2(self) -> _EmbeddingProto:
|
||||||
"""Generates a EmbeddingEntry protobuf object."""
|
"""Generates a Embedding protobuf object."""
|
||||||
|
|
||||||
if self.embedding.dtype == float:
|
if self.embedding.dtype == float:
|
||||||
return _EmbeddingEntryProto(float_embedding=self.embedding)
|
return _EmbeddingProto(float_embedding=self.embedding,
|
||||||
|
head_index=self.head_index,
|
||||||
|
head_name=self.head_name)
|
||||||
|
|
||||||
elif self.embedding.dtype == np.uint8:
|
elif self.embedding.dtype == np.uint8:
|
||||||
return _EmbeddingEntryProto(quantized_embedding=bytes(self.embedding))
|
return _EmbeddingProto(quantized_embedding=bytes(self.embedding),
|
||||||
|
head_index=self.head_index,
|
||||||
|
head_name=self.head_name)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise ValueError("Invalid dtype. Only float and np.uint8 are supported.")
|
raise ValueError("Invalid dtype. Only float and np.uint8 are supported.")
|
||||||
|
@ -128,17 +133,21 @@ class EmbeddingEntry:
|
||||||
@classmethod
|
@classmethod
|
||||||
@doc_controls.do_not_generate_docs
|
@doc_controls.do_not_generate_docs
|
||||||
def create_from_pb2(
|
def create_from_pb2(
|
||||||
cls, pb2_obj: _EmbeddingEntryProto) -> 'EmbeddingEntry':
|
cls, pb2_obj: _EmbeddingProto) -> 'Embedding':
|
||||||
"""Creates a `EmbeddingEntry` object from the given protobuf object."""
|
"""Creates a `Embedding` object from the given protobuf object."""
|
||||||
|
|
||||||
quantized_embedding = np.array(
|
quantized_embedding = np.array(
|
||||||
bytearray(pb2_obj.quantized_embedding.values))
|
bytearray(pb2_obj.quantized_embedding.values))
|
||||||
float_embedding = np.array(pb2_obj.float_embedding.values, dtype=float)
|
float_embedding = np.array(pb2_obj.float_embedding.values, dtype=float)
|
||||||
|
|
||||||
if len(quantized_embedding) == 0:
|
if len(quantized_embedding) == 0:
|
||||||
return EmbeddingEntry(embedding=float_embedding)
|
return Embedding(embedding=float_embedding,
|
||||||
|
head_index=pb2_obj.head_index,
|
||||||
|
head_name=pb2_obj.head_name)
|
||||||
else:
|
else:
|
||||||
return EmbeddingEntry(embedding=quantized_embedding)
|
return Embedding(embedding=quantized_embedding,
|
||||||
|
head_index=pb2_obj.head_index,
|
||||||
|
head_name=pb2_obj.head_name)
|
||||||
|
|
||||||
def __eq__(self, other: Any) -> bool:
|
def __eq__(self, other: Any) -> bool:
|
||||||
"""Checks if this object is equal to the given object.
|
"""Checks if this object is equal to the given object.
|
||||||
|
@ -147,55 +156,7 @@ class EmbeddingEntry:
|
||||||
Returns:
|
Returns:
|
||||||
True if the objects are equal.
|
True if the objects are equal.
|
||||||
"""
|
"""
|
||||||
if not isinstance(other, EmbeddingEntry):
|
if not isinstance(other, Embedding):
|
||||||
return False
|
|
||||||
|
|
||||||
return self.to_pb2().__eq__(other.to_pb2())
|
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
|
||||||
class Embeddings:
|
|
||||||
"""Embeddings for a given embedder head.
|
|
||||||
Attributes:
|
|
||||||
entries: A list of `ClassificationEntry` objects.
|
|
||||||
head_index: The index of the embedder head that produced this embedding.
|
|
||||||
This is useful for multi-head models.
|
|
||||||
head_name: The name of the embedder head, which is the corresponding tensor
|
|
||||||
metadata name (if any). This is useful for multi-head models.
|
|
||||||
"""
|
|
||||||
|
|
||||||
entries: List[EmbeddingEntry]
|
|
||||||
head_index: int
|
|
||||||
head_name: str
|
|
||||||
|
|
||||||
@doc_controls.do_not_generate_docs
|
|
||||||
def to_pb2(self) -> _EmbeddingsProto:
|
|
||||||
"""Generates a Embeddings protobuf object."""
|
|
||||||
return _EmbeddingsProto(
|
|
||||||
entries=[entry.to_pb2() for entry in self.entries],
|
|
||||||
head_index=self.head_index,
|
|
||||||
head_name=self.head_name)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
@doc_controls.do_not_generate_docs
|
|
||||||
def create_from_pb2(cls, pb2_obj: _EmbeddingsProto) -> 'Embeddings':
|
|
||||||
"""Creates a `Embeddings` object from the given protobuf object."""
|
|
||||||
return Embeddings(
|
|
||||||
entries=[
|
|
||||||
EmbeddingEntry.create_from_pb2(entry)
|
|
||||||
for entry in pb2_obj.entries
|
|
||||||
],
|
|
||||||
head_index=pb2_obj.head_index,
|
|
||||||
head_name=pb2_obj.head_name)
|
|
||||||
|
|
||||||
def __eq__(self, other: Any) -> bool:
|
|
||||||
"""Checks if this object is equal to the given object.
|
|
||||||
Args:
|
|
||||||
other: The object to be compared with.
|
|
||||||
Returns:
|
|
||||||
True if the objects are equal.
|
|
||||||
"""
|
|
||||||
if not isinstance(other, Embeddings):
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
return self.to_pb2().__eq__(other.to_pb2())
|
return self.to_pb2().__eq__(other.to_pb2())
|
||||||
|
@ -203,12 +164,12 @@ class Embeddings:
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class EmbeddingResult:
|
class EmbeddingResult:
|
||||||
"""Contains one set of results per embedder head.
|
"""Embedding results for a given embedder model.
|
||||||
Attributes:
|
Attributes:
|
||||||
embeddings: A list of `Embeddings` objects.
|
embeddings: A list of `Embedding` objects.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
embeddings: List[Embeddings]
|
embeddings: List[Embedding]
|
||||||
|
|
||||||
@doc_controls.do_not_generate_docs
|
@doc_controls.do_not_generate_docs
|
||||||
def to_pb2(self) -> _EmbeddingResultProto:
|
def to_pb2(self) -> _EmbeddingResultProto:
|
||||||
|
@ -225,7 +186,7 @@ class EmbeddingResult:
|
||||||
"""Creates a `EmbeddingResult` object from the given protobuf object."""
|
"""Creates a `EmbeddingResult` object from the given protobuf object."""
|
||||||
return EmbeddingResult(
|
return EmbeddingResult(
|
||||||
embeddings=[
|
embeddings=[
|
||||||
Embeddings.create_from_pb2(embedding)
|
Embedding.create_from_pb2(embedding)
|
||||||
for embedding in pb2_obj.embeddings
|
for embedding in pb2_obj.embeddings
|
||||||
])
|
])
|
||||||
|
|
||||||
|
|
|
@ -28,3 +28,12 @@ py_library(
|
||||||
"//mediapipe/tasks/python/core:optional_dependencies",
|
"//mediapipe/tasks/python/core:optional_dependencies",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "embedder_options",
|
||||||
|
srcs = ["embedder_options.py"],
|
||||||
|
deps = [
|
||||||
|
"//mediapipe/tasks/cc/components/processors/proto:embedder_options_py_pb2",
|
||||||
|
"//mediapipe/tasks/python/core:optional_dependencies",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
|
@ -16,7 +16,7 @@
|
||||||
import dataclasses
|
import dataclasses
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
from mediapipe.tasks.cc.components.proto import embedder_options_pb2
|
from mediapipe.tasks.cc.components.processors.proto import embedder_options_pb2
|
||||||
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
|
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
|
||||||
|
|
||||||
_EmbedderOptionsProto = embedder_options_pb2.EmbedderOptions
|
_EmbedderOptionsProto = embedder_options_pb2.EmbedderOptions
|
|
@ -1,28 +0,0 @@
|
||||||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
# Placeholder for internal Python strict library compatibility macro.
|
|
||||||
|
|
||||||
package(default_visibility = ["//mediapipe/tasks:internal"])
|
|
||||||
|
|
||||||
licenses(["notice"])
|
|
||||||
|
|
||||||
py_library(
|
|
||||||
name = "embedder_options",
|
|
||||||
srcs = ["embedder_options.py"],
|
|
||||||
deps = [
|
|
||||||
"//mediapipe/tasks/cc/components/proto:embedder_options_py_pb2",
|
|
||||||
"//mediapipe/tasks/python/core:optional_dependencies",
|
|
||||||
],
|
|
||||||
)
|
|
|
@ -1,13 +0,0 @@
|
||||||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
|
@ -23,6 +23,6 @@ py_library(
|
||||||
srcs = ["cosine_similarity.py"],
|
srcs = ["cosine_similarity.py"],
|
||||||
deps = [
|
deps = [
|
||||||
"//mediapipe/tasks/python/components/containers:embeddings",
|
"//mediapipe/tasks/python/components/containers:embeddings",
|
||||||
"//mediapipe/tasks/python/components/proto:embedder_options",
|
"//mediapipe/tasks/python/components/processors:embedder_options",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
|
@ -16,9 +16,9 @@
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from mediapipe.tasks.python.components.containers import embeddings
|
from mediapipe.tasks.python.components.containers import embeddings
|
||||||
from mediapipe.tasks.python.components.proto import embedder_options
|
from mediapipe.tasks.python.components.processors import embedder_options
|
||||||
|
|
||||||
_EmbeddingEntry = embeddings.EmbeddingEntry
|
_Embedding = embeddings.Embedding
|
||||||
_EmbedderOptions = embedder_options.EmbedderOptions
|
_EmbedderOptions = embedder_options.EmbedderOptions
|
||||||
|
|
||||||
|
|
||||||
|
@ -36,15 +36,15 @@ def _compute_cosine_similarity(u, v):
|
||||||
return np.dot(u.embedding, v.embedding.T) / (norm_u * norm_v)
|
return np.dot(u.embedding, v.embedding.T) / (norm_u * norm_v)
|
||||||
|
|
||||||
|
|
||||||
def cosine_similarity(u: _EmbeddingEntry, v: _EmbeddingEntry) -> float:
|
def cosine_similarity(u: _Embedding, v: _Embedding) -> float:
|
||||||
"""Utility function to compute cosine similarity between two embedding
|
"""Utility function to compute cosine similarity between two embedding.
|
||||||
entries. May return an InvalidArgumentError if e.g. the feature vectors are
|
May return an InvalidArgumentError if e.g. the feature vectors are of
|
||||||
of different types (quantized vs. float), have different sizes, or have an
|
different types (quantized vs. float), have different sizes, or have an
|
||||||
L2-norm of 0.
|
L2-norm of 0.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
u: An embedding entry.
|
u: An embedding.
|
||||||
v: An embedding entry.
|
v: An embedding.
|
||||||
"""
|
"""
|
||||||
if len(u.embedding) != len(v.embedding):
|
if len(u.embedding) != len(v.embedding):
|
||||||
raise ValueError(f"Cannot compute cosine similarity between embeddings "
|
raise ValueError(f"Cannot compute cosine similarity between embeddings "
|
||||||
|
|
|
@ -83,7 +83,7 @@ py_test(
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
"//mediapipe/python:_framework_bindings",
|
"//mediapipe/python:_framework_bindings",
|
||||||
"//mediapipe/tasks/python/components/proto:embedder_options",
|
"//mediapipe/tasks/python/components/processors:embedder_options",
|
||||||
"//mediapipe/tasks/python/components/utils:cosine_similarity",
|
"//mediapipe/tasks/python/components/utils:cosine_similarity",
|
||||||
"//mediapipe/tasks/python/components/containers:embeddings",
|
"//mediapipe/tasks/python/components/containers:embeddings",
|
||||||
"//mediapipe/tasks/python/components/containers:rect",
|
"//mediapipe/tasks/python/components/containers:rect",
|
||||||
|
|
|
@ -22,7 +22,7 @@ from absl.testing import absltest
|
||||||
from absl.testing import parameterized
|
from absl.testing import parameterized
|
||||||
|
|
||||||
from mediapipe.python._framework_bindings import image as image_module
|
from mediapipe.python._framework_bindings import image as image_module
|
||||||
from mediapipe.tasks.python.components.proto import embedder_options as embedder_options_module
|
from mediapipe.tasks.python.components.processors import embedder_options as embedder_options_module
|
||||||
from mediapipe.tasks.python.components.containers import embeddings as embeddings_module
|
from mediapipe.tasks.python.components.containers import embeddings as embeddings_module
|
||||||
from mediapipe.tasks.python.components.containers import rect
|
from mediapipe.tasks.python.components.containers import rect
|
||||||
from mediapipe.tasks.python.core import base_options as base_options_module
|
from mediapipe.tasks.python.core import base_options as base_options_module
|
||||||
|
@ -36,8 +36,7 @@ _BaseOptions = base_options_module.BaseOptions
|
||||||
_EmbedderOptions = embedder_options_module.EmbedderOptions
|
_EmbedderOptions = embedder_options_module.EmbedderOptions
|
||||||
_FloatEmbedding = embeddings_module.FloatEmbedding
|
_FloatEmbedding = embeddings_module.FloatEmbedding
|
||||||
_QuantizedEmbedding = embeddings_module.QuantizedEmbedding
|
_QuantizedEmbedding = embeddings_module.QuantizedEmbedding
|
||||||
_EmbeddingEntry = embeddings_module.EmbeddingEntry
|
_Embedding = embeddings_module.Embedding
|
||||||
_Embeddings = embeddings_module.Embeddings
|
|
||||||
_EmbeddingResult = embeddings_module.EmbeddingResult
|
_EmbeddingResult = embeddings_module.EmbeddingResult
|
||||||
_Image = image_module.Image
|
_Image = image_module.Image
|
||||||
_ImageEmbedder = image_embedder.ImageEmbedder
|
_ImageEmbedder = image_embedder.ImageEmbedder
|
||||||
|
@ -81,12 +80,12 @@ class ImageEmbedderTest(parameterized.TestCase):
|
||||||
# Check embedding sizes.
|
# Check embedding sizes.
|
||||||
def _check_embedding_size(result):
|
def _check_embedding_size(result):
|
||||||
self.assertLen(result.embeddings, 1)
|
self.assertLen(result.embeddings, 1)
|
||||||
embedding_entry = result.embeddings[0].entries[0]
|
embedding_result = result.embeddings[0]
|
||||||
self.assertLen(embedding_entry.embedding, 1024)
|
self.assertLen(embedding_result.embedding, 1024)
|
||||||
if quantize:
|
if quantize:
|
||||||
self.assertEqual(embedding_entry.embedding.dtype, np.uint8)
|
self.assertEqual(embedding_result.embedding.dtype, np.uint8)
|
||||||
else:
|
else:
|
||||||
self.assertEqual(embedding_entry.embedding.dtype, float)
|
self.assertEqual(embedding_result.embedding.dtype, float)
|
||||||
|
|
||||||
# Checks results sizes.
|
# Checks results sizes.
|
||||||
_check_embedding_size(result0)
|
_check_embedding_size(result0)
|
||||||
|
@ -94,7 +93,7 @@ class ImageEmbedderTest(parameterized.TestCase):
|
||||||
|
|
||||||
# Checks cosine similarity.
|
# Checks cosine similarity.
|
||||||
similarity = _ImageEmbedder.cosine_similarity(
|
similarity = _ImageEmbedder.cosine_similarity(
|
||||||
result0.embeddings[0].entries[0], result1.embeddings[0].entries[0])
|
result0.embeddings[0], result1.embeddings[0])
|
||||||
self.assertAlmostEqual(similarity, expected_similarity,
|
self.assertAlmostEqual(similarity, expected_similarity,
|
||||||
delta=_SIMILARITY_TOLERANCE)
|
delta=_SIMILARITY_TOLERANCE)
|
||||||
|
|
||||||
|
@ -134,7 +133,7 @@ class ImageEmbedderTest(parameterized.TestCase):
|
||||||
crop_result = embedder.embed(self.test_cropped_image)
|
crop_result = embedder.embed(self.test_cropped_image)
|
||||||
|
|
||||||
# Check embedding value.
|
# Check embedding value.
|
||||||
self.assertAlmostEqual(image_result.embeddings[0].entries[0].embedding[0],
|
self.assertAlmostEqual(image_result.embeddings[0].embedding[0],
|
||||||
expected_first_value)
|
expected_first_value)
|
||||||
|
|
||||||
# Checks cosine similarity.
|
# Checks cosine similarity.
|
||||||
|
|
|
@ -114,6 +114,7 @@ py_library(
|
||||||
"//mediapipe/python:packet_getter",
|
"//mediapipe/python:packet_getter",
|
||||||
"//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_py_pb2",
|
"//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_py_pb2",
|
||||||
"//mediapipe/tasks/python/components/containers:embeddings",
|
"//mediapipe/tasks/python/components/containers:embeddings",
|
||||||
|
"//mediapipe/tasks/python/components/processors:classifier_options",
|
||||||
"//mediapipe/tasks/python/core:base_options",
|
"//mediapipe/tasks/python/core:base_options",
|
||||||
"//mediapipe/tasks/python/core:optional_dependencies",
|
"//mediapipe/tasks/python/core:optional_dependencies",
|
||||||
"//mediapipe/tasks/python/core:task_info",
|
"//mediapipe/tasks/python/core:task_info",
|
||||||
|
|
|
@ -22,7 +22,7 @@ from mediapipe.python._framework_bindings import image as image_module
|
||||||
from mediapipe.python._framework_bindings import packet as packet_module
|
from mediapipe.python._framework_bindings import packet as packet_module
|
||||||
from mediapipe.python._framework_bindings import task_runner as task_runner_module
|
from mediapipe.python._framework_bindings import task_runner as task_runner_module
|
||||||
from mediapipe.tasks.cc.vision.image_embedder.proto import image_embedder_graph_options_pb2
|
from mediapipe.tasks.cc.vision.image_embedder.proto import image_embedder_graph_options_pb2
|
||||||
from mediapipe.tasks.python.components.proto import embedder_options
|
from mediapipe.tasks.python.components.processors import embedder_options
|
||||||
from mediapipe.tasks.python.components.utils import cosine_similarity
|
from mediapipe.tasks.python.components.utils import cosine_similarity
|
||||||
from mediapipe.tasks.python.components.containers import embeddings as embeddings_module
|
from mediapipe.tasks.python.components.containers import embeddings as embeddings_module
|
||||||
from mediapipe.tasks.python.core import base_options as base_options_module
|
from mediapipe.tasks.python.core import base_options as base_options_module
|
||||||
|
@ -32,6 +32,7 @@ from mediapipe.tasks.python.vision.core import base_vision_task_api
|
||||||
from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
|
from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
|
||||||
from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module
|
from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module
|
||||||
|
|
||||||
|
_ImageEmbedderResult = embeddings_module.EmbeddingResult
|
||||||
_BaseOptions = base_options_module.BaseOptions
|
_BaseOptions = base_options_module.BaseOptions
|
||||||
_ImageEmbedderGraphOptionsProto = image_embedder_graph_options_pb2.ImageEmbedderGraphOptions
|
_ImageEmbedderGraphOptionsProto = image_embedder_graph_options_pb2.ImageEmbedderGraphOptions
|
||||||
_EmbedderOptions = embedder_options.EmbedderOptions
|
_EmbedderOptions = embedder_options.EmbedderOptions
|
||||||
|
@ -40,8 +41,8 @@ _TaskInfo = task_info_module.TaskInfo
|
||||||
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
|
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
|
||||||
_TaskRunner = task_runner_module.TaskRunner
|
_TaskRunner = task_runner_module.TaskRunner
|
||||||
|
|
||||||
_EMBEDDING_RESULT_OUT_STREAM_NAME = 'embedding_result_out'
|
_EMBEDDINGS_OUT_STREAM_NAME = 'embeddings_out'
|
||||||
_EMBEDDING_RESULT_TAG = 'EMBEDDING_RESULT'
|
_EMBEDDINGS_TAG = 'EMBEDDINGS'
|
||||||
_IMAGE_IN_STREAM_NAME = 'image_in'
|
_IMAGE_IN_STREAM_NAME = 'image_in'
|
||||||
_IMAGE_OUT_STREAM_NAME = 'image_out'
|
_IMAGE_OUT_STREAM_NAME = 'image_out'
|
||||||
_IMAGE_TAG = 'IMAGE'
|
_IMAGE_TAG = 'IMAGE'
|
||||||
|
@ -140,15 +141,15 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi):
|
||||||
if output_packets[_IMAGE_OUT_STREAM_NAME].is_empty():
|
if output_packets[_IMAGE_OUT_STREAM_NAME].is_empty():
|
||||||
return
|
return
|
||||||
embedding_result_proto = packet_getter.get_proto(
|
embedding_result_proto = packet_getter.get_proto(
|
||||||
output_packets[_EMBEDDING_RESULT_OUT_STREAM_NAME])
|
output_packets[_EMBEDDINGS_OUT_STREAM_NAME])
|
||||||
|
|
||||||
embedding_result = embeddings_module.EmbeddingResult([
|
embeddings = embeddings_module.EmbeddingResult([
|
||||||
embeddings_module.Embeddings.create_from_pb2(embedding)
|
embeddings_module.Embedding.create_from_pb2(embedding)
|
||||||
for embedding in embedding_result_proto.embeddings
|
for embedding in embedding_result_proto.embeddings
|
||||||
])
|
])
|
||||||
image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME])
|
image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME])
|
||||||
timestamp = output_packets[_IMAGE_OUT_STREAM_NAME].timestamp
|
timestamp = output_packets[_IMAGE_OUT_STREAM_NAME].timestamp
|
||||||
options.result_callback(embedding_result, image,
|
options.result_callback(embeddings, image,
|
||||||
timestamp.value // _MICRO_SECONDS_PER_MILLISECOND)
|
timestamp.value // _MICRO_SECONDS_PER_MILLISECOND)
|
||||||
|
|
||||||
task_info = _TaskInfo(
|
task_info = _TaskInfo(
|
||||||
|
@ -158,8 +159,8 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi):
|
||||||
':'.join([_NORM_RECT_TAG, _NORM_RECT_STREAM_NAME]),
|
':'.join([_NORM_RECT_TAG, _NORM_RECT_STREAM_NAME]),
|
||||||
],
|
],
|
||||||
output_streams=[
|
output_streams=[
|
||||||
':'.join([_EMBEDDING_RESULT_TAG,
|
':'.join([_EMBEDDINGS_TAG,
|
||||||
_EMBEDDING_RESULT_OUT_STREAM_NAME]),
|
_EMBEDDINGS_OUT_STREAM_NAME]),
|
||||||
':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME])
|
':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME])
|
||||||
],
|
],
|
||||||
task_options=options)
|
task_options=options)
|
||||||
|
@ -173,7 +174,7 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi):
|
||||||
self,
|
self,
|
||||||
image: image_module.Image,
|
image: image_module.Image,
|
||||||
image_processing_options: Optional[_ImageProcessingOptions] = None
|
image_processing_options: Optional[_ImageProcessingOptions] = None
|
||||||
) -> embeddings_module.EmbeddingResult:
|
) -> _ImageEmbedderResult:
|
||||||
"""Performs image embedding extraction on the provided MediaPipe Image.
|
"""Performs image embedding extraction on the provided MediaPipe Image.
|
||||||
Extraction is performed on the region of interest specified by the `roi`
|
Extraction is performed on the region of interest specified by the `roi`
|
||||||
argument if provided, or on the entire image otherwise.
|
argument if provided, or on the entire image otherwise.
|
||||||
|
@ -195,18 +196,18 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi):
|
||||||
_NORM_RECT_STREAM_NAME: packet_creator.create_proto(
|
_NORM_RECT_STREAM_NAME: packet_creator.create_proto(
|
||||||
normalized_rect.to_pb2())})
|
normalized_rect.to_pb2())})
|
||||||
embedding_result_proto = packet_getter.get_proto(
|
embedding_result_proto = packet_getter.get_proto(
|
||||||
output_packets[_EMBEDDING_RESULT_OUT_STREAM_NAME])
|
output_packets[_EMBEDDINGS_OUT_STREAM_NAME])
|
||||||
|
|
||||||
return embeddings_module.EmbeddingResult([
|
return embeddings_module.EmbeddingResult([
|
||||||
embeddings_module.Embeddings.create_from_pb2(embedding)
|
embeddings_module.Embedding.create_from_pb2(embedding)
|
||||||
for embedding in embedding_result_proto.embeddings
|
for embedding in embedding_result_proto.embeddings
|
||||||
])
|
])
|
||||||
|
|
||||||
def embed_for_video(
|
def embed_for_video(
|
||||||
self, image: image_module.Image,
|
self, image: image_module.Image,
|
||||||
timestamp_ms: int,
|
timestamp_ms: int,
|
||||||
image_processing_options: Optional[_ImageProcessingOptions] = None
|
image_processing_options: Optional[_ImageProcessingOptions] = None
|
||||||
) -> embeddings_module.EmbeddingResult:
|
) -> _ImageEmbedderResult:
|
||||||
"""Performs image embedding extraction on the provided video frames.
|
"""Performs image embedding extraction on the provided video frames.
|
||||||
Extraction is performed on the region of interested specified by the `roi`
|
Extraction is performed on the region of interested specified by the `roi`
|
||||||
argument if provided, or on the entire image otherwise.
|
argument if provided, or on the entire image otherwise.
|
||||||
|
@ -237,11 +238,11 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi):
|
||||||
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND)
|
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND)
|
||||||
})
|
})
|
||||||
embedding_result_proto = packet_getter.get_proto(
|
embedding_result_proto = packet_getter.get_proto(
|
||||||
output_packets[_EMBEDDING_RESULT_OUT_STREAM_NAME])
|
output_packets[_EMBEDDINGS_OUT_STREAM_NAME])
|
||||||
|
|
||||||
return embeddings_module.EmbeddingResult([
|
return embeddings_module.EmbeddingResult([
|
||||||
embeddings_module.Embeddings.create_from_pb2(embedding)
|
embeddings_module.Embedding.create_from_pb2(embedding)
|
||||||
for embedding in embedding_result_proto.embeddings
|
for embedding in embedding_result_proto.embeddings
|
||||||
])
|
])
|
||||||
|
|
||||||
def embed_async(
|
def embed_async(
|
||||||
|
@ -289,8 +290,8 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi):
|
||||||
})
|
})
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def cosine_similarity(u: embeddings_module.EmbeddingEntry,
|
def cosine_similarity(u: embeddings_module.Embedding,
|
||||||
v: embeddings_module.EmbeddingEntry) -> float:
|
v: embeddings_module.Embedding) -> float:
|
||||||
"""Utility function to compute cosine similarity [1] between two embedding
|
"""Utility function to compute cosine similarity [1] between two embedding
|
||||||
entries. May return an InvalidArgumentError if e.g. the feature vectors are
|
entries. May return an InvalidArgumentError if e.g. the feature vectors are
|
||||||
of different types (quantized vs. float), have different sizes, or have a
|
of different types (quantized vs. float), have different sizes, or have a
|
||||||
|
|
Loading…
Reference in New Issue
Block a user