Internal change

PiperOrigin-RevId: 493427500
This commit is contained in:
MediaPipe Team 2022-12-06 15:22:03 -08:00 committed by Copybara-Service
parent fca0f5806b
commit 576c6da173
23 changed files with 162 additions and 212 deletions

View File

@ -29,11 +29,11 @@ py_library(
"//mediapipe/python:packet_getter", "//mediapipe/python:packet_getter",
"//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_graph_options_py_pb2", "//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_graph_options_py_pb2",
"//mediapipe/tasks/cc/components/containers/proto:classifications_py_pb2", "//mediapipe/tasks/cc/components/containers/proto:classifications_py_pb2",
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_py_pb2",
"//mediapipe/tasks/python/audio/core:audio_task_running_mode", "//mediapipe/tasks/python/audio/core:audio_task_running_mode",
"//mediapipe/tasks/python/audio/core:base_audio_task_api", "//mediapipe/tasks/python/audio/core:base_audio_task_api",
"//mediapipe/tasks/python/components/containers:audio_data", "//mediapipe/tasks/python/components/containers:audio_data",
"//mediapipe/tasks/python/components/containers:classification_result", "//mediapipe/tasks/python/components/containers:classification_result",
"//mediapipe/tasks/python/components/processors:classifier_options",
"//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/core:base_options",
"//mediapipe/tasks/python/core:optional_dependencies", "//mediapipe/tasks/python/core:optional_dependencies",
"//mediapipe/tasks/python/core:task_info", "//mediapipe/tasks/python/core:task_info",
@ -51,11 +51,11 @@ py_library(
"//mediapipe/python:packet_getter", "//mediapipe/python:packet_getter",
"//mediapipe/tasks/cc/audio/audio_embedder/proto:audio_embedder_graph_options_py_pb2", "//mediapipe/tasks/cc/audio/audio_embedder/proto:audio_embedder_graph_options_py_pb2",
"//mediapipe/tasks/cc/components/containers/proto:embeddings_py_pb2", "//mediapipe/tasks/cc/components/containers/proto:embeddings_py_pb2",
"//mediapipe/tasks/cc/components/processors/proto:embedder_options_py_pb2",
"//mediapipe/tasks/python/audio/core:audio_task_running_mode", "//mediapipe/tasks/python/audio/core:audio_task_running_mode",
"//mediapipe/tasks/python/audio/core:base_audio_task_api", "//mediapipe/tasks/python/audio/core:base_audio_task_api",
"//mediapipe/tasks/python/components/containers:audio_data", "//mediapipe/tasks/python/components/containers:audio_data",
"//mediapipe/tasks/python/components/containers:embedding_result", "//mediapipe/tasks/python/components/containers:embedding_result",
"//mediapipe/tasks/python/components/processors:embedder_options",
"//mediapipe/tasks/python/components/utils:cosine_similarity", "//mediapipe/tasks/python/components/utils:cosine_similarity",
"//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/core:base_options",
"//mediapipe/tasks/python/core:optional_dependencies", "//mediapipe/tasks/python/core:optional_dependencies",

View File

@ -21,11 +21,11 @@ from mediapipe.python import packet_getter
from mediapipe.python._framework_bindings import packet from mediapipe.python._framework_bindings import packet
from mediapipe.tasks.cc.audio.audio_classifier.proto import audio_classifier_graph_options_pb2 from mediapipe.tasks.cc.audio.audio_classifier.proto import audio_classifier_graph_options_pb2
from mediapipe.tasks.cc.components.containers.proto import classifications_pb2 from mediapipe.tasks.cc.components.containers.proto import classifications_pb2
from mediapipe.tasks.cc.components.processors.proto import classifier_options_pb2
from mediapipe.tasks.python.audio.core import audio_task_running_mode as running_mode_module from mediapipe.tasks.python.audio.core import audio_task_running_mode as running_mode_module
from mediapipe.tasks.python.audio.core import base_audio_task_api from mediapipe.tasks.python.audio.core import base_audio_task_api
from mediapipe.tasks.python.components.containers import audio_data as audio_data_module from mediapipe.tasks.python.components.containers import audio_data as audio_data_module
from mediapipe.tasks.python.components.containers import classification_result as classification_result_module from mediapipe.tasks.python.components.containers import classification_result as classification_result_module
from mediapipe.tasks.python.components.processors import classifier_options as classifier_options_module
from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.core import task_info as task_info_module from mediapipe.tasks.python.core import task_info as task_info_module
from mediapipe.tasks.python.core.optional_dependencies import doc_controls from mediapipe.tasks.python.core.optional_dependencies import doc_controls
@ -34,7 +34,7 @@ AudioClassifierResult = classification_result_module.ClassificationResult
_AudioClassifierGraphOptionsProto = audio_classifier_graph_options_pb2.AudioClassifierGraphOptions _AudioClassifierGraphOptionsProto = audio_classifier_graph_options_pb2.AudioClassifierGraphOptions
_AudioData = audio_data_module.AudioData _AudioData = audio_data_module.AudioData
_BaseOptions = base_options_module.BaseOptions _BaseOptions = base_options_module.BaseOptions
_ClassifierOptions = classifier_options_module.ClassifierOptions _ClassifierOptionsProto = classifier_options_pb2.ClassifierOptions
_RunningMode = running_mode_module.AudioTaskRunningMode _RunningMode = running_mode_module.AudioTaskRunningMode
_TaskInfo = task_info_module.TaskInfo _TaskInfo = task_info_module.TaskInfo
@ -62,16 +62,31 @@ class AudioClassifierOptions:
mode for running classification on the audio stream, such as from mode for running classification on the audio stream, such as from
microphone. In this mode, the "result_callback" below must be specified microphone. In this mode, the "result_callback" below must be specified
to receive the classification results asynchronously. to receive the classification results asynchronously.
classifier_options: Options for configuring the classifier behavior, such as display_names_locale: The locale to use for display names specified through
score threshold, number of results, etc. the TFLite Model Metadata.
max_results: The maximum number of top-scored classification results to
return.
score_threshold: Overrides the ones provided in the model metadata. Results
below this value are rejected.
category_allowlist: Allowlist of category names. If non-empty,
classification results whose category name is not in this set will be
filtered out. Duplicate or unknown category names are ignored. Mutually
exclusive with `category_denylist`.
category_denylist: Denylist of category names. If non-empty, classification
results whose category name is in this set will be filtered out. Duplicate
or unknown category names are ignored. Mutually exclusive with
`category_allowlist`.
result_callback: The user-defined result callback for processing audio result_callback: The user-defined result callback for processing audio
stream data. The result callback should only be specified when the running stream data. The result callback should only be specified when the running
mode is set to the audio stream mode. mode is set to the audio stream mode.
""" """
base_options: _BaseOptions base_options: _BaseOptions
running_mode: _RunningMode = _RunningMode.AUDIO_CLIPS running_mode: _RunningMode = _RunningMode.AUDIO_CLIPS
classifier_options: Optional[_ClassifierOptions] = dataclasses.field( display_names_locale: Optional[str] = None
default_factory=_ClassifierOptions) max_results: Optional[int] = None
score_threshold: Optional[float] = None
category_allowlist: Optional[List[str]] = None
category_denylist: Optional[List[str]] = None
result_callback: Optional[Callable[[AudioClassifierResult, int], None]] = None result_callback: Optional[Callable[[AudioClassifierResult, int], None]] = None
@doc_controls.do_not_generate_docs @doc_controls.do_not_generate_docs
@ -79,7 +94,12 @@ class AudioClassifierOptions:
"""Generates an AudioClassifierOptions protobuf object.""" """Generates an AudioClassifierOptions protobuf object."""
base_options_proto = self.base_options.to_pb2() base_options_proto = self.base_options.to_pb2()
base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.AUDIO_CLIPS else True base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.AUDIO_CLIPS else True
classifier_options_proto = self.classifier_options.to_pb2() classifier_options_proto = _ClassifierOptionsProto(
score_threshold=self.score_threshold,
category_allowlist=self.category_allowlist,
category_denylist=self.category_denylist,
display_names_locale=self.display_names_locale,
max_results=self.max_results)
return _AudioClassifierGraphOptionsProto( return _AudioClassifierGraphOptionsProto(
base_options=base_options_proto, base_options=base_options_proto,

View File

@ -21,11 +21,11 @@ from mediapipe.python import packet_getter
from mediapipe.python._framework_bindings import packet from mediapipe.python._framework_bindings import packet
from mediapipe.tasks.cc.audio.audio_embedder.proto import audio_embedder_graph_options_pb2 from mediapipe.tasks.cc.audio.audio_embedder.proto import audio_embedder_graph_options_pb2
from mediapipe.tasks.cc.components.containers.proto import embeddings_pb2 from mediapipe.tasks.cc.components.containers.proto import embeddings_pb2
from mediapipe.tasks.cc.components.processors.proto import embedder_options_pb2
from mediapipe.tasks.python.audio.core import audio_task_running_mode as running_mode_module from mediapipe.tasks.python.audio.core import audio_task_running_mode as running_mode_module
from mediapipe.tasks.python.audio.core import base_audio_task_api from mediapipe.tasks.python.audio.core import base_audio_task_api
from mediapipe.tasks.python.components.containers import audio_data as audio_data_module from mediapipe.tasks.python.components.containers import audio_data as audio_data_module
from mediapipe.tasks.python.components.containers import embedding_result as embedding_result_module from mediapipe.tasks.python.components.containers import embedding_result as embedding_result_module
from mediapipe.tasks.python.components.processors import embedder_options as embedder_options_module
from mediapipe.tasks.python.components.utils import cosine_similarity from mediapipe.tasks.python.components.utils import cosine_similarity
from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.core import task_info as task_info_module from mediapipe.tasks.python.core import task_info as task_info_module
@ -35,7 +35,7 @@ AudioEmbedderResult = embedding_result_module.EmbeddingResult
_AudioEmbedderGraphOptionsProto = audio_embedder_graph_options_pb2.AudioEmbedderGraphOptions _AudioEmbedderGraphOptionsProto = audio_embedder_graph_options_pb2.AudioEmbedderGraphOptions
_AudioData = audio_data_module.AudioData _AudioData = audio_data_module.AudioData
_BaseOptions = base_options_module.BaseOptions _BaseOptions = base_options_module.BaseOptions
_EmbedderOptions = embedder_options_module.EmbedderOptions _EmbedderOptionsProto = embedder_options_pb2.EmbedderOptions
_RunningMode = running_mode_module.AudioTaskRunningMode _RunningMode = running_mode_module.AudioTaskRunningMode
_TaskInfo = task_info_module.TaskInfo _TaskInfo = task_info_module.TaskInfo
@ -63,16 +63,22 @@ class AudioEmbedderOptions:
stream mode for running embedding extraction on the audio stream, such as stream mode for running embedding extraction on the audio stream, such as
from microphone. In this mode, the "result_callback" below must be from microphone. In this mode, the "result_callback" below must be
specified to receive the embedding results asynchronously. specified to receive the embedding results asynchronously.
embedder_options: Options for configuring the embedder behavior, such as l2_normalize: Whether to normalize the returned feature vector with L2 norm.
l2_normalize and quantize. Use this option only if the model does not already contain a native
L2_NORMALIZATION TF Lite Op. In most cases, this is already the case and
L2 norm is thus achieved through TF Lite inference.
quantize: Whether the returned embedding should be quantized to bytes via
scalar quantization. Embeddings are implicitly assumed to be unit-norm and
therefore any dimension is guaranteed to have a value in [-1.0, 1.0]. Use
the l2_normalize option if this is not the case.
result_callback: The user-defined result callback for processing audio result_callback: The user-defined result callback for processing audio
stream data. The result callback should only be specified when the running stream data. The result callback should only be specified when the running
mode is set to the audio stream mode. mode is set to the audio stream mode.
""" """
base_options: _BaseOptions base_options: _BaseOptions
running_mode: _RunningMode = _RunningMode.AUDIO_CLIPS running_mode: _RunningMode = _RunningMode.AUDIO_CLIPS
embedder_options: Optional[_EmbedderOptions] = dataclasses.field( l2_normalize: Optional[bool] = None
default_factory=_EmbedderOptions) quantize: Optional[bool] = None
result_callback: Optional[Callable[[AudioEmbedderResult, int], None]] = None result_callback: Optional[Callable[[AudioEmbedderResult, int], None]] = None
@doc_controls.do_not_generate_docs @doc_controls.do_not_generate_docs
@ -80,7 +86,8 @@ class AudioEmbedderOptions:
"""Generates an AudioEmbedderOptions protobuf object.""" """Generates an AudioEmbedderOptions protobuf object."""
base_options_proto = self.base_options.to_pb2() base_options_proto = self.base_options.to_pb2()
base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.AUDIO_CLIPS else True base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.AUDIO_CLIPS else True
embedder_options_proto = self.embedder_options.to_pb2() embedder_options_proto = _EmbedderOptionsProto(
l2_normalize=self.l2_normalize, quantize=self.quantize)
return _AudioEmbedderGraphOptionsProto( return _AudioEmbedderGraphOptionsProto(
base_options=base_options_proto, base_options=base_options_proto,

View File

@ -28,12 +28,3 @@ py_library(
"//mediapipe/tasks/python/core:optional_dependencies", "//mediapipe/tasks/python/core:optional_dependencies",
], ],
) )
py_library(
name = "embedder_options",
srcs = ["embedder_options.py"],
deps = [
"//mediapipe/tasks/cc/components/processors/proto:embedder_options_py_pb2",
"//mediapipe/tasks/python/core:optional_dependencies",
],
)

View File

@ -15,12 +15,9 @@
"""MediaPipe Tasks Components Processors API.""" """MediaPipe Tasks Components Processors API."""
import mediapipe.tasks.python.components.processors.classifier_options import mediapipe.tasks.python.components.processors.classifier_options
import mediapipe.tasks.python.components.processors.embedder_options
ClassifierOptions = classifier_options.ClassifierOptions ClassifierOptions = classifier_options.ClassifierOptions
EmbedderOptions = embedder_options.EmbedderOptions
# Remove unnecessary modules to avoid duplication in API docs. # Remove unnecessary modules to avoid duplication in API docs.
del classifier_options del classifier_options
del embedder_options
del mediapipe del mediapipe

View File

@ -1,68 +0,0 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Embedder options data class."""
import dataclasses
from typing import Any, Optional
from mediapipe.tasks.cc.components.processors.proto import embedder_options_pb2
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
_EmbedderOptionsProto = embedder_options_pb2.EmbedderOptions
@dataclasses.dataclass
class EmbedderOptions:
"""Shared options used by all embedding extraction tasks.
Attributes:
l2_normalize: Whether to normalize the returned feature vector with L2 norm.
Use this option only if the model does not already contain a native
L2_NORMALIZATION TF Lite Op. In most cases, this is already the case and
L2 norm is thus achieved through TF Lite inference.
quantize: Whether the returned embedding should be quantized to bytes via
scalar quantization. Embeddings are implicitly assumed to be unit-norm and
therefore any dimension is guaranteed to have a value in [-1.0, 1.0]. Use
the l2_normalize option if this is not the case.
"""
l2_normalize: Optional[bool] = None
quantize: Optional[bool] = None
@doc_controls.do_not_generate_docs
def to_pb2(self) -> _EmbedderOptionsProto:
"""Generates a EmbedderOptions protobuf object."""
return _EmbedderOptionsProto(
l2_normalize=self.l2_normalize, quantize=self.quantize)
@classmethod
@doc_controls.do_not_generate_docs
def create_from_pb2(cls, pb2_obj: _EmbedderOptionsProto) -> 'EmbedderOptions':
"""Creates a `EmbedderOptions` object from the given protobuf object."""
return EmbedderOptions(
l2_normalize=pb2_obj.l2_normalize, quantize=pb2_obj.quantize)
def __eq__(self, other: Any) -> bool:
"""Checks if this object is equal to the given object.
Args:
other: The object to be compared with.
Returns:
True if the objects are equal.
"""
if not isinstance(other, EmbedderOptions):
return False
return self.to_pb2().__eq__(other.to_pb2())

View File

@ -23,8 +23,5 @@ licenses(["notice"])
py_library( py_library(
name = "cosine_similarity", name = "cosine_similarity",
srcs = ["cosine_similarity.py"], srcs = ["cosine_similarity.py"],
deps = [ deps = ["//mediapipe/tasks/python/components/containers:embedding_result"],
"//mediapipe/tasks/python/components/containers:embedding_result",
"//mediapipe/tasks/python/components/processors:embedder_options",
],
) )

View File

@ -16,10 +16,8 @@
import numpy as np import numpy as np
from mediapipe.tasks.python.components.containers import embedding_result from mediapipe.tasks.python.components.containers import embedding_result
from mediapipe.tasks.python.components.processors import embedder_options
_Embedding = embedding_result.Embedding _Embedding = embedding_result.Embedding
_EmbedderOptions = embedder_options.EmbedderOptions
def _compute_cosine_similarity(u, v): def _compute_cosine_similarity(u, v):

View File

@ -30,7 +30,6 @@ py_test(
"//mediapipe/tasks/python/audio/core:audio_task_running_mode", "//mediapipe/tasks/python/audio/core:audio_task_running_mode",
"//mediapipe/tasks/python/components/containers:audio_data", "//mediapipe/tasks/python/components/containers:audio_data",
"//mediapipe/tasks/python/components/containers:classification_result", "//mediapipe/tasks/python/components/containers:classification_result",
"//mediapipe/tasks/python/components/processors:classifier_options",
"//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/core:base_options",
"//mediapipe/tasks/python/test:test_utils", "//mediapipe/tasks/python/test:test_utils",
], ],
@ -48,7 +47,6 @@ py_test(
"//mediapipe/tasks/python/audio/core:audio_task_running_mode", "//mediapipe/tasks/python/audio/core:audio_task_running_mode",
"//mediapipe/tasks/python/components/containers:audio_data", "//mediapipe/tasks/python/components/containers:audio_data",
"//mediapipe/tasks/python/components/containers:embedding_result", "//mediapipe/tasks/python/components/containers:embedding_result",
"//mediapipe/tasks/python/components/processors:embedder_options",
"//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/core:base_options",
"//mediapipe/tasks/python/test:test_utils", "//mediapipe/tasks/python/test:test_utils",
], ],

View File

@ -27,7 +27,6 @@ from mediapipe.tasks.python.audio import audio_classifier
from mediapipe.tasks.python.audio.core import audio_task_running_mode from mediapipe.tasks.python.audio.core import audio_task_running_mode
from mediapipe.tasks.python.components.containers import audio_data as audio_data_module from mediapipe.tasks.python.components.containers import audio_data as audio_data_module
from mediapipe.tasks.python.components.containers import classification_result as classification_result_module from mediapipe.tasks.python.components.containers import classification_result as classification_result_module
from mediapipe.tasks.python.components.processors import classifier_options
from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.test import test_utils from mediapipe.tasks.python.test import test_utils
@ -36,7 +35,6 @@ _AudioClassifierOptions = audio_classifier.AudioClassifierOptions
_AudioClassifierResult = classification_result_module.ClassificationResult _AudioClassifierResult = classification_result_module.ClassificationResult
_AudioData = audio_data_module.AudioData _AudioData = audio_data_module.AudioData
_BaseOptions = base_options_module.BaseOptions _BaseOptions = base_options_module.BaseOptions
_ClassifierOptions = classifier_options.ClassifierOptions
_RUNNING_MODE = audio_task_running_mode.AudioTaskRunningMode _RUNNING_MODE = audio_task_running_mode.AudioTaskRunningMode
_YAMNET_MODEL_FILE = 'yamnet_audio_classifier_with_metadata.tflite' _YAMNET_MODEL_FILE = 'yamnet_audio_classifier_with_metadata.tflite'
@ -210,8 +208,7 @@ class AudioClassifierTest(parameterized.TestCase):
with _AudioClassifier.create_from_options( with _AudioClassifier.create_from_options(
_AudioClassifierOptions( _AudioClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.yamnet_model_path), base_options=_BaseOptions(model_asset_path=self.yamnet_model_path),
classifier_options=_ClassifierOptions( max_results=1)) as classifier:
max_results=1))) as classifier:
for audio_file in [_SPEECH_WAV_16K_MONO, _SPEECH_WAV_16K_MONO]: for audio_file in [_SPEECH_WAV_16K_MONO, _SPEECH_WAV_16K_MONO]:
classification_result_list = classifier.classify( classification_result_list = classifier.classify(
self._read_wav_file(audio_file)) self._read_wav_file(audio_file))
@ -222,8 +219,7 @@ class AudioClassifierTest(parameterized.TestCase):
with _AudioClassifier.create_from_options( with _AudioClassifier.create_from_options(
_AudioClassifierOptions( _AudioClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.yamnet_model_path), base_options=_BaseOptions(model_asset_path=self.yamnet_model_path),
classifier_options=_ClassifierOptions( score_threshold=0.9)) as classifier:
score_threshold=0.9))) as classifier:
for audio_file in [_SPEECH_WAV_16K_MONO, _SPEECH_WAV_16K_MONO]: for audio_file in [_SPEECH_WAV_16K_MONO, _SPEECH_WAV_16K_MONO]:
classification_result_list = classifier.classify( classification_result_list = classifier.classify(
self._read_wav_file(audio_file)) self._read_wav_file(audio_file))
@ -234,8 +230,7 @@ class AudioClassifierTest(parameterized.TestCase):
with _AudioClassifier.create_from_options( with _AudioClassifier.create_from_options(
_AudioClassifierOptions( _AudioClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.yamnet_model_path), base_options=_BaseOptions(model_asset_path=self.yamnet_model_path),
classifier_options=_ClassifierOptions( category_allowlist=['Speech'])) as classifier:
category_allowlist=['Speech']))) as classifier:
for audio_file in [_SPEECH_WAV_16K_MONO, _SPEECH_WAV_16K_MONO]: for audio_file in [_SPEECH_WAV_16K_MONO, _SPEECH_WAV_16K_MONO]:
classification_result_list = classifier.classify( classification_result_list = classifier.classify(
self._read_wav_file(audio_file)) self._read_wav_file(audio_file))
@ -250,8 +245,8 @@ class AudioClassifierTest(parameterized.TestCase):
r'exclusive options.'): r'exclusive options.'):
options = _AudioClassifierOptions( options = _AudioClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.yamnet_model_path), base_options=_BaseOptions(model_asset_path=self.yamnet_model_path),
classifier_options=_ClassifierOptions( category_allowlist=['foo'],
category_allowlist=['foo'], category_denylist=['bar'])) category_denylist=['bar'])
with _AudioClassifier.create_from_options(options) as unused_classifier: with _AudioClassifier.create_from_options(options) as unused_classifier:
pass pass
@ -278,8 +273,7 @@ class AudioClassifierTest(parameterized.TestCase):
_AudioClassifierOptions( _AudioClassifierOptions(
base_options=_BaseOptions( base_options=_BaseOptions(
model_asset_path=self.two_heads_model_path), model_asset_path=self.two_heads_model_path),
classifier_options=_ClassifierOptions( max_results=1)) as classifier:
max_results=1))) as classifier:
for audio_file in [_TWO_HEADS_WAV_16K_MONO, _TWO_HEADS_WAV_44K_MONO]: for audio_file in [_TWO_HEADS_WAV_16K_MONO, _TWO_HEADS_WAV_44K_MONO]:
classification_result_list = classifier.classify( classification_result_list = classifier.classify(
self._read_wav_file(audio_file)) self._read_wav_file(audio_file))
@ -364,7 +358,7 @@ class AudioClassifierTest(parameterized.TestCase):
options = _AudioClassifierOptions( options = _AudioClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.yamnet_model_path), base_options=_BaseOptions(model_asset_path=self.yamnet_model_path),
running_mode=_RUNNING_MODE.AUDIO_STREAM, running_mode=_RUNNING_MODE.AUDIO_STREAM,
classifier_options=_ClassifierOptions(max_results=1), max_results=1,
result_callback=save_result) result_callback=save_result)
classifier = _AudioClassifier.create_from_options(options) classifier = _AudioClassifier.create_from_options(options)
audio_data_list = self._read_wav_file_as_stream(audio_file) audio_data_list = self._read_wav_file_as_stream(audio_file)

View File

@ -26,7 +26,6 @@ from scipy.io import wavfile
from mediapipe.tasks.python.audio import audio_embedder from mediapipe.tasks.python.audio import audio_embedder
from mediapipe.tasks.python.audio.core import audio_task_running_mode from mediapipe.tasks.python.audio.core import audio_task_running_mode
from mediapipe.tasks.python.components.containers import audio_data as audio_data_module from mediapipe.tasks.python.components.containers import audio_data as audio_data_module
from mediapipe.tasks.python.components.processors import embedder_options
from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.test import test_utils from mediapipe.tasks.python.test import test_utils
@ -35,7 +34,6 @@ _AudioEmbedderOptions = audio_embedder.AudioEmbedderOptions
_AudioEmbedderResult = audio_embedder.AudioEmbedderResult _AudioEmbedderResult = audio_embedder.AudioEmbedderResult
_AudioData = audio_data_module.AudioData _AudioData = audio_data_module.AudioData
_BaseOptions = base_options_module.BaseOptions _BaseOptions = base_options_module.BaseOptions
_EmbedderOptions = embedder_options.EmbedderOptions
_RUNNING_MODE = audio_task_running_mode.AudioTaskRunningMode _RUNNING_MODE = audio_task_running_mode.AudioTaskRunningMode
_YAMNET_MODEL_FILE = 'yamnet_embedding_metadata.tflite' _YAMNET_MODEL_FILE = 'yamnet_embedding_metadata.tflite'
@ -172,9 +170,7 @@ class AudioEmbedderTest(parameterized.TestCase):
raise ValueError('model_file_type is invalid.') raise ValueError('model_file_type is invalid.')
options = _AudioEmbedderOptions( options = _AudioEmbedderOptions(
base_options=base_options, base_options=base_options, l2_normalize=l2_normalize, quantize=quantize)
embedder_options=_EmbedderOptions(
l2_normalize=l2_normalize, quantize=quantize))
with _AudioEmbedder.create_from_options(options) as embedder: with _AudioEmbedder.create_from_options(options) as embedder:
embedding_result0_list = embedder.embed(self._read_wav_file(audio_file0)) embedding_result0_list = embedder.embed(self._read_wav_file(audio_file0))
@ -291,8 +287,8 @@ class AudioEmbedderTest(parameterized.TestCase):
options = _AudioEmbedderOptions( options = _AudioEmbedderOptions(
base_options=_BaseOptions(model_asset_path=self.yamnet_model_path), base_options=_BaseOptions(model_asset_path=self.yamnet_model_path),
running_mode=_RUNNING_MODE.AUDIO_STREAM, running_mode=_RUNNING_MODE.AUDIO_STREAM,
embedder_options=_EmbedderOptions( l2_normalize=l2_normalize,
l2_normalize=l2_normalize, quantize=quantize), quantize=quantize,
result_callback=save_result) result_callback=save_result)
with _AudioEmbedder.create_from_options(options) as embedder: with _AudioEmbedder.create_from_options(options) as embedder:

View File

@ -28,7 +28,6 @@ py_test(
deps = [ deps = [
"//mediapipe/tasks/python/components/containers:category", "//mediapipe/tasks/python/components/containers:category",
"//mediapipe/tasks/python/components/containers:classification_result", "//mediapipe/tasks/python/components/containers:classification_result",
"//mediapipe/tasks/python/components/processors:classifier_options",
"//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/core:base_options",
"//mediapipe/tasks/python/test:test_utils", "//mediapipe/tasks/python/test:test_utils",
"//mediapipe/tasks/python/text:text_classifier", "//mediapipe/tasks/python/text:text_classifier",
@ -44,7 +43,6 @@ py_test(
], ],
deps = [ deps = [
"//mediapipe/tasks/python/components/containers:embedding_result", "//mediapipe/tasks/python/components/containers:embedding_result",
"//mediapipe/tasks/python/components/processors:embedder_options",
"//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/core:base_options",
"//mediapipe/tasks/python/test:test_utils", "//mediapipe/tasks/python/test:test_utils",
"//mediapipe/tasks/python/text:text_embedder", "//mediapipe/tasks/python/text:text_embedder",

View File

@ -21,14 +21,12 @@ from absl.testing import parameterized
from mediapipe.tasks.python.components.containers import category from mediapipe.tasks.python.components.containers import category
from mediapipe.tasks.python.components.containers import classification_result as classification_result_module from mediapipe.tasks.python.components.containers import classification_result as classification_result_module
from mediapipe.tasks.python.components.processors import classifier_options
from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.test import test_utils from mediapipe.tasks.python.test import test_utils
from mediapipe.tasks.python.text import text_classifier from mediapipe.tasks.python.text import text_classifier
TextClassifierResult = classification_result_module.ClassificationResult TextClassifierResult = classification_result_module.ClassificationResult
_BaseOptions = base_options_module.BaseOptions _BaseOptions = base_options_module.BaseOptions
_ClassifierOptions = classifier_options.ClassifierOptions
_Category = category.Category _Category = category.Category
_Classifications = classification_result_module.Classifications _Classifications = classification_result_module.Classifications
_TextClassifier = text_classifier.TextClassifier _TextClassifier = text_classifier.TextClassifier

View File

@ -21,13 +21,11 @@ from absl.testing import parameterized
import numpy as np import numpy as np
from mediapipe.tasks.python.components.containers import embedding_result as embedding_result_module from mediapipe.tasks.python.components.containers import embedding_result as embedding_result_module
from mediapipe.tasks.python.components.processors import embedder_options as embedder_options_module
from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.test import test_utils from mediapipe.tasks.python.test import test_utils
from mediapipe.tasks.python.text import text_embedder from mediapipe.tasks.python.text import text_embedder
_BaseOptions = base_options_module.BaseOptions _BaseOptions = base_options_module.BaseOptions
_EmbedderOptions = embedder_options_module.EmbedderOptions
_Embedding = embedding_result_module.Embedding _Embedding = embedding_result_module.Embedding
_TextEmbedder = text_embedder.TextEmbedder _TextEmbedder = text_embedder.TextEmbedder
_TextEmbedderOptions = text_embedder.TextEmbedderOptions _TextEmbedderOptions = text_embedder.TextEmbedderOptions
@ -128,10 +126,8 @@ class TextEmbedderTest(parameterized.TestCase):
# Should never happen # Should never happen
raise ValueError('model_file_type is invalid.') raise ValueError('model_file_type is invalid.')
embedder_options = _EmbedderOptions(
l2_normalize=l2_normalize, quantize=quantize)
options = _TextEmbedderOptions( options = _TextEmbedderOptions(
base_options=base_options, embedder_options=embedder_options) base_options=base_options, l2_normalize=l2_normalize, quantize=quantize)
embedder = _TextEmbedder.create_from_options(options) embedder = _TextEmbedder.create_from_options(options)
# Extracts both embeddings. # Extracts both embeddings.
@ -178,10 +174,8 @@ class TextEmbedderTest(parameterized.TestCase):
# Should never happen # Should never happen
raise ValueError('model_file_type is invalid.') raise ValueError('model_file_type is invalid.')
embedder_options = _EmbedderOptions(
l2_normalize=l2_normalize, quantize=quantize)
options = _TextEmbedderOptions( options = _TextEmbedderOptions(
base_options=base_options, embedder_options=embedder_options) base_options=base_options, l2_normalize=l2_normalize, quantize=quantize)
with _TextEmbedder.create_from_options(options) as embedder: with _TextEmbedder.create_from_options(options) as embedder:
# Extracts both embeddings. # Extracts both embeddings.
positive_text0 = "it's a charming and often affecting journey" positive_text0 = "it's a charming and often affecting journey"

View File

@ -49,7 +49,6 @@ py_test(
"//mediapipe/tasks/python/components/containers:category", "//mediapipe/tasks/python/components/containers:category",
"//mediapipe/tasks/python/components/containers:classification_result", "//mediapipe/tasks/python/components/containers:classification_result",
"//mediapipe/tasks/python/components/containers:rect", "//mediapipe/tasks/python/components/containers:rect",
"//mediapipe/tasks/python/components/processors:classifier_options",
"//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/core:base_options",
"//mediapipe/tasks/python/test:test_utils", "//mediapipe/tasks/python/test:test_utils",
"//mediapipe/tasks/python/vision:image_classifier", "//mediapipe/tasks/python/vision:image_classifier",
@ -69,7 +68,6 @@ py_test(
"//mediapipe/python:_framework_bindings", "//mediapipe/python:_framework_bindings",
"//mediapipe/tasks/python/components/containers:embedding_result", "//mediapipe/tasks/python/components/containers:embedding_result",
"//mediapipe/tasks/python/components/containers:rect", "//mediapipe/tasks/python/components/containers:rect",
"//mediapipe/tasks/python/components/processors:embedder_options",
"//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/core:base_options",
"//mediapipe/tasks/python/test:test_utils", "//mediapipe/tasks/python/test:test_utils",
"//mediapipe/tasks/python/vision:image_embedder", "//mediapipe/tasks/python/vision:image_embedder",

View File

@ -26,7 +26,6 @@ from mediapipe.python._framework_bindings import image
from mediapipe.tasks.python.components.containers import category as category_module from mediapipe.tasks.python.components.containers import category as category_module
from mediapipe.tasks.python.components.containers import classification_result as classification_result_module from mediapipe.tasks.python.components.containers import classification_result as classification_result_module
from mediapipe.tasks.python.components.containers import rect from mediapipe.tasks.python.components.containers import rect
from mediapipe.tasks.python.components.processors import classifier_options
from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.test import test_utils from mediapipe.tasks.python.test import test_utils
from mediapipe.tasks.python.vision import image_classifier from mediapipe.tasks.python.vision import image_classifier
@ -36,7 +35,6 @@ from mediapipe.tasks.python.vision.core import vision_task_running_mode
ImageClassifierResult = classification_result_module.ClassificationResult ImageClassifierResult = classification_result_module.ClassificationResult
_Rect = rect.Rect _Rect = rect.Rect
_BaseOptions = base_options_module.BaseOptions _BaseOptions = base_options_module.BaseOptions
_ClassifierOptions = classifier_options.ClassifierOptions
_Category = category_module.Category _Category = category_module.Category
_Classifications = classification_result_module.Classifications _Classifications = classification_result_module.Classifications
_Image = image.Image _Image = image.Image
@ -171,9 +169,8 @@ class ImageClassifierTest(parameterized.TestCase):
# Should never happen # Should never happen
raise ValueError('model_file_type is invalid.') raise ValueError('model_file_type is invalid.')
custom_classifier_options = _ClassifierOptions(max_results=max_results)
options = _ImageClassifierOptions( options = _ImageClassifierOptions(
base_options=base_options, classifier_options=custom_classifier_options) base_options=base_options, max_results=max_results)
classifier = _ImageClassifier.create_from_options(options) classifier = _ImageClassifier.create_from_options(options)
# Performs image classification on the input. # Performs image classification on the input.
@ -200,9 +197,8 @@ class ImageClassifierTest(parameterized.TestCase):
# Should never happen # Should never happen
raise ValueError('model_file_type is invalid.') raise ValueError('model_file_type is invalid.')
custom_classifier_options = _ClassifierOptions(max_results=max_results)
options = _ImageClassifierOptions( options = _ImageClassifierOptions(
base_options=base_options, classifier_options=custom_classifier_options) base_options=base_options, max_results=max_results)
with _ImageClassifier.create_from_options(options) as classifier: with _ImageClassifier.create_from_options(options) as classifier:
# Performs image classification on the input. # Performs image classification on the input.
image_result = classifier.classify(self.test_image) image_result = classifier.classify(self.test_image)
@ -212,9 +208,7 @@ class ImageClassifierTest(parameterized.TestCase):
def test_classify_succeeds_with_region_of_interest(self): def test_classify_succeeds_with_region_of_interest(self):
base_options = _BaseOptions(model_asset_path=self.model_path) base_options = _BaseOptions(model_asset_path=self.model_path)
custom_classifier_options = _ClassifierOptions(max_results=1) options = _ImageClassifierOptions(base_options=base_options, max_results=1)
options = _ImageClassifierOptions(
base_options=base_options, classifier_options=custom_classifier_options)
with _ImageClassifier.create_from_options(options) as classifier: with _ImageClassifier.create_from_options(options) as classifier:
# Load the test image. # Load the test image.
test_image = _Image.create_from_file( test_image = _Image.create_from_file(
@ -230,11 +224,9 @@ class ImageClassifierTest(parameterized.TestCase):
_generate_soccer_ball_results().to_pb2()) _generate_soccer_ball_results().to_pb2())
def test_score_threshold_option(self): def test_score_threshold_option(self):
custom_classifier_options = _ClassifierOptions(
score_threshold=_SCORE_THRESHOLD)
options = _ImageClassifierOptions( options = _ImageClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.model_path), base_options=_BaseOptions(model_asset_path=self.model_path),
classifier_options=custom_classifier_options) score_threshold=_SCORE_THRESHOLD)
with _ImageClassifier.create_from_options(options) as classifier: with _ImageClassifier.create_from_options(options) as classifier:
# Performs image classification on the input. # Performs image classification on the input.
image_result = classifier.classify(self.test_image) image_result = classifier.classify(self.test_image)
@ -249,11 +241,9 @@ class ImageClassifierTest(parameterized.TestCase):
f'{classification}') f'{classification}')
def test_max_results_option(self): def test_max_results_option(self):
custom_classifier_options = _ClassifierOptions(
score_threshold=_SCORE_THRESHOLD)
options = _ImageClassifierOptions( options = _ImageClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.model_path), base_options=_BaseOptions(model_asset_path=self.model_path),
classifier_options=custom_classifier_options) score_threshold=_SCORE_THRESHOLD)
with _ImageClassifier.create_from_options(options) as classifier: with _ImageClassifier.create_from_options(options) as classifier:
# Performs image classification on the input. # Performs image classification on the input.
image_result = classifier.classify(self.test_image) image_result = classifier.classify(self.test_image)
@ -263,11 +253,9 @@ class ImageClassifierTest(parameterized.TestCase):
len(categories), _MAX_RESULTS, 'Too many results returned.') len(categories), _MAX_RESULTS, 'Too many results returned.')
def test_allow_list_option(self): def test_allow_list_option(self):
custom_classifier_options = _ClassifierOptions(
category_allowlist=_ALLOW_LIST)
options = _ImageClassifierOptions( options = _ImageClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.model_path), base_options=_BaseOptions(model_asset_path=self.model_path),
classifier_options=custom_classifier_options) category_allowlist=_ALLOW_LIST)
with _ImageClassifier.create_from_options(options) as classifier: with _ImageClassifier.create_from_options(options) as classifier:
# Performs image classification on the input. # Performs image classification on the input.
image_result = classifier.classify(self.test_image) image_result = classifier.classify(self.test_image)
@ -280,10 +268,9 @@ class ImageClassifierTest(parameterized.TestCase):
f'Label {label} found but not in label allow list') f'Label {label} found but not in label allow list')
def test_deny_list_option(self): def test_deny_list_option(self):
custom_classifier_options = _ClassifierOptions(category_denylist=_DENY_LIST)
options = _ImageClassifierOptions( options = _ImageClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.model_path), base_options=_BaseOptions(model_asset_path=self.model_path),
classifier_options=custom_classifier_options) category_denylist=_DENY_LIST)
with _ImageClassifier.create_from_options(options) as classifier: with _ImageClassifier.create_from_options(options) as classifier:
# Performs image classification on the input. # Performs image classification on the input.
image_result = classifier.classify(self.test_image) image_result = classifier.classify(self.test_image)
@ -301,19 +288,17 @@ class ImageClassifierTest(parameterized.TestCase):
ValueError, ValueError,
r'`category_allowlist` and `category_denylist` are mutually ' r'`category_allowlist` and `category_denylist` are mutually '
r'exclusive options.'): r'exclusive options.'):
custom_classifier_options = _ClassifierOptions(
category_allowlist=['foo'], category_denylist=['bar'])
options = _ImageClassifierOptions( options = _ImageClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.model_path), base_options=_BaseOptions(model_asset_path=self.model_path),
classifier_options=custom_classifier_options) category_allowlist=['foo'],
category_denylist=['bar'])
with _ImageClassifier.create_from_options(options) as unused_classifier: with _ImageClassifier.create_from_options(options) as unused_classifier:
pass pass
def test_empty_classification_outputs(self): def test_empty_classification_outputs(self):
custom_classifier_options = _ClassifierOptions(score_threshold=1)
options = _ImageClassifierOptions( options = _ImageClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.model_path), base_options=_BaseOptions(model_asset_path=self.model_path),
classifier_options=custom_classifier_options) score_threshold=1)
with _ImageClassifier.create_from_options(options) as classifier: with _ImageClassifier.create_from_options(options) as classifier:
# Performs image classification on the input. # Performs image classification on the input.
image_result = classifier.classify(self.test_image) image_result = classifier.classify(self.test_image)
@ -386,11 +371,10 @@ class ImageClassifierTest(parameterized.TestCase):
classifier.classify_for_video(self.test_image, 0) classifier.classify_for_video(self.test_image, 0)
def test_classify_for_video(self): def test_classify_for_video(self):
custom_classifier_options = _ClassifierOptions(max_results=4)
options = _ImageClassifierOptions( options = _ImageClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.model_path), base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.VIDEO, running_mode=_RUNNING_MODE.VIDEO,
classifier_options=custom_classifier_options) max_results=4)
with _ImageClassifier.create_from_options(options) as classifier: with _ImageClassifier.create_from_options(options) as classifier:
for timestamp in range(0, 300, 30): for timestamp in range(0, 300, 30):
classification_result = classifier.classify_for_video( classification_result = classifier.classify_for_video(
@ -399,11 +383,10 @@ class ImageClassifierTest(parameterized.TestCase):
_generate_burger_results().to_pb2()) _generate_burger_results().to_pb2())
def test_classify_for_video_succeeds_with_region_of_interest(self): def test_classify_for_video_succeeds_with_region_of_interest(self):
custom_classifier_options = _ClassifierOptions(max_results=1)
options = _ImageClassifierOptions( options = _ImageClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.model_path), base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.VIDEO, running_mode=_RUNNING_MODE.VIDEO,
classifier_options=custom_classifier_options) max_results=1)
with _ImageClassifier.create_from_options(options) as classifier: with _ImageClassifier.create_from_options(options) as classifier:
# Load the test image. # Load the test image.
test_image = _Image.create_from_file( test_image = _Image.create_from_file(
@ -439,11 +422,10 @@ class ImageClassifierTest(parameterized.TestCase):
classifier.classify_for_video(self.test_image, 0) classifier.classify_for_video(self.test_image, 0)
def test_classify_async_calls_with_illegal_timestamp(self): def test_classify_async_calls_with_illegal_timestamp(self):
custom_classifier_options = _ClassifierOptions(max_results=4)
options = _ImageClassifierOptions( options = _ImageClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.model_path), base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.LIVE_STREAM, running_mode=_RUNNING_MODE.LIVE_STREAM,
classifier_options=custom_classifier_options, max_results=4,
result_callback=mock.MagicMock()) result_callback=mock.MagicMock())
with _ImageClassifier.create_from_options(options) as classifier: with _ImageClassifier.create_from_options(options) as classifier:
classifier.classify_async(self.test_image, 100) classifier.classify_async(self.test_image, 100)
@ -466,12 +448,11 @@ class ImageClassifierTest(parameterized.TestCase):
self.assertLess(observed_timestamp_ms, timestamp_ms) self.assertLess(observed_timestamp_ms, timestamp_ms)
self.observed_timestamp_ms = timestamp_ms self.observed_timestamp_ms = timestamp_ms
custom_classifier_options = _ClassifierOptions(
max_results=4, score_threshold=threshold)
options = _ImageClassifierOptions( options = _ImageClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.model_path), base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.LIVE_STREAM, running_mode=_RUNNING_MODE.LIVE_STREAM,
classifier_options=custom_classifier_options, max_results=4,
score_threshold=threshold,
result_callback=check_result) result_callback=check_result)
with _ImageClassifier.create_from_options(options) as classifier: with _ImageClassifier.create_from_options(options) as classifier:
for timestamp in range(0, 300, 30): for timestamp in range(0, 300, 30):
@ -496,11 +477,10 @@ class ImageClassifierTest(parameterized.TestCase):
self.assertLess(observed_timestamp_ms, timestamp_ms) self.assertLess(observed_timestamp_ms, timestamp_ms)
self.observed_timestamp_ms = timestamp_ms self.observed_timestamp_ms = timestamp_ms
custom_classifier_options = _ClassifierOptions(max_results=1)
options = _ImageClassifierOptions( options = _ImageClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.model_path), base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.LIVE_STREAM, running_mode=_RUNNING_MODE.LIVE_STREAM,
classifier_options=custom_classifier_options, max_results=1,
result_callback=check_result) result_callback=check_result)
with _ImageClassifier.create_from_options(options) as classifier: with _ImageClassifier.create_from_options(options) as classifier:
for timestamp in range(0, 300, 30): for timestamp in range(0, 300, 30):

View File

@ -24,7 +24,6 @@ import numpy as np
from mediapipe.python._framework_bindings import image as image_module from mediapipe.python._framework_bindings import image as image_module
from mediapipe.tasks.python.components.containers import embedding_result as embedding_result_module from mediapipe.tasks.python.components.containers import embedding_result as embedding_result_module
from mediapipe.tasks.python.components.containers import rect from mediapipe.tasks.python.components.containers import rect
from mediapipe.tasks.python.components.processors import embedder_options as embedder_options_module
from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.test import test_utils from mediapipe.tasks.python.test import test_utils
from mediapipe.tasks.python.vision import image_embedder from mediapipe.tasks.python.vision import image_embedder
@ -33,7 +32,6 @@ from mediapipe.tasks.python.vision.core import vision_task_running_mode as runni
_Rect = rect.Rect _Rect = rect.Rect
_BaseOptions = base_options_module.BaseOptions _BaseOptions = base_options_module.BaseOptions
_EmbedderOptions = embedder_options_module.EmbedderOptions
_Embedding = embedding_result_module.Embedding _Embedding = embedding_result_module.Embedding
_Image = image_module.Image _Image = image_module.Image
_ImageEmbedder = image_embedder.ImageEmbedder _ImageEmbedder = image_embedder.ImageEmbedder
@ -142,10 +140,8 @@ class ImageEmbedderTest(parameterized.TestCase):
# Should never happen # Should never happen
raise ValueError('model_file_type is invalid.') raise ValueError('model_file_type is invalid.')
embedder_options = _EmbedderOptions(
l2_normalize=l2_normalize, quantize=quantize)
options = _ImageEmbedderOptions( options = _ImageEmbedderOptions(
base_options=base_options, embedder_options=embedder_options) base_options=base_options, l2_normalize=l2_normalize, quantize=quantize)
embedder = _ImageEmbedder.create_from_options(options) embedder = _ImageEmbedder.create_from_options(options)
image_processing_options = None image_processing_options = None
@ -186,10 +182,8 @@ class ImageEmbedderTest(parameterized.TestCase):
# Should never happen # Should never happen
raise ValueError('model_file_type is invalid.') raise ValueError('model_file_type is invalid.')
embedder_options = _EmbedderOptions(
l2_normalize=l2_normalize, quantize=quantize)
options = _ImageEmbedderOptions( options = _ImageEmbedderOptions(
base_options=base_options, embedder_options=embedder_options) base_options=base_options, l2_normalize=l2_normalize, quantize=quantize)
with _ImageEmbedder.create_from_options(options) as embedder: with _ImageEmbedder.create_from_options(options) as embedder:
# Extracts both embeddings. # Extracts both embeddings.

View File

@ -28,9 +28,9 @@ py_library(
"//mediapipe/python:packet_creator", "//mediapipe/python:packet_creator",
"//mediapipe/python:packet_getter", "//mediapipe/python:packet_getter",
"//mediapipe/tasks/cc/components/containers/proto:classifications_py_pb2", "//mediapipe/tasks/cc/components/containers/proto:classifications_py_pb2",
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_py_pb2",
"//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_py_pb2", "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_py_pb2",
"//mediapipe/tasks/python/components/containers:classification_result", "//mediapipe/tasks/python/components/containers:classification_result",
"//mediapipe/tasks/python/components/processors:classifier_options",
"//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/core:base_options",
"//mediapipe/tasks/python/core:optional_dependencies", "//mediapipe/tasks/python/core:optional_dependencies",
"//mediapipe/tasks/python/core:task_info", "//mediapipe/tasks/python/core:task_info",
@ -47,9 +47,9 @@ py_library(
"//mediapipe/python:packet_creator", "//mediapipe/python:packet_creator",
"//mediapipe/python:packet_getter", "//mediapipe/python:packet_getter",
"//mediapipe/tasks/cc/components/containers/proto:embeddings_py_pb2", "//mediapipe/tasks/cc/components/containers/proto:embeddings_py_pb2",
"//mediapipe/tasks/cc/components/processors/proto:embedder_options_py_pb2",
"//mediapipe/tasks/cc/text/text_embedder/proto:text_embedder_graph_options_py_pb2", "//mediapipe/tasks/cc/text/text_embedder/proto:text_embedder_graph_options_py_pb2",
"//mediapipe/tasks/python/components/containers:embedding_result", "//mediapipe/tasks/python/components/containers:embedding_result",
"//mediapipe/tasks/python/components/processors:embedder_options",
"//mediapipe/tasks/python/components/utils:cosine_similarity", "//mediapipe/tasks/python/components/utils:cosine_similarity",
"//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/core:base_options",
"//mediapipe/tasks/python/core:optional_dependencies", "//mediapipe/tasks/python/core:optional_dependencies",

View File

@ -14,14 +14,14 @@
"""MediaPipe text classifier task.""" """MediaPipe text classifier task."""
import dataclasses import dataclasses
from typing import Optional from typing import Optional, List
from mediapipe.python import packet_creator from mediapipe.python import packet_creator
from mediapipe.python import packet_getter from mediapipe.python import packet_getter
from mediapipe.tasks.cc.components.containers.proto import classifications_pb2 from mediapipe.tasks.cc.components.containers.proto import classifications_pb2
from mediapipe.tasks.cc.components.processors.proto import classifier_options_pb2
from mediapipe.tasks.cc.text.text_classifier.proto import text_classifier_graph_options_pb2 from mediapipe.tasks.cc.text.text_classifier.proto import text_classifier_graph_options_pb2
from mediapipe.tasks.python.components.containers import classification_result as classification_result_module from mediapipe.tasks.python.components.containers import classification_result as classification_result_module
from mediapipe.tasks.python.components.processors import classifier_options
from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.core import task_info as task_info_module from mediapipe.tasks.python.core import task_info as task_info_module
from mediapipe.tasks.python.core.optional_dependencies import doc_controls from mediapipe.tasks.python.core.optional_dependencies import doc_controls
@ -30,7 +30,7 @@ from mediapipe.tasks.python.text.core import base_text_task_api
TextClassifierResult = classification_result_module.ClassificationResult TextClassifierResult = classification_result_module.ClassificationResult
_BaseOptions = base_options_module.BaseOptions _BaseOptions = base_options_module.BaseOptions
_TextClassifierGraphOptionsProto = text_classifier_graph_options_pb2.TextClassifierGraphOptions _TextClassifierGraphOptionsProto = text_classifier_graph_options_pb2.TextClassifierGraphOptions
_ClassifierOptions = classifier_options.ClassifierOptions _ClassifierOptionsProto = classifier_options_pb2.ClassifierOptions
_TaskInfo = task_info_module.TaskInfo _TaskInfo = task_info_module.TaskInfo
_CLASSIFICATIONS_STREAM_NAME = 'classifications_out' _CLASSIFICATIONS_STREAM_NAME = 'classifications_out'
@ -46,17 +46,38 @@ class TextClassifierOptions:
Attributes: Attributes:
base_options: Base options for the text classifier task. base_options: Base options for the text classifier task.
classifier_options: Options for the text classification task. display_names_locale: The locale to use for display names specified through
the TFLite Model Metadata.
max_results: The maximum number of top-scored classification results to
return.
score_threshold: Overrides the ones provided in the model metadata. Results
below this value are rejected.
category_allowlist: Allowlist of category names. If non-empty,
classification results whose category name is not in this set will be
filtered out. Duplicate or unknown category names are ignored. Mutually
exclusive with `category_denylist`.
category_denylist: Denylist of category names. If non-empty, classification
results whose category name is in this set will be filtered out. Duplicate
or unknown category names are ignored. Mutually exclusive with
`category_allowlist`.
""" """
base_options: _BaseOptions base_options: _BaseOptions
classifier_options: Optional[_ClassifierOptions] = dataclasses.field( display_names_locale: Optional[str] = None
default_factory=_ClassifierOptions) max_results: Optional[int] = None
score_threshold: Optional[float] = None
category_allowlist: Optional[List[str]] = None
category_denylist: Optional[List[str]] = None
@doc_controls.do_not_generate_docs @doc_controls.do_not_generate_docs
def to_pb2(self) -> _TextClassifierGraphOptionsProto: def to_pb2(self) -> _TextClassifierGraphOptionsProto:
"""Generates an TextClassifierOptions protobuf object.""" """Generates an TextClassifierOptions protobuf object."""
base_options_proto = self.base_options.to_pb2() base_options_proto = self.base_options.to_pb2()
classifier_options_proto = self.classifier_options.to_pb2() classifier_options_proto = _ClassifierOptionsProto(
score_threshold=self.score_threshold,
category_allowlist=self.category_allowlist,
category_denylist=self.category_denylist,
display_names_locale=self.display_names_locale,
max_results=self.max_results)
return _TextClassifierGraphOptionsProto( return _TextClassifierGraphOptionsProto(
base_options=base_options_proto, base_options=base_options_proto,

View File

@ -19,9 +19,9 @@ from typing import Optional
from mediapipe.python import packet_creator from mediapipe.python import packet_creator
from mediapipe.python import packet_getter from mediapipe.python import packet_getter
from mediapipe.tasks.cc.components.containers.proto import embeddings_pb2 from mediapipe.tasks.cc.components.containers.proto import embeddings_pb2
from mediapipe.tasks.cc.components.processors.proto import embedder_options_pb2
from mediapipe.tasks.cc.text.text_embedder.proto import text_embedder_graph_options_pb2 from mediapipe.tasks.cc.text.text_embedder.proto import text_embedder_graph_options_pb2
from mediapipe.tasks.python.components.containers import embedding_result as embedding_result_module from mediapipe.tasks.python.components.containers import embedding_result as embedding_result_module
from mediapipe.tasks.python.components.processors import embedder_options
from mediapipe.tasks.python.components.utils import cosine_similarity from mediapipe.tasks.python.components.utils import cosine_similarity
from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.core import task_info as task_info_module from mediapipe.tasks.python.core import task_info as task_info_module
@ -31,7 +31,7 @@ from mediapipe.tasks.python.text.core import base_text_task_api
TextEmbedderResult = embedding_result_module.EmbeddingResult TextEmbedderResult = embedding_result_module.EmbeddingResult
_BaseOptions = base_options_module.BaseOptions _BaseOptions = base_options_module.BaseOptions
_TextEmbedderGraphOptionsProto = text_embedder_graph_options_pb2.TextEmbedderGraphOptions _TextEmbedderGraphOptionsProto = text_embedder_graph_options_pb2.TextEmbedderGraphOptions
_EmbedderOptions = embedder_options.EmbedderOptions _EmbedderOptionsProto = embedder_options_pb2.EmbedderOptions
_TaskInfo = task_info_module.TaskInfo _TaskInfo = task_info_module.TaskInfo
_EMBEDDINGS_OUT_STREAM_NAME = 'embeddings_out' _EMBEDDINGS_OUT_STREAM_NAME = 'embeddings_out'
@ -47,17 +47,25 @@ class TextEmbedderOptions:
Attributes: Attributes:
base_options: Base options for the text embedder task. base_options: Base options for the text embedder task.
embedder_options: Options for the text embedder task. l2_normalize: Whether to normalize the returned feature vector with L2 norm.
Use this option only if the model does not already contain a native
L2_NORMALIZATION TF Lite Op. In most cases, this is already the case and
L2 norm is thus achieved through TF Lite inference.
quantize: Whether the returned embedding should be quantized to bytes via
scalar quantization. Embeddings are implicitly assumed to be unit-norm and
therefore any dimension is guaranteed to have a value in [-1.0, 1.0]. Use
the l2_normalize option if this is not the case.
""" """
base_options: _BaseOptions base_options: _BaseOptions
embedder_options: Optional[_EmbedderOptions] = dataclasses.field( l2_normalize: Optional[bool] = None
default_factory=_EmbedderOptions) quantize: Optional[bool] = None
@doc_controls.do_not_generate_docs @doc_controls.do_not_generate_docs
def to_pb2(self) -> _TextEmbedderGraphOptionsProto: def to_pb2(self) -> _TextEmbedderGraphOptionsProto:
"""Generates an TextEmbedderOptions protobuf object.""" """Generates an TextEmbedderOptions protobuf object."""
base_options_proto = self.base_options.to_pb2() base_options_proto = self.base_options.to_pb2()
embedder_options_proto = self.embedder_options.to_pb2() embedder_options_proto = _EmbedderOptionsProto(
l2_normalize=self.l2_normalize, quantize=self.quantize)
return _TextEmbedderGraphOptionsProto( return _TextEmbedderGraphOptionsProto(
base_options=base_options_proto, base_options=base_options_proto,

View File

@ -47,10 +47,10 @@ py_library(
"//mediapipe/python:packet_creator", "//mediapipe/python:packet_creator",
"//mediapipe/python:packet_getter", "//mediapipe/python:packet_getter",
"//mediapipe/tasks/cc/components/containers/proto:classifications_py_pb2", "//mediapipe/tasks/cc/components/containers/proto:classifications_py_pb2",
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_py_pb2",
"//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_py_pb2", "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_py_pb2",
"//mediapipe/tasks/python/components/containers:classification_result", "//mediapipe/tasks/python/components/containers:classification_result",
"//mediapipe/tasks/python/components/containers:rect", "//mediapipe/tasks/python/components/containers:rect",
"//mediapipe/tasks/python/components/processors:classifier_options",
"//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/core:base_options",
"//mediapipe/tasks/python/core:optional_dependencies", "//mediapipe/tasks/python/core:optional_dependencies",
"//mediapipe/tasks/python/core:task_info", "//mediapipe/tasks/python/core:task_info",
@ -89,9 +89,9 @@ py_library(
"//mediapipe/python:packet_creator", "//mediapipe/python:packet_creator",
"//mediapipe/python:packet_getter", "//mediapipe/python:packet_getter",
"//mediapipe/tasks/cc/components/containers/proto:embeddings_py_pb2", "//mediapipe/tasks/cc/components/containers/proto:embeddings_py_pb2",
"//mediapipe/tasks/cc/components/processors/proto:embedder_options_py_pb2",
"//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_py_pb2", "//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_py_pb2",
"//mediapipe/tasks/python/components/containers:embedding_result", "//mediapipe/tasks/python/components/containers:embedding_result",
"//mediapipe/tasks/python/components/processors:embedder_options",
"//mediapipe/tasks/python/components/utils:cosine_similarity", "//mediapipe/tasks/python/components/utils:cosine_similarity",
"//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/core:base_options",
"//mediapipe/tasks/python/core:optional_dependencies", "//mediapipe/tasks/python/core:optional_dependencies",

View File

@ -14,17 +14,17 @@
"""MediaPipe image classifier task.""" """MediaPipe image classifier task."""
import dataclasses import dataclasses
from typing import Callable, Mapping, Optional from typing import Callable, Mapping, Optional, List
from mediapipe.python import packet_creator from mediapipe.python import packet_creator
from mediapipe.python import packet_getter from mediapipe.python import packet_getter
from mediapipe.python._framework_bindings import image as image_module from mediapipe.python._framework_bindings import image as image_module
from mediapipe.python._framework_bindings import packet from mediapipe.python._framework_bindings import packet
from mediapipe.tasks.cc.components.containers.proto import classifications_pb2 from mediapipe.tasks.cc.components.containers.proto import classifications_pb2
from mediapipe.tasks.cc.components.processors.proto import classifier_options_pb2
from mediapipe.tasks.cc.vision.image_classifier.proto import image_classifier_graph_options_pb2 from mediapipe.tasks.cc.vision.image_classifier.proto import image_classifier_graph_options_pb2
from mediapipe.tasks.python.components.containers import classification_result as classification_result_module from mediapipe.tasks.python.components.containers import classification_result as classification_result_module
from mediapipe.tasks.python.components.containers import rect from mediapipe.tasks.python.components.containers import rect
from mediapipe.tasks.python.components.processors import classifier_options
from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.core import task_info as task_info_module from mediapipe.tasks.python.core import task_info as task_info_module
from mediapipe.tasks.python.core.optional_dependencies import doc_controls from mediapipe.tasks.python.core.optional_dependencies import doc_controls
@ -36,7 +36,7 @@ ImageClassifierResult = classification_result_module.ClassificationResult
_NormalizedRect = rect.NormalizedRect _NormalizedRect = rect.NormalizedRect
_BaseOptions = base_options_module.BaseOptions _BaseOptions = base_options_module.BaseOptions
_ImageClassifierGraphOptionsProto = image_classifier_graph_options_pb2.ImageClassifierGraphOptions _ImageClassifierGraphOptionsProto = image_classifier_graph_options_pb2.ImageClassifierGraphOptions
_ClassifierOptions = classifier_options.ClassifierOptions _ClassifierOptionsProto = classifier_options_pb2.ClassifierOptions
_RunningMode = vision_task_running_mode.VisionTaskRunningMode _RunningMode = vision_task_running_mode.VisionTaskRunningMode
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions _ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
_TaskInfo = task_info_module.TaskInfo _TaskInfo = task_info_module.TaskInfo
@ -63,15 +63,31 @@ class ImageClassifierOptions:
objects on single image inputs. 2) The video mode for classifying objects objects on single image inputs. 2) The video mode for classifying objects
on the decoded frames of a video. 3) The live stream mode for classifying on the decoded frames of a video. 3) The live stream mode for classifying
objects on a live stream of input data, such as from camera. objects on a live stream of input data, such as from camera.
classifier_options: Options for the image classification task. display_names_locale: The locale to use for display names specified through
the TFLite Model Metadata.
max_results: The maximum number of top-scored classification results to
return.
score_threshold: Overrides the ones provided in the model metadata. Results
below this value are rejected.
category_allowlist: Allowlist of category names. If non-empty,
classification results whose category name is not in this set will be
filtered out. Duplicate or unknown category names are ignored. Mutually
exclusive with `category_denylist`.
category_denylist: Denylist of category names. If non-empty, classification
results whose category name is in this set will be filtered out. Duplicate
or unknown category names are ignored. Mutually exclusive with
`category_allowlist`.
result_callback: The user-defined result callback for processing live stream result_callback: The user-defined result callback for processing live stream
data. The result callback should only be specified when the running mode data. The result callback should only be specified when the running mode
is set to the live stream mode. is set to the live stream mode.
""" """
base_options: _BaseOptions base_options: _BaseOptions
running_mode: _RunningMode = _RunningMode.IMAGE running_mode: _RunningMode = _RunningMode.IMAGE
classifier_options: Optional[_ClassifierOptions] = dataclasses.field( display_names_locale: Optional[str] = None
default_factory=_ClassifierOptions) max_results: Optional[int] = None
score_threshold: Optional[float] = None
category_allowlist: Optional[List[str]] = None
category_denylist: Optional[List[str]] = None
result_callback: Optional[Callable[ result_callback: Optional[Callable[
[ImageClassifierResult, image_module.Image, int], None]] = None [ImageClassifierResult, image_module.Image, int], None]] = None
@ -80,7 +96,12 @@ class ImageClassifierOptions:
"""Generates an ImageClassifierOptions protobuf object.""" """Generates an ImageClassifierOptions protobuf object."""
base_options_proto = self.base_options.to_pb2() base_options_proto = self.base_options.to_pb2()
base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.IMAGE else True base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.IMAGE else True
classifier_options_proto = self.classifier_options.to_pb2() classifier_options_proto = _ClassifierOptionsProto(
score_threshold=self.score_threshold,
category_allowlist=self.category_allowlist,
category_denylist=self.category_denylist,
display_names_locale=self.display_names_locale,
max_results=self.max_results)
return _ImageClassifierGraphOptionsProto( return _ImageClassifierGraphOptionsProto(
base_options=base_options_proto, base_options=base_options_proto,

View File

@ -21,9 +21,9 @@ from mediapipe.python import packet_getter
from mediapipe.python._framework_bindings import image as image_module from mediapipe.python._framework_bindings import image as image_module
from mediapipe.python._framework_bindings import packet as packet_module from mediapipe.python._framework_bindings import packet as packet_module
from mediapipe.tasks.cc.components.containers.proto import embeddings_pb2 from mediapipe.tasks.cc.components.containers.proto import embeddings_pb2
from mediapipe.tasks.cc.components.processors.proto import embedder_options_pb2
from mediapipe.tasks.cc.vision.image_embedder.proto import image_embedder_graph_options_pb2 from mediapipe.tasks.cc.vision.image_embedder.proto import image_embedder_graph_options_pb2
from mediapipe.tasks.python.components.containers import embedding_result as embedding_result_module from mediapipe.tasks.python.components.containers import embedding_result as embedding_result_module
from mediapipe.tasks.python.components.processors import embedder_options
from mediapipe.tasks.python.components.utils import cosine_similarity from mediapipe.tasks.python.components.utils import cosine_similarity
from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.core import task_info as task_info_module from mediapipe.tasks.python.core import task_info as task_info_module
@ -35,7 +35,7 @@ from mediapipe.tasks.python.vision.core import vision_task_running_mode as runni
ImageEmbedderResult = embedding_result_module.EmbeddingResult ImageEmbedderResult = embedding_result_module.EmbeddingResult
_BaseOptions = base_options_module.BaseOptions _BaseOptions = base_options_module.BaseOptions
_ImageEmbedderGraphOptionsProto = image_embedder_graph_options_pb2.ImageEmbedderGraphOptions _ImageEmbedderGraphOptionsProto = image_embedder_graph_options_pb2.ImageEmbedderGraphOptions
_EmbedderOptions = embedder_options.EmbedderOptions _EmbedderOptionsProto = embedder_options_pb2.EmbedderOptions
_RunningMode = running_mode_module.VisionTaskRunningMode _RunningMode = running_mode_module.VisionTaskRunningMode
_TaskInfo = task_info_module.TaskInfo _TaskInfo = task_info_module.TaskInfo
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions _ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
@ -62,15 +62,22 @@ class ImageEmbedderOptions:
image on single image inputs. 2) The video mode for embedding image on the image on single image inputs. 2) The video mode for embedding image on the
decoded frames of a video. 3) The live stream mode for embedding image on decoded frames of a video. 3) The live stream mode for embedding image on
a live stream of input data, such as from camera. a live stream of input data, such as from camera.
embedder_options: Options for the image embedder task. l2_normalize: Whether to normalize the returned feature vector with L2 norm.
Use this option only if the model does not already contain a native
L2_NORMALIZATION TF Lite Op. In most cases, this is already the case and
L2 norm is thus achieved through TF Lite inference.
quantize: Whether the returned embedding should be quantized to bytes via
scalar quantization. Embeddings are implicitly assumed to be unit-norm and
therefore any dimension is guaranteed to have a value in [-1.0, 1.0]. Use
the l2_normalize option if this is not the case.
result_callback: The user-defined result callback for processing live stream result_callback: The user-defined result callback for processing live stream
data. The result callback should only be specified when the running mode data. The result callback should only be specified when the running mode
is set to the live stream mode. is set to the live stream mode.
""" """
base_options: _BaseOptions base_options: _BaseOptions
running_mode: _RunningMode = _RunningMode.IMAGE running_mode: _RunningMode = _RunningMode.IMAGE
embedder_options: Optional[_EmbedderOptions] = dataclasses.field( l2_normalize: Optional[bool] = None
default_factory=_EmbedderOptions) quantize: Optional[bool] = None
result_callback: Optional[Callable[ result_callback: Optional[Callable[
[ImageEmbedderResult, image_module.Image, int], None]] = None [ImageEmbedderResult, image_module.Image, int], None]] = None
@ -79,7 +86,8 @@ class ImageEmbedderOptions:
"""Generates an ImageEmbedderOptions protobuf object.""" """Generates an ImageEmbedderOptions protobuf object."""
base_options_proto = self.base_options.to_pb2() base_options_proto = self.base_options.to_pb2()
base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.IMAGE else True base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.IMAGE else True
embedder_options_proto = self.embedder_options.to_pb2() embedder_options_proto = _EmbedderOptionsProto(
l2_normalize=self.l2_normalize, quantize=self.quantize)
return _ImageEmbedderGraphOptionsProto( return _ImageEmbedderGraphOptionsProto(
base_options=base_options_proto, base_options=base_options_proto,