Internal change

PiperOrigin-RevId: 493427500
This commit is contained in:
MediaPipe Team 2022-12-06 15:22:03 -08:00 committed by Copybara-Service
parent fca0f5806b
commit 576c6da173
23 changed files with 162 additions and 212 deletions

View File

@ -29,11 +29,11 @@ py_library(
"//mediapipe/python:packet_getter",
"//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_graph_options_py_pb2",
"//mediapipe/tasks/cc/components/containers/proto:classifications_py_pb2",
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_py_pb2",
"//mediapipe/tasks/python/audio/core:audio_task_running_mode",
"//mediapipe/tasks/python/audio/core:base_audio_task_api",
"//mediapipe/tasks/python/components/containers:audio_data",
"//mediapipe/tasks/python/components/containers:classification_result",
"//mediapipe/tasks/python/components/processors:classifier_options",
"//mediapipe/tasks/python/core:base_options",
"//mediapipe/tasks/python/core:optional_dependencies",
"//mediapipe/tasks/python/core:task_info",
@ -51,11 +51,11 @@ py_library(
"//mediapipe/python:packet_getter",
"//mediapipe/tasks/cc/audio/audio_embedder/proto:audio_embedder_graph_options_py_pb2",
"//mediapipe/tasks/cc/components/containers/proto:embeddings_py_pb2",
"//mediapipe/tasks/cc/components/processors/proto:embedder_options_py_pb2",
"//mediapipe/tasks/python/audio/core:audio_task_running_mode",
"//mediapipe/tasks/python/audio/core:base_audio_task_api",
"//mediapipe/tasks/python/components/containers:audio_data",
"//mediapipe/tasks/python/components/containers:embedding_result",
"//mediapipe/tasks/python/components/processors:embedder_options",
"//mediapipe/tasks/python/components/utils:cosine_similarity",
"//mediapipe/tasks/python/core:base_options",
"//mediapipe/tasks/python/core:optional_dependencies",

View File

@ -21,11 +21,11 @@ from mediapipe.python import packet_getter
from mediapipe.python._framework_bindings import packet
from mediapipe.tasks.cc.audio.audio_classifier.proto import audio_classifier_graph_options_pb2
from mediapipe.tasks.cc.components.containers.proto import classifications_pb2
from mediapipe.tasks.cc.components.processors.proto import classifier_options_pb2
from mediapipe.tasks.python.audio.core import audio_task_running_mode as running_mode_module
from mediapipe.tasks.python.audio.core import base_audio_task_api
from mediapipe.tasks.python.components.containers import audio_data as audio_data_module
from mediapipe.tasks.python.components.containers import classification_result as classification_result_module
from mediapipe.tasks.python.components.processors import classifier_options as classifier_options_module
from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.core import task_info as task_info_module
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
@ -34,7 +34,7 @@ AudioClassifierResult = classification_result_module.ClassificationResult
_AudioClassifierGraphOptionsProto = audio_classifier_graph_options_pb2.AudioClassifierGraphOptions
_AudioData = audio_data_module.AudioData
_BaseOptions = base_options_module.BaseOptions
_ClassifierOptions = classifier_options_module.ClassifierOptions
_ClassifierOptionsProto = classifier_options_pb2.ClassifierOptions
_RunningMode = running_mode_module.AudioTaskRunningMode
_TaskInfo = task_info_module.TaskInfo
@ -62,16 +62,31 @@ class AudioClassifierOptions:
mode for running classification on the audio stream, such as from
microphone. In this mode, the "result_callback" below must be specified
to receive the classification results asynchronously.
classifier_options: Options for configuring the classifier behavior, such as
score threshold, number of results, etc.
display_names_locale: The locale to use for display names specified through
the TFLite Model Metadata.
max_results: The maximum number of top-scored classification results to
return.
score_threshold: Overrides the ones provided in the model metadata. Results
below this value are rejected.
category_allowlist: Allowlist of category names. If non-empty,
classification results whose category name is not in this set will be
filtered out. Duplicate or unknown category names are ignored. Mutually
exclusive with `category_denylist`.
category_denylist: Denylist of category names. If non-empty, classification
results whose category name is in this set will be filtered out. Duplicate
or unknown category names are ignored. Mutually exclusive with
`category_allowlist`.
result_callback: The user-defined result callback for processing audio
stream data. The result callback should only be specified when the running
mode is set to the audio stream mode.
"""
base_options: _BaseOptions
running_mode: _RunningMode = _RunningMode.AUDIO_CLIPS
classifier_options: Optional[_ClassifierOptions] = dataclasses.field(
default_factory=_ClassifierOptions)
display_names_locale: Optional[str] = None
max_results: Optional[int] = None
score_threshold: Optional[float] = None
category_allowlist: Optional[List[str]] = None
category_denylist: Optional[List[str]] = None
result_callback: Optional[Callable[[AudioClassifierResult, int], None]] = None
@doc_controls.do_not_generate_docs
@ -79,7 +94,12 @@ class AudioClassifierOptions:
"""Generates an AudioClassifierOptions protobuf object."""
base_options_proto = self.base_options.to_pb2()
base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.AUDIO_CLIPS else True
classifier_options_proto = self.classifier_options.to_pb2()
classifier_options_proto = _ClassifierOptionsProto(
score_threshold=self.score_threshold,
category_allowlist=self.category_allowlist,
category_denylist=self.category_denylist,
display_names_locale=self.display_names_locale,
max_results=self.max_results)
return _AudioClassifierGraphOptionsProto(
base_options=base_options_proto,

View File

@ -21,11 +21,11 @@ from mediapipe.python import packet_getter
from mediapipe.python._framework_bindings import packet
from mediapipe.tasks.cc.audio.audio_embedder.proto import audio_embedder_graph_options_pb2
from mediapipe.tasks.cc.components.containers.proto import embeddings_pb2
from mediapipe.tasks.cc.components.processors.proto import embedder_options_pb2
from mediapipe.tasks.python.audio.core import audio_task_running_mode as running_mode_module
from mediapipe.tasks.python.audio.core import base_audio_task_api
from mediapipe.tasks.python.components.containers import audio_data as audio_data_module
from mediapipe.tasks.python.components.containers import embedding_result as embedding_result_module
from mediapipe.tasks.python.components.processors import embedder_options as embedder_options_module
from mediapipe.tasks.python.components.utils import cosine_similarity
from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.core import task_info as task_info_module
@ -35,7 +35,7 @@ AudioEmbedderResult = embedding_result_module.EmbeddingResult
_AudioEmbedderGraphOptionsProto = audio_embedder_graph_options_pb2.AudioEmbedderGraphOptions
_AudioData = audio_data_module.AudioData
_BaseOptions = base_options_module.BaseOptions
_EmbedderOptions = embedder_options_module.EmbedderOptions
_EmbedderOptionsProto = embedder_options_pb2.EmbedderOptions
_RunningMode = running_mode_module.AudioTaskRunningMode
_TaskInfo = task_info_module.TaskInfo
@ -63,16 +63,22 @@ class AudioEmbedderOptions:
stream mode for running embedding extraction on the audio stream, such as
from microphone. In this mode, the "result_callback" below must be
specified to receive the embedding results asynchronously.
embedder_options: Options for configuring the embedder behavior, such as
l2_normalize and quantize.
l2_normalize: Whether to normalize the returned feature vector with L2 norm.
Use this option only if the model does not already contain a native
L2_NORMALIZATION TF Lite Op. In most cases, this is already the case and
L2 norm is thus achieved through TF Lite inference.
quantize: Whether the returned embedding should be quantized to bytes via
scalar quantization. Embeddings are implicitly assumed to be unit-norm and
therefore any dimension is guaranteed to have a value in [-1.0, 1.0]. Use
the l2_normalize option if this is not the case.
result_callback: The user-defined result callback for processing audio
stream data. The result callback should only be specified when the running
mode is set to the audio stream mode.
"""
base_options: _BaseOptions
running_mode: _RunningMode = _RunningMode.AUDIO_CLIPS
embedder_options: Optional[_EmbedderOptions] = dataclasses.field(
default_factory=_EmbedderOptions)
l2_normalize: Optional[bool] = None
quantize: Optional[bool] = None
result_callback: Optional[Callable[[AudioEmbedderResult, int], None]] = None
@doc_controls.do_not_generate_docs
@ -80,7 +86,8 @@ class AudioEmbedderOptions:
"""Generates an AudioEmbedderOptions protobuf object."""
base_options_proto = self.base_options.to_pb2()
base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.AUDIO_CLIPS else True
embedder_options_proto = self.embedder_options.to_pb2()
embedder_options_proto = _EmbedderOptionsProto(
l2_normalize=self.l2_normalize, quantize=self.quantize)
return _AudioEmbedderGraphOptionsProto(
base_options=base_options_proto,

View File

@ -28,12 +28,3 @@ py_library(
"//mediapipe/tasks/python/core:optional_dependencies",
],
)
py_library(
name = "embedder_options",
srcs = ["embedder_options.py"],
deps = [
"//mediapipe/tasks/cc/components/processors/proto:embedder_options_py_pb2",
"//mediapipe/tasks/python/core:optional_dependencies",
],
)

View File

@ -15,12 +15,9 @@
"""MediaPipe Tasks Components Processors API."""
import mediapipe.tasks.python.components.processors.classifier_options
import mediapipe.tasks.python.components.processors.embedder_options
ClassifierOptions = classifier_options.ClassifierOptions
EmbedderOptions = embedder_options.EmbedderOptions
# Remove unnecessary modules to avoid duplication in API docs.
del classifier_options
del embedder_options
del mediapipe

View File

@ -1,68 +0,0 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Embedder options data class."""
import dataclasses
from typing import Any, Optional
from mediapipe.tasks.cc.components.processors.proto import embedder_options_pb2
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
_EmbedderOptionsProto = embedder_options_pb2.EmbedderOptions
@dataclasses.dataclass
class EmbedderOptions:
"""Shared options used by all embedding extraction tasks.
Attributes:
l2_normalize: Whether to normalize the returned feature vector with L2 norm.
Use this option only if the model does not already contain a native
L2_NORMALIZATION TF Lite Op. In most cases, this is already the case and
L2 norm is thus achieved through TF Lite inference.
quantize: Whether the returned embedding should be quantized to bytes via
scalar quantization. Embeddings are implicitly assumed to be unit-norm and
therefore any dimension is guaranteed to have a value in [-1.0, 1.0]. Use
the l2_normalize option if this is not the case.
"""
l2_normalize: Optional[bool] = None
quantize: Optional[bool] = None
@doc_controls.do_not_generate_docs
def to_pb2(self) -> _EmbedderOptionsProto:
"""Generates a EmbedderOptions protobuf object."""
return _EmbedderOptionsProto(
l2_normalize=self.l2_normalize, quantize=self.quantize)
@classmethod
@doc_controls.do_not_generate_docs
def create_from_pb2(cls, pb2_obj: _EmbedderOptionsProto) -> 'EmbedderOptions':
"""Creates a `EmbedderOptions` object from the given protobuf object."""
return EmbedderOptions(
l2_normalize=pb2_obj.l2_normalize, quantize=pb2_obj.quantize)
def __eq__(self, other: Any) -> bool:
"""Checks if this object is equal to the given object.
Args:
other: The object to be compared with.
Returns:
True if the objects are equal.
"""
if not isinstance(other, EmbedderOptions):
return False
return self.to_pb2().__eq__(other.to_pb2())

View File

@ -23,8 +23,5 @@ licenses(["notice"])
py_library(
name = "cosine_similarity",
srcs = ["cosine_similarity.py"],
deps = [
"//mediapipe/tasks/python/components/containers:embedding_result",
"//mediapipe/tasks/python/components/processors:embedder_options",
],
deps = ["//mediapipe/tasks/python/components/containers:embedding_result"],
)

View File

@ -16,10 +16,8 @@
import numpy as np
from mediapipe.tasks.python.components.containers import embedding_result
from mediapipe.tasks.python.components.processors import embedder_options
_Embedding = embedding_result.Embedding
_EmbedderOptions = embedder_options.EmbedderOptions
def _compute_cosine_similarity(u, v):

View File

@ -30,7 +30,6 @@ py_test(
"//mediapipe/tasks/python/audio/core:audio_task_running_mode",
"//mediapipe/tasks/python/components/containers:audio_data",
"//mediapipe/tasks/python/components/containers:classification_result",
"//mediapipe/tasks/python/components/processors:classifier_options",
"//mediapipe/tasks/python/core:base_options",
"//mediapipe/tasks/python/test:test_utils",
],
@ -48,7 +47,6 @@ py_test(
"//mediapipe/tasks/python/audio/core:audio_task_running_mode",
"//mediapipe/tasks/python/components/containers:audio_data",
"//mediapipe/tasks/python/components/containers:embedding_result",
"//mediapipe/tasks/python/components/processors:embedder_options",
"//mediapipe/tasks/python/core:base_options",
"//mediapipe/tasks/python/test:test_utils",
],

View File

@ -27,7 +27,6 @@ from mediapipe.tasks.python.audio import audio_classifier
from mediapipe.tasks.python.audio.core import audio_task_running_mode
from mediapipe.tasks.python.components.containers import audio_data as audio_data_module
from mediapipe.tasks.python.components.containers import classification_result as classification_result_module
from mediapipe.tasks.python.components.processors import classifier_options
from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.test import test_utils
@ -36,7 +35,6 @@ _AudioClassifierOptions = audio_classifier.AudioClassifierOptions
_AudioClassifierResult = classification_result_module.ClassificationResult
_AudioData = audio_data_module.AudioData
_BaseOptions = base_options_module.BaseOptions
_ClassifierOptions = classifier_options.ClassifierOptions
_RUNNING_MODE = audio_task_running_mode.AudioTaskRunningMode
_YAMNET_MODEL_FILE = 'yamnet_audio_classifier_with_metadata.tflite'
@ -210,8 +208,7 @@ class AudioClassifierTest(parameterized.TestCase):
with _AudioClassifier.create_from_options(
_AudioClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.yamnet_model_path),
classifier_options=_ClassifierOptions(
max_results=1))) as classifier:
max_results=1)) as classifier:
for audio_file in [_SPEECH_WAV_16K_MONO, _SPEECH_WAV_16K_MONO]:
classification_result_list = classifier.classify(
self._read_wav_file(audio_file))
@ -222,8 +219,7 @@ class AudioClassifierTest(parameterized.TestCase):
with _AudioClassifier.create_from_options(
_AudioClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.yamnet_model_path),
classifier_options=_ClassifierOptions(
score_threshold=0.9))) as classifier:
score_threshold=0.9)) as classifier:
for audio_file in [_SPEECH_WAV_16K_MONO, _SPEECH_WAV_16K_MONO]:
classification_result_list = classifier.classify(
self._read_wav_file(audio_file))
@ -234,8 +230,7 @@ class AudioClassifierTest(parameterized.TestCase):
with _AudioClassifier.create_from_options(
_AudioClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.yamnet_model_path),
classifier_options=_ClassifierOptions(
category_allowlist=['Speech']))) as classifier:
category_allowlist=['Speech'])) as classifier:
for audio_file in [_SPEECH_WAV_16K_MONO, _SPEECH_WAV_16K_MONO]:
classification_result_list = classifier.classify(
self._read_wav_file(audio_file))
@ -250,8 +245,8 @@ class AudioClassifierTest(parameterized.TestCase):
r'exclusive options.'):
options = _AudioClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.yamnet_model_path),
classifier_options=_ClassifierOptions(
category_allowlist=['foo'], category_denylist=['bar']))
category_allowlist=['foo'],
category_denylist=['bar'])
with _AudioClassifier.create_from_options(options) as unused_classifier:
pass
@ -278,8 +273,7 @@ class AudioClassifierTest(parameterized.TestCase):
_AudioClassifierOptions(
base_options=_BaseOptions(
model_asset_path=self.two_heads_model_path),
classifier_options=_ClassifierOptions(
max_results=1))) as classifier:
max_results=1)) as classifier:
for audio_file in [_TWO_HEADS_WAV_16K_MONO, _TWO_HEADS_WAV_44K_MONO]:
classification_result_list = classifier.classify(
self._read_wav_file(audio_file))
@ -364,7 +358,7 @@ class AudioClassifierTest(parameterized.TestCase):
options = _AudioClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.yamnet_model_path),
running_mode=_RUNNING_MODE.AUDIO_STREAM,
classifier_options=_ClassifierOptions(max_results=1),
max_results=1,
result_callback=save_result)
classifier = _AudioClassifier.create_from_options(options)
audio_data_list = self._read_wav_file_as_stream(audio_file)

View File

@ -26,7 +26,6 @@ from scipy.io import wavfile
from mediapipe.tasks.python.audio import audio_embedder
from mediapipe.tasks.python.audio.core import audio_task_running_mode
from mediapipe.tasks.python.components.containers import audio_data as audio_data_module
from mediapipe.tasks.python.components.processors import embedder_options
from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.test import test_utils
@ -35,7 +34,6 @@ _AudioEmbedderOptions = audio_embedder.AudioEmbedderOptions
_AudioEmbedderResult = audio_embedder.AudioEmbedderResult
_AudioData = audio_data_module.AudioData
_BaseOptions = base_options_module.BaseOptions
_EmbedderOptions = embedder_options.EmbedderOptions
_RUNNING_MODE = audio_task_running_mode.AudioTaskRunningMode
_YAMNET_MODEL_FILE = 'yamnet_embedding_metadata.tflite'
@ -172,9 +170,7 @@ class AudioEmbedderTest(parameterized.TestCase):
raise ValueError('model_file_type is invalid.')
options = _AudioEmbedderOptions(
base_options=base_options,
embedder_options=_EmbedderOptions(
l2_normalize=l2_normalize, quantize=quantize))
base_options=base_options, l2_normalize=l2_normalize, quantize=quantize)
with _AudioEmbedder.create_from_options(options) as embedder:
embedding_result0_list = embedder.embed(self._read_wav_file(audio_file0))
@ -291,8 +287,8 @@ class AudioEmbedderTest(parameterized.TestCase):
options = _AudioEmbedderOptions(
base_options=_BaseOptions(model_asset_path=self.yamnet_model_path),
running_mode=_RUNNING_MODE.AUDIO_STREAM,
embedder_options=_EmbedderOptions(
l2_normalize=l2_normalize, quantize=quantize),
l2_normalize=l2_normalize,
quantize=quantize,
result_callback=save_result)
with _AudioEmbedder.create_from_options(options) as embedder:

View File

@ -28,7 +28,6 @@ py_test(
deps = [
"//mediapipe/tasks/python/components/containers:category",
"//mediapipe/tasks/python/components/containers:classification_result",
"//mediapipe/tasks/python/components/processors:classifier_options",
"//mediapipe/tasks/python/core:base_options",
"//mediapipe/tasks/python/test:test_utils",
"//mediapipe/tasks/python/text:text_classifier",
@ -44,7 +43,6 @@ py_test(
],
deps = [
"//mediapipe/tasks/python/components/containers:embedding_result",
"//mediapipe/tasks/python/components/processors:embedder_options",
"//mediapipe/tasks/python/core:base_options",
"//mediapipe/tasks/python/test:test_utils",
"//mediapipe/tasks/python/text:text_embedder",

View File

@ -21,14 +21,12 @@ from absl.testing import parameterized
from mediapipe.tasks.python.components.containers import category
from mediapipe.tasks.python.components.containers import classification_result as classification_result_module
from mediapipe.tasks.python.components.processors import classifier_options
from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.test import test_utils
from mediapipe.tasks.python.text import text_classifier
TextClassifierResult = classification_result_module.ClassificationResult
_BaseOptions = base_options_module.BaseOptions
_ClassifierOptions = classifier_options.ClassifierOptions
_Category = category.Category
_Classifications = classification_result_module.Classifications
_TextClassifier = text_classifier.TextClassifier

View File

@ -21,13 +21,11 @@ from absl.testing import parameterized
import numpy as np
from mediapipe.tasks.python.components.containers import embedding_result as embedding_result_module
from mediapipe.tasks.python.components.processors import embedder_options as embedder_options_module
from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.test import test_utils
from mediapipe.tasks.python.text import text_embedder
_BaseOptions = base_options_module.BaseOptions
_EmbedderOptions = embedder_options_module.EmbedderOptions
_Embedding = embedding_result_module.Embedding
_TextEmbedder = text_embedder.TextEmbedder
_TextEmbedderOptions = text_embedder.TextEmbedderOptions
@ -128,10 +126,8 @@ class TextEmbedderTest(parameterized.TestCase):
# Should never happen
raise ValueError('model_file_type is invalid.')
embedder_options = _EmbedderOptions(
l2_normalize=l2_normalize, quantize=quantize)
options = _TextEmbedderOptions(
base_options=base_options, embedder_options=embedder_options)
base_options=base_options, l2_normalize=l2_normalize, quantize=quantize)
embedder = _TextEmbedder.create_from_options(options)
# Extracts both embeddings.
@ -178,10 +174,8 @@ class TextEmbedderTest(parameterized.TestCase):
# Should never happen
raise ValueError('model_file_type is invalid.')
embedder_options = _EmbedderOptions(
l2_normalize=l2_normalize, quantize=quantize)
options = _TextEmbedderOptions(
base_options=base_options, embedder_options=embedder_options)
base_options=base_options, l2_normalize=l2_normalize, quantize=quantize)
with _TextEmbedder.create_from_options(options) as embedder:
# Extracts both embeddings.
positive_text0 = "it's a charming and often affecting journey"

View File

@ -49,7 +49,6 @@ py_test(
"//mediapipe/tasks/python/components/containers:category",
"//mediapipe/tasks/python/components/containers:classification_result",
"//mediapipe/tasks/python/components/containers:rect",
"//mediapipe/tasks/python/components/processors:classifier_options",
"//mediapipe/tasks/python/core:base_options",
"//mediapipe/tasks/python/test:test_utils",
"//mediapipe/tasks/python/vision:image_classifier",
@ -69,7 +68,6 @@ py_test(
"//mediapipe/python:_framework_bindings",
"//mediapipe/tasks/python/components/containers:embedding_result",
"//mediapipe/tasks/python/components/containers:rect",
"//mediapipe/tasks/python/components/processors:embedder_options",
"//mediapipe/tasks/python/core:base_options",
"//mediapipe/tasks/python/test:test_utils",
"//mediapipe/tasks/python/vision:image_embedder",

View File

@ -26,7 +26,6 @@ from mediapipe.python._framework_bindings import image
from mediapipe.tasks.python.components.containers import category as category_module
from mediapipe.tasks.python.components.containers import classification_result as classification_result_module
from mediapipe.tasks.python.components.containers import rect
from mediapipe.tasks.python.components.processors import classifier_options
from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.test import test_utils
from mediapipe.tasks.python.vision import image_classifier
@ -36,7 +35,6 @@ from mediapipe.tasks.python.vision.core import vision_task_running_mode
ImageClassifierResult = classification_result_module.ClassificationResult
_Rect = rect.Rect
_BaseOptions = base_options_module.BaseOptions
_ClassifierOptions = classifier_options.ClassifierOptions
_Category = category_module.Category
_Classifications = classification_result_module.Classifications
_Image = image.Image
@ -171,9 +169,8 @@ class ImageClassifierTest(parameterized.TestCase):
# Should never happen
raise ValueError('model_file_type is invalid.')
custom_classifier_options = _ClassifierOptions(max_results=max_results)
options = _ImageClassifierOptions(
base_options=base_options, classifier_options=custom_classifier_options)
base_options=base_options, max_results=max_results)
classifier = _ImageClassifier.create_from_options(options)
# Performs image classification on the input.
@ -200,9 +197,8 @@ class ImageClassifierTest(parameterized.TestCase):
# Should never happen
raise ValueError('model_file_type is invalid.')
custom_classifier_options = _ClassifierOptions(max_results=max_results)
options = _ImageClassifierOptions(
base_options=base_options, classifier_options=custom_classifier_options)
base_options=base_options, max_results=max_results)
with _ImageClassifier.create_from_options(options) as classifier:
# Performs image classification on the input.
image_result = classifier.classify(self.test_image)
@ -212,9 +208,7 @@ class ImageClassifierTest(parameterized.TestCase):
def test_classify_succeeds_with_region_of_interest(self):
base_options = _BaseOptions(model_asset_path=self.model_path)
custom_classifier_options = _ClassifierOptions(max_results=1)
options = _ImageClassifierOptions(
base_options=base_options, classifier_options=custom_classifier_options)
options = _ImageClassifierOptions(base_options=base_options, max_results=1)
with _ImageClassifier.create_from_options(options) as classifier:
# Load the test image.
test_image = _Image.create_from_file(
@ -230,11 +224,9 @@ class ImageClassifierTest(parameterized.TestCase):
_generate_soccer_ball_results().to_pb2())
def test_score_threshold_option(self):
custom_classifier_options = _ClassifierOptions(
score_threshold=_SCORE_THRESHOLD)
options = _ImageClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
classifier_options=custom_classifier_options)
score_threshold=_SCORE_THRESHOLD)
with _ImageClassifier.create_from_options(options) as classifier:
# Performs image classification on the input.
image_result = classifier.classify(self.test_image)
@ -249,11 +241,9 @@ class ImageClassifierTest(parameterized.TestCase):
f'{classification}')
def test_max_results_option(self):
custom_classifier_options = _ClassifierOptions(
score_threshold=_SCORE_THRESHOLD)
options = _ImageClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
classifier_options=custom_classifier_options)
score_threshold=_SCORE_THRESHOLD)
with _ImageClassifier.create_from_options(options) as classifier:
# Performs image classification on the input.
image_result = classifier.classify(self.test_image)
@ -263,11 +253,9 @@ class ImageClassifierTest(parameterized.TestCase):
len(categories), _MAX_RESULTS, 'Too many results returned.')
def test_allow_list_option(self):
custom_classifier_options = _ClassifierOptions(
category_allowlist=_ALLOW_LIST)
options = _ImageClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
classifier_options=custom_classifier_options)
category_allowlist=_ALLOW_LIST)
with _ImageClassifier.create_from_options(options) as classifier:
# Performs image classification on the input.
image_result = classifier.classify(self.test_image)
@ -280,10 +268,9 @@ class ImageClassifierTest(parameterized.TestCase):
f'Label {label} found but not in label allow list')
def test_deny_list_option(self):
custom_classifier_options = _ClassifierOptions(category_denylist=_DENY_LIST)
options = _ImageClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
classifier_options=custom_classifier_options)
category_denylist=_DENY_LIST)
with _ImageClassifier.create_from_options(options) as classifier:
# Performs image classification on the input.
image_result = classifier.classify(self.test_image)
@ -301,19 +288,17 @@ class ImageClassifierTest(parameterized.TestCase):
ValueError,
r'`category_allowlist` and `category_denylist` are mutually '
r'exclusive options.'):
custom_classifier_options = _ClassifierOptions(
category_allowlist=['foo'], category_denylist=['bar'])
options = _ImageClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
classifier_options=custom_classifier_options)
category_allowlist=['foo'],
category_denylist=['bar'])
with _ImageClassifier.create_from_options(options) as unused_classifier:
pass
def test_empty_classification_outputs(self):
custom_classifier_options = _ClassifierOptions(score_threshold=1)
options = _ImageClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
classifier_options=custom_classifier_options)
score_threshold=1)
with _ImageClassifier.create_from_options(options) as classifier:
# Performs image classification on the input.
image_result = classifier.classify(self.test_image)
@ -386,11 +371,10 @@ class ImageClassifierTest(parameterized.TestCase):
classifier.classify_for_video(self.test_image, 0)
def test_classify_for_video(self):
custom_classifier_options = _ClassifierOptions(max_results=4)
options = _ImageClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.VIDEO,
classifier_options=custom_classifier_options)
max_results=4)
with _ImageClassifier.create_from_options(options) as classifier:
for timestamp in range(0, 300, 30):
classification_result = classifier.classify_for_video(
@ -399,11 +383,10 @@ class ImageClassifierTest(parameterized.TestCase):
_generate_burger_results().to_pb2())
def test_classify_for_video_succeeds_with_region_of_interest(self):
custom_classifier_options = _ClassifierOptions(max_results=1)
options = _ImageClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.VIDEO,
classifier_options=custom_classifier_options)
max_results=1)
with _ImageClassifier.create_from_options(options) as classifier:
# Load the test image.
test_image = _Image.create_from_file(
@ -439,11 +422,10 @@ class ImageClassifierTest(parameterized.TestCase):
classifier.classify_for_video(self.test_image, 0)
def test_classify_async_calls_with_illegal_timestamp(self):
custom_classifier_options = _ClassifierOptions(max_results=4)
options = _ImageClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.LIVE_STREAM,
classifier_options=custom_classifier_options,
max_results=4,
result_callback=mock.MagicMock())
with _ImageClassifier.create_from_options(options) as classifier:
classifier.classify_async(self.test_image, 100)
@ -466,12 +448,11 @@ class ImageClassifierTest(parameterized.TestCase):
self.assertLess(observed_timestamp_ms, timestamp_ms)
self.observed_timestamp_ms = timestamp_ms
custom_classifier_options = _ClassifierOptions(
max_results=4, score_threshold=threshold)
options = _ImageClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.LIVE_STREAM,
classifier_options=custom_classifier_options,
max_results=4,
score_threshold=threshold,
result_callback=check_result)
with _ImageClassifier.create_from_options(options) as classifier:
for timestamp in range(0, 300, 30):
@ -496,11 +477,10 @@ class ImageClassifierTest(parameterized.TestCase):
self.assertLess(observed_timestamp_ms, timestamp_ms)
self.observed_timestamp_ms = timestamp_ms
custom_classifier_options = _ClassifierOptions(max_results=1)
options = _ImageClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.LIVE_STREAM,
classifier_options=custom_classifier_options,
max_results=1,
result_callback=check_result)
with _ImageClassifier.create_from_options(options) as classifier:
for timestamp in range(0, 300, 30):

View File

@ -24,7 +24,6 @@ import numpy as np
from mediapipe.python._framework_bindings import image as image_module
from mediapipe.tasks.python.components.containers import embedding_result as embedding_result_module
from mediapipe.tasks.python.components.containers import rect
from mediapipe.tasks.python.components.processors import embedder_options as embedder_options_module
from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.test import test_utils
from mediapipe.tasks.python.vision import image_embedder
@ -33,7 +32,6 @@ from mediapipe.tasks.python.vision.core import vision_task_running_mode as runni
_Rect = rect.Rect
_BaseOptions = base_options_module.BaseOptions
_EmbedderOptions = embedder_options_module.EmbedderOptions
_Embedding = embedding_result_module.Embedding
_Image = image_module.Image
_ImageEmbedder = image_embedder.ImageEmbedder
@ -142,10 +140,8 @@ class ImageEmbedderTest(parameterized.TestCase):
# Should never happen
raise ValueError('model_file_type is invalid.')
embedder_options = _EmbedderOptions(
l2_normalize=l2_normalize, quantize=quantize)
options = _ImageEmbedderOptions(
base_options=base_options, embedder_options=embedder_options)
base_options=base_options, l2_normalize=l2_normalize, quantize=quantize)
embedder = _ImageEmbedder.create_from_options(options)
image_processing_options = None
@ -186,10 +182,8 @@ class ImageEmbedderTest(parameterized.TestCase):
# Should never happen
raise ValueError('model_file_type is invalid.')
embedder_options = _EmbedderOptions(
l2_normalize=l2_normalize, quantize=quantize)
options = _ImageEmbedderOptions(
base_options=base_options, embedder_options=embedder_options)
base_options=base_options, l2_normalize=l2_normalize, quantize=quantize)
with _ImageEmbedder.create_from_options(options) as embedder:
# Extracts both embeddings.

View File

@ -28,9 +28,9 @@ py_library(
"//mediapipe/python:packet_creator",
"//mediapipe/python:packet_getter",
"//mediapipe/tasks/cc/components/containers/proto:classifications_py_pb2",
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_py_pb2",
"//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_py_pb2",
"//mediapipe/tasks/python/components/containers:classification_result",
"//mediapipe/tasks/python/components/processors:classifier_options",
"//mediapipe/tasks/python/core:base_options",
"//mediapipe/tasks/python/core:optional_dependencies",
"//mediapipe/tasks/python/core:task_info",
@ -47,9 +47,9 @@ py_library(
"//mediapipe/python:packet_creator",
"//mediapipe/python:packet_getter",
"//mediapipe/tasks/cc/components/containers/proto:embeddings_py_pb2",
"//mediapipe/tasks/cc/components/processors/proto:embedder_options_py_pb2",
"//mediapipe/tasks/cc/text/text_embedder/proto:text_embedder_graph_options_py_pb2",
"//mediapipe/tasks/python/components/containers:embedding_result",
"//mediapipe/tasks/python/components/processors:embedder_options",
"//mediapipe/tasks/python/components/utils:cosine_similarity",
"//mediapipe/tasks/python/core:base_options",
"//mediapipe/tasks/python/core:optional_dependencies",

View File

@ -14,14 +14,14 @@
"""MediaPipe text classifier task."""
import dataclasses
from typing import Optional
from typing import Optional, List
from mediapipe.python import packet_creator
from mediapipe.python import packet_getter
from mediapipe.tasks.cc.components.containers.proto import classifications_pb2
from mediapipe.tasks.cc.components.processors.proto import classifier_options_pb2
from mediapipe.tasks.cc.text.text_classifier.proto import text_classifier_graph_options_pb2
from mediapipe.tasks.python.components.containers import classification_result as classification_result_module
from mediapipe.tasks.python.components.processors import classifier_options
from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.core import task_info as task_info_module
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
@ -30,7 +30,7 @@ from mediapipe.tasks.python.text.core import base_text_task_api
TextClassifierResult = classification_result_module.ClassificationResult
_BaseOptions = base_options_module.BaseOptions
_TextClassifierGraphOptionsProto = text_classifier_graph_options_pb2.TextClassifierGraphOptions
_ClassifierOptions = classifier_options.ClassifierOptions
_ClassifierOptionsProto = classifier_options_pb2.ClassifierOptions
_TaskInfo = task_info_module.TaskInfo
_CLASSIFICATIONS_STREAM_NAME = 'classifications_out'
@ -46,17 +46,38 @@ class TextClassifierOptions:
Attributes:
base_options: Base options for the text classifier task.
classifier_options: Options for the text classification task.
display_names_locale: The locale to use for display names specified through
the TFLite Model Metadata.
max_results: The maximum number of top-scored classification results to
return.
score_threshold: Overrides the ones provided in the model metadata. Results
below this value are rejected.
category_allowlist: Allowlist of category names. If non-empty,
classification results whose category name is not in this set will be
filtered out. Duplicate or unknown category names are ignored. Mutually
exclusive with `category_denylist`.
category_denylist: Denylist of category names. If non-empty, classification
results whose category name is in this set will be filtered out. Duplicate
or unknown category names are ignored. Mutually exclusive with
`category_allowlist`.
"""
base_options: _BaseOptions
classifier_options: Optional[_ClassifierOptions] = dataclasses.field(
default_factory=_ClassifierOptions)
display_names_locale: Optional[str] = None
max_results: Optional[int] = None
score_threshold: Optional[float] = None
category_allowlist: Optional[List[str]] = None
category_denylist: Optional[List[str]] = None
@doc_controls.do_not_generate_docs
def to_pb2(self) -> _TextClassifierGraphOptionsProto:
"""Generates an TextClassifierOptions protobuf object."""
base_options_proto = self.base_options.to_pb2()
classifier_options_proto = self.classifier_options.to_pb2()
classifier_options_proto = _ClassifierOptionsProto(
score_threshold=self.score_threshold,
category_allowlist=self.category_allowlist,
category_denylist=self.category_denylist,
display_names_locale=self.display_names_locale,
max_results=self.max_results)
return _TextClassifierGraphOptionsProto(
base_options=base_options_proto,

View File

@ -19,9 +19,9 @@ from typing import Optional
from mediapipe.python import packet_creator
from mediapipe.python import packet_getter
from mediapipe.tasks.cc.components.containers.proto import embeddings_pb2
from mediapipe.tasks.cc.components.processors.proto import embedder_options_pb2
from mediapipe.tasks.cc.text.text_embedder.proto import text_embedder_graph_options_pb2
from mediapipe.tasks.python.components.containers import embedding_result as embedding_result_module
from mediapipe.tasks.python.components.processors import embedder_options
from mediapipe.tasks.python.components.utils import cosine_similarity
from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.core import task_info as task_info_module
@ -31,7 +31,7 @@ from mediapipe.tasks.python.text.core import base_text_task_api
TextEmbedderResult = embedding_result_module.EmbeddingResult
_BaseOptions = base_options_module.BaseOptions
_TextEmbedderGraphOptionsProto = text_embedder_graph_options_pb2.TextEmbedderGraphOptions
_EmbedderOptions = embedder_options.EmbedderOptions
_EmbedderOptionsProto = embedder_options_pb2.EmbedderOptions
_TaskInfo = task_info_module.TaskInfo
_EMBEDDINGS_OUT_STREAM_NAME = 'embeddings_out'
@ -47,17 +47,25 @@ class TextEmbedderOptions:
Attributes:
base_options: Base options for the text embedder task.
embedder_options: Options for the text embedder task.
l2_normalize: Whether to normalize the returned feature vector with L2 norm.
Use this option only if the model does not already contain a native
L2_NORMALIZATION TF Lite Op. In most cases, this is already the case and
L2 norm is thus achieved through TF Lite inference.
quantize: Whether the returned embedding should be quantized to bytes via
scalar quantization. Embeddings are implicitly assumed to be unit-norm and
therefore any dimension is guaranteed to have a value in [-1.0, 1.0]. Use
the l2_normalize option if this is not the case.
"""
base_options: _BaseOptions
embedder_options: Optional[_EmbedderOptions] = dataclasses.field(
default_factory=_EmbedderOptions)
l2_normalize: Optional[bool] = None
quantize: Optional[bool] = None
@doc_controls.do_not_generate_docs
def to_pb2(self) -> _TextEmbedderGraphOptionsProto:
"""Generates an TextEmbedderOptions protobuf object."""
base_options_proto = self.base_options.to_pb2()
embedder_options_proto = self.embedder_options.to_pb2()
embedder_options_proto = _EmbedderOptionsProto(
l2_normalize=self.l2_normalize, quantize=self.quantize)
return _TextEmbedderGraphOptionsProto(
base_options=base_options_proto,

View File

@ -47,10 +47,10 @@ py_library(
"//mediapipe/python:packet_creator",
"//mediapipe/python:packet_getter",
"//mediapipe/tasks/cc/components/containers/proto:classifications_py_pb2",
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_py_pb2",
"//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_py_pb2",
"//mediapipe/tasks/python/components/containers:classification_result",
"//mediapipe/tasks/python/components/containers:rect",
"//mediapipe/tasks/python/components/processors:classifier_options",
"//mediapipe/tasks/python/core:base_options",
"//mediapipe/tasks/python/core:optional_dependencies",
"//mediapipe/tasks/python/core:task_info",
@ -89,9 +89,9 @@ py_library(
"//mediapipe/python:packet_creator",
"//mediapipe/python:packet_getter",
"//mediapipe/tasks/cc/components/containers/proto:embeddings_py_pb2",
"//mediapipe/tasks/cc/components/processors/proto:embedder_options_py_pb2",
"//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_py_pb2",
"//mediapipe/tasks/python/components/containers:embedding_result",
"//mediapipe/tasks/python/components/processors:embedder_options",
"//mediapipe/tasks/python/components/utils:cosine_similarity",
"//mediapipe/tasks/python/core:base_options",
"//mediapipe/tasks/python/core:optional_dependencies",

View File

@ -14,17 +14,17 @@
"""MediaPipe image classifier task."""
import dataclasses
from typing import Callable, Mapping, Optional
from typing import Callable, Mapping, Optional, List
from mediapipe.python import packet_creator
from mediapipe.python import packet_getter
from mediapipe.python._framework_bindings import image as image_module
from mediapipe.python._framework_bindings import packet
from mediapipe.tasks.cc.components.containers.proto import classifications_pb2
from mediapipe.tasks.cc.components.processors.proto import classifier_options_pb2
from mediapipe.tasks.cc.vision.image_classifier.proto import image_classifier_graph_options_pb2
from mediapipe.tasks.python.components.containers import classification_result as classification_result_module
from mediapipe.tasks.python.components.containers import rect
from mediapipe.tasks.python.components.processors import classifier_options
from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.core import task_info as task_info_module
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
@ -36,7 +36,7 @@ ImageClassifierResult = classification_result_module.ClassificationResult
_NormalizedRect = rect.NormalizedRect
_BaseOptions = base_options_module.BaseOptions
_ImageClassifierGraphOptionsProto = image_classifier_graph_options_pb2.ImageClassifierGraphOptions
_ClassifierOptions = classifier_options.ClassifierOptions
_ClassifierOptionsProto = classifier_options_pb2.ClassifierOptions
_RunningMode = vision_task_running_mode.VisionTaskRunningMode
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
_TaskInfo = task_info_module.TaskInfo
@ -63,15 +63,31 @@ class ImageClassifierOptions:
objects on single image inputs. 2) The video mode for classifying objects
on the decoded frames of a video. 3) The live stream mode for classifying
objects on a live stream of input data, such as from camera.
classifier_options: Options for the image classification task.
display_names_locale: The locale to use for display names specified through
the TFLite Model Metadata.
max_results: The maximum number of top-scored classification results to
return.
score_threshold: Overrides the ones provided in the model metadata. Results
below this value are rejected.
category_allowlist: Allowlist of category names. If non-empty,
classification results whose category name is not in this set will be
filtered out. Duplicate or unknown category names are ignored. Mutually
exclusive with `category_denylist`.
category_denylist: Denylist of category names. If non-empty, classification
results whose category name is in this set will be filtered out. Duplicate
or unknown category names are ignored. Mutually exclusive with
`category_allowlist`.
result_callback: The user-defined result callback for processing live stream
data. The result callback should only be specified when the running mode
is set to the live stream mode.
"""
base_options: _BaseOptions
running_mode: _RunningMode = _RunningMode.IMAGE
classifier_options: Optional[_ClassifierOptions] = dataclasses.field(
default_factory=_ClassifierOptions)
display_names_locale: Optional[str] = None
max_results: Optional[int] = None
score_threshold: Optional[float] = None
category_allowlist: Optional[List[str]] = None
category_denylist: Optional[List[str]] = None
result_callback: Optional[Callable[
[ImageClassifierResult, image_module.Image, int], None]] = None
@ -80,7 +96,12 @@ class ImageClassifierOptions:
"""Generates an ImageClassifierOptions protobuf object."""
base_options_proto = self.base_options.to_pb2()
base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.IMAGE else True
classifier_options_proto = self.classifier_options.to_pb2()
classifier_options_proto = _ClassifierOptionsProto(
score_threshold=self.score_threshold,
category_allowlist=self.category_allowlist,
category_denylist=self.category_denylist,
display_names_locale=self.display_names_locale,
max_results=self.max_results)
return _ImageClassifierGraphOptionsProto(
base_options=base_options_proto,

View File

@ -21,9 +21,9 @@ from mediapipe.python import packet_getter
from mediapipe.python._framework_bindings import image as image_module
from mediapipe.python._framework_bindings import packet as packet_module
from mediapipe.tasks.cc.components.containers.proto import embeddings_pb2
from mediapipe.tasks.cc.components.processors.proto import embedder_options_pb2
from mediapipe.tasks.cc.vision.image_embedder.proto import image_embedder_graph_options_pb2
from mediapipe.tasks.python.components.containers import embedding_result as embedding_result_module
from mediapipe.tasks.python.components.processors import embedder_options
from mediapipe.tasks.python.components.utils import cosine_similarity
from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.core import task_info as task_info_module
@ -35,7 +35,7 @@ from mediapipe.tasks.python.vision.core import vision_task_running_mode as runni
ImageEmbedderResult = embedding_result_module.EmbeddingResult
_BaseOptions = base_options_module.BaseOptions
_ImageEmbedderGraphOptionsProto = image_embedder_graph_options_pb2.ImageEmbedderGraphOptions
_EmbedderOptions = embedder_options.EmbedderOptions
_EmbedderOptionsProto = embedder_options_pb2.EmbedderOptions
_RunningMode = running_mode_module.VisionTaskRunningMode
_TaskInfo = task_info_module.TaskInfo
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
@ -62,15 +62,22 @@ class ImageEmbedderOptions:
image on single image inputs. 2) The video mode for embedding image on the
decoded frames of a video. 3) The live stream mode for embedding image on
a live stream of input data, such as from camera.
embedder_options: Options for the image embedder task.
l2_normalize: Whether to normalize the returned feature vector with L2 norm.
Use this option only if the model does not already contain a native
L2_NORMALIZATION TF Lite Op. In most cases, this is already the case and
L2 norm is thus achieved through TF Lite inference.
quantize: Whether the returned embedding should be quantized to bytes via
scalar quantization. Embeddings are implicitly assumed to be unit-norm and
therefore any dimension is guaranteed to have a value in [-1.0, 1.0]. Use
the l2_normalize option if this is not the case.
result_callback: The user-defined result callback for processing live stream
data. The result callback should only be specified when the running mode
is set to the live stream mode.
"""
base_options: _BaseOptions
running_mode: _RunningMode = _RunningMode.IMAGE
embedder_options: Optional[_EmbedderOptions] = dataclasses.field(
default_factory=_EmbedderOptions)
l2_normalize: Optional[bool] = None
quantize: Optional[bool] = None
result_callback: Optional[Callable[
[ImageEmbedderResult, image_module.Image, int], None]] = None
@ -79,7 +86,8 @@ class ImageEmbedderOptions:
"""Generates an ImageEmbedderOptions protobuf object."""
base_options_proto = self.base_options.to_pb2()
base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.IMAGE else True
embedder_options_proto = self.embedder_options.to_pb2()
embedder_options_proto = _EmbedderOptionsProto(
l2_normalize=self.l2_normalize, quantize=self.quantize)
return _ImageEmbedderGraphOptionsProto(
base_options=base_options_proto,