Internal change
PiperOrigin-RevId: 493427500
This commit is contained in:
parent
fca0f5806b
commit
576c6da173
|
@ -29,11 +29,11 @@ py_library(
|
|||
"//mediapipe/python:packet_getter",
|
||||
"//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_graph_options_py_pb2",
|
||||
"//mediapipe/tasks/cc/components/containers/proto:classifications_py_pb2",
|
||||
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_py_pb2",
|
||||
"//mediapipe/tasks/python/audio/core:audio_task_running_mode",
|
||||
"//mediapipe/tasks/python/audio/core:base_audio_task_api",
|
||||
"//mediapipe/tasks/python/components/containers:audio_data",
|
||||
"//mediapipe/tasks/python/components/containers:classification_result",
|
||||
"//mediapipe/tasks/python/components/processors:classifier_options",
|
||||
"//mediapipe/tasks/python/core:base_options",
|
||||
"//mediapipe/tasks/python/core:optional_dependencies",
|
||||
"//mediapipe/tasks/python/core:task_info",
|
||||
|
@ -51,11 +51,11 @@ py_library(
|
|||
"//mediapipe/python:packet_getter",
|
||||
"//mediapipe/tasks/cc/audio/audio_embedder/proto:audio_embedder_graph_options_py_pb2",
|
||||
"//mediapipe/tasks/cc/components/containers/proto:embeddings_py_pb2",
|
||||
"//mediapipe/tasks/cc/components/processors/proto:embedder_options_py_pb2",
|
||||
"//mediapipe/tasks/python/audio/core:audio_task_running_mode",
|
||||
"//mediapipe/tasks/python/audio/core:base_audio_task_api",
|
||||
"//mediapipe/tasks/python/components/containers:audio_data",
|
||||
"//mediapipe/tasks/python/components/containers:embedding_result",
|
||||
"//mediapipe/tasks/python/components/processors:embedder_options",
|
||||
"//mediapipe/tasks/python/components/utils:cosine_similarity",
|
||||
"//mediapipe/tasks/python/core:base_options",
|
||||
"//mediapipe/tasks/python/core:optional_dependencies",
|
||||
|
|
|
@ -21,11 +21,11 @@ from mediapipe.python import packet_getter
|
|||
from mediapipe.python._framework_bindings import packet
|
||||
from mediapipe.tasks.cc.audio.audio_classifier.proto import audio_classifier_graph_options_pb2
|
||||
from mediapipe.tasks.cc.components.containers.proto import classifications_pb2
|
||||
from mediapipe.tasks.cc.components.processors.proto import classifier_options_pb2
|
||||
from mediapipe.tasks.python.audio.core import audio_task_running_mode as running_mode_module
|
||||
from mediapipe.tasks.python.audio.core import base_audio_task_api
|
||||
from mediapipe.tasks.python.components.containers import audio_data as audio_data_module
|
||||
from mediapipe.tasks.python.components.containers import classification_result as classification_result_module
|
||||
from mediapipe.tasks.python.components.processors import classifier_options as classifier_options_module
|
||||
from mediapipe.tasks.python.core import base_options as base_options_module
|
||||
from mediapipe.tasks.python.core import task_info as task_info_module
|
||||
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
|
||||
|
@ -34,7 +34,7 @@ AudioClassifierResult = classification_result_module.ClassificationResult
|
|||
_AudioClassifierGraphOptionsProto = audio_classifier_graph_options_pb2.AudioClassifierGraphOptions
|
||||
_AudioData = audio_data_module.AudioData
|
||||
_BaseOptions = base_options_module.BaseOptions
|
||||
_ClassifierOptions = classifier_options_module.ClassifierOptions
|
||||
_ClassifierOptionsProto = classifier_options_pb2.ClassifierOptions
|
||||
_RunningMode = running_mode_module.AudioTaskRunningMode
|
||||
_TaskInfo = task_info_module.TaskInfo
|
||||
|
||||
|
@ -62,16 +62,31 @@ class AudioClassifierOptions:
|
|||
mode for running classification on the audio stream, such as from
|
||||
microphone. In this mode, the "result_callback" below must be specified
|
||||
to receive the classification results asynchronously.
|
||||
classifier_options: Options for configuring the classifier behavior, such as
|
||||
score threshold, number of results, etc.
|
||||
display_names_locale: The locale to use for display names specified through
|
||||
the TFLite Model Metadata.
|
||||
max_results: The maximum number of top-scored classification results to
|
||||
return.
|
||||
score_threshold: Overrides the ones provided in the model metadata. Results
|
||||
below this value are rejected.
|
||||
category_allowlist: Allowlist of category names. If non-empty,
|
||||
classification results whose category name is not in this set will be
|
||||
filtered out. Duplicate or unknown category names are ignored. Mutually
|
||||
exclusive with `category_denylist`.
|
||||
category_denylist: Denylist of category names. If non-empty, classification
|
||||
results whose category name is in this set will be filtered out. Duplicate
|
||||
or unknown category names are ignored. Mutually exclusive with
|
||||
`category_allowlist`.
|
||||
result_callback: The user-defined result callback for processing audio
|
||||
stream data. The result callback should only be specified when the running
|
||||
mode is set to the audio stream mode.
|
||||
"""
|
||||
base_options: _BaseOptions
|
||||
running_mode: _RunningMode = _RunningMode.AUDIO_CLIPS
|
||||
classifier_options: Optional[_ClassifierOptions] = dataclasses.field(
|
||||
default_factory=_ClassifierOptions)
|
||||
display_names_locale: Optional[str] = None
|
||||
max_results: Optional[int] = None
|
||||
score_threshold: Optional[float] = None
|
||||
category_allowlist: Optional[List[str]] = None
|
||||
category_denylist: Optional[List[str]] = None
|
||||
result_callback: Optional[Callable[[AudioClassifierResult, int], None]] = None
|
||||
|
||||
@doc_controls.do_not_generate_docs
|
||||
|
@ -79,7 +94,12 @@ class AudioClassifierOptions:
|
|||
"""Generates an AudioClassifierOptions protobuf object."""
|
||||
base_options_proto = self.base_options.to_pb2()
|
||||
base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.AUDIO_CLIPS else True
|
||||
classifier_options_proto = self.classifier_options.to_pb2()
|
||||
classifier_options_proto = _ClassifierOptionsProto(
|
||||
score_threshold=self.score_threshold,
|
||||
category_allowlist=self.category_allowlist,
|
||||
category_denylist=self.category_denylist,
|
||||
display_names_locale=self.display_names_locale,
|
||||
max_results=self.max_results)
|
||||
|
||||
return _AudioClassifierGraphOptionsProto(
|
||||
base_options=base_options_proto,
|
||||
|
|
|
@ -21,11 +21,11 @@ from mediapipe.python import packet_getter
|
|||
from mediapipe.python._framework_bindings import packet
|
||||
from mediapipe.tasks.cc.audio.audio_embedder.proto import audio_embedder_graph_options_pb2
|
||||
from mediapipe.tasks.cc.components.containers.proto import embeddings_pb2
|
||||
from mediapipe.tasks.cc.components.processors.proto import embedder_options_pb2
|
||||
from mediapipe.tasks.python.audio.core import audio_task_running_mode as running_mode_module
|
||||
from mediapipe.tasks.python.audio.core import base_audio_task_api
|
||||
from mediapipe.tasks.python.components.containers import audio_data as audio_data_module
|
||||
from mediapipe.tasks.python.components.containers import embedding_result as embedding_result_module
|
||||
from mediapipe.tasks.python.components.processors import embedder_options as embedder_options_module
|
||||
from mediapipe.tasks.python.components.utils import cosine_similarity
|
||||
from mediapipe.tasks.python.core import base_options as base_options_module
|
||||
from mediapipe.tasks.python.core import task_info as task_info_module
|
||||
|
@ -35,7 +35,7 @@ AudioEmbedderResult = embedding_result_module.EmbeddingResult
|
|||
_AudioEmbedderGraphOptionsProto = audio_embedder_graph_options_pb2.AudioEmbedderGraphOptions
|
||||
_AudioData = audio_data_module.AudioData
|
||||
_BaseOptions = base_options_module.BaseOptions
|
||||
_EmbedderOptions = embedder_options_module.EmbedderOptions
|
||||
_EmbedderOptionsProto = embedder_options_pb2.EmbedderOptions
|
||||
_RunningMode = running_mode_module.AudioTaskRunningMode
|
||||
_TaskInfo = task_info_module.TaskInfo
|
||||
|
||||
|
@ -63,16 +63,22 @@ class AudioEmbedderOptions:
|
|||
stream mode for running embedding extraction on the audio stream, such as
|
||||
from microphone. In this mode, the "result_callback" below must be
|
||||
specified to receive the embedding results asynchronously.
|
||||
embedder_options: Options for configuring the embedder behavior, such as
|
||||
l2_normalize and quantize.
|
||||
l2_normalize: Whether to normalize the returned feature vector with L2 norm.
|
||||
Use this option only if the model does not already contain a native
|
||||
L2_NORMALIZATION TF Lite Op. In most cases, this is already the case and
|
||||
L2 norm is thus achieved through TF Lite inference.
|
||||
quantize: Whether the returned embedding should be quantized to bytes via
|
||||
scalar quantization. Embeddings are implicitly assumed to be unit-norm and
|
||||
therefore any dimension is guaranteed to have a value in [-1.0, 1.0]. Use
|
||||
the l2_normalize option if this is not the case.
|
||||
result_callback: The user-defined result callback for processing audio
|
||||
stream data. The result callback should only be specified when the running
|
||||
mode is set to the audio stream mode.
|
||||
"""
|
||||
base_options: _BaseOptions
|
||||
running_mode: _RunningMode = _RunningMode.AUDIO_CLIPS
|
||||
embedder_options: Optional[_EmbedderOptions] = dataclasses.field(
|
||||
default_factory=_EmbedderOptions)
|
||||
l2_normalize: Optional[bool] = None
|
||||
quantize: Optional[bool] = None
|
||||
result_callback: Optional[Callable[[AudioEmbedderResult, int], None]] = None
|
||||
|
||||
@doc_controls.do_not_generate_docs
|
||||
|
@ -80,7 +86,8 @@ class AudioEmbedderOptions:
|
|||
"""Generates an AudioEmbedderOptions protobuf object."""
|
||||
base_options_proto = self.base_options.to_pb2()
|
||||
base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.AUDIO_CLIPS else True
|
||||
embedder_options_proto = self.embedder_options.to_pb2()
|
||||
embedder_options_proto = _EmbedderOptionsProto(
|
||||
l2_normalize=self.l2_normalize, quantize=self.quantize)
|
||||
|
||||
return _AudioEmbedderGraphOptionsProto(
|
||||
base_options=base_options_proto,
|
||||
|
|
|
@ -28,12 +28,3 @@ py_library(
|
|||
"//mediapipe/tasks/python/core:optional_dependencies",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "embedder_options",
|
||||
srcs = ["embedder_options.py"],
|
||||
deps = [
|
||||
"//mediapipe/tasks/cc/components/processors/proto:embedder_options_py_pb2",
|
||||
"//mediapipe/tasks/python/core:optional_dependencies",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -15,12 +15,9 @@
|
|||
"""MediaPipe Tasks Components Processors API."""
|
||||
|
||||
import mediapipe.tasks.python.components.processors.classifier_options
|
||||
import mediapipe.tasks.python.components.processors.embedder_options
|
||||
|
||||
ClassifierOptions = classifier_options.ClassifierOptions
|
||||
EmbedderOptions = embedder_options.EmbedderOptions
|
||||
|
||||
# Remove unnecessary modules to avoid duplication in API docs.
|
||||
del classifier_options
|
||||
del embedder_options
|
||||
del mediapipe
|
||||
|
|
|
@ -1,68 +0,0 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Embedder options data class."""
|
||||
|
||||
import dataclasses
|
||||
from typing import Any, Optional
|
||||
|
||||
from mediapipe.tasks.cc.components.processors.proto import embedder_options_pb2
|
||||
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
|
||||
|
||||
_EmbedderOptionsProto = embedder_options_pb2.EmbedderOptions
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class EmbedderOptions:
|
||||
"""Shared options used by all embedding extraction tasks.
|
||||
|
||||
Attributes:
|
||||
l2_normalize: Whether to normalize the returned feature vector with L2 norm.
|
||||
Use this option only if the model does not already contain a native
|
||||
L2_NORMALIZATION TF Lite Op. In most cases, this is already the case and
|
||||
L2 norm is thus achieved through TF Lite inference.
|
||||
quantize: Whether the returned embedding should be quantized to bytes via
|
||||
scalar quantization. Embeddings are implicitly assumed to be unit-norm and
|
||||
therefore any dimension is guaranteed to have a value in [-1.0, 1.0]. Use
|
||||
the l2_normalize option if this is not the case.
|
||||
"""
|
||||
|
||||
l2_normalize: Optional[bool] = None
|
||||
quantize: Optional[bool] = None
|
||||
|
||||
@doc_controls.do_not_generate_docs
|
||||
def to_pb2(self) -> _EmbedderOptionsProto:
|
||||
"""Generates a EmbedderOptions protobuf object."""
|
||||
return _EmbedderOptionsProto(
|
||||
l2_normalize=self.l2_normalize, quantize=self.quantize)
|
||||
|
||||
@classmethod
|
||||
@doc_controls.do_not_generate_docs
|
||||
def create_from_pb2(cls, pb2_obj: _EmbedderOptionsProto) -> 'EmbedderOptions':
|
||||
"""Creates a `EmbedderOptions` object from the given protobuf object."""
|
||||
return EmbedderOptions(
|
||||
l2_normalize=pb2_obj.l2_normalize, quantize=pb2_obj.quantize)
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
"""Checks if this object is equal to the given object.
|
||||
|
||||
Args:
|
||||
other: The object to be compared with.
|
||||
|
||||
Returns:
|
||||
True if the objects are equal.
|
||||
"""
|
||||
if not isinstance(other, EmbedderOptions):
|
||||
return False
|
||||
|
||||
return self.to_pb2().__eq__(other.to_pb2())
|
|
@ -23,8 +23,5 @@ licenses(["notice"])
|
|||
py_library(
|
||||
name = "cosine_similarity",
|
||||
srcs = ["cosine_similarity.py"],
|
||||
deps = [
|
||||
"//mediapipe/tasks/python/components/containers:embedding_result",
|
||||
"//mediapipe/tasks/python/components/processors:embedder_options",
|
||||
],
|
||||
deps = ["//mediapipe/tasks/python/components/containers:embedding_result"],
|
||||
)
|
||||
|
|
|
@ -16,10 +16,8 @@
|
|||
import numpy as np
|
||||
|
||||
from mediapipe.tasks.python.components.containers import embedding_result
|
||||
from mediapipe.tasks.python.components.processors import embedder_options
|
||||
|
||||
_Embedding = embedding_result.Embedding
|
||||
_EmbedderOptions = embedder_options.EmbedderOptions
|
||||
|
||||
|
||||
def _compute_cosine_similarity(u, v):
|
||||
|
|
|
@ -30,7 +30,6 @@ py_test(
|
|||
"//mediapipe/tasks/python/audio/core:audio_task_running_mode",
|
||||
"//mediapipe/tasks/python/components/containers:audio_data",
|
||||
"//mediapipe/tasks/python/components/containers:classification_result",
|
||||
"//mediapipe/tasks/python/components/processors:classifier_options",
|
||||
"//mediapipe/tasks/python/core:base_options",
|
||||
"//mediapipe/tasks/python/test:test_utils",
|
||||
],
|
||||
|
@ -48,7 +47,6 @@ py_test(
|
|||
"//mediapipe/tasks/python/audio/core:audio_task_running_mode",
|
||||
"//mediapipe/tasks/python/components/containers:audio_data",
|
||||
"//mediapipe/tasks/python/components/containers:embedding_result",
|
||||
"//mediapipe/tasks/python/components/processors:embedder_options",
|
||||
"//mediapipe/tasks/python/core:base_options",
|
||||
"//mediapipe/tasks/python/test:test_utils",
|
||||
],
|
||||
|
|
|
@ -27,7 +27,6 @@ from mediapipe.tasks.python.audio import audio_classifier
|
|||
from mediapipe.tasks.python.audio.core import audio_task_running_mode
|
||||
from mediapipe.tasks.python.components.containers import audio_data as audio_data_module
|
||||
from mediapipe.tasks.python.components.containers import classification_result as classification_result_module
|
||||
from mediapipe.tasks.python.components.processors import classifier_options
|
||||
from mediapipe.tasks.python.core import base_options as base_options_module
|
||||
from mediapipe.tasks.python.test import test_utils
|
||||
|
||||
|
@ -36,7 +35,6 @@ _AudioClassifierOptions = audio_classifier.AudioClassifierOptions
|
|||
_AudioClassifierResult = classification_result_module.ClassificationResult
|
||||
_AudioData = audio_data_module.AudioData
|
||||
_BaseOptions = base_options_module.BaseOptions
|
||||
_ClassifierOptions = classifier_options.ClassifierOptions
|
||||
_RUNNING_MODE = audio_task_running_mode.AudioTaskRunningMode
|
||||
|
||||
_YAMNET_MODEL_FILE = 'yamnet_audio_classifier_with_metadata.tflite'
|
||||
|
@ -210,8 +208,7 @@ class AudioClassifierTest(parameterized.TestCase):
|
|||
with _AudioClassifier.create_from_options(
|
||||
_AudioClassifierOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.yamnet_model_path),
|
||||
classifier_options=_ClassifierOptions(
|
||||
max_results=1))) as classifier:
|
||||
max_results=1)) as classifier:
|
||||
for audio_file in [_SPEECH_WAV_16K_MONO, _SPEECH_WAV_16K_MONO]:
|
||||
classification_result_list = classifier.classify(
|
||||
self._read_wav_file(audio_file))
|
||||
|
@ -222,8 +219,7 @@ class AudioClassifierTest(parameterized.TestCase):
|
|||
with _AudioClassifier.create_from_options(
|
||||
_AudioClassifierOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.yamnet_model_path),
|
||||
classifier_options=_ClassifierOptions(
|
||||
score_threshold=0.9))) as classifier:
|
||||
score_threshold=0.9)) as classifier:
|
||||
for audio_file in [_SPEECH_WAV_16K_MONO, _SPEECH_WAV_16K_MONO]:
|
||||
classification_result_list = classifier.classify(
|
||||
self._read_wav_file(audio_file))
|
||||
|
@ -234,8 +230,7 @@ class AudioClassifierTest(parameterized.TestCase):
|
|||
with _AudioClassifier.create_from_options(
|
||||
_AudioClassifierOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.yamnet_model_path),
|
||||
classifier_options=_ClassifierOptions(
|
||||
category_allowlist=['Speech']))) as classifier:
|
||||
category_allowlist=['Speech'])) as classifier:
|
||||
for audio_file in [_SPEECH_WAV_16K_MONO, _SPEECH_WAV_16K_MONO]:
|
||||
classification_result_list = classifier.classify(
|
||||
self._read_wav_file(audio_file))
|
||||
|
@ -250,8 +245,8 @@ class AudioClassifierTest(parameterized.TestCase):
|
|||
r'exclusive options.'):
|
||||
options = _AudioClassifierOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.yamnet_model_path),
|
||||
classifier_options=_ClassifierOptions(
|
||||
category_allowlist=['foo'], category_denylist=['bar']))
|
||||
category_allowlist=['foo'],
|
||||
category_denylist=['bar'])
|
||||
with _AudioClassifier.create_from_options(options) as unused_classifier:
|
||||
pass
|
||||
|
||||
|
@ -278,8 +273,7 @@ class AudioClassifierTest(parameterized.TestCase):
|
|||
_AudioClassifierOptions(
|
||||
base_options=_BaseOptions(
|
||||
model_asset_path=self.two_heads_model_path),
|
||||
classifier_options=_ClassifierOptions(
|
||||
max_results=1))) as classifier:
|
||||
max_results=1)) as classifier:
|
||||
for audio_file in [_TWO_HEADS_WAV_16K_MONO, _TWO_HEADS_WAV_44K_MONO]:
|
||||
classification_result_list = classifier.classify(
|
||||
self._read_wav_file(audio_file))
|
||||
|
@ -364,7 +358,7 @@ class AudioClassifierTest(parameterized.TestCase):
|
|||
options = _AudioClassifierOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.yamnet_model_path),
|
||||
running_mode=_RUNNING_MODE.AUDIO_STREAM,
|
||||
classifier_options=_ClassifierOptions(max_results=1),
|
||||
max_results=1,
|
||||
result_callback=save_result)
|
||||
classifier = _AudioClassifier.create_from_options(options)
|
||||
audio_data_list = self._read_wav_file_as_stream(audio_file)
|
||||
|
|
|
@ -26,7 +26,6 @@ from scipy.io import wavfile
|
|||
from mediapipe.tasks.python.audio import audio_embedder
|
||||
from mediapipe.tasks.python.audio.core import audio_task_running_mode
|
||||
from mediapipe.tasks.python.components.containers import audio_data as audio_data_module
|
||||
from mediapipe.tasks.python.components.processors import embedder_options
|
||||
from mediapipe.tasks.python.core import base_options as base_options_module
|
||||
from mediapipe.tasks.python.test import test_utils
|
||||
|
||||
|
@ -35,7 +34,6 @@ _AudioEmbedderOptions = audio_embedder.AudioEmbedderOptions
|
|||
_AudioEmbedderResult = audio_embedder.AudioEmbedderResult
|
||||
_AudioData = audio_data_module.AudioData
|
||||
_BaseOptions = base_options_module.BaseOptions
|
||||
_EmbedderOptions = embedder_options.EmbedderOptions
|
||||
_RUNNING_MODE = audio_task_running_mode.AudioTaskRunningMode
|
||||
|
||||
_YAMNET_MODEL_FILE = 'yamnet_embedding_metadata.tflite'
|
||||
|
@ -172,9 +170,7 @@ class AudioEmbedderTest(parameterized.TestCase):
|
|||
raise ValueError('model_file_type is invalid.')
|
||||
|
||||
options = _AudioEmbedderOptions(
|
||||
base_options=base_options,
|
||||
embedder_options=_EmbedderOptions(
|
||||
l2_normalize=l2_normalize, quantize=quantize))
|
||||
base_options=base_options, l2_normalize=l2_normalize, quantize=quantize)
|
||||
|
||||
with _AudioEmbedder.create_from_options(options) as embedder:
|
||||
embedding_result0_list = embedder.embed(self._read_wav_file(audio_file0))
|
||||
|
@ -291,8 +287,8 @@ class AudioEmbedderTest(parameterized.TestCase):
|
|||
options = _AudioEmbedderOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.yamnet_model_path),
|
||||
running_mode=_RUNNING_MODE.AUDIO_STREAM,
|
||||
embedder_options=_EmbedderOptions(
|
||||
l2_normalize=l2_normalize, quantize=quantize),
|
||||
l2_normalize=l2_normalize,
|
||||
quantize=quantize,
|
||||
result_callback=save_result)
|
||||
|
||||
with _AudioEmbedder.create_from_options(options) as embedder:
|
||||
|
|
|
@ -28,7 +28,6 @@ py_test(
|
|||
deps = [
|
||||
"//mediapipe/tasks/python/components/containers:category",
|
||||
"//mediapipe/tasks/python/components/containers:classification_result",
|
||||
"//mediapipe/tasks/python/components/processors:classifier_options",
|
||||
"//mediapipe/tasks/python/core:base_options",
|
||||
"//mediapipe/tasks/python/test:test_utils",
|
||||
"//mediapipe/tasks/python/text:text_classifier",
|
||||
|
@ -44,7 +43,6 @@ py_test(
|
|||
],
|
||||
deps = [
|
||||
"//mediapipe/tasks/python/components/containers:embedding_result",
|
||||
"//mediapipe/tasks/python/components/processors:embedder_options",
|
||||
"//mediapipe/tasks/python/core:base_options",
|
||||
"//mediapipe/tasks/python/test:test_utils",
|
||||
"//mediapipe/tasks/python/text:text_embedder",
|
||||
|
|
|
@ -21,14 +21,12 @@ from absl.testing import parameterized
|
|||
|
||||
from mediapipe.tasks.python.components.containers import category
|
||||
from mediapipe.tasks.python.components.containers import classification_result as classification_result_module
|
||||
from mediapipe.tasks.python.components.processors import classifier_options
|
||||
from mediapipe.tasks.python.core import base_options as base_options_module
|
||||
from mediapipe.tasks.python.test import test_utils
|
||||
from mediapipe.tasks.python.text import text_classifier
|
||||
|
||||
TextClassifierResult = classification_result_module.ClassificationResult
|
||||
_BaseOptions = base_options_module.BaseOptions
|
||||
_ClassifierOptions = classifier_options.ClassifierOptions
|
||||
_Category = category.Category
|
||||
_Classifications = classification_result_module.Classifications
|
||||
_TextClassifier = text_classifier.TextClassifier
|
||||
|
|
|
@ -21,13 +21,11 @@ from absl.testing import parameterized
|
|||
import numpy as np
|
||||
|
||||
from mediapipe.tasks.python.components.containers import embedding_result as embedding_result_module
|
||||
from mediapipe.tasks.python.components.processors import embedder_options as embedder_options_module
|
||||
from mediapipe.tasks.python.core import base_options as base_options_module
|
||||
from mediapipe.tasks.python.test import test_utils
|
||||
from mediapipe.tasks.python.text import text_embedder
|
||||
|
||||
_BaseOptions = base_options_module.BaseOptions
|
||||
_EmbedderOptions = embedder_options_module.EmbedderOptions
|
||||
_Embedding = embedding_result_module.Embedding
|
||||
_TextEmbedder = text_embedder.TextEmbedder
|
||||
_TextEmbedderOptions = text_embedder.TextEmbedderOptions
|
||||
|
@ -128,10 +126,8 @@ class TextEmbedderTest(parameterized.TestCase):
|
|||
# Should never happen
|
||||
raise ValueError('model_file_type is invalid.')
|
||||
|
||||
embedder_options = _EmbedderOptions(
|
||||
l2_normalize=l2_normalize, quantize=quantize)
|
||||
options = _TextEmbedderOptions(
|
||||
base_options=base_options, embedder_options=embedder_options)
|
||||
base_options=base_options, l2_normalize=l2_normalize, quantize=quantize)
|
||||
embedder = _TextEmbedder.create_from_options(options)
|
||||
|
||||
# Extracts both embeddings.
|
||||
|
@ -178,10 +174,8 @@ class TextEmbedderTest(parameterized.TestCase):
|
|||
# Should never happen
|
||||
raise ValueError('model_file_type is invalid.')
|
||||
|
||||
embedder_options = _EmbedderOptions(
|
||||
l2_normalize=l2_normalize, quantize=quantize)
|
||||
options = _TextEmbedderOptions(
|
||||
base_options=base_options, embedder_options=embedder_options)
|
||||
base_options=base_options, l2_normalize=l2_normalize, quantize=quantize)
|
||||
with _TextEmbedder.create_from_options(options) as embedder:
|
||||
# Extracts both embeddings.
|
||||
positive_text0 = "it's a charming and often affecting journey"
|
||||
|
|
|
@ -49,7 +49,6 @@ py_test(
|
|||
"//mediapipe/tasks/python/components/containers:category",
|
||||
"//mediapipe/tasks/python/components/containers:classification_result",
|
||||
"//mediapipe/tasks/python/components/containers:rect",
|
||||
"//mediapipe/tasks/python/components/processors:classifier_options",
|
||||
"//mediapipe/tasks/python/core:base_options",
|
||||
"//mediapipe/tasks/python/test:test_utils",
|
||||
"//mediapipe/tasks/python/vision:image_classifier",
|
||||
|
@ -69,7 +68,6 @@ py_test(
|
|||
"//mediapipe/python:_framework_bindings",
|
||||
"//mediapipe/tasks/python/components/containers:embedding_result",
|
||||
"//mediapipe/tasks/python/components/containers:rect",
|
||||
"//mediapipe/tasks/python/components/processors:embedder_options",
|
||||
"//mediapipe/tasks/python/core:base_options",
|
||||
"//mediapipe/tasks/python/test:test_utils",
|
||||
"//mediapipe/tasks/python/vision:image_embedder",
|
||||
|
|
|
@ -26,7 +26,6 @@ from mediapipe.python._framework_bindings import image
|
|||
from mediapipe.tasks.python.components.containers import category as category_module
|
||||
from mediapipe.tasks.python.components.containers import classification_result as classification_result_module
|
||||
from mediapipe.tasks.python.components.containers import rect
|
||||
from mediapipe.tasks.python.components.processors import classifier_options
|
||||
from mediapipe.tasks.python.core import base_options as base_options_module
|
||||
from mediapipe.tasks.python.test import test_utils
|
||||
from mediapipe.tasks.python.vision import image_classifier
|
||||
|
@ -36,7 +35,6 @@ from mediapipe.tasks.python.vision.core import vision_task_running_mode
|
|||
ImageClassifierResult = classification_result_module.ClassificationResult
|
||||
_Rect = rect.Rect
|
||||
_BaseOptions = base_options_module.BaseOptions
|
||||
_ClassifierOptions = classifier_options.ClassifierOptions
|
||||
_Category = category_module.Category
|
||||
_Classifications = classification_result_module.Classifications
|
||||
_Image = image.Image
|
||||
|
@ -171,9 +169,8 @@ class ImageClassifierTest(parameterized.TestCase):
|
|||
# Should never happen
|
||||
raise ValueError('model_file_type is invalid.')
|
||||
|
||||
custom_classifier_options = _ClassifierOptions(max_results=max_results)
|
||||
options = _ImageClassifierOptions(
|
||||
base_options=base_options, classifier_options=custom_classifier_options)
|
||||
base_options=base_options, max_results=max_results)
|
||||
classifier = _ImageClassifier.create_from_options(options)
|
||||
|
||||
# Performs image classification on the input.
|
||||
|
@ -200,9 +197,8 @@ class ImageClassifierTest(parameterized.TestCase):
|
|||
# Should never happen
|
||||
raise ValueError('model_file_type is invalid.')
|
||||
|
||||
custom_classifier_options = _ClassifierOptions(max_results=max_results)
|
||||
options = _ImageClassifierOptions(
|
||||
base_options=base_options, classifier_options=custom_classifier_options)
|
||||
base_options=base_options, max_results=max_results)
|
||||
with _ImageClassifier.create_from_options(options) as classifier:
|
||||
# Performs image classification on the input.
|
||||
image_result = classifier.classify(self.test_image)
|
||||
|
@ -212,9 +208,7 @@ class ImageClassifierTest(parameterized.TestCase):
|
|||
|
||||
def test_classify_succeeds_with_region_of_interest(self):
|
||||
base_options = _BaseOptions(model_asset_path=self.model_path)
|
||||
custom_classifier_options = _ClassifierOptions(max_results=1)
|
||||
options = _ImageClassifierOptions(
|
||||
base_options=base_options, classifier_options=custom_classifier_options)
|
||||
options = _ImageClassifierOptions(base_options=base_options, max_results=1)
|
||||
with _ImageClassifier.create_from_options(options) as classifier:
|
||||
# Load the test image.
|
||||
test_image = _Image.create_from_file(
|
||||
|
@ -230,11 +224,9 @@ class ImageClassifierTest(parameterized.TestCase):
|
|||
_generate_soccer_ball_results().to_pb2())
|
||||
|
||||
def test_score_threshold_option(self):
|
||||
custom_classifier_options = _ClassifierOptions(
|
||||
score_threshold=_SCORE_THRESHOLD)
|
||||
options = _ImageClassifierOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
classifier_options=custom_classifier_options)
|
||||
score_threshold=_SCORE_THRESHOLD)
|
||||
with _ImageClassifier.create_from_options(options) as classifier:
|
||||
# Performs image classification on the input.
|
||||
image_result = classifier.classify(self.test_image)
|
||||
|
@ -249,11 +241,9 @@ class ImageClassifierTest(parameterized.TestCase):
|
|||
f'{classification}')
|
||||
|
||||
def test_max_results_option(self):
|
||||
custom_classifier_options = _ClassifierOptions(
|
||||
score_threshold=_SCORE_THRESHOLD)
|
||||
options = _ImageClassifierOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
classifier_options=custom_classifier_options)
|
||||
score_threshold=_SCORE_THRESHOLD)
|
||||
with _ImageClassifier.create_from_options(options) as classifier:
|
||||
# Performs image classification on the input.
|
||||
image_result = classifier.classify(self.test_image)
|
||||
|
@ -263,11 +253,9 @@ class ImageClassifierTest(parameterized.TestCase):
|
|||
len(categories), _MAX_RESULTS, 'Too many results returned.')
|
||||
|
||||
def test_allow_list_option(self):
|
||||
custom_classifier_options = _ClassifierOptions(
|
||||
category_allowlist=_ALLOW_LIST)
|
||||
options = _ImageClassifierOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
classifier_options=custom_classifier_options)
|
||||
category_allowlist=_ALLOW_LIST)
|
||||
with _ImageClassifier.create_from_options(options) as classifier:
|
||||
# Performs image classification on the input.
|
||||
image_result = classifier.classify(self.test_image)
|
||||
|
@ -280,10 +268,9 @@ class ImageClassifierTest(parameterized.TestCase):
|
|||
f'Label {label} found but not in label allow list')
|
||||
|
||||
def test_deny_list_option(self):
|
||||
custom_classifier_options = _ClassifierOptions(category_denylist=_DENY_LIST)
|
||||
options = _ImageClassifierOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
classifier_options=custom_classifier_options)
|
||||
category_denylist=_DENY_LIST)
|
||||
with _ImageClassifier.create_from_options(options) as classifier:
|
||||
# Performs image classification on the input.
|
||||
image_result = classifier.classify(self.test_image)
|
||||
|
@ -301,19 +288,17 @@ class ImageClassifierTest(parameterized.TestCase):
|
|||
ValueError,
|
||||
r'`category_allowlist` and `category_denylist` are mutually '
|
||||
r'exclusive options.'):
|
||||
custom_classifier_options = _ClassifierOptions(
|
||||
category_allowlist=['foo'], category_denylist=['bar'])
|
||||
options = _ImageClassifierOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
classifier_options=custom_classifier_options)
|
||||
category_allowlist=['foo'],
|
||||
category_denylist=['bar'])
|
||||
with _ImageClassifier.create_from_options(options) as unused_classifier:
|
||||
pass
|
||||
|
||||
def test_empty_classification_outputs(self):
|
||||
custom_classifier_options = _ClassifierOptions(score_threshold=1)
|
||||
options = _ImageClassifierOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
classifier_options=custom_classifier_options)
|
||||
score_threshold=1)
|
||||
with _ImageClassifier.create_from_options(options) as classifier:
|
||||
# Performs image classification on the input.
|
||||
image_result = classifier.classify(self.test_image)
|
||||
|
@ -386,11 +371,10 @@ class ImageClassifierTest(parameterized.TestCase):
|
|||
classifier.classify_for_video(self.test_image, 0)
|
||||
|
||||
def test_classify_for_video(self):
|
||||
custom_classifier_options = _ClassifierOptions(max_results=4)
|
||||
options = _ImageClassifierOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.VIDEO,
|
||||
classifier_options=custom_classifier_options)
|
||||
max_results=4)
|
||||
with _ImageClassifier.create_from_options(options) as classifier:
|
||||
for timestamp in range(0, 300, 30):
|
||||
classification_result = classifier.classify_for_video(
|
||||
|
@ -399,11 +383,10 @@ class ImageClassifierTest(parameterized.TestCase):
|
|||
_generate_burger_results().to_pb2())
|
||||
|
||||
def test_classify_for_video_succeeds_with_region_of_interest(self):
|
||||
custom_classifier_options = _ClassifierOptions(max_results=1)
|
||||
options = _ImageClassifierOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.VIDEO,
|
||||
classifier_options=custom_classifier_options)
|
||||
max_results=1)
|
||||
with _ImageClassifier.create_from_options(options) as classifier:
|
||||
# Load the test image.
|
||||
test_image = _Image.create_from_file(
|
||||
|
@ -439,11 +422,10 @@ class ImageClassifierTest(parameterized.TestCase):
|
|||
classifier.classify_for_video(self.test_image, 0)
|
||||
|
||||
def test_classify_async_calls_with_illegal_timestamp(self):
|
||||
custom_classifier_options = _ClassifierOptions(max_results=4)
|
||||
options = _ImageClassifierOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.LIVE_STREAM,
|
||||
classifier_options=custom_classifier_options,
|
||||
max_results=4,
|
||||
result_callback=mock.MagicMock())
|
||||
with _ImageClassifier.create_from_options(options) as classifier:
|
||||
classifier.classify_async(self.test_image, 100)
|
||||
|
@ -466,12 +448,11 @@ class ImageClassifierTest(parameterized.TestCase):
|
|||
self.assertLess(observed_timestamp_ms, timestamp_ms)
|
||||
self.observed_timestamp_ms = timestamp_ms
|
||||
|
||||
custom_classifier_options = _ClassifierOptions(
|
||||
max_results=4, score_threshold=threshold)
|
||||
options = _ImageClassifierOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.LIVE_STREAM,
|
||||
classifier_options=custom_classifier_options,
|
||||
max_results=4,
|
||||
score_threshold=threshold,
|
||||
result_callback=check_result)
|
||||
with _ImageClassifier.create_from_options(options) as classifier:
|
||||
for timestamp in range(0, 300, 30):
|
||||
|
@ -496,11 +477,10 @@ class ImageClassifierTest(parameterized.TestCase):
|
|||
self.assertLess(observed_timestamp_ms, timestamp_ms)
|
||||
self.observed_timestamp_ms = timestamp_ms
|
||||
|
||||
custom_classifier_options = _ClassifierOptions(max_results=1)
|
||||
options = _ImageClassifierOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.LIVE_STREAM,
|
||||
classifier_options=custom_classifier_options,
|
||||
max_results=1,
|
||||
result_callback=check_result)
|
||||
with _ImageClassifier.create_from_options(options) as classifier:
|
||||
for timestamp in range(0, 300, 30):
|
||||
|
|
|
@ -24,7 +24,6 @@ import numpy as np
|
|||
from mediapipe.python._framework_bindings import image as image_module
|
||||
from mediapipe.tasks.python.components.containers import embedding_result as embedding_result_module
|
||||
from mediapipe.tasks.python.components.containers import rect
|
||||
from mediapipe.tasks.python.components.processors import embedder_options as embedder_options_module
|
||||
from mediapipe.tasks.python.core import base_options as base_options_module
|
||||
from mediapipe.tasks.python.test import test_utils
|
||||
from mediapipe.tasks.python.vision import image_embedder
|
||||
|
@ -33,7 +32,6 @@ from mediapipe.tasks.python.vision.core import vision_task_running_mode as runni
|
|||
|
||||
_Rect = rect.Rect
|
||||
_BaseOptions = base_options_module.BaseOptions
|
||||
_EmbedderOptions = embedder_options_module.EmbedderOptions
|
||||
_Embedding = embedding_result_module.Embedding
|
||||
_Image = image_module.Image
|
||||
_ImageEmbedder = image_embedder.ImageEmbedder
|
||||
|
@ -142,10 +140,8 @@ class ImageEmbedderTest(parameterized.TestCase):
|
|||
# Should never happen
|
||||
raise ValueError('model_file_type is invalid.')
|
||||
|
||||
embedder_options = _EmbedderOptions(
|
||||
l2_normalize=l2_normalize, quantize=quantize)
|
||||
options = _ImageEmbedderOptions(
|
||||
base_options=base_options, embedder_options=embedder_options)
|
||||
base_options=base_options, l2_normalize=l2_normalize, quantize=quantize)
|
||||
embedder = _ImageEmbedder.create_from_options(options)
|
||||
|
||||
image_processing_options = None
|
||||
|
@ -186,10 +182,8 @@ class ImageEmbedderTest(parameterized.TestCase):
|
|||
# Should never happen
|
||||
raise ValueError('model_file_type is invalid.')
|
||||
|
||||
embedder_options = _EmbedderOptions(
|
||||
l2_normalize=l2_normalize, quantize=quantize)
|
||||
options = _ImageEmbedderOptions(
|
||||
base_options=base_options, embedder_options=embedder_options)
|
||||
base_options=base_options, l2_normalize=l2_normalize, quantize=quantize)
|
||||
|
||||
with _ImageEmbedder.create_from_options(options) as embedder:
|
||||
# Extracts both embeddings.
|
||||
|
|
|
@ -28,9 +28,9 @@ py_library(
|
|||
"//mediapipe/python:packet_creator",
|
||||
"//mediapipe/python:packet_getter",
|
||||
"//mediapipe/tasks/cc/components/containers/proto:classifications_py_pb2",
|
||||
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_py_pb2",
|
||||
"//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_py_pb2",
|
||||
"//mediapipe/tasks/python/components/containers:classification_result",
|
||||
"//mediapipe/tasks/python/components/processors:classifier_options",
|
||||
"//mediapipe/tasks/python/core:base_options",
|
||||
"//mediapipe/tasks/python/core:optional_dependencies",
|
||||
"//mediapipe/tasks/python/core:task_info",
|
||||
|
@ -47,9 +47,9 @@ py_library(
|
|||
"//mediapipe/python:packet_creator",
|
||||
"//mediapipe/python:packet_getter",
|
||||
"//mediapipe/tasks/cc/components/containers/proto:embeddings_py_pb2",
|
||||
"//mediapipe/tasks/cc/components/processors/proto:embedder_options_py_pb2",
|
||||
"//mediapipe/tasks/cc/text/text_embedder/proto:text_embedder_graph_options_py_pb2",
|
||||
"//mediapipe/tasks/python/components/containers:embedding_result",
|
||||
"//mediapipe/tasks/python/components/processors:embedder_options",
|
||||
"//mediapipe/tasks/python/components/utils:cosine_similarity",
|
||||
"//mediapipe/tasks/python/core:base_options",
|
||||
"//mediapipe/tasks/python/core:optional_dependencies",
|
||||
|
|
|
@ -14,14 +14,14 @@
|
|||
"""MediaPipe text classifier task."""
|
||||
|
||||
import dataclasses
|
||||
from typing import Optional
|
||||
from typing import Optional, List
|
||||
|
||||
from mediapipe.python import packet_creator
|
||||
from mediapipe.python import packet_getter
|
||||
from mediapipe.tasks.cc.components.containers.proto import classifications_pb2
|
||||
from mediapipe.tasks.cc.components.processors.proto import classifier_options_pb2
|
||||
from mediapipe.tasks.cc.text.text_classifier.proto import text_classifier_graph_options_pb2
|
||||
from mediapipe.tasks.python.components.containers import classification_result as classification_result_module
|
||||
from mediapipe.tasks.python.components.processors import classifier_options
|
||||
from mediapipe.tasks.python.core import base_options as base_options_module
|
||||
from mediapipe.tasks.python.core import task_info as task_info_module
|
||||
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
|
||||
|
@ -30,7 +30,7 @@ from mediapipe.tasks.python.text.core import base_text_task_api
|
|||
TextClassifierResult = classification_result_module.ClassificationResult
|
||||
_BaseOptions = base_options_module.BaseOptions
|
||||
_TextClassifierGraphOptionsProto = text_classifier_graph_options_pb2.TextClassifierGraphOptions
|
||||
_ClassifierOptions = classifier_options.ClassifierOptions
|
||||
_ClassifierOptionsProto = classifier_options_pb2.ClassifierOptions
|
||||
_TaskInfo = task_info_module.TaskInfo
|
||||
|
||||
_CLASSIFICATIONS_STREAM_NAME = 'classifications_out'
|
||||
|
@ -46,17 +46,38 @@ class TextClassifierOptions:
|
|||
|
||||
Attributes:
|
||||
base_options: Base options for the text classifier task.
|
||||
classifier_options: Options for the text classification task.
|
||||
display_names_locale: The locale to use for display names specified through
|
||||
the TFLite Model Metadata.
|
||||
max_results: The maximum number of top-scored classification results to
|
||||
return.
|
||||
score_threshold: Overrides the ones provided in the model metadata. Results
|
||||
below this value are rejected.
|
||||
category_allowlist: Allowlist of category names. If non-empty,
|
||||
classification results whose category name is not in this set will be
|
||||
filtered out. Duplicate or unknown category names are ignored. Mutually
|
||||
exclusive with `category_denylist`.
|
||||
category_denylist: Denylist of category names. If non-empty, classification
|
||||
results whose category name is in this set will be filtered out. Duplicate
|
||||
or unknown category names are ignored. Mutually exclusive with
|
||||
`category_allowlist`.
|
||||
"""
|
||||
base_options: _BaseOptions
|
||||
classifier_options: Optional[_ClassifierOptions] = dataclasses.field(
|
||||
default_factory=_ClassifierOptions)
|
||||
display_names_locale: Optional[str] = None
|
||||
max_results: Optional[int] = None
|
||||
score_threshold: Optional[float] = None
|
||||
category_allowlist: Optional[List[str]] = None
|
||||
category_denylist: Optional[List[str]] = None
|
||||
|
||||
@doc_controls.do_not_generate_docs
|
||||
def to_pb2(self) -> _TextClassifierGraphOptionsProto:
|
||||
"""Generates an TextClassifierOptions protobuf object."""
|
||||
base_options_proto = self.base_options.to_pb2()
|
||||
classifier_options_proto = self.classifier_options.to_pb2()
|
||||
classifier_options_proto = _ClassifierOptionsProto(
|
||||
score_threshold=self.score_threshold,
|
||||
category_allowlist=self.category_allowlist,
|
||||
category_denylist=self.category_denylist,
|
||||
display_names_locale=self.display_names_locale,
|
||||
max_results=self.max_results)
|
||||
|
||||
return _TextClassifierGraphOptionsProto(
|
||||
base_options=base_options_proto,
|
||||
|
|
|
@ -19,9 +19,9 @@ from typing import Optional
|
|||
from mediapipe.python import packet_creator
|
||||
from mediapipe.python import packet_getter
|
||||
from mediapipe.tasks.cc.components.containers.proto import embeddings_pb2
|
||||
from mediapipe.tasks.cc.components.processors.proto import embedder_options_pb2
|
||||
from mediapipe.tasks.cc.text.text_embedder.proto import text_embedder_graph_options_pb2
|
||||
from mediapipe.tasks.python.components.containers import embedding_result as embedding_result_module
|
||||
from mediapipe.tasks.python.components.processors import embedder_options
|
||||
from mediapipe.tasks.python.components.utils import cosine_similarity
|
||||
from mediapipe.tasks.python.core import base_options as base_options_module
|
||||
from mediapipe.tasks.python.core import task_info as task_info_module
|
||||
|
@ -31,7 +31,7 @@ from mediapipe.tasks.python.text.core import base_text_task_api
|
|||
TextEmbedderResult = embedding_result_module.EmbeddingResult
|
||||
_BaseOptions = base_options_module.BaseOptions
|
||||
_TextEmbedderGraphOptionsProto = text_embedder_graph_options_pb2.TextEmbedderGraphOptions
|
||||
_EmbedderOptions = embedder_options.EmbedderOptions
|
||||
_EmbedderOptionsProto = embedder_options_pb2.EmbedderOptions
|
||||
_TaskInfo = task_info_module.TaskInfo
|
||||
|
||||
_EMBEDDINGS_OUT_STREAM_NAME = 'embeddings_out'
|
||||
|
@ -47,17 +47,25 @@ class TextEmbedderOptions:
|
|||
|
||||
Attributes:
|
||||
base_options: Base options for the text embedder task.
|
||||
embedder_options: Options for the text embedder task.
|
||||
l2_normalize: Whether to normalize the returned feature vector with L2 norm.
|
||||
Use this option only if the model does not already contain a native
|
||||
L2_NORMALIZATION TF Lite Op. In most cases, this is already the case and
|
||||
L2 norm is thus achieved through TF Lite inference.
|
||||
quantize: Whether the returned embedding should be quantized to bytes via
|
||||
scalar quantization. Embeddings are implicitly assumed to be unit-norm and
|
||||
therefore any dimension is guaranteed to have a value in [-1.0, 1.0]. Use
|
||||
the l2_normalize option if this is not the case.
|
||||
"""
|
||||
base_options: _BaseOptions
|
||||
embedder_options: Optional[_EmbedderOptions] = dataclasses.field(
|
||||
default_factory=_EmbedderOptions)
|
||||
l2_normalize: Optional[bool] = None
|
||||
quantize: Optional[bool] = None
|
||||
|
||||
@doc_controls.do_not_generate_docs
|
||||
def to_pb2(self) -> _TextEmbedderGraphOptionsProto:
|
||||
"""Generates an TextEmbedderOptions protobuf object."""
|
||||
base_options_proto = self.base_options.to_pb2()
|
||||
embedder_options_proto = self.embedder_options.to_pb2()
|
||||
embedder_options_proto = _EmbedderOptionsProto(
|
||||
l2_normalize=self.l2_normalize, quantize=self.quantize)
|
||||
|
||||
return _TextEmbedderGraphOptionsProto(
|
||||
base_options=base_options_proto,
|
||||
|
|
|
@ -47,10 +47,10 @@ py_library(
|
|||
"//mediapipe/python:packet_creator",
|
||||
"//mediapipe/python:packet_getter",
|
||||
"//mediapipe/tasks/cc/components/containers/proto:classifications_py_pb2",
|
||||
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_py_pb2",
|
||||
"//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_py_pb2",
|
||||
"//mediapipe/tasks/python/components/containers:classification_result",
|
||||
"//mediapipe/tasks/python/components/containers:rect",
|
||||
"//mediapipe/tasks/python/components/processors:classifier_options",
|
||||
"//mediapipe/tasks/python/core:base_options",
|
||||
"//mediapipe/tasks/python/core:optional_dependencies",
|
||||
"//mediapipe/tasks/python/core:task_info",
|
||||
|
@ -89,9 +89,9 @@ py_library(
|
|||
"//mediapipe/python:packet_creator",
|
||||
"//mediapipe/python:packet_getter",
|
||||
"//mediapipe/tasks/cc/components/containers/proto:embeddings_py_pb2",
|
||||
"//mediapipe/tasks/cc/components/processors/proto:embedder_options_py_pb2",
|
||||
"//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_py_pb2",
|
||||
"//mediapipe/tasks/python/components/containers:embedding_result",
|
||||
"//mediapipe/tasks/python/components/processors:embedder_options",
|
||||
"//mediapipe/tasks/python/components/utils:cosine_similarity",
|
||||
"//mediapipe/tasks/python/core:base_options",
|
||||
"//mediapipe/tasks/python/core:optional_dependencies",
|
||||
|
|
|
@ -14,17 +14,17 @@
|
|||
"""MediaPipe image classifier task."""
|
||||
|
||||
import dataclasses
|
||||
from typing import Callable, Mapping, Optional
|
||||
from typing import Callable, Mapping, Optional, List
|
||||
|
||||
from mediapipe.python import packet_creator
|
||||
from mediapipe.python import packet_getter
|
||||
from mediapipe.python._framework_bindings import image as image_module
|
||||
from mediapipe.python._framework_bindings import packet
|
||||
from mediapipe.tasks.cc.components.containers.proto import classifications_pb2
|
||||
from mediapipe.tasks.cc.components.processors.proto import classifier_options_pb2
|
||||
from mediapipe.tasks.cc.vision.image_classifier.proto import image_classifier_graph_options_pb2
|
||||
from mediapipe.tasks.python.components.containers import classification_result as classification_result_module
|
||||
from mediapipe.tasks.python.components.containers import rect
|
||||
from mediapipe.tasks.python.components.processors import classifier_options
|
||||
from mediapipe.tasks.python.core import base_options as base_options_module
|
||||
from mediapipe.tasks.python.core import task_info as task_info_module
|
||||
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
|
||||
|
@ -36,7 +36,7 @@ ImageClassifierResult = classification_result_module.ClassificationResult
|
|||
_NormalizedRect = rect.NormalizedRect
|
||||
_BaseOptions = base_options_module.BaseOptions
|
||||
_ImageClassifierGraphOptionsProto = image_classifier_graph_options_pb2.ImageClassifierGraphOptions
|
||||
_ClassifierOptions = classifier_options.ClassifierOptions
|
||||
_ClassifierOptionsProto = classifier_options_pb2.ClassifierOptions
|
||||
_RunningMode = vision_task_running_mode.VisionTaskRunningMode
|
||||
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
|
||||
_TaskInfo = task_info_module.TaskInfo
|
||||
|
@ -63,15 +63,31 @@ class ImageClassifierOptions:
|
|||
objects on single image inputs. 2) The video mode for classifying objects
|
||||
on the decoded frames of a video. 3) The live stream mode for classifying
|
||||
objects on a live stream of input data, such as from camera.
|
||||
classifier_options: Options for the image classification task.
|
||||
display_names_locale: The locale to use for display names specified through
|
||||
the TFLite Model Metadata.
|
||||
max_results: The maximum number of top-scored classification results to
|
||||
return.
|
||||
score_threshold: Overrides the ones provided in the model metadata. Results
|
||||
below this value are rejected.
|
||||
category_allowlist: Allowlist of category names. If non-empty,
|
||||
classification results whose category name is not in this set will be
|
||||
filtered out. Duplicate or unknown category names are ignored. Mutually
|
||||
exclusive with `category_denylist`.
|
||||
category_denylist: Denylist of category names. If non-empty, classification
|
||||
results whose category name is in this set will be filtered out. Duplicate
|
||||
or unknown category names are ignored. Mutually exclusive with
|
||||
`category_allowlist`.
|
||||
result_callback: The user-defined result callback for processing live stream
|
||||
data. The result callback should only be specified when the running mode
|
||||
is set to the live stream mode.
|
||||
"""
|
||||
base_options: _BaseOptions
|
||||
running_mode: _RunningMode = _RunningMode.IMAGE
|
||||
classifier_options: Optional[_ClassifierOptions] = dataclasses.field(
|
||||
default_factory=_ClassifierOptions)
|
||||
display_names_locale: Optional[str] = None
|
||||
max_results: Optional[int] = None
|
||||
score_threshold: Optional[float] = None
|
||||
category_allowlist: Optional[List[str]] = None
|
||||
category_denylist: Optional[List[str]] = None
|
||||
result_callback: Optional[Callable[
|
||||
[ImageClassifierResult, image_module.Image, int], None]] = None
|
||||
|
||||
|
@ -80,7 +96,12 @@ class ImageClassifierOptions:
|
|||
"""Generates an ImageClassifierOptions protobuf object."""
|
||||
base_options_proto = self.base_options.to_pb2()
|
||||
base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.IMAGE else True
|
||||
classifier_options_proto = self.classifier_options.to_pb2()
|
||||
classifier_options_proto = _ClassifierOptionsProto(
|
||||
score_threshold=self.score_threshold,
|
||||
category_allowlist=self.category_allowlist,
|
||||
category_denylist=self.category_denylist,
|
||||
display_names_locale=self.display_names_locale,
|
||||
max_results=self.max_results)
|
||||
|
||||
return _ImageClassifierGraphOptionsProto(
|
||||
base_options=base_options_proto,
|
||||
|
|
|
@ -21,9 +21,9 @@ from mediapipe.python import packet_getter
|
|||
from mediapipe.python._framework_bindings import image as image_module
|
||||
from mediapipe.python._framework_bindings import packet as packet_module
|
||||
from mediapipe.tasks.cc.components.containers.proto import embeddings_pb2
|
||||
from mediapipe.tasks.cc.components.processors.proto import embedder_options_pb2
|
||||
from mediapipe.tasks.cc.vision.image_embedder.proto import image_embedder_graph_options_pb2
|
||||
from mediapipe.tasks.python.components.containers import embedding_result as embedding_result_module
|
||||
from mediapipe.tasks.python.components.processors import embedder_options
|
||||
from mediapipe.tasks.python.components.utils import cosine_similarity
|
||||
from mediapipe.tasks.python.core import base_options as base_options_module
|
||||
from mediapipe.tasks.python.core import task_info as task_info_module
|
||||
|
@ -35,7 +35,7 @@ from mediapipe.tasks.python.vision.core import vision_task_running_mode as runni
|
|||
ImageEmbedderResult = embedding_result_module.EmbeddingResult
|
||||
_BaseOptions = base_options_module.BaseOptions
|
||||
_ImageEmbedderGraphOptionsProto = image_embedder_graph_options_pb2.ImageEmbedderGraphOptions
|
||||
_EmbedderOptions = embedder_options.EmbedderOptions
|
||||
_EmbedderOptionsProto = embedder_options_pb2.EmbedderOptions
|
||||
_RunningMode = running_mode_module.VisionTaskRunningMode
|
||||
_TaskInfo = task_info_module.TaskInfo
|
||||
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
|
||||
|
@ -62,15 +62,22 @@ class ImageEmbedderOptions:
|
|||
image on single image inputs. 2) The video mode for embedding image on the
|
||||
decoded frames of a video. 3) The live stream mode for embedding image on
|
||||
a live stream of input data, such as from camera.
|
||||
embedder_options: Options for the image embedder task.
|
||||
l2_normalize: Whether to normalize the returned feature vector with L2 norm.
|
||||
Use this option only if the model does not already contain a native
|
||||
L2_NORMALIZATION TF Lite Op. In most cases, this is already the case and
|
||||
L2 norm is thus achieved through TF Lite inference.
|
||||
quantize: Whether the returned embedding should be quantized to bytes via
|
||||
scalar quantization. Embeddings are implicitly assumed to be unit-norm and
|
||||
therefore any dimension is guaranteed to have a value in [-1.0, 1.0]. Use
|
||||
the l2_normalize option if this is not the case.
|
||||
result_callback: The user-defined result callback for processing live stream
|
||||
data. The result callback should only be specified when the running mode
|
||||
is set to the live stream mode.
|
||||
"""
|
||||
base_options: _BaseOptions
|
||||
running_mode: _RunningMode = _RunningMode.IMAGE
|
||||
embedder_options: Optional[_EmbedderOptions] = dataclasses.field(
|
||||
default_factory=_EmbedderOptions)
|
||||
l2_normalize: Optional[bool] = None
|
||||
quantize: Optional[bool] = None
|
||||
result_callback: Optional[Callable[
|
||||
[ImageEmbedderResult, image_module.Image, int], None]] = None
|
||||
|
||||
|
@ -79,7 +86,8 @@ class ImageEmbedderOptions:
|
|||
"""Generates an ImageEmbedderOptions protobuf object."""
|
||||
base_options_proto = self.base_options.to_pb2()
|
||||
base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.IMAGE else True
|
||||
embedder_options_proto = self.embedder_options.to_pb2()
|
||||
embedder_options_proto = _EmbedderOptionsProto(
|
||||
l2_normalize=self.l2_normalize, quantize=self.quantize)
|
||||
|
||||
return _ImageEmbedderGraphOptionsProto(
|
||||
base_options=base_options_proto,
|
||||
|
|
Loading…
Reference in New Issue
Block a user