From a2bb58382d2b263eee59d959e86f56077e6be3b0 Mon Sep 17 00:00:00 2001 From: kinaryml Date: Mon, 14 Nov 2022 01:20:45 -0800 Subject: [PATCH 001/469] Added AudioEmbedder API and tests along with fixing a couple of typos in AudioClassifier --- mediapipe/tasks/python/audio/BUILD | 23 ++ .../tasks/python/audio/audio_classifier.py | 4 +- .../tasks/python/audio/audio_embedder.py | 285 ++++++++++++++++ mediapipe/tasks/python/test/audio/BUILD | 18 + .../python/test/audio/audio_embedder_test.py | 317 ++++++++++++++++++ 5 files changed, 645 insertions(+), 2 deletions(-) create mode 100644 mediapipe/tasks/python/audio/audio_embedder.py create mode 100644 mediapipe/tasks/python/test/audio/audio_embedder_test.py diff --git a/mediapipe/tasks/python/audio/BUILD b/mediapipe/tasks/python/audio/BUILD index dd8719151..2e5815ff0 100644 --- a/mediapipe/tasks/python/audio/BUILD +++ b/mediapipe/tasks/python/audio/BUILD @@ -39,3 +39,26 @@ py_library( "//mediapipe/tasks/python/core:task_info", ], ) + +py_library( + name = "audio_embedder", + srcs = [ + "audio_embedder.py", + ], + deps = [ + "//mediapipe/python:_framework_bindings", + "//mediapipe/python:packet_creator", + "//mediapipe/python:packet_getter", + "//mediapipe/tasks/cc/audio/audio_embedder/proto:audio_embedder_graph_options_py_pb2", + "//mediapipe/tasks/cc/components/containers/proto:embeddings_py_pb2", + "//mediapipe/tasks/python/audio/core:audio_task_running_mode", + "//mediapipe/tasks/python/audio/core:base_audio_task_api", + "//mediapipe/tasks/python/components/containers:audio_data", + "//mediapipe/tasks/python/components/containers:embedding_result", + "//mediapipe/tasks/python/components/processors:embedder_options", + "//mediapipe/tasks/python/components/utils:cosine_similarity", + "//mediapipe/tasks/python/core:base_options", + "//mediapipe/tasks/python/core:optional_dependencies", + "//mediapipe/tasks/python/core:task_info", + ], +) diff --git a/mediapipe/tasks/python/audio/audio_classifier.py b/mediapipe/tasks/python/audio/audio_classifier.py index e04e778b5..a081e5ecd 100644 --- a/mediapipe/tasks/python/audio/audio_classifier.py +++ b/mediapipe/tasks/python/audio/audio_classifier.py @@ -257,7 +257,7 @@ class AudioClassifier(base_audio_task_api.BaseAudioTaskApi): Raises: ValueError: If any of the followings: 1) The sample rate is not provided in the `AudioData` object or the - provided sample rate is inconsisent with the previously recevied. + provided sample rate is inconsistent with the previously received. 2) The current input timestamp is smaller than what the audio classifier has already processed. """ @@ -270,7 +270,7 @@ class AudioClassifier(base_audio_task_api.BaseAudioTaskApi): elif audio_block.audio_format.sample_rate != self._default_sample_rate: raise ValueError( f'The audio sample rate provided in audio data: ' - f'{audio_block.audio_format.sample_rate} is inconsisent with ' + f'{audio_block.audio_format.sample_rate} is inconsistent with ' f'the previously received: {self._default_sample_rate}.') self._send_audio_stream_data({ diff --git a/mediapipe/tasks/python/audio/audio_embedder.py b/mediapipe/tasks/python/audio/audio_embedder.py new file mode 100644 index 000000000..0580b3518 --- /dev/null +++ b/mediapipe/tasks/python/audio/audio_embedder.py @@ -0,0 +1,285 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""MediaPipe audio embedder task.""" + +import dataclasses +from typing import Callable, Mapping, List, Optional + +from mediapipe.python import packet_creator +from mediapipe.python import packet_getter +from mediapipe.python._framework_bindings import packet +from mediapipe.tasks.cc.audio.audio_embedder.proto import audio_embedder_graph_options_pb2 +from mediapipe.tasks.cc.components.containers.proto import embeddings_pb2 +from mediapipe.tasks.python.audio.core import audio_task_running_mode as running_mode_module +from mediapipe.tasks.python.audio.core import base_audio_task_api +from mediapipe.tasks.python.components.containers import audio_data as audio_data_module +from mediapipe.tasks.python.components.containers import embedding_result as embedding_result_module +from mediapipe.tasks.python.components.processors import embedder_options as embedder_options_module +from mediapipe.tasks.python.components.utils import cosine_similarity +from mediapipe.tasks.python.core import base_options as base_options_module +from mediapipe.tasks.python.core import task_info as task_info_module +from mediapipe.tasks.python.core.optional_dependencies import doc_controls + +AudioEmbedderResult = embedding_result_module.EmbeddingResult +_AudioEmbedderGraphOptionsProto = audio_embedder_graph_options_pb2.AudioEmbedderGraphOptions +_AudioData = audio_data_module.AudioData +_BaseOptions = base_options_module.BaseOptions +_EmbedderOptions = embedder_options_module.EmbedderOptions +_RunningMode = running_mode_module.AudioTaskRunningMode +_TaskInfo = task_info_module.TaskInfo + +_AUDIO_IN_STREAM_NAME = 'audio_in' +_AUDIO_TAG = 'AUDIO' +_EMBEDDINGS_STREAM_NAME = 'embeddings_out' +_EMBEDDINGS_TAG = 'EMBEDDINGS' +_SAMPLE_RATE_IN_STREAM_NAME = 'sample_rate_in' +_SAMPLE_RATE_TAG = 'SAMPLE_RATE' +_TASK_GRAPH_NAME = 'mediapipe.tasks.audio.audio_embedder.AudioEmbedderGraph' +_TIMESTAMPTED_EMBEDDINGS_STREAM_NAME = 'timestamped_embeddings_out' +_TIMESTAMPTED_EMBEDDINGS_TAG = 'TIMESTAMPED_EMBEDDINGS' +_MICRO_SECONDS_PER_MILLISECOND = 1000 + + +@dataclasses.dataclass +class AudioEmbedderOptions: + """Options for the audio embedder task. + + Attributes: + base_options: Base options for the audio embedder task. + running_mode: The running mode of the task. Default to the audio clips mode. + Audio embedder task has two running modes: 1) The audio clips mode for + running embedding extraction on independent audio clips. 2) The audio + stream mode for running embedding extraction on the audio stream, such as + from microphone. In this mode, the "result_callback" below must be + specified to receive the embedding results asynchronously. + embedder_options: Options for configuring the embedder behavior, such as + l2_normalize and quantize. + result_callback: The user-defined result callback for processing audio + stream data. The result callback should only be specified when the running + mode is set to the audio stream mode. 
+ """ + base_options: _BaseOptions + running_mode: _RunningMode = _RunningMode.AUDIO_CLIPS + embedder_options: _EmbedderOptions = _EmbedderOptions() + result_callback: Optional[Callable[[AudioEmbedderResult, int], None]] = None + + @doc_controls.do_not_generate_docs + def to_pb2(self) -> _AudioEmbedderGraphOptionsProto: + """Generates an AudioEmbedderOptions protobuf object.""" + base_options_proto = self.base_options.to_pb2() + base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.AUDIO_CLIPS else True + embedder_options_proto = self.embedder_options.to_pb2() + + return _AudioEmbedderGraphOptionsProto( + base_options=base_options_proto, + embedder_options=embedder_options_proto) + + +class AudioEmbedder(base_audio_task_api.BaseAudioTaskApi): + """Class that performs embedding extraction on audio clips or audio stream.""" + + @classmethod + def create_from_model_path(cls, model_path: str) -> 'AudioEmbedder': + """Creates an `AudioEmbedder` object from a TensorFlow Lite model and the default `AudioEmbedderOptions`. + + Note that the created `AudioEmbedder` instance is in audio clips mode, for + embedding extraction on the independent audio clips. + + Args: + model_path: Path to the model. + + Returns: + `AudioEmbedder` object that's created from the model file and the + default `AudioEmbedderOptions`. + + Raises: + ValueError: If failed to create `AudioEmbedder` object from the provided + file such as invalid file path. + RuntimeError: If other types of error occurred. + """ + base_options = _BaseOptions(model_asset_path=model_path) + options = AudioEmbedderOptions( + base_options=base_options, running_mode=_RunningMode.AUDIO_CLIPS) + return cls.create_from_options(options) + + @classmethod + def create_from_options(cls, + options: AudioEmbedderOptions) -> 'AudioEmbedder': + """Creates the `AudioEmbedder` object from audio embedder options. + + Args: + options: Options for the audio embedder task. + + Returns: + `AudioEmbedder` object that's created from `options`. + + Raises: + ValueError: If failed to create `AudioEmbedder` object from + `AudioEmbedderOptions` such as missing the model. + RuntimeError: If other types of error occurred. + """ + + def packets_callback(output_packets: Mapping[str, packet.Packet]): + timestamp_ms = output_packets[ + _EMBEDDINGS_STREAM_NAME].timestamp.value // _MICRO_SECONDS_PER_MILLISECOND + if output_packets[_EMBEDDINGS_STREAM_NAME].is_empty(): + options.result_callback( + AudioEmbedderResult(embeddings=[]), timestamp_ms) + return + embedding_result_proto = embeddings_pb2.EmbeddingResult() + embedding_result_proto.CopyFrom( + packet_getter.get_proto(output_packets[_EMBEDDINGS_STREAM_NAME])) + options.result_callback( + AudioEmbedderResult.create_from_pb2(embedding_result_proto), + timestamp_ms) + + task_info = _TaskInfo( + task_graph=_TASK_GRAPH_NAME, + input_streams=[ + ':'.join([_AUDIO_TAG, _AUDIO_IN_STREAM_NAME]), + ':'.join([_SAMPLE_RATE_TAG, _SAMPLE_RATE_IN_STREAM_NAME]) + ], + output_streams=[ + ':'.join([_EMBEDDINGS_TAG, _EMBEDDINGS_STREAM_NAME]), + ':'.join([ + _TIMESTAMPTED_EMBEDDINGS_TAG, + _TIMESTAMPTED_EMBEDDINGS_STREAM_NAME + ]) + ], + task_options=options) + return cls( + # Audio tasks should not drop input audio due to flow limiting, which + # may cause data inconsistency. 
task_info.generate_graph_config(enable_flow_limiting=False),
+        options.running_mode,
+        packets_callback if options.result_callback else None)
+
+  def embed(self, audio_clip: _AudioData) -> List[AudioEmbedderResult]:
+    """Performs embedding extraction on the provided audio clips.
+
+    The audio clip is represented as a MediaPipe AudioData. The method accepts
+    audio clips of various lengths and sample rates. The corresponding audio
+    sample rate must be provided within the `AudioData` object.
+
+    The input audio clip may be longer than what the model is able to process
+    in a single inference. When this occurs, the input audio clip is split into
+    multiple chunks starting at different timestamps. For this reason, this
+    function returns a vector of EmbeddingResult objects, each associated
+    with a timestamp corresponding to the start (in milliseconds) of the chunk
+    data on which embedding extraction was carried out.
+
+    Args:
+      audio_clip: MediaPipe AudioData.
+
+    Returns:
+      An `AudioEmbedderResult` object that contains a list of embedding result
+      objects, each associated with a timestamp corresponding to the start
+      (in milliseconds) of the chunk data on which embedding extraction was
+      carried out.
+
+    Raises:
+      ValueError: If any of the input arguments is invalid, such as the sample
+        rate is not provided in the `AudioData` object.
+      RuntimeError: If audio embedding extraction failed to run.
+    """
+    if not audio_clip.audio_format.sample_rate:
+      raise ValueError('Must provide the audio sample rate in audio data.')
+    output_packets = self._process_audio_clip({
+        _AUDIO_IN_STREAM_NAME:
+            packet_creator.create_matrix(audio_clip.buffer, transpose=True),
+        _SAMPLE_RATE_IN_STREAM_NAME:
+            packet_creator.create_double(audio_clip.audio_format.sample_rate)
+    })
+    output_list = []
+    embeddings_proto_list = packet_getter.get_proto_list(
+        output_packets[_TIMESTAMPTED_EMBEDDINGS_STREAM_NAME])
+    for proto in embeddings_proto_list:
+      embedding_result_proto = embeddings_pb2.EmbeddingResult()
+      embedding_result_proto.CopyFrom(proto)
+      output_list.append(
+          AudioEmbedderResult.create_from_pb2(embedding_result_proto))
+    return output_list
+
+  def embed_async(self, audio_block: _AudioData, timestamp_ms: int) -> None:
+    """Sends audio data (a block in a continuous audio stream) to perform audio embedding extraction.
+
+    Only use this method when the AudioEmbedder is created with the audio
+    stream running mode. The input timestamps should be monotonically increasing
+    for adjacent calls of this method. This method will return immediately after
+    the input audio data is accepted. The results will be available via the
+    `result_callback` provided in the `AudioEmbedderOptions`. The
+    `embed_async` method is designed to process audio stream data such as
+    microphone input.
+
+    The input audio data may be longer than what the model is able to process
+    in a single inference. When this occurs, the input audio block is split
+    into multiple chunks. For this reason, the callback may be called multiple
+    times (once per chunk) for each call to this function.
+
+    The `result_callback` provides:
+      - An `AudioEmbedderResult` object that contains a list of
+        embeddings.
+      - The input timestamp in milliseconds.
+
+    Args:
+      audio_block: MediaPipe AudioData.
+      timestamp_ms: The timestamp of the input audio data in milliseconds.
+
+    Raises:
+      ValueError: If any of the following:
+        1) The sample rate is not provided in the `AudioData` object or the
+          provided sample rate is inconsistent with the previously received.
+        2) The current input timestamp is smaller than what the audio
+          embedder has already processed.
+    """
+    if not audio_block.audio_format.sample_rate:
+      raise ValueError('Must provide the audio sample rate in audio data.')
+    if not self._default_sample_rate:
+      self._default_sample_rate = audio_block.audio_format.sample_rate
+      self._set_sample_rate(_SAMPLE_RATE_IN_STREAM_NAME,
+                            self._default_sample_rate)
+    elif audio_block.audio_format.sample_rate != self._default_sample_rate:
+      raise ValueError(
+          f'The audio sample rate provided in audio data: '
+          f'{audio_block.audio_format.sample_rate} is inconsistent with '
+          f'the previously received: {self._default_sample_rate}.')
+
+    self._send_audio_stream_data({
+        _AUDIO_IN_STREAM_NAME:
+            packet_creator.create_matrix(audio_block.buffer, transpose=True).at(
+                timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND)
+    })
+
+  @classmethod
+  def cosine_similarity(cls, u: embedding_result_module.Embedding,
+                        v: embedding_result_module.Embedding) -> float:
+    """Utility function to compute cosine similarity between two embedding entries.
+
+    Raises a ValueError if e.g. the feature vectors are of different types
+    (quantized vs. float), have different sizes, or have an L2-norm of 0.
+
+    Args:
+      u: An embedding entry.
+      v: An embedding entry.
+
+    Returns:
+      The cosine similarity for the two embeddings.
+
+    Raises:
+      ValueError: If e.g. the feature vectors are of different types
+        (quantized vs. float), have different sizes, or have an L2-norm of 0.
+    """
+    return cosine_similarity.cosine_similarity(u, v)
diff --git a/mediapipe/tasks/python/test/audio/BUILD b/mediapipe/tasks/python/test/audio/BUILD
index 863449126..9278cea55 100644
--- a/mediapipe/tasks/python/test/audio/BUILD
+++ b/mediapipe/tasks/python/test/audio/BUILD
@@ -35,3 +35,21 @@ py_test(
         "//mediapipe/tasks/python/test:test_utils",
     ],
 )
+
+py_test(
+    name = "audio_embedder_test",
+    srcs = ["audio_embedder_test.py"],
+    data = [
+        "//mediapipe/tasks/testdata/audio:test_audio_clips",
+        "//mediapipe/tasks/testdata/audio:test_models",
+    ],
+    deps = [
+        "//mediapipe/tasks/python/audio:audio_embedder",
+        "//mediapipe/tasks/python/audio/core:audio_task_running_mode",
+        "//mediapipe/tasks/python/components/containers:audio_data",
+        "//mediapipe/tasks/python/components/containers:embedding_result",
+        "//mediapipe/tasks/python/components/processors:embedder_options",
+        "//mediapipe/tasks/python/core:base_options",
+        "//mediapipe/tasks/python/test:test_utils",
+    ],
+)
diff --git a/mediapipe/tasks/python/test/audio/audio_embedder_test.py b/mediapipe/tasks/python/test/audio/audio_embedder_test.py
new file mode 100644
index 000000000..d085317e6
--- /dev/null
+++ b/mediapipe/tasks/python/test/audio/audio_embedder_test.py
@@ -0,0 +1,317 @@
+# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for audio embedder."""
+import enum
+import os
+from typing import List, Tuple
+from unittest import mock
+
+from absl.testing import absltest
+from absl.testing import parameterized
+
+import numpy as np
+from scipy.io import wavfile
+
+from mediapipe.tasks.python.audio import audio_embedder
+from mediapipe.tasks.python.audio.core import audio_task_running_mode
+from mediapipe.tasks.python.components.containers import audio_data as audio_data_module
+from mediapipe.tasks.python.components.containers import embedding_result as embedding_result_module
+from mediapipe.tasks.python.components.processors import embedder_options
+from mediapipe.tasks.python.core import base_options as base_options_module
+from mediapipe.tasks.python.test import test_utils
+
+_AudioEmbedder = audio_embedder.AudioEmbedder
+_AudioEmbedderOptions = audio_embedder.AudioEmbedderOptions
+_AudioEmbedderResult = embedding_result_module.EmbeddingResult
+_AudioData = audio_data_module.AudioData
+_BaseOptions = base_options_module.BaseOptions
+_EmbedderOptions = embedder_options.EmbedderOptions
+_RUNNING_MODE = audio_task_running_mode.AudioTaskRunningMode
+
+_YAMNET_MODEL_FILE = 'yamnet_embedding_metadata.tflite'
+_YAMNET_MODEL_SAMPLE_RATE = 16000
+_SPEECH_WAV_16K_MONO = 'speech_16000_hz_mono.wav'
+_SPEECH_WAV_48K_MONO = 'speech_48000_hz_mono.wav'
+_TWO_HEADS_WAV_16K_MONO = 'two_heads_16000_hz_mono.wav'
+_TEST_DATA_DIR = 'mediapipe/tasks/testdata/audio'
+_SPEECH_SIMILARITIES = [0.985359, 0.994349, 0.993227, 0.996658, 0.996384]
+_YAMNET_NUM_OF_SAMPLES = 15600
+_MILLISECONDS_PER_SECOND = 1000
+# Tolerance for embedding vector coordinate values.
+_EPSILON = 3e-6
+# Tolerance for cosine similarity evaluation.
+_SIMILARITY_TOLERANCE = 1e-6
+
+
+class ModelFileType(enum.Enum):
+  FILE_CONTENT = 1
+  FILE_NAME = 2
+
+
+class AudioEmbedderTest(parameterized.TestCase):
+
+  def setUp(self):
+    super().setUp()
+    self.yamnet_model_path = test_utils.get_test_data_path(
+        os.path.join(_TEST_DATA_DIR, _YAMNET_MODEL_FILE))
+
+  def _read_wav_file(self, file_name) -> _AudioData:
+    sample_rate, buffer = wavfile.read(
+        test_utils.get_test_data_path(os.path.join(_TEST_DATA_DIR, file_name)))
+    return _AudioData.create_from_array(
+        buffer.astype(float) / np.iinfo(np.int16).max, sample_rate)
+
+  def _read_wav_file_as_stream(self, file_name) -> List[Tuple[_AudioData, int]]:
+    sample_rate, buffer = wavfile.read(
+        test_utils.get_test_data_path(os.path.join(_TEST_DATA_DIR, file_name)))
+    audio_data_list = []
+    start = 0
+    step_size = _YAMNET_NUM_OF_SAMPLES * sample_rate / _YAMNET_MODEL_SAMPLE_RATE
+    while start < len(buffer):
+      end = min(start + int(step_size), len(buffer))
+      audio_data_list.append((_AudioData.create_from_array(
+          buffer[start:end].astype(float) / np.iinfo(np.int16).max,
+          sample_rate), int(start / sample_rate * _MILLISECONDS_PER_SECOND)))
+      start = end
+    return audio_data_list
+
+  def _check_embedding_value(self, result, expected_first_value):
+    # Check the first value of the embedding.
+    self.assertAlmostEqual(
+        result.embeddings[0].embedding[0], expected_first_value, delta=_EPSILON)
+
+  def _check_embedding_size(self, result, quantize, expected_embedding_size):
+    # Check embedding size.
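+    # The YAMNet embedding model has a single embedding head, so each result
+    # is expected to contain exactly one embedding entry.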
+ self.assertLen(result.embeddings, 1) + embedding_result = result.embeddings[0] + self.assertLen(embedding_result.embedding, expected_embedding_size) + if quantize: + self.assertEqual(embedding_result.embedding.dtype, np.uint8) + else: + self.assertEqual(embedding_result.embedding.dtype, float) + + def _check_cosine_similarity(self, result0, result1, expected_similarity): + # Checks cosine similarity. + similarity = _AudioEmbedder.cosine_similarity(result0.embeddings[0], + result1.embeddings[0]) + self.assertAlmostEqual( + similarity, expected_similarity, delta=_SIMILARITY_TOLERANCE) + + # TODO: Compares the exact score values to capture unexpected + # changes in the inference pipeline. + def _check_yamnet_result( + self, + embedding_result0_list: List[_AudioEmbedderResult], + embedding_result1_list: List[_AudioEmbedderResult], + expected_similarities: List[float]): + expected_size = len(expected_similarities) + self.assertLen(embedding_result0_list, expected_size) + self.assertLen(embedding_result1_list, expected_size) + + for idx in range(expected_size): + embedding_result0 = embedding_result0_list[idx] + embedding_result1 = embedding_result1_list[idx] + self._check_cosine_similarity(embedding_result0, embedding_result1, + expected_similarities[idx]) + + def test_create_from_file_succeeds_with_valid_model_path(self): + # Creates with default option and valid model file successfully. + with _AudioEmbedder.create_from_model_path( + self.yamnet_model_path) as embedder: + self.assertIsInstance(embedder, _AudioEmbedder) + + def test_create_from_options_succeeds_with_valid_model_path(self): + # Creates with options containing model file successfully. + with _AudioEmbedder.create_from_options( + _AudioEmbedderOptions( + base_options=_BaseOptions( + model_asset_path=self.yamnet_model_path))) as embedder: + self.assertIsInstance(embedder, _AudioEmbedder) + + def test_create_from_options_fails_with_invalid_model_path(self): + with self.assertRaisesRegex( + RuntimeError, 'Unable to open file at /path/to/invalid/model.tflite'): + base_options = _BaseOptions( + model_asset_path='/path/to/invalid/model.tflite') + options = _AudioEmbedderOptions(base_options=base_options) + _AudioEmbedder.create_from_options(options) + + def test_create_from_options_succeeds_with_valid_model_content(self): + # Creates with options containing model content successfully. + with open(self.yamnet_model_path, 'rb') as f: + base_options = _BaseOptions(model_asset_buffer=f.read()) + options = _AudioEmbedderOptions(base_options=base_options) + embedder = _AudioEmbedder.create_from_options(options) + self.assertIsInstance(embedder, _AudioEmbedder) + + @parameterized.parameters( + # Same audio inputs but different sample rates. + (False, False, ModelFileType.FILE_NAME, _SPEECH_WAV_16K_MONO, + _SPEECH_WAV_48K_MONO, 1024, (0, 0)), + (False, False, ModelFileType.FILE_CONTENT, _SPEECH_WAV_16K_MONO, + _SPEECH_WAV_48K_MONO, 1024, (0, 0))) + def test_embed_with_yamnet_model( + self, l2_normalize, quantize, model_file_type, audio_file0, audio_file1, + expected_size, expected_first_values): + # Creates embedder. 
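+    # The base options accept the model either as a file path or as an
+    # in-memory buffer; the parameterized cases exercise both.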
+ if model_file_type is ModelFileType.FILE_NAME: + base_options = _BaseOptions(model_asset_path=self.yamnet_model_path) + elif model_file_type is ModelFileType.FILE_CONTENT: + with open(self.yamnet_model_path, 'rb') as f: + model_content = f.read() + base_options = _BaseOptions(model_asset_buffer=model_content) + else: + # Should never happen + raise ValueError('model_file_type is invalid.') + + embedder_options = _EmbedderOptions( + l2_normalize=l2_normalize, quantize=quantize) + options = _AudioEmbedderOptions( + base_options=base_options, embedder_options=embedder_options) + + with _AudioEmbedder.create_from_options(options) as embedder: + embedding_result0_list = embedder.embed(self._read_wav_file(audio_file0)) + embedding_result1_list = embedder.embed(self._read_wav_file(audio_file1)) + + # Checks embeddings and cosine similarity. + expected_result0_value, expected_result1_value = expected_first_values + self._check_embedding_size(embedding_result0_list[0], quantize, + expected_size) + self._check_embedding_size(embedding_result1_list[0], quantize, + expected_size) + self._check_embedding_value(embedding_result0_list[0], + expected_result0_value) + self._check_embedding_value(embedding_result1_list[0], + expected_result1_value) + self._check_yamnet_result(embedding_result0_list, embedding_result1_list, + expected_similarities=_SPEECH_SIMILARITIES) + + def test_embed_with_yamnet_model_and_different_inputs(self): + with _AudioEmbedder.create_from_model_path( + self.yamnet_model_path) as embedder: + embedding_result0_list = embedder.embed( + self._read_wav_file(_SPEECH_WAV_16K_MONO)) + embedding_result1_list = embedder.embed( + self._read_wav_file(_TWO_HEADS_WAV_16K_MONO)) + self.assertLen(embedding_result0_list, 5) + self.assertLen(embedding_result1_list, 1) + self._check_cosine_similarity(embedding_result0_list[0], + embedding_result1_list[0], + expected_similarity=0.09017) + + def test_missing_sample_rate_in_audio_clips_mode(self): + options = _AudioEmbedderOptions( + base_options=_BaseOptions(model_asset_path=self.yamnet_model_path), + running_mode=_RUNNING_MODE.AUDIO_CLIPS) + with self.assertRaisesRegex(ValueError, + r'Must provide the audio sample rate'): + with _AudioEmbedder.create_from_options(options) as embedder: + embedder.embed(_AudioData(buffer_length=100)) + + def test_missing_sample_rate_in_audio_stream_mode(self): + options = _AudioEmbedderOptions( + base_options=_BaseOptions(model_asset_path=self.yamnet_model_path), + running_mode=_RUNNING_MODE.AUDIO_STREAM, + result_callback=mock.MagicMock()) + with self.assertRaisesRegex(ValueError, + r'provide the audio sample rate in audio data'): + with _AudioEmbedder.create_from_options(options) as embedder: + embedder.embed(_AudioData(buffer_length=100)) + + def test_missing_result_callback(self): + options = _AudioEmbedderOptions( + base_options=_BaseOptions(model_asset_path=self.yamnet_model_path), + running_mode=_RUNNING_MODE.AUDIO_STREAM) + with self.assertRaisesRegex(ValueError, + r'result callback must be provided'): + with _AudioEmbedder.create_from_options(options) as unused_embedder: + pass + + def test_illegal_result_callback(self): + options = _AudioEmbedderOptions( + base_options=_BaseOptions(model_asset_path=self.yamnet_model_path), + running_mode=_RUNNING_MODE.AUDIO_CLIPS, + result_callback=mock.MagicMock()) + with self.assertRaisesRegex(ValueError, + r'result callback should not be provided'): + with _AudioEmbedder.create_from_options(options) as unused_embedder: + pass + + def 
test_calling_embed_in_audio_stream_mode(self): + options = _AudioEmbedderOptions( + base_options=_BaseOptions(model_asset_path=self.yamnet_model_path), + running_mode=_RUNNING_MODE.AUDIO_STREAM, + result_callback=mock.MagicMock()) + with _AudioEmbedder.create_from_options(options) as embedder: + with self.assertRaisesRegex(ValueError, + r'not initialized with the audio clips mode'): + embedder.embed(self._read_wav_file(_SPEECH_WAV_16K_MONO)) + + def test_calling_embed_async_in_audio_clips_mode(self): + options = _AudioEmbedderOptions( + base_options=_BaseOptions(model_asset_path=self.yamnet_model_path), + running_mode=_RUNNING_MODE.AUDIO_CLIPS) + with _AudioEmbedder.create_from_options(options) as embedder: + with self.assertRaisesRegex( + ValueError, r'not initialized with the audio stream mode'): + embedder.embed_async(self._read_wav_file(_SPEECH_WAV_16K_MONO), 0) + + def test_embed_async_calls_with_illegal_timestamp(self): + options = _AudioEmbedderOptions( + base_options=_BaseOptions(model_asset_path=self.yamnet_model_path), + running_mode=_RUNNING_MODE.AUDIO_STREAM, + result_callback=mock.MagicMock()) + with _AudioEmbedder.create_from_options(options) as embedder: + embedder.embed_async(self._read_wav_file(_SPEECH_WAV_16K_MONO), 100) + with self.assertRaisesRegex( + ValueError, r'Input timestamp must be monotonically increasing'): + embedder.embed_async(self._read_wav_file(_SPEECH_WAV_16K_MONO), 0) + + @parameterized.parameters( + # Same audio inputs but different sample rates. + (False, False, _SPEECH_WAV_16K_MONO, _SPEECH_WAV_48K_MONO)) + def test_embed_async(self, l2_normalize, quantize, audio_file0, audio_file1): + embedding_result_list = [] + embedding_result_list_copy = embedding_result_list.copy() + + def save_result(result: _AudioEmbedderResult, timestamp_ms: int): + result.timestamp_ms = timestamp_ms + embedding_result_list.append(result) + + options = _AudioEmbedderOptions( + base_options=_BaseOptions(model_asset_path=self.yamnet_model_path), + running_mode=_RUNNING_MODE.AUDIO_STREAM, + embedder_options=_EmbedderOptions(l2_normalize=l2_normalize, + quantize=quantize), + result_callback=save_result) + + with _AudioEmbedder.create_from_options(options) as embedder: + audio_data0_list = self._read_wav_file_as_stream(audio_file0) + for audio_data, timestamp_ms in audio_data0_list: + embedder.embed_async(audio_data, timestamp_ms) + embedding_result0_list = embedding_result_list + + with _AudioEmbedder.create_from_options(options) as embedder: + audio_data1_list = self._read_wav_file_as_stream(audio_file1) + embedding_result_list = embedding_result_list_copy + for audio_data, timestamp_ms in audio_data1_list: + embedder.embed_async(audio_data, timestamp_ms) + embedding_result1_list = embedding_result_list + + self._check_yamnet_result(embedding_result0_list, embedding_result1_list, + expected_similarities=_SPEECH_SIMILARITIES) + + +if __name__ == '__main__': + absltest.main() From 6610ca72ba82cd2605b698e5cc1a36d25340a441 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Mon, 14 Nov 2022 02:25:28 -0800 Subject: [PATCH 002/469] Add ts_declaration rule for OSS PiperOrigin-RevId: 488307893 --- mediapipe/framework/port/build_config.bzl | 39 +++++++++++++++++-- .../tasks/web/components/containers/BUILD | 8 ++-- .../web/components/containers/landmark.d.ts | 4 +- mediapipe/tasks/web/core/BUILD | 6 +-- mediapipe/tasks/web/vision/BUILD | 3 ++ .../tasks/web/vision/gesture_recognizer/BUILD | 24 ++++++++---- .../gesture_recognizer/gesture_recognizer.ts | 4 -- ...ons.ts => 
gesture_recognizer_options.d.ts} | 0 ...sult.ts => gesture_recognizer_result.d.ts} | 0 .../tasks/web/vision/image_classifier/BUILD | 21 +++++++--- ...tions.ts => image_classifier_options.d.ts} | 0 ...result.ts => image_classifier_result.d.ts} | 0 .../tasks/web/vision/object_detector/BUILD | 18 +++++++-- ...ptions.ts => object_detector_options.d.ts} | 2 +- ..._result.ts => object_detector_result.d.ts} | 0 15 files changed, 95 insertions(+), 34 deletions(-) rename mediapipe/tasks/web/vision/gesture_recognizer/{gesture_recognizer_options.ts => gesture_recognizer_options.d.ts} (100%) rename mediapipe/tasks/web/vision/gesture_recognizer/{gesture_recognizer_result.ts => gesture_recognizer_result.d.ts} (100%) rename mediapipe/tasks/web/vision/image_classifier/{image_classifier_options.ts => image_classifier_options.d.ts} (100%) rename mediapipe/tasks/web/vision/image_classifier/{image_classifier_result.ts => image_classifier_result.d.ts} (100%) rename mediapipe/tasks/web/vision/object_detector/{object_detector_options.ts => object_detector_options.d.ts} (97%) rename mediapipe/tasks/web/vision/object_detector/{object_detector_result.ts => object_detector_result.d.ts} (100%) diff --git a/mediapipe/framework/port/build_config.bzl b/mediapipe/framework/port/build_config.bzl index 80e9bfc4d..8d1e6cbf7 100644 --- a/mediapipe/framework/port/build_config.bzl +++ b/mediapipe/framework/port/build_config.bzl @@ -214,10 +214,10 @@ def mediapipe_ts_library( """Generate ts_project for MediaPipe open source version. Args: - name: the name of the cc_proto_library. - srcs: the .proto files of the cc_proto_library for Bazel use. + name: the name of the mediapipe_ts_library. + srcs: the .proto files of the mediapipe_ts_library for Bazel use. visibility: visibility of this target. - deps: a list of dependency labels for Bazel use; must be cc_proto_library. + deps: a list of dependency labels for Bazel use. testonly: test only or not. allow_unoptimized_namespaces: ignored, used only internally """ @@ -235,3 +235,36 @@ def mediapipe_ts_library( declaration = True, tsconfig = "//:tsconfig.json", )) + +def mediapipe_ts_declaration( + name, + srcs, + visibility = None, + deps = []): + """Generate ts_declaration for MediaPipe open source version. + + Args: + name: the name of the mediapipe_ts_declaration. + srcs: the .proto files of the mediapipe_ts_declaration for Bazel use. + visibility: visibility of this target. + deps: a list of dependency labels for Bazel use + """ + + # Bazel does not create JS files for .d.ts files, which leads to import + # failures in our open source build. We simply re-name the .d.ts files + # to .ts to work around this problem. + for src in srcs: + native.genrule( + name = replace_suffix(src, ".d.ts", "_d_ts"), + srcs = [src], + outs = [replace_suffix(src, ".d.ts", ".ts")], + visibility = visibility, + cmd = "cp -n $< $@;", + ) + + mediapipe_ts_library( + name = name, + srcs = [replace_suffix(src, ".d.ts", "_d_ts") for src in srcs], + visibility = visibility, + deps = deps, + ) diff --git a/mediapipe/tasks/web/components/containers/BUILD b/mediapipe/tasks/web/components/containers/BUILD index 1b0e403ff..3b4fe4ef9 100644 --- a/mediapipe/tasks/web/components/containers/BUILD +++ b/mediapipe/tasks/web/components/containers/BUILD @@ -1,21 +1,21 @@ # This package contains options shared by all MediaPipe Tasks for Web. 
-load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration") package(default_visibility = ["//mediapipe/tasks:internal"]) -mediapipe_ts_library( +mediapipe_ts_declaration( name = "category", srcs = ["category.d.ts"], ) -mediapipe_ts_library( +mediapipe_ts_declaration( name = "classification_result", srcs = ["classification_result.d.ts"], deps = [":category"], ) -mediapipe_ts_library( +mediapipe_ts_declaration( name = "landmark", srcs = ["landmark.d.ts"], ) diff --git a/mediapipe/tasks/web/components/containers/landmark.d.ts b/mediapipe/tasks/web/components/containers/landmark.d.ts index f790d8a0b..0c0799074 100644 --- a/mediapipe/tasks/web/components/containers/landmark.d.ts +++ b/mediapipe/tasks/web/components/containers/landmark.d.ts @@ -18,9 +18,9 @@ * Landmark represents a point in 3D space with x, y, z coordinates. If * normalized is true, the landmark coordinates is normalized respect to the * dimension of image, and the coordinates values are in the range of [0,1]. - * Otherwise, it represenet a point in world coordinates. + * Otherwise, it represents a point in world coordinates. */ -export declare class Landmark { +export declare interface Landmark { /** The x coordinates of the landmark. */ x: number; diff --git a/mediapipe/tasks/web/core/BUILD b/mediapipe/tasks/web/core/BUILD index 4fb57d6c3..158f5e05f 100644 --- a/mediapipe/tasks/web/core/BUILD +++ b/mediapipe/tasks/web/core/BUILD @@ -1,10 +1,10 @@ # This package contains options shared by all MediaPipe Tasks for Web. -load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library") package(default_visibility = ["//mediapipe/tasks:internal"]) -mediapipe_ts_library( +mediapipe_ts_declaration( name = "core", srcs = [ "base_options.d.ts", @@ -24,7 +24,7 @@ mediapipe_ts_library( ], ) -mediapipe_ts_library( +mediapipe_ts_declaration( name = "classifier_options", srcs = [ "classifier_options.d.ts", diff --git a/mediapipe/tasks/web/vision/BUILD b/mediapipe/tasks/web/vision/BUILD index abdbc54ea..279a1f197 100644 --- a/mediapipe/tasks/web/vision/BUILD +++ b/mediapipe/tasks/web/vision/BUILD @@ -9,7 +9,10 @@ mediapipe_ts_library( srcs = ["index.ts"], deps = [ "//mediapipe/tasks/web/vision/gesture_recognizer", + "//mediapipe/tasks/web/vision/gesture_recognizer:gesture_recognizer_types", "//mediapipe/tasks/web/vision/image_classifier", + "//mediapipe/tasks/web/vision/image_classifier:image_classifier_types", "//mediapipe/tasks/web/vision/object_detector", + "//mediapipe/tasks/web/vision/object_detector:object_detector_types", ], ) diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/BUILD b/mediapipe/tasks/web/vision/gesture_recognizer/BUILD index 6b99f6ce4..7ed04a5b9 100644 --- a/mediapipe/tasks/web/vision/gesture_recognizer/BUILD +++ b/mediapipe/tasks/web/vision/gesture_recognizer/BUILD @@ -3,7 +3,7 @@ # This task takes video frames and outputs synchronized frames along with # the detection results for one or more gesture categories, using Gesture Recognizer. 
-load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library") package(default_visibility = ["//mediapipe/tasks:internal"]) @@ -11,12 +11,9 @@ licenses(["notice"]) mediapipe_ts_library( name = "gesture_recognizer", - srcs = [ - "gesture_recognizer.ts", - "gesture_recognizer_options.ts", - "gesture_recognizer_result.ts", - ], + srcs = ["gesture_recognizer.ts"], deps = [ + ":gesture_recognizer_types", "//mediapipe/framework:calculator_jspb_proto", "//mediapipe/framework:calculator_options_jspb_proto", "//mediapipe/framework/formats:classification_jspb_proto", @@ -33,8 +30,21 @@ mediapipe_ts_library( "//mediapipe/tasks/web/components/processors:base_options", "//mediapipe/tasks/web/components/processors:classifier_options", "//mediapipe/tasks/web/core", - "//mediapipe/tasks/web/core:classifier_options", "//mediapipe/tasks/web/core:task_runner", "//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts", ], ) + +mediapipe_ts_declaration( + name = "gesture_recognizer_types", + srcs = [ + "gesture_recognizer_options.d.ts", + "gesture_recognizer_result.d.ts", + ], + deps = [ + "//mediapipe/tasks/web/components/containers:category", + "//mediapipe/tasks/web/components/containers:landmark", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:classifier_options", + ], +) diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts index c24d1a7b3..b06fbf371 100644 --- a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts +++ b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts @@ -154,10 +154,6 @@ export class GestureRecognizer extends TaskRunner { this.handGestureRecognizerGraphOptions); this.initDefaults(); - - // Disables the automatic render-to-screen code, which allows for pure - // CPU processing. - this.setAutoRenderToScreen(false); } /** diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_options.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_options.d.ts similarity index 100% rename from mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_options.ts rename to mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_options.d.ts diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_result.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_result.d.ts similarity index 100% rename from mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_result.ts rename to mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_result.d.ts diff --git a/mediapipe/tasks/web/vision/image_classifier/BUILD b/mediapipe/tasks/web/vision/image_classifier/BUILD index e96d6a8e3..4d9559cfc 100644 --- a/mediapipe/tasks/web/vision/image_classifier/BUILD +++ b/mediapipe/tasks/web/vision/image_classifier/BUILD @@ -2,7 +2,7 @@ # # This task takes video or image frames and outputs the classification result. 
-load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library") package(default_visibility = ["//mediapipe/tasks:internal"]) @@ -12,22 +12,31 @@ mediapipe_ts_library( name = "image_classifier", srcs = [ "image_classifier.ts", - "image_classifier_options.ts", - "image_classifier_result.ts", ], deps = [ + ":image_classifier_types", "//mediapipe/framework:calculator_jspb_proto", "//mediapipe/framework:calculator_options_jspb_proto", "//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto", "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_jspb_proto", - "//mediapipe/tasks/web/components/containers:category", - "//mediapipe/tasks/web/components/containers:classification_result", "//mediapipe/tasks/web/components/processors:base_options", "//mediapipe/tasks/web/components/processors:classifier_options", "//mediapipe/tasks/web/components/processors:classifier_result", "//mediapipe/tasks/web/core", - "//mediapipe/tasks/web/core:classifier_options", "//mediapipe/tasks/web/core:task_runner", "//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts", ], ) + +mediapipe_ts_declaration( + name = "image_classifier_types", + srcs = [ + "image_classifier_options.d.ts", + "image_classifier_result.d.ts", + ], + deps = [ + "//mediapipe/tasks/web/components/containers:category", + "//mediapipe/tasks/web/components/containers:classification_result", + "//mediapipe/tasks/web/core:classifier_options", + ], +) diff --git a/mediapipe/tasks/web/vision/image_classifier/image_classifier_options.ts b/mediapipe/tasks/web/vision/image_classifier/image_classifier_options.d.ts similarity index 100% rename from mediapipe/tasks/web/vision/image_classifier/image_classifier_options.ts rename to mediapipe/tasks/web/vision/image_classifier/image_classifier_options.d.ts diff --git a/mediapipe/tasks/web/vision/image_classifier/image_classifier_result.ts b/mediapipe/tasks/web/vision/image_classifier/image_classifier_result.d.ts similarity index 100% rename from mediapipe/tasks/web/vision/image_classifier/image_classifier_result.ts rename to mediapipe/tasks/web/vision/image_classifier/image_classifier_result.d.ts diff --git a/mediapipe/tasks/web/vision/object_detector/BUILD b/mediapipe/tasks/web/vision/object_detector/BUILD index 095a84b52..2ce701b17 100644 --- a/mediapipe/tasks/web/vision/object_detector/BUILD +++ b/mediapipe/tasks/web/vision/object_detector/BUILD @@ -3,7 +3,7 @@ # This task takes video frames and outputs synchronized frames along with # the detection results for one or more object categories, using Object Detector. 
-load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library") package(default_visibility = ["//mediapipe/tasks:internal"]) @@ -13,18 +13,28 @@ mediapipe_ts_library( name = "object_detector", srcs = [ "object_detector.ts", - "object_detector_options.ts", - "object_detector_result.ts", ], deps = [ + ":object_detector_types", "//mediapipe/framework:calculator_jspb_proto", "//mediapipe/framework:calculator_options_jspb_proto", "//mediapipe/framework/formats:detection_jspb_proto", "//mediapipe/tasks/cc/vision/object_detector/proto:object_detector_options_jspb_proto", - "//mediapipe/tasks/web/components/containers:category", "//mediapipe/tasks/web/components/processors:base_options", "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:task_runner", "//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts", ], ) + +mediapipe_ts_declaration( + name = "object_detector_types", + srcs = [ + "object_detector_options.d.ts", + "object_detector_result.d.ts", + ], + deps = [ + "//mediapipe/tasks/web/components/containers:category", + "//mediapipe/tasks/web/core", + ], +) diff --git a/mediapipe/tasks/web/vision/object_detector/object_detector_options.ts b/mediapipe/tasks/web/vision/object_detector/object_detector_options.d.ts similarity index 97% rename from mediapipe/tasks/web/vision/object_detector/object_detector_options.ts rename to mediapipe/tasks/web/vision/object_detector/object_detector_options.d.ts index eec12cf17..3eb7df986 100644 --- a/mediapipe/tasks/web/vision/object_detector/object_detector_options.ts +++ b/mediapipe/tasks/web/vision/object_detector/object_detector_options.d.ts @@ -17,7 +17,7 @@ import {BaseOptions} from '../../../../tasks/web/core/base_options'; /** Options to configure the MediaPipe Object Detector Task */ -export interface ObjectDetectorOptions { +export declare interface ObjectDetectorOptions { /** Options to configure the loading of the model assets. */ baseOptions?: BaseOptions; diff --git a/mediapipe/tasks/web/vision/object_detector/object_detector_result.ts b/mediapipe/tasks/web/vision/object_detector/object_detector_result.d.ts similarity index 100% rename from mediapipe/tasks/web/vision/object_detector/object_detector_result.ts rename to mediapipe/tasks/web/vision/object_detector/object_detector_result.d.ts From bc6240e989490ba5650834861a2a7efe4cf06ee2 Mon Sep 17 00:00:00 2001 From: Nikolay Chirkov Date: Mon, 14 Nov 2022 02:29:30 -0800 Subject: [PATCH 003/469] Zero-initialize id etc. 
members in Tensor PiperOrigin-RevId: 488308585 --- mediapipe/framework/formats/tensor.cc | 4 ++++ mediapipe/framework/formats/tensor.h | 6 +++--- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/mediapipe/framework/formats/tensor.cc b/mediapipe/framework/formats/tensor.cc index ef0cddea4..c31eba350 100644 --- a/mediapipe/framework/formats/tensor.cc +++ b/mediapipe/framework/formats/tensor.cc @@ -387,7 +387,9 @@ void Tensor::Move(Tensor* src) { src->cpu_buffer_ = nullptr; #if MEDIAPIPE_METAL_ENABLED device_ = src->device_; + src->device_ = nil; command_buffer_ = src->command_buffer_; + src->command_buffer_ = nil; metal_buffer_ = src->metal_buffer_; src->metal_buffer_ = nil; #endif // MEDIAPIPE_METAL_ENABLED @@ -431,6 +433,8 @@ void Tensor::Invalidate() { DeallocateVirtualMemory(cpu_buffer_, AlignToPageSize(bytes())); } metal_buffer_ = nil; + command_buffer_ = nil; + device_ = nil; cpu_buffer_ = nullptr; #if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30 // Don't need to wait for the resource to be deleted bacause if will be diff --git a/mediapipe/framework/formats/tensor.h b/mediapipe/framework/formats/tensor.h index ff9da3ec6..ecd63c8c6 100644 --- a/mediapipe/framework/formats/tensor.h +++ b/mediapipe/framework/formats/tensor.h @@ -384,9 +384,9 @@ class Tensor { mutable void* cpu_buffer_ = nullptr; void AllocateCpuBuffer() const; #if MEDIAPIPE_METAL_ENABLED - mutable id command_buffer_; - mutable id device_; - mutable id metal_buffer_; + mutable id command_buffer_ = nil; + mutable id device_ = nil; + mutable id metal_buffer_ = nil; void AllocateMtlBuffer(id device) const; #endif // MEDIAPIPE_METAL_ENABLED From badaccfb04762bd4b20970a3843041a177683f5b Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Mon, 14 Nov 2022 05:03:40 -0800 Subject: [PATCH 004/469] Internal change PiperOrigin-RevId: 488333493 --- mediapipe/framework/port/build_config.bzl | 39 ++----------------- .../tasks/web/components/containers/BUILD | 8 ++-- .../web/components/containers/landmark.d.ts | 4 +- mediapipe/tasks/web/core/BUILD | 6 +-- mediapipe/tasks/web/vision/BUILD | 3 -- .../tasks/web/vision/gesture_recognizer/BUILD | 24 ++++-------- .../gesture_recognizer/gesture_recognizer.ts | 4 ++ ...ons.d.ts => gesture_recognizer_options.ts} | 0 ...sult.d.ts => gesture_recognizer_result.ts} | 0 .../tasks/web/vision/image_classifier/BUILD | 21 +++------- ...tions.d.ts => image_classifier_options.ts} | 0 ...result.d.ts => image_classifier_result.ts} | 0 .../tasks/web/vision/object_detector/BUILD | 18 ++------- ...ptions.d.ts => object_detector_options.ts} | 2 +- ..._result.d.ts => object_detector_result.ts} | 0 15 files changed, 34 insertions(+), 95 deletions(-) rename mediapipe/tasks/web/vision/gesture_recognizer/{gesture_recognizer_options.d.ts => gesture_recognizer_options.ts} (100%) rename mediapipe/tasks/web/vision/gesture_recognizer/{gesture_recognizer_result.d.ts => gesture_recognizer_result.ts} (100%) rename mediapipe/tasks/web/vision/image_classifier/{image_classifier_options.d.ts => image_classifier_options.ts} (100%) rename mediapipe/tasks/web/vision/image_classifier/{image_classifier_result.d.ts => image_classifier_result.ts} (100%) rename mediapipe/tasks/web/vision/object_detector/{object_detector_options.d.ts => object_detector_options.ts} (97%) rename mediapipe/tasks/web/vision/object_detector/{object_detector_result.d.ts => object_detector_result.ts} (100%) diff --git a/mediapipe/framework/port/build_config.bzl b/mediapipe/framework/port/build_config.bzl index 8d1e6cbf7..80e9bfc4d 100644 
--- a/mediapipe/framework/port/build_config.bzl +++ b/mediapipe/framework/port/build_config.bzl @@ -214,10 +214,10 @@ def mediapipe_ts_library( """Generate ts_project for MediaPipe open source version. Args: - name: the name of the mediapipe_ts_library. - srcs: the .proto files of the mediapipe_ts_library for Bazel use. + name: the name of the cc_proto_library. + srcs: the .proto files of the cc_proto_library for Bazel use. visibility: visibility of this target. - deps: a list of dependency labels for Bazel use. + deps: a list of dependency labels for Bazel use; must be cc_proto_library. testonly: test only or not. allow_unoptimized_namespaces: ignored, used only internally """ @@ -235,36 +235,3 @@ def mediapipe_ts_library( declaration = True, tsconfig = "//:tsconfig.json", )) - -def mediapipe_ts_declaration( - name, - srcs, - visibility = None, - deps = []): - """Generate ts_declaration for MediaPipe open source version. - - Args: - name: the name of the mediapipe_ts_declaration. - srcs: the .proto files of the mediapipe_ts_declaration for Bazel use. - visibility: visibility of this target. - deps: a list of dependency labels for Bazel use - """ - - # Bazel does not create JS files for .d.ts files, which leads to import - # failures in our open source build. We simply re-name the .d.ts files - # to .ts to work around this problem. - for src in srcs: - native.genrule( - name = replace_suffix(src, ".d.ts", "_d_ts"), - srcs = [src], - outs = [replace_suffix(src, ".d.ts", ".ts")], - visibility = visibility, - cmd = "cp -n $< $@;", - ) - - mediapipe_ts_library( - name = name, - srcs = [replace_suffix(src, ".d.ts", "_d_ts") for src in srcs], - visibility = visibility, - deps = deps, - ) diff --git a/mediapipe/tasks/web/components/containers/BUILD b/mediapipe/tasks/web/components/containers/BUILD index 3b4fe4ef9..1b0e403ff 100644 --- a/mediapipe/tasks/web/components/containers/BUILD +++ b/mediapipe/tasks/web/components/containers/BUILD @@ -1,21 +1,21 @@ # This package contains options shared by all MediaPipe Tasks for Web. -load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration") +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") package(default_visibility = ["//mediapipe/tasks:internal"]) -mediapipe_ts_declaration( +mediapipe_ts_library( name = "category", srcs = ["category.d.ts"], ) -mediapipe_ts_declaration( +mediapipe_ts_library( name = "classification_result", srcs = ["classification_result.d.ts"], deps = [":category"], ) -mediapipe_ts_declaration( +mediapipe_ts_library( name = "landmark", srcs = ["landmark.d.ts"], ) diff --git a/mediapipe/tasks/web/components/containers/landmark.d.ts b/mediapipe/tasks/web/components/containers/landmark.d.ts index 0c0799074..f790d8a0b 100644 --- a/mediapipe/tasks/web/components/containers/landmark.d.ts +++ b/mediapipe/tasks/web/components/containers/landmark.d.ts @@ -18,9 +18,9 @@ * Landmark represents a point in 3D space with x, y, z coordinates. If * normalized is true, the landmark coordinates is normalized respect to the * dimension of image, and the coordinates values are in the range of [0,1]. - * Otherwise, it represents a point in world coordinates. + * Otherwise, it represenet a point in world coordinates. */ -export declare interface Landmark { +export declare class Landmark { /** The x coordinates of the landmark. 
*/ x: number; diff --git a/mediapipe/tasks/web/core/BUILD b/mediapipe/tasks/web/core/BUILD index 158f5e05f..4fb57d6c3 100644 --- a/mediapipe/tasks/web/core/BUILD +++ b/mediapipe/tasks/web/core/BUILD @@ -1,10 +1,10 @@ # This package contains options shared by all MediaPipe Tasks for Web. -load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library") +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") package(default_visibility = ["//mediapipe/tasks:internal"]) -mediapipe_ts_declaration( +mediapipe_ts_library( name = "core", srcs = [ "base_options.d.ts", @@ -24,7 +24,7 @@ mediapipe_ts_library( ], ) -mediapipe_ts_declaration( +mediapipe_ts_library( name = "classifier_options", srcs = [ "classifier_options.d.ts", diff --git a/mediapipe/tasks/web/vision/BUILD b/mediapipe/tasks/web/vision/BUILD index 279a1f197..abdbc54ea 100644 --- a/mediapipe/tasks/web/vision/BUILD +++ b/mediapipe/tasks/web/vision/BUILD @@ -9,10 +9,7 @@ mediapipe_ts_library( srcs = ["index.ts"], deps = [ "//mediapipe/tasks/web/vision/gesture_recognizer", - "//mediapipe/tasks/web/vision/gesture_recognizer:gesture_recognizer_types", "//mediapipe/tasks/web/vision/image_classifier", - "//mediapipe/tasks/web/vision/image_classifier:image_classifier_types", "//mediapipe/tasks/web/vision/object_detector", - "//mediapipe/tasks/web/vision/object_detector:object_detector_types", ], ) diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/BUILD b/mediapipe/tasks/web/vision/gesture_recognizer/BUILD index 7ed04a5b9..6b99f6ce4 100644 --- a/mediapipe/tasks/web/vision/gesture_recognizer/BUILD +++ b/mediapipe/tasks/web/vision/gesture_recognizer/BUILD @@ -3,7 +3,7 @@ # This task takes video frames and outputs synchronized frames along with # the detection results for one or more gesture categories, using Gesture Recognizer. 
-load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library") +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") package(default_visibility = ["//mediapipe/tasks:internal"]) @@ -11,9 +11,12 @@ licenses(["notice"]) mediapipe_ts_library( name = "gesture_recognizer", - srcs = ["gesture_recognizer.ts"], + srcs = [ + "gesture_recognizer.ts", + "gesture_recognizer_options.ts", + "gesture_recognizer_result.ts", + ], deps = [ - ":gesture_recognizer_types", "//mediapipe/framework:calculator_jspb_proto", "//mediapipe/framework:calculator_options_jspb_proto", "//mediapipe/framework/formats:classification_jspb_proto", @@ -30,21 +33,8 @@ mediapipe_ts_library( "//mediapipe/tasks/web/components/processors:base_options", "//mediapipe/tasks/web/components/processors:classifier_options", "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:classifier_options", "//mediapipe/tasks/web/core:task_runner", "//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts", ], ) - -mediapipe_ts_declaration( - name = "gesture_recognizer_types", - srcs = [ - "gesture_recognizer_options.d.ts", - "gesture_recognizer_result.d.ts", - ], - deps = [ - "//mediapipe/tasks/web/components/containers:category", - "//mediapipe/tasks/web/components/containers:landmark", - "//mediapipe/tasks/web/core", - "//mediapipe/tasks/web/core:classifier_options", - ], -) diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts index b06fbf371..c24d1a7b3 100644 --- a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts +++ b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts @@ -154,6 +154,10 @@ export class GestureRecognizer extends TaskRunner { this.handGestureRecognizerGraphOptions); this.initDefaults(); + + // Disables the automatic render-to-screen code, which allows for pure + // CPU processing. + this.setAutoRenderToScreen(false); } /** diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_options.d.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_options.ts similarity index 100% rename from mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_options.d.ts rename to mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_options.ts diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_result.d.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_result.ts similarity index 100% rename from mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_result.d.ts rename to mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_result.ts diff --git a/mediapipe/tasks/web/vision/image_classifier/BUILD b/mediapipe/tasks/web/vision/image_classifier/BUILD index 4d9559cfc..e96d6a8e3 100644 --- a/mediapipe/tasks/web/vision/image_classifier/BUILD +++ b/mediapipe/tasks/web/vision/image_classifier/BUILD @@ -2,7 +2,7 @@ # # This task takes video or image frames and outputs the classification result. 
-load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library") +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") package(default_visibility = ["//mediapipe/tasks:internal"]) @@ -12,31 +12,22 @@ mediapipe_ts_library( name = "image_classifier", srcs = [ "image_classifier.ts", + "image_classifier_options.ts", + "image_classifier_result.ts", ], deps = [ - ":image_classifier_types", "//mediapipe/framework:calculator_jspb_proto", "//mediapipe/framework:calculator_options_jspb_proto", "//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto", "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_jspb_proto", + "//mediapipe/tasks/web/components/containers:category", + "//mediapipe/tasks/web/components/containers:classification_result", "//mediapipe/tasks/web/components/processors:base_options", "//mediapipe/tasks/web/components/processors:classifier_options", "//mediapipe/tasks/web/components/processors:classifier_result", "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:classifier_options", "//mediapipe/tasks/web/core:task_runner", "//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts", ], ) - -mediapipe_ts_declaration( - name = "image_classifier_types", - srcs = [ - "image_classifier_options.d.ts", - "image_classifier_result.d.ts", - ], - deps = [ - "//mediapipe/tasks/web/components/containers:category", - "//mediapipe/tasks/web/components/containers:classification_result", - "//mediapipe/tasks/web/core:classifier_options", - ], -) diff --git a/mediapipe/tasks/web/vision/image_classifier/image_classifier_options.d.ts b/mediapipe/tasks/web/vision/image_classifier/image_classifier_options.ts similarity index 100% rename from mediapipe/tasks/web/vision/image_classifier/image_classifier_options.d.ts rename to mediapipe/tasks/web/vision/image_classifier/image_classifier_options.ts diff --git a/mediapipe/tasks/web/vision/image_classifier/image_classifier_result.d.ts b/mediapipe/tasks/web/vision/image_classifier/image_classifier_result.ts similarity index 100% rename from mediapipe/tasks/web/vision/image_classifier/image_classifier_result.d.ts rename to mediapipe/tasks/web/vision/image_classifier/image_classifier_result.ts diff --git a/mediapipe/tasks/web/vision/object_detector/BUILD b/mediapipe/tasks/web/vision/object_detector/BUILD index 2ce701b17..095a84b52 100644 --- a/mediapipe/tasks/web/vision/object_detector/BUILD +++ b/mediapipe/tasks/web/vision/object_detector/BUILD @@ -3,7 +3,7 @@ # This task takes video frames and outputs synchronized frames along with # the detection results for one or more object categories, using Object Detector. 
-load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library") +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") package(default_visibility = ["//mediapipe/tasks:internal"]) @@ -13,28 +13,18 @@ mediapipe_ts_library( name = "object_detector", srcs = [ "object_detector.ts", + "object_detector_options.ts", + "object_detector_result.ts", ], deps = [ - ":object_detector_types", "//mediapipe/framework:calculator_jspb_proto", "//mediapipe/framework:calculator_options_jspb_proto", "//mediapipe/framework/formats:detection_jspb_proto", "//mediapipe/tasks/cc/vision/object_detector/proto:object_detector_options_jspb_proto", + "//mediapipe/tasks/web/components/containers:category", "//mediapipe/tasks/web/components/processors:base_options", "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:task_runner", "//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts", ], ) - -mediapipe_ts_declaration( - name = "object_detector_types", - srcs = [ - "object_detector_options.d.ts", - "object_detector_result.d.ts", - ], - deps = [ - "//mediapipe/tasks/web/components/containers:category", - "//mediapipe/tasks/web/core", - ], -) diff --git a/mediapipe/tasks/web/vision/object_detector/object_detector_options.d.ts b/mediapipe/tasks/web/vision/object_detector/object_detector_options.ts similarity index 97% rename from mediapipe/tasks/web/vision/object_detector/object_detector_options.d.ts rename to mediapipe/tasks/web/vision/object_detector/object_detector_options.ts index 3eb7df986..eec12cf17 100644 --- a/mediapipe/tasks/web/vision/object_detector/object_detector_options.d.ts +++ b/mediapipe/tasks/web/vision/object_detector/object_detector_options.ts @@ -17,7 +17,7 @@ import {BaseOptions} from '../../../../tasks/web/core/base_options'; /** Options to configure the MediaPipe Object Detector Task */ -export declare interface ObjectDetectorOptions { +export interface ObjectDetectorOptions { /** Options to configure the loading of the model assets. 
*/
  baseOptions?: BaseOptions;
diff --git a/mediapipe/tasks/web/vision/object_detector/object_detector_result.d.ts b/mediapipe/tasks/web/vision/object_detector/object_detector_result.ts
similarity index 100%
rename from mediapipe/tasks/web/vision/object_detector/object_detector_result.d.ts
rename to mediapipe/tasks/web/vision/object_detector/object_detector_result.ts

From 05cb40ff79fd31bfec226966161ba54ad96027b7 Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Mon, 14 Nov 2022 08:31:32 -0800
Subject: [PATCH 005/469] MediaPipe TextEmbedder task for web

PiperOrigin-RevId: 488373613
---
 .../tasks/web/components/containers/BUILD     |   5 +
 .../containers/embedding_result.d.ts          |  66 +++++++
 .../tasks/web/components/processors/BUILD     |  22 ++-
 .../components/processors/embedder_options.ts |  46 +++++
 .../components/processors/embedder_result.ts  |  53 ++++++
 mediapipe/tasks/web/core/BUILD                |   8 +
 .../tasks/web/core/embedder_options.d.ts      |  39 ++++
 mediapipe/tasks/web/text/text_embedder/BUILD  |  32 ++++
 .../web/text/text_embedder/text_embedder.ts   | 173 ++++++++++++++++++
 .../text_embedder/text_embedder_options.d.ts  |  17 ++
 .../text_embedder/text_embedder_result.d.ts   |  17 ++
 11 files changed, 477 insertions(+), 1 deletion(-)
 create mode 100644 mediapipe/tasks/web/components/containers/embedding_result.d.ts
 create mode 100644 mediapipe/tasks/web/components/processors/embedder_options.ts
 create mode 100644 mediapipe/tasks/web/components/processors/embedder_result.ts
 create mode 100644 mediapipe/tasks/web/core/embedder_options.d.ts
 create mode 100644 mediapipe/tasks/web/text/text_embedder/BUILD
 create mode 100644 mediapipe/tasks/web/text/text_embedder/text_embedder.ts
 create mode 100644 mediapipe/tasks/web/text/text_embedder/text_embedder_options.d.ts
 create mode 100644 mediapipe/tasks/web/text/text_embedder/text_embedder_result.d.ts

diff --git a/mediapipe/tasks/web/components/containers/BUILD b/mediapipe/tasks/web/components/containers/BUILD
index 1b0e403ff..d1bc480db 100644
--- a/mediapipe/tasks/web/components/containers/BUILD
+++ b/mediapipe/tasks/web/components/containers/BUILD
@@ -19,3 +19,8 @@ mediapipe_ts_library(
     name = "landmark",
     srcs = ["landmark.d.ts"],
 )
+
+mediapipe_ts_library(
+    name = "embedding_result",
+    srcs = ["embedding_result.d.ts"],
+)
diff --git a/mediapipe/tasks/web/components/containers/embedding_result.d.ts b/mediapipe/tasks/web/components/containers/embedding_result.d.ts
new file mode 100644
index 000000000..e1efd94ce
--- /dev/null
+++ b/mediapipe/tasks/web/components/containers/embedding_result.d.ts
@@ -0,0 +1,66 @@
+/**
+ * Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * An embedding produced by one head of an embedder model.
+ *
+ * One and only one of the two 'floatEmbedding' and 'quantizedEmbedding' will
+ * contain data, based on whether or not the embedder was configured to perform
+ * scalar quantization.
+ */
+export interface Embedding {
+  /**
+   * Floating-point embedding.
Empty if the embedder was configured to perform
+   * scalar quantization.
+   */
+  floatEmbedding?: number[];
+
+  /**
+   * Scalar-quantized embedding. Empty if the embedder was not configured to
+   * perform scalar quantization.
+   */
+  quantizedEmbedding?: Uint8Array;
+  /**
+   * The index of the embedder head that produced this embedding. This is
+   * useful for multi-head models.
+   */
+  headIndex: number;
+
+  /**
+   * The name of the embedder head, which is the corresponding tensor
+   * metadata name.
+   */
+  headName: string;
+}
+
+/** Embedding results for a given embedder model. */
+export interface EmbeddingResult {
+  /**
+   * The embedding results for each model head, i.e. one for each output tensor.
+   */
+  embeddings: Embedding[];
+
+  /**
+   * The optional timestamp (in milliseconds) of the start of the chunk of
+   * data corresponding to these results.
+   *
+   * This is only used for embedding extraction on time series (e.g. audio
+   * embedding). In these use cases, the amount of data to process might
+   * exceed the maximum size that the model can process: to solve this, the
+   * input data is split into multiple chunks starting at different timestamps.
+   */
+  timestampMs?: number;
+}
diff --git a/mediapipe/tasks/web/components/processors/BUILD b/mediapipe/tasks/web/components/processors/BUILD
index e0d84b632..1b56bf4c9 100644
--- a/mediapipe/tasks/web/components/processors/BUILD
+++ b/mediapipe/tasks/web/components/processors/BUILD
@@ -23,9 +23,29 @@ mediapipe_ts_library(
     ],
 )
 
+mediapipe_ts_library(
+    name = "embedder_result",
+    srcs = ["embedder_result.ts"],
+    deps = [
+        "//mediapipe/tasks/cc/components/containers/proto:embeddings_jspb_proto",
+        "//mediapipe/tasks/web/components/containers:embedding_result",
+    ],
+)
+
+mediapipe_ts_library(
+    name = "embedder_options",
+    srcs = ["embedder_options.ts"],
+    deps = [
+        "//mediapipe/tasks/cc/components/processors/proto:embedder_options_jspb_proto",
+        "//mediapipe/tasks/web/core:embedder_options",
+    ],
+)
+
 mediapipe_ts_library(
     name = "base_options",
-    srcs = ["base_options.ts"],
+    srcs = [
+        "base_options.ts",
+    ],
     deps = [
         "//mediapipe/calculators/tensor:inference_calculator_jspb_proto",
         "//mediapipe/tasks/cc/core/proto:acceleration_jspb_proto",
diff --git a/mediapipe/tasks/web/components/processors/embedder_options.ts b/mediapipe/tasks/web/components/processors/embedder_options.ts
new file mode 100644
index 000000000..f000dbd64
--- /dev/null
+++ b/mediapipe/tasks/web/components/processors/embedder_options.ts
@@ -0,0 +1,46 @@
+/**
+ * Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import {EmbedderOptions as EmbedderOptionsProto} from '../../../../tasks/cc/components/processors/proto/embedder_options_pb';
+import {EmbedderOptions} from '../../../../tasks/web/core/embedder_options';
+
+/**
+ * Converts an EmbedderOptions object to its Proto representation, optionally
+ * based on an existing definition.
+ * @param options The options object to convert to a Proto. Only options that
+ *     are explicitly provided are set.
+ * @param baseOptions A base object that options can be merged into.
+ */
+export function convertEmbedderOptionsToProto(
+    options: EmbedderOptions,
+    baseOptions?: EmbedderOptionsProto): EmbedderOptionsProto {
+  const embedderOptions =
+      baseOptions ? baseOptions.clone() : new EmbedderOptionsProto();
+
+  if (options.l2Normalize !== undefined) {
+    embedderOptions.setL2Normalize(options.l2Normalize);
+  } else if ('l2Normalize' in options) {  // Explicitly set to undefined: reset
+    embedderOptions.clearL2Normalize();
+  }
+
+  if (options.quantize !== undefined) {
+    embedderOptions.setQuantize(options.quantize);
+  } else if ('quantize' in options) {  // Explicitly set to undefined: reset
+    embedderOptions.clearQuantize();
+  }
+
+  return embedderOptions;
+}
diff --git a/mediapipe/tasks/web/components/processors/embedder_result.ts b/mediapipe/tasks/web/components/processors/embedder_result.ts
new file mode 100644
index 000000000..285afe68a
--- /dev/null
+++ b/mediapipe/tasks/web/components/processors/embedder_result.ts
@@ -0,0 +1,53 @@
+/**
+ * Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import {Embedding as EmbeddingProto, EmbeddingResult as EmbeddingResultProto} from '../../../../tasks/cc/components/containers/proto/embeddings_pb';
+import {Embedding, EmbeddingResult} from '../../../../tasks/web/components/containers/embedding_result';
+
+const DEFAULT_INDEX = -1;
+
+/**
+ * Converts an Embedding proto to the Embedding object.
+ */
+function convertFromEmbeddingsProto(source: EmbeddingProto): Embedding {
+  const embedding: Embedding = {
+    headIndex: source.getHeadIndex() ?? DEFAULT_INDEX,
+    headName: source.getHeadName() ?? '',
+  };
+
+  if (source.hasFloatEmbedding()) {
+    embedding.floatEmbedding = source.getFloatEmbedding()!.getValuesList();
+  } else {
+    const encodedValue = source.getQuantizedEmbedding()?.getValues() ?? '';
+    embedding.quantizedEmbedding = typeof encodedValue == 'string' ?
+        Uint8Array.from(atob(encodedValue), c => c.charCodeAt(0)) :
+        encodedValue;
+  }
+
+  return embedding;
+}
+
+/**
+ * Converts an EmbeddingResult proto to an EmbeddingResult object.
+ */ +export function convertFromEmbeddingResultProto( + embeddingResult: EmbeddingResultProto): EmbeddingResult { + const result: EmbeddingResult = { + embeddings: embeddingResult.getEmbeddingsList().map( + e => convertFromEmbeddingsProto(e)), + timestampMs: embeddingResult.getTimestampMs(), + }; + return result; +} diff --git a/mediapipe/tasks/web/core/BUILD b/mediapipe/tasks/web/core/BUILD index 4fb57d6c3..edfc1e5c5 100644 --- a/mediapipe/tasks/web/core/BUILD +++ b/mediapipe/tasks/web/core/BUILD @@ -31,3 +31,11 @@ mediapipe_ts_library( ], deps = [":core"], ) + +mediapipe_ts_library( + name = "embedder_options", + srcs = [ + "embedder_options.d.ts", + ], + deps = [":core"], +) diff --git a/mediapipe/tasks/web/core/embedder_options.d.ts b/mediapipe/tasks/web/core/embedder_options.d.ts new file mode 100644 index 000000000..78ddad1ae --- /dev/null +++ b/mediapipe/tasks/web/core/embedder_options.d.ts @@ -0,0 +1,39 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import {BaseOptions} from '../../../tasks/web/core/base_options'; + +/** Options to configure the MediaPipe Embedder Task */ +export declare interface EmbedderOptions { + /** Options to configure the loading of the model assets. */ + baseOptions?: BaseOptions; + + /** + * Whether to normalize the returned feature vector with L2 norm. Use this + * option only if the model does not already contain a native L2_NORMALIZATION + * TF Lite Op. In most cases, this is already the case and L2 norm is thus + * achieved through TF Lite inference. + */ + l2Normalize?: boolean|undefined; + + /** + * Whether the returned embedding should be quantized to bytes via scalar + * quantization. Embeddings are implicitly assumed to be unit-norm and + * therefore any dimension is guaranteed to have a value in [-1.0, 1.0]. Use + * the l2_normalize option if this is not the case. + */ + quantize?: boolean|undefined; +} diff --git a/mediapipe/tasks/web/text/text_embedder/BUILD b/mediapipe/tasks/web/text/text_embedder/BUILD new file mode 100644 index 000000000..8e397ce6f --- /dev/null +++ b/mediapipe/tasks/web/text/text_embedder/BUILD @@ -0,0 +1,32 @@ +# This contains the MediaPipe Text Embedder Task. 
+#
+# This task takes text input and performs embedding extraction.
+#
+
+load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library")
+
+package(default_visibility = ["//mediapipe/tasks:internal"])
+
+licenses(["notice"])
+
+mediapipe_ts_library(
+    name = "text_embedder",
+    srcs = [
+        "text_embedder.ts",
+        "text_embedder_options.d.ts",
+        "text_embedder_result.d.ts",
+    ],
+    deps = [
+        "//mediapipe/framework:calculator_jspb_proto",
+        "//mediapipe/tasks/cc/components/containers/proto:embeddings_jspb_proto",
+        "//mediapipe/tasks/cc/text/text_embedder/proto:text_embedder_graph_options_jspb_proto",
+        "//mediapipe/tasks/web/components/containers:embedding_result",
+        "//mediapipe/tasks/web/components/processors:base_options",
+        "//mediapipe/tasks/web/components/processors:embedder_options",
+        "//mediapipe/tasks/web/components/processors:embedder_result",
+        "//mediapipe/tasks/web/core",
+        "//mediapipe/tasks/web/core:embedder_options",
+        "//mediapipe/tasks/web/core:task_runner",
+        "//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts",
+    ],
+)
diff --git a/mediapipe/tasks/web/text/text_embedder/text_embedder.ts b/mediapipe/tasks/web/text/text_embedder/text_embedder.ts
new file mode 100644
index 000000000..65df5df6a
--- /dev/null
+++ b/mediapipe/tasks/web/text/text_embedder/text_embedder.ts
@@ -0,0 +1,173 @@
+/**
+ * Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
+import {CalculatorOptions} from '../../../../framework/calculator_options_pb';
+import {EmbeddingResult} from '../../../../tasks/cc/components/containers/proto/embeddings_pb';
+import {TextEmbedderGraphOptions as TextEmbedderGraphOptionsProto} from '../../../../tasks/cc/text/text_embedder/proto/text_embedder_graph_options_pb';
+import {convertBaseOptionsToProto} from '../../../../tasks/web/components/processors/base_options';
+import {convertEmbedderOptionsToProto} from '../../../../tasks/web/components/processors/embedder_options';
+import {convertFromEmbeddingResultProto} from '../../../../tasks/web/components/processors/embedder_result';
+import {TaskRunner} from '../../../../tasks/web/core/task_runner';
+import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options';
+import {createMediaPipeLib, FileLocator} from '../../../../web/graph_runner/wasm_mediapipe_lib';
+// Placeholder for internal dependency on trusted resource url
+
+import {TextEmbedderOptions} from './text_embedder_options';
+import {TextEmbedderResult} from './text_embedder_result';
+
+
+// The OSS JS API does not support the builder pattern.
+// tslint:disable:jspb-use-builder-pattern
+
+const INPUT_STREAM = 'text_in';
+const EMBEDDINGS_STREAM = 'embeddings_out';
+const TEXT_EMBEDDER_CALCULATOR =
+    'mediapipe.tasks.text.text_embedder.TextEmbedderGraph';
+
+/**
+ * Performs embedding extraction on text.
+ */
+export class TextEmbedder extends TaskRunner {
+  private embeddingResult: TextEmbedderResult = {embeddings: []};
+  private readonly options = new TextEmbedderGraphOptionsProto();
+
+  /**
+   * Initializes the Wasm runtime and creates a new text embedder from the
+   * provided options.
+   * @param wasmLoaderOptions A configuration object that provides the location
+   *     of the Wasm binary and its loader.
+   * @param textEmbedderOptions The options for the text embedder. Note that
+   *     either a path to the TFLite model or the model itself needs to be
+   *     provided (via `baseOptions`).
+   */
+  static async createFromOptions(
+      wasmLoaderOptions: WasmLoaderOptions,
+      textEmbedderOptions: TextEmbedderOptions): Promise<TextEmbedder> {
+    // Create a file locator based on the loader options
+    const fileLocator: FileLocator = {
+      locateFile() {
+        // The only file we load is the Wasm binary
+        return wasmLoaderOptions.wasmBinaryPath.toString();
+      }
+    };
+
+    const embedder = await createMediaPipeLib(
+        TextEmbedder, wasmLoaderOptions.wasmLoaderPath,
+        /* assetLoaderScript= */ undefined,
+        /* glCanvas= */ undefined, fileLocator);
+    await embedder.setOptions(textEmbedderOptions);
+    return embedder;
+  }
+
+  /**
+   * Initializes the Wasm runtime and creates a new text embedder based on the
+   * provided model asset buffer.
+   * @param wasmLoaderOptions A configuration object that provides the location
+   *     of the Wasm binary and its loader.
+   * @param modelAssetBuffer A binary representation of the TFLite model.
+   */
+  static createFromModelBuffer(
+      wasmLoaderOptions: WasmLoaderOptions,
+      modelAssetBuffer: Uint8Array): Promise<TextEmbedder> {
+    return TextEmbedder.createFromOptions(
+        wasmLoaderOptions, {baseOptions: {modelAssetBuffer}});
+  }
+
+  /**
+   * Initializes the Wasm runtime and creates a new text embedder based on the
+   * path to the model asset.
+   * @param wasmLoaderOptions A configuration object that provides the location
+   *     of the Wasm binary and its loader.
+   * @param modelAssetPath The path to the TFLite model.
+   */
+  static async createFromModelPath(
+      wasmLoaderOptions: WasmLoaderOptions,
+      modelAssetPath: string): Promise<TextEmbedder> {
+    const response = await fetch(modelAssetPath.toString());
+    const graphData = await response.arrayBuffer();
+    return TextEmbedder.createFromModelBuffer(
+        wasmLoaderOptions, new Uint8Array(graphData));
+  }
+
+  /**
+   * Sets new options for the text embedder.
+   *
+   * Calling `setOptions()` with a subset of options only affects those
+   * options. You can reset an option back to its default value by explicitly
+   * setting it to `undefined`.
+   *
+   * @param options The options for the text embedder.
+   */
+  async setOptions(options: TextEmbedderOptions): Promise<void> {
+    if (options.baseOptions) {
+      const baseOptionsProto = await convertBaseOptionsToProto(
+          options.baseOptions, this.options.getBaseOptions());
+      this.options.setBaseOptions(baseOptionsProto);
+    }
+
+    this.options.setEmbedderOptions(convertEmbedderOptionsToProto(
+        options, this.options.getEmbedderOptions()));
+
+    this.refreshGraph();
+  }
+
+
+  /**
+   * Performs embedding extraction on the provided text and waits synchronously
+   * for the response.
+   *
+   * @param text The text to process.
+   * @return The embedding results of the text.
+   */
+  embed(text: string): TextEmbedderResult {
+    // Get text embeddings by running our MediaPipe graph.
+    this.addStringToStream(
+        text, INPUT_STREAM, /* timestamp= */ performance.now());
+    this.finishProcessing();
+    return this.embeddingResult;
+  }
+
+  /** Updates the MediaPipe graph configuration.
*/ + private refreshGraph(): void { + const graphConfig = new CalculatorGraphConfig(); + graphConfig.addInputStream(INPUT_STREAM); + graphConfig.addOutputStream(EMBEDDINGS_STREAM); + + const calculatorOptions = new CalculatorOptions(); + calculatorOptions.setExtension( + TextEmbedderGraphOptionsProto.ext, this.options); + + const embedderNode = new CalculatorGraphConfig.Node(); + embedderNode.setCalculator(TEXT_EMBEDDER_CALCULATOR); + embedderNode.addInputStream('TEXT:' + INPUT_STREAM); + embedderNode.addOutputStream('EMBEDDINGS:' + EMBEDDINGS_STREAM); + embedderNode.setOptions(calculatorOptions); + + graphConfig.addNode(embedderNode); + + this.attachProtoListener(EMBEDDINGS_STREAM, binaryProto => { + const embeddingResult = EmbeddingResult.deserializeBinary(binaryProto); + this.embeddingResult = convertFromEmbeddingResultProto(embeddingResult); + }); + + const binaryGraph = graphConfig.serializeBinary(); + this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true); + } +} + + + diff --git a/mediapipe/tasks/web/text/text_embedder/text_embedder_options.d.ts b/mediapipe/tasks/web/text/text_embedder/text_embedder_options.d.ts new file mode 100644 index 000000000..9af263765 --- /dev/null +++ b/mediapipe/tasks/web/text/text_embedder/text_embedder_options.d.ts @@ -0,0 +1,17 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +export {EmbedderOptions as TextEmbedderOptions} from '../../../../tasks/web/core/embedder_options'; diff --git a/mediapipe/tasks/web/text/text_embedder/text_embedder_result.d.ts b/mediapipe/tasks/web/text/text_embedder/text_embedder_result.d.ts new file mode 100644 index 000000000..65640b507 --- /dev/null +++ b/mediapipe/tasks/web/text/text_embedder/text_embedder_result.d.ts @@ -0,0 +1,17 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +export {Embedding, EmbeddingResult as TextEmbedderResult} from '../../../../tasks/web/components/containers/embedding_result'; From c7d531ebb2b97a56a11f469cf59719f9892dbc9f Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Mon, 14 Nov 2022 10:38:38 -0800 Subject: [PATCH 006/469] AddTarget -> ConnectTo and documentation PiperOrigin-RevId: 488407930 --- mediapipe/framework/api2/builder.h | 28 +++++++++++++++++++++++----- 1 file changed, 23 insertions(+), 5 deletions(-) diff --git a/mediapipe/framework/api2/builder.h b/mediapipe/framework/api2/builder.h index 5af9ee5e0..6d3323b97 100644 --- a/mediapipe/framework/api2/builder.h +++ b/mediapipe/framework/api2/builder.h @@ -176,22 +176,40 @@ class SourceImpl { : SourceImpl(&GetWithAutoGrow(vec, 0)) {} explicit SourceImpl(SourceBase* base) : base_(base) {} + // Connects MediaPipe stream or side packet to a destination: + // - node input (input stream) / side input (input side packet) + // - graph output (output stream) / side output (output side packet). + // + // MediaPipe streams and side packets can be connected to multiple + // destinations. Side packets and packets added to streams are sent to all + // connected destinations. template {}, int>::type = 0> - Src& AddTarget(const Dst& dest) { + Src& ConnectTo(const Dst& dest) { CHECK(dest.base_.source == nullptr); dest.base_.source = base_; base_->dests_.emplace_back(&dest.base_); return *this; } + + // Shortcut for `ConnectTo`. + // + // Connects MediaPipe stream or side packet to a destination: + // - node input (input stream) / side input (input side packet) + // - graph output (output stream) / side output (output side packet). + // + // MediaPipe streams and side packets can be connected to multiple + // destinations. Side packets and packets added to streams are sent to all + // connected destinations. + template + Src& operator>>(const Dst& dest) { + return ConnectTo(dest); + } + Src& SetName(std::string name) { base_->name_ = std::move(name); return *this; } - template - Src& operator>>(const Dst& dest) { - return AddTarget(dest); - } template {}, int> = 0> From 4b5c3521af8789554978452d6cee1550fe279cf1 Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Mon, 14 Nov 2022 10:42:06 -0800 Subject: [PATCH 007/469] Dividing the timestamp by 1000 when returning a "none" result object from GestureRecognizer and HandLandmarker APIs. 
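(A hedged usage sketch for the web TextEmbedder from patch 005 above; the wasm and model file names are assumptions for illustration, and the `setOptions()` calls demonstrate the subset-update and explicit-`undefined` reset semantics documented in the class.)

// End-to-end sketch: create, reconfigure, embed. File names assumed.
const embedder = await TextEmbedder.createFromModelPath(
    {wasmLoaderPath: 'wasm_loader.js', wasmBinaryPath: 'text.wasm'},
    'mobilebert_embedding.tflite');
await embedder.setOptions({quantize: true});       // enables scalar quantization
await embedder.setOptions({});                     // subset update: quantize stays true
await embedder.setOptions({quantize: undefined});  // explicit undefined resets it
const result = embedder.embed('what a great and fantastic trip');
console.log(result.embeddings[0].floatEmbedding?.length);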
PiperOrigin-RevId: 488409077 --- .../tasks/vision/gesturerecognizer/GestureRecognizer.java | 4 +++- .../mediapipe/tasks/vision/handlandmarker/HandLandmarker.java | 3 ++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizer.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizer.java index 7cbedb32e..e9e74a067 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizer.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizer.java @@ -164,7 +164,9 @@ public final class GestureRecognizer extends BaseVisionTaskApi { new ArrayList<>(), new ArrayList<>(), new ArrayList<>(), - packets.get(HAND_GESTURES_OUT_STREAM_INDEX).getTimestamp()); + BaseVisionTaskApi.generateResultTimestampMs( + recognizerOptions.runningMode(), + packets.get(HAND_GESTURES_OUT_STREAM_INDEX))); } return GestureRecognizerResult.create( PacketGetter.getProtoVector( diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmarker.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmarker.java index 9be489bbe..a9270d347 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmarker.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmarker.java @@ -156,7 +156,8 @@ public final class HandLandmarker extends BaseVisionTaskApi { new ArrayList<>(), new ArrayList<>(), new ArrayList<>(), - packets.get(LANDMARKS_OUT_STREAM_INDEX).getTimestamp()); + BaseVisionTaskApi.generateResultTimestampMs( + landmarkerOptions.runningMode(), packets.get(LANDMARKS_OUT_STREAM_INDEX))); } return HandLandmarkerResult.create( PacketGetter.getProtoVector( From b40b2ade140740c667fda92f624d0b941d41b978 Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Mon, 14 Nov 2022 11:06:13 -0800 Subject: [PATCH 008/469] Fix typos. PiperOrigin-RevId: 488416345 --- .../google/mediapipe/tasks/audio/core/BaseAudioTaskApi.java | 2 +- mediapipe/tasks/python/audio/audio_classifier.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/core/BaseAudioTaskApi.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/core/BaseAudioTaskApi.java index affe43559..8eaf0adcb 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/core/BaseAudioTaskApi.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/core/BaseAudioTaskApi.java @@ -91,7 +91,7 @@ public class BaseAudioTaskApi implements AutoCloseable { * * @param sampleRate the audio sample rate. * @throws MediaPipeException if the task is not in the audio stream mode or the provided sample - * rate is inconsisent with the previously recevied. + * rate is inconsistent with the previously received. 
*/ protected void checkOrSetSampleRate(double sampleRate) { if (runningMode != RunningMode.AUDIO_STREAM) { diff --git a/mediapipe/tasks/python/audio/audio_classifier.py b/mediapipe/tasks/python/audio/audio_classifier.py index e04e778b5..a081e5ecd 100644 --- a/mediapipe/tasks/python/audio/audio_classifier.py +++ b/mediapipe/tasks/python/audio/audio_classifier.py @@ -257,7 +257,7 @@ class AudioClassifier(base_audio_task_api.BaseAudioTaskApi): Raises: ValueError: If any of the followings: 1) The sample rate is not provided in the `AudioData` object or the - provided sample rate is inconsisent with the previously recevied. + provided sample rate is inconsistent with the previously received. 2) The current input timestamp is smaller than what the audio classifier has already processed. """ @@ -270,7 +270,7 @@ class AudioClassifier(base_audio_task_api.BaseAudioTaskApi): elif audio_block.audio_format.sample_rate != self._default_sample_rate: raise ValueError( f'The audio sample rate provided in audio data: ' - f'{audio_block.audio_format.sample_rate} is inconsisent with ' + f'{audio_block.audio_format.sample_rate} is inconsistent with ' f'the previously received: {self._default_sample_rate}.') self._send_audio_stream_data({ From 34daba4747bcb598c576381a5be8e896e7761fc8 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Mon, 14 Nov 2022 11:46:37 -0800 Subject: [PATCH 009/469] Add Java TextEmbedder API. PiperOrigin-RevId: 488427327 --- .../proto/text_embedder_graph_options.proto | 3 + .../mediapipe/tasks/core/TaskOptions.java | 4 +- .../mediapipe/tasks/mediapipe_tasks_aar.bzl | 1 + .../com/google/mediapipe/tasks/text/BUILD | 28 ++ .../text/textembedder/AndroidManifest.xml | 8 + .../tasks/text/textembedder/TextEmbedder.java | 256 ++++++++++++++++++ .../text/textembedder/TextEmbedderResult.java | 54 ++++ .../textclassifier/TextClassifierTest.java | 4 +- .../text/textembedder/AndroidManifest.xml | 24 ++ .../mediapipe/tasks/text/textembedder/BUILD | 19 ++ .../text/textembedder/TextEmbedderTest.java | 98 +++++++ 11 files changed, 494 insertions(+), 5 deletions(-) create mode 100644 mediapipe/tasks/java/com/google/mediapipe/tasks/text/textembedder/AndroidManifest.xml create mode 100644 mediapipe/tasks/java/com/google/mediapipe/tasks/text/textembedder/TextEmbedder.java create mode 100644 mediapipe/tasks/java/com/google/mediapipe/tasks/text/textembedder/TextEmbedderResult.java create mode 100644 mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textembedder/AndroidManifest.xml create mode 100644 mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textembedder/BUILD create mode 100644 mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textembedder/TextEmbedderTest.java diff --git a/mediapipe/tasks/cc/text/text_embedder/proto/text_embedder_graph_options.proto b/mediapipe/tasks/cc/text/text_embedder/proto/text_embedder_graph_options.proto index 6b8d41a57..e7e3a63c7 100644 --- a/mediapipe/tasks/cc/text/text_embedder/proto/text_embedder_graph_options.proto +++ b/mediapipe/tasks/cc/text/text_embedder/proto/text_embedder_graph_options.proto @@ -21,6 +21,9 @@ import "mediapipe/framework/calculator.proto"; import "mediapipe/tasks/cc/components/processors/proto/embedder_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; +option java_package = "com.google.mediapipe.tasks.text.textembedder.proto"; +option java_outer_classname = "TextEmbedderGraphOptionsProto"; + message TextEmbedderGraphOptions { extend mediapipe.CalculatorOptions { optional TextEmbedderGraphOptions ext = 
477589892;
diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskOptions.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskOptions.java
index 9bf600360..0fc48742e 100644
--- a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskOptions.java
+++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskOptions.java
@@ -58,8 +58,8 @@ public abstract class TaskOptions {
         AccelerationProto.Acceleration.newBuilder();
     switch (options.delegate()) {
       case CPU:
-        accelerationBuilder.setXnnpack(
-            InferenceCalculatorProto.InferenceCalculatorOptions.Delegate.Xnnpack
+        accelerationBuilder.setTflite(
+            InferenceCalculatorProto.InferenceCalculatorOptions.Delegate.TfLite
                 .getDefaultInstance());
         break;
       case GPU:
diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl b/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl
index f0c9f81c6..ab7ad6616 100644
--- a/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl
+++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl
@@ -49,6 +49,7 @@ _VISION_TASKS_JAVA_PROTO_LITE_TARGETS = [

 _TEXT_TASKS_JAVA_PROTO_LITE_TARGETS = [
     "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_java_proto_lite",
+    "//mediapipe/tasks/cc/text/text_embedder/proto:text_embedder_graph_options_java_proto_lite",
 ]

 def mediapipe_tasks_core_aar(name, srcs, manifest):
diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD
index b49169529..0e72878ab 100644
--- a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD
+++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD
@@ -24,6 +24,7 @@ cc_binary(
     deps = [
         "//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni",
         "//mediapipe/tasks/cc/text/text_classifier:text_classifier_graph",
+        "//mediapipe/tasks/cc/text/text_embedder:text_embedder_graph",
         "//mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni:model_resources_cache_jni",
     ],
 )
@@ -60,6 +61,33 @@ android_library(
     ],
 )

+android_library(
+    name = "textembedder",
+    srcs = [
+        "textembedder/TextEmbedder.java",
+        "textembedder/TextEmbedderResult.java",
+    ],
+    javacopts = [
+        "-Xep:AndroidJdkLibsChecker:OFF",
+    ],
+    manifest = "textembedder/AndroidManifest.xml",
+    deps = [
+        "//mediapipe/framework:calculator_options_java_proto_lite",
+        "//mediapipe/java/com/google/mediapipe/framework:android_framework",
+        "//mediapipe/tasks/cc/components/containers/proto:embeddings_java_proto_lite",
+        "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
+        "//mediapipe/tasks/cc/text/text_embedder/proto:text_embedder_graph_options_java_proto_lite",
+        "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:embedding",
+        "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:embeddingresult",
+        "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:embedderoptions",
+        "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/utils:cosinesimilarity",
+        "//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
+        "//mediapipe/tasks/java/com/google/mediapipe/tasks/text:libmediapipe_tasks_text_jni_lib",
+        "//third_party:autovalue",
+        "@maven//:com_google_guava_guava",
+    ],
+)
+
 load("//mediapipe/tasks/java/com/google/mediapipe/tasks:mediapipe_tasks_aar.bzl", "mediapipe_tasks_text_aar")

 mediapipe_tasks_text_aar(
diff --git
a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textembedder/AndroidManifest.xml b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textembedder/AndroidManifest.xml new file mode 100644 index 000000000..d9c885d16 --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textembedder/AndroidManifest.xml @@ -0,0 +1,8 @@ + + + + + + diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textembedder/TextEmbedder.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textembedder/TextEmbedder.java new file mode 100644 index 000000000..95fa1f087 --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textembedder/TextEmbedder.java @@ -0,0 +1,256 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.tasks.text.textembedder; + +import android.content.Context; +import android.os.ParcelFileDescriptor; +import com.google.auto.value.AutoValue; +import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions; +import com.google.mediapipe.framework.MediaPipeException; +import com.google.mediapipe.framework.Packet; +import com.google.mediapipe.framework.PacketGetter; +import com.google.mediapipe.framework.ProtoUtil; +import com.google.mediapipe.tasks.components.containers.Embedding; +import com.google.mediapipe.tasks.components.containers.EmbeddingResult; +import com.google.mediapipe.tasks.components.containers.proto.EmbeddingsProto; +import com.google.mediapipe.tasks.components.processors.EmbedderOptions; +import com.google.mediapipe.tasks.components.utils.CosineSimilarity; +import com.google.mediapipe.tasks.core.BaseOptions; +import com.google.mediapipe.tasks.core.OutputHandler; +import com.google.mediapipe.tasks.core.TaskInfo; +import com.google.mediapipe.tasks.core.TaskOptions; +import com.google.mediapipe.tasks.core.TaskRunner; +import com.google.mediapipe.tasks.core.proto.BaseOptionsProto; +import com.google.mediapipe.tasks.text.textembedder.proto.TextEmbedderGraphOptionsProto; +import java.io.File; +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +/** + * Performs embedding extraction on text. + * + *

<p>This API expects a TFLite model with (optional) TFLite Model Metadata.
+ *
+ * <p>Metadata is required for models with int32 input tensors because it contains the input
+ * process unit for the model's Tokenizer. No metadata is required for models with string input
+ * tensors.
+ *
+ * <ul>
+ *   <li>Input tensors
+ *       <ul>
+ *         <li>Three input tensors ({@code kTfLiteInt32}) of shape {@code [batch_size x
+ *             bert_max_seq_len]} representing the input ids, mask ids, and segment ids. This
+ *             input signature requires a Bert Tokenizer process unit in the model metadata.
+ *         <li>Or one input tensor ({@code kTfLiteInt32}) of shape {@code [batch_size x
+ *             max_seq_len]} representing the input ids. This input signature requires a Regex
+ *             Tokenizer process unit in the model metadata.
+ *         <li>Or one input tensor ({@code kTfLiteString}) that is shapeless or has shape {@code
+ *             [1]} containing the input string.
+ *       </ul>
+ *   <li>At least one output tensor ({@code kTfLiteFloat32}/{@code kTfLiteUint8}) with shape
+ *       {@code [1 x N]} where N is the number of dimensions in the produced embeddings.
+ * </ul>
+ */
+public final class TextEmbedder implements AutoCloseable {
+  private static final String TAG = TextEmbedder.class.getSimpleName();
+  private static final String TEXT_IN_STREAM_NAME = "text_in";
+
+  @SuppressWarnings("ConstantCaseForConstants")
+  private static final List<String> INPUT_STREAMS =
+      Collections.unmodifiableList(Arrays.asList("TEXT:" + TEXT_IN_STREAM_NAME));
+
+  @SuppressWarnings("ConstantCaseForConstants")
+  private static final List<String> OUTPUT_STREAMS =
+      Collections.unmodifiableList(Arrays.asList("EMBEDDINGS:embeddings_out"));
+
+  private static final int EMBEDDINGS_OUT_STREAM_INDEX = 0;
+  private static final String TASK_GRAPH_NAME =
+      "mediapipe.tasks.text.text_embedder.TextEmbedderGraph";
+  private final TaskRunner runner;
+
+  static {
+    System.loadLibrary("mediapipe_tasks_text_jni");
+    ProtoUtil.registerTypeName(
+        EmbeddingsProto.EmbeddingResult.class,
+        "mediapipe.tasks.components.containers.proto.EmbeddingResult");
+  }
+
+  /**
+   * Creates a {@link TextEmbedder} instance from a model file and the default {@link
+   * TextEmbedderOptions}.
+   *
+   * @param context an Android {@link Context}.
+   * @param modelPath path to the text model with metadata in the assets.
+   * @throws MediaPipeException if there is an error during {@link TextEmbedder} creation.
+   */
+  public static TextEmbedder createFromFile(Context context, String modelPath) {
+    BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(modelPath).build();
+    return createFromOptions(
+        context, TextEmbedderOptions.builder().setBaseOptions(baseOptions).build());
+  }
+
+  /**
+   * Creates a {@link TextEmbedder} instance from a model file and the default {@link
+   * TextEmbedderOptions}.
+   *
+   * @param context an Android {@link Context}.
+   * @param modelFile the text model {@link File} instance.
+   * @throws IOException if an I/O error occurs when opening the tflite model file.
+   * @throws MediaPipeException if there is an error during {@link TextEmbedder} creation.
+   */
+  public static TextEmbedder createFromFile(Context context, File modelFile) throws IOException {
+    try (ParcelFileDescriptor descriptor =
+        ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) {
+      BaseOptions baseOptions =
+          BaseOptions.builder().setModelAssetFileDescriptor(descriptor.getFd()).build();
+      return createFromOptions(
+          context, TextEmbedderOptions.builder().setBaseOptions(baseOptions).build());
+    }
+  }
+
+  /**
+   * Creates a {@link TextEmbedder} instance from {@link TextEmbedderOptions}.
+   *
+   * @param context an Android {@link Context}.
+   * @param options a {@link TextEmbedderOptions} instance.
+   * @throws MediaPipeException if there is an error during {@link TextEmbedder} creation.
+   */
+  public static TextEmbedder createFromOptions(Context context, TextEmbedderOptions options) {
+    OutputHandler<TextEmbedderResult, Void> handler = new OutputHandler<>();
+    handler.setOutputPacketConverter(
+        new OutputHandler.OutputPacketConverter<TextEmbedderResult, Void>() {
+          @Override
+          public TextEmbedderResult convertToTaskResult(List<Packet> packets) {
+            try {
+              return TextEmbedderResult.create(
+                  EmbeddingResult.createFromProto(
+                      PacketGetter.getProto(
+                          packets.get(EMBEDDINGS_OUT_STREAM_INDEX),
+                          EmbeddingsProto.EmbeddingResult.getDefaultInstance())),
+                  packets.get(EMBEDDINGS_OUT_STREAM_INDEX).getTimestamp());
+            } catch (IOException e) {
+              throw new MediaPipeException(
+                  MediaPipeException.StatusCode.INTERNAL.ordinal(), e.getMessage());
+            }
+          }
+
+          @Override
+          public Void convertToTaskInput(List<Packet> packets) {
+            return null;
+          }
+        });
+    TaskRunner runner =
+        TaskRunner.create(
+            context,
+            TaskInfo.builder()
+                .setTaskGraphName(TASK_GRAPH_NAME)
+                .setInputStreams(INPUT_STREAMS)
+                .setOutputStreams(OUTPUT_STREAMS)
+                .setTaskOptions(options)
+                .setEnableFlowLimiting(false)
+                .build(),
+            handler);
+    return new TextEmbedder(runner);
+  }
+
+  /**
+   * Constructor to initialize a {@link TextEmbedder} from a {@link TaskRunner}.
+   *
+   * @param runner a {@link TaskRunner}.
+   */
+  private TextEmbedder(TaskRunner runner) {
+    this.runner = runner;
+  }
+
+  /**
+   * Performs embedding extraction on the input text.
+   *
+   * @param inputText a {@link String} for processing.
+   */
+  public TextEmbedderResult embed(String inputText) {
+    Map<String, Packet> inputPackets = new HashMap<>();
+    inputPackets.put(TEXT_IN_STREAM_NAME, runner.getPacketCreator().createString(inputText));
+    return (TextEmbedderResult) runner.process(inputPackets);
+  }
+
+  /** Closes and cleans up the {@link TextEmbedder}. */
+  @Override
+  public void close() {
+    runner.close();
+  }
+
+  /**
+   * Utility function to compute cosine similarity between two {@link Embedding} objects.
+   *
+   * @throws IllegalArgumentException if the embeddings are of different types (float vs.
+   *     quantized), have different sizes, or have an L2-norm of 0.
+   */
+  public static double cosineSimilarity(Embedding u, Embedding v) {
+    return CosineSimilarity.compute(u, v);
+  }
+
+  /** Options for setting up a {@link TextEmbedder}. */
+  @AutoValue
+  public abstract static class TextEmbedderOptions extends TaskOptions {
+
+    /** Builder for {@link TextEmbedderOptions}. */
+    @AutoValue.Builder
+    public abstract static class Builder {
+      /** Sets the base options for the text embedder task. */
+      public abstract Builder setBaseOptions(BaseOptions value);
+
+      /**
+       * Sets the optional {@link EmbedderOptions} controlling embedder behavior, such as
+       * L2-normalization and scalar quantization.
+       */
+      public abstract Builder setEmbedderOptions(EmbedderOptions embedderOptions);
+
+      public abstract TextEmbedderOptions build();
+    }
+
+    abstract BaseOptions baseOptions();
+
+    abstract Optional<EmbedderOptions> embedderOptions();
+
+    public static Builder builder() {
+      return new AutoValue_TextEmbedder_TextEmbedderOptions.Builder();
+    }
+
+    /** Converts a {@link TextEmbedderOptions} to a {@link CalculatorOptions} protobuf message.
*/
+    @Override
+    public CalculatorOptions convertToCalculatorOptionsProto() {
+      BaseOptionsProto.BaseOptions.Builder baseOptionsBuilder =
+          BaseOptionsProto.BaseOptions.newBuilder();
+      baseOptionsBuilder.mergeFrom(convertBaseOptionsToProto(baseOptions()));
+      TextEmbedderGraphOptionsProto.TextEmbedderGraphOptions.Builder taskOptionsBuilder =
+          TextEmbedderGraphOptionsProto.TextEmbedderGraphOptions.newBuilder()
+              .setBaseOptions(baseOptionsBuilder);
+      if (embedderOptions().isPresent()) {
+        taskOptionsBuilder.setEmbedderOptions(embedderOptions().get().convertToProto());
+      }
+      return CalculatorOptions.newBuilder()
+          .setExtension(
+              TextEmbedderGraphOptionsProto.TextEmbedderGraphOptions.ext,
+              taskOptionsBuilder.build())
+          .build();
+    }
+  }
+}
diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textembedder/TextEmbedderResult.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textembedder/TextEmbedderResult.java
new file mode 100644
index 000000000..9d8e108ec
--- /dev/null
+++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textembedder/TextEmbedderResult.java
@@ -0,0 +1,54 @@
+// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package com.google.mediapipe.tasks.text.textembedder;
+
+import com.google.auto.value.AutoValue;
+import com.google.mediapipe.tasks.components.containers.EmbeddingResult;
+import com.google.mediapipe.tasks.components.containers.proto.EmbeddingsProto;
+import com.google.mediapipe.tasks.core.TaskResult;
+
+/** Represents the embedding results generated by {@link TextEmbedder}. */
+@AutoValue
+public abstract class TextEmbedderResult implements TaskResult {
+
+  /**
+   * Creates a {@link TextEmbedderResult} instance.
+   *
+   * @param embeddingResult the {@link EmbeddingResult} object containing one embedding per
+   *     embedder head.
+   * @param timestampMs a timestamp for this result.
+   */
+  static TextEmbedderResult create(EmbeddingResult embeddingResult, long timestampMs) {
+    return new AutoValue_TextEmbedderResult(embeddingResult, timestampMs);
+  }
+
+  /**
+   * Creates a {@link TextEmbedderResult} instance from an {@link EmbeddingsProto.EmbeddingResult}
+   * protobuf message.
+   *
+   * @param proto the {@link EmbeddingsProto.EmbeddingResult} protobuf message to convert.
+   * @param timestampMs a timestamp for this result.
+   */
+  static TextEmbedderResult createFromProto(
+      EmbeddingsProto.EmbeddingResult proto, long timestampMs) {
+    return create(EmbeddingResult.createFromProto(proto), timestampMs);
+  }
+
+  /** Contains one embedding per embedder head.
*/ + public abstract EmbeddingResult embeddingResult(); + + @Override + public abstract long timestampMs(); +} diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textclassifier/TextClassifierTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textclassifier/TextClassifierTest.java index d3f0e90f3..5e03d2a4c 100644 --- a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textclassifier/TextClassifierTest.java +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textclassifier/TextClassifierTest.java @@ -67,9 +67,7 @@ public class TextClassifierTest { ApplicationProvider.getApplicationContext(), options)); // TODO: Make MediaPipe InferenceCalculator report the detailed. // interpreter errors (e.g., "Encountered unresolved custom op"). - assertThat(exception) - .hasMessageThat() - .contains("interpreter_builder(&interpreter) == kTfLiteOk"); + assertThat(exception).hasMessageThat().contains("== kTfLiteOk"); } @Test diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textembedder/AndroidManifest.xml b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textembedder/AndroidManifest.xml new file mode 100644 index 000000000..5d55d7cfe --- /dev/null +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textembedder/AndroidManifest.xml @@ -0,0 +1,24 @@ + + + + + + + + + + + + + + + diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textembedder/BUILD b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textembedder/BUILD new file mode 100644 index 000000000..a7f804c64 --- /dev/null +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textembedder/BUILD @@ -0,0 +1,19 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +# TODO: Enable this in OSS diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textembedder/TextEmbedderTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textembedder/TextEmbedderTest.java new file mode 100644 index 000000000..b6d53c94d --- /dev/null +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textembedder/TextEmbedderTest.java @@ -0,0 +1,98 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+package com.google.mediapipe.tasks.text.textembedder;
+
+import static com.google.common.truth.Truth.assertThat;
+import static org.junit.Assert.assertThrows;
+
+import androidx.test.core.app.ApplicationProvider;
+import androidx.test.ext.junit.runners.AndroidJUnit4;
+import com.google.mediapipe.framework.MediaPipeException;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+
+/** Test for {@link TextEmbedder}. */
+@RunWith(AndroidJUnit4.class)
+public class TextEmbedderTest {
+  private static final String BERT_MODEL_FILE = "mobilebert_embedding_with_metadata.tflite";
+  private static final String REGEX_MODEL_FILE = "regex_one_embedding_with_metadata.tflite";
+
+  private static final double DOUBLE_DIFF_TOLERANCE = 1e-4;
+  private static final float FLOAT_DIFF_TOLERANCE = 1e-4f;
+
+  @Test
+  public void create_failsWithMissingModel() throws Exception {
+    String nonExistentFile = "/path/to/non/existent/file";
+    MediaPipeException exception =
+        assertThrows(
+            MediaPipeException.class,
+            () ->
+                TextEmbedder.createFromFile(
+                    ApplicationProvider.getApplicationContext(), nonExistentFile));
+    assertThat(exception).hasMessageThat().contains(nonExistentFile);
+  }
+
+  @Test
+  public void embed_succeedsWithBert() throws Exception {
+    TextEmbedder textEmbedder =
+        TextEmbedder.createFromFile(ApplicationProvider.getApplicationContext(), BERT_MODEL_FILE);
+
+    TextEmbedderResult result0 = textEmbedder.embed("it's a charming and often affecting journey");
+    assertThat(result0.embeddingResult().embeddings().size()).isEqualTo(1);
+    assertThat(result0.embeddingResult().embeddings().get(0).floatEmbedding()).hasLength(512);
+    assertThat(result0.embeddingResult().embeddings().get(0).floatEmbedding()[0])
+        .isWithin(FLOAT_DIFF_TOLERANCE)
+        .of(20.59746f);
+    TextEmbedderResult result1 = textEmbedder.embed("what a great and fantastic trip");
+    assertThat(result1.embeddingResult().embeddings().size()).isEqualTo(1);
+    assertThat(result1.embeddingResult().embeddings().get(0).floatEmbedding()).hasLength(512);
+    assertThat(result1.embeddingResult().embeddings().get(0).floatEmbedding()[0])
+        .isWithin(FLOAT_DIFF_TOLERANCE)
+        .of(21.774776f);
+
+    // Check cosine similarity.
+    double similarity =
+        TextEmbedder.cosineSimilarity(
+            result0.embeddingResult().embeddings().get(0),
+            result1.embeddingResult().embeddings().get(0));
+    assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.968879);
+  }
+
+  @Test
+  public void embed_succeedsWithRegex() throws Exception {
+    TextEmbedder textEmbedder =
+        TextEmbedder.createFromFile(ApplicationProvider.getApplicationContext(), REGEX_MODEL_FILE);
+
+    TextEmbedderResult result0 = textEmbedder.embed("it's a charming and often affecting journey");
+    assertThat(result0.embeddingResult().embeddings().size()).isEqualTo(1);
+    assertThat(result0.embeddingResult().embeddings().get(0).floatEmbedding()).hasLength(16);
+    assertThat(result0.embeddingResult().embeddings().get(0).floatEmbedding()[0])
+        .isWithin(FLOAT_DIFF_TOLERANCE)
+        .of(0.030935612f);
+    TextEmbedderResult result1 = textEmbedder.embed("what a great and fantastic trip");
+    assertThat(result1.embeddingResult().embeddings().size()).isEqualTo(1);
+    assertThat(result1.embeddingResult().embeddings().get(0).floatEmbedding()).hasLength(16);
+    assertThat(result1.embeddingResult().embeddings().get(0).floatEmbedding()[0])
+        .isWithin(FLOAT_DIFF_TOLERANCE)
+        .of(0.0312863f);
+
+    // Check cosine similarity.
+ double similarity = + TextEmbedder.cosineSimilarity( + result0.embeddingResult().embeddings().get(0), + result1.embeddingResult().embeddings().get(0)); + assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.999937); + } +} From b00236e86e00105145e47bef8498b1b715f6bf36 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Mon, 14 Nov 2022 12:11:40 -0800 Subject: [PATCH 010/469] Hand Landmarker Web API PiperOrigin-RevId: 488434079 --- mediapipe/tasks/testdata/vision/BUILD | 1 + mediapipe/tasks/web/vision/BUILD | 1 + .../tasks/web/vision/hand_landmarker/BUILD | 35 ++ .../vision/hand_landmarker/hand_landmarker.ts | 319 ++++++++++++++++++ .../hand_landmarker_options.ts | 47 +++ .../hand_landmarker/hand_landmarker_result.ts | 32 ++ mediapipe/tasks/web/vision/index.ts | 5 + 7 files changed, 440 insertions(+) create mode 100644 mediapipe/tasks/web/vision/hand_landmarker/BUILD create mode 100644 mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts create mode 100644 mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_options.ts create mode 100644 mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_result.ts diff --git a/mediapipe/tasks/testdata/vision/BUILD b/mediapipe/tasks/testdata/vision/BUILD index ad8072b87..95b721fdb 100644 --- a/mediapipe/tasks/testdata/vision/BUILD +++ b/mediapipe/tasks/testdata/vision/BUILD @@ -40,6 +40,7 @@ mediapipe_files(srcs = [ "fist.jpg", "hand_landmark_full.tflite", "hand_landmark_lite.tflite", + "hand_landmarker.task", "left_hands.jpg", "left_hands_rotated.jpg", "mobilenet_v1_0.25_192_quantized_1_default_1.tflite", diff --git a/mediapipe/tasks/web/vision/BUILD b/mediapipe/tasks/web/vision/BUILD index abdbc54ea..395860892 100644 --- a/mediapipe/tasks/web/vision/BUILD +++ b/mediapipe/tasks/web/vision/BUILD @@ -9,6 +9,7 @@ mediapipe_ts_library( srcs = ["index.ts"], deps = [ "//mediapipe/tasks/web/vision/gesture_recognizer", + "//mediapipe/tasks/web/vision/hand_landmarker", "//mediapipe/tasks/web/vision/image_classifier", "//mediapipe/tasks/web/vision/object_detector", ], diff --git a/mediapipe/tasks/web/vision/hand_landmarker/BUILD b/mediapipe/tasks/web/vision/hand_landmarker/BUILD new file mode 100644 index 000000000..9006b54ef --- /dev/null +++ b/mediapipe/tasks/web/vision/hand_landmarker/BUILD @@ -0,0 +1,35 @@ +# This contains the MediaPipe Hand Landmarker Task. +# +# This task takes video frames and outputs synchronized frames along with +# the detection results for one or more hand categories, using Hand Landmarker. 
+ +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +mediapipe_ts_library( + name = "hand_landmarker", + srcs = [ + "hand_landmarker.ts", + "hand_landmarker_options.ts", + "hand_landmarker_result.ts", + ], + deps = [ + "//mediapipe/framework:calculator_jspb_proto", + "//mediapipe/framework:calculator_options_jspb_proto", + "//mediapipe/framework/formats:classification_jspb_proto", + "//mediapipe/framework/formats:landmark_jspb_proto", + "//mediapipe/framework/formats:rect_jspb_proto", + "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_jspb_proto", + "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_jspb_proto", + "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_jspb_proto", + "//mediapipe/tasks/web/components/containers:category", + "//mediapipe/tasks/web/components/containers:landmark", + "//mediapipe/tasks/web/components/processors:base_options", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:task_runner", + "//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts", + ], +) diff --git a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts new file mode 100644 index 000000000..017a9098c --- /dev/null +++ b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts @@ -0,0 +1,319 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; +import {CalculatorOptions} from '../../../../framework/calculator_options_pb'; +import {ClassificationList} from '../../../../framework/formats/classification_pb'; +import {LandmarkList, NormalizedLandmarkList} from '../../../../framework/formats/landmark_pb'; +import {NormalizedRect} from '../../../../framework/formats/rect_pb'; +import {HandDetectorGraphOptions} from '../../../../tasks/cc/vision/hand_detector/proto/hand_detector_graph_options_pb'; +import {HandLandmarkerGraphOptions} from '../../../../tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options_pb'; +import {HandLandmarksDetectorGraphOptions} from '../../../../tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options_pb'; +import {Category} from '../../../../tasks/web/components/containers/category'; +import {Landmark} from '../../../../tasks/web/components/containers/landmark'; +import {convertBaseOptionsToProto} from '../../../../tasks/web/components/processors/base_options'; +import {TaskRunner} from '../../../../tasks/web/core/task_runner'; +import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options'; +import {createMediaPipeLib, FileLocator, ImageSource, WasmModule} from '../../../../web/graph_runner/wasm_mediapipe_lib'; +// Placeholder for internal dependency on trusted resource url + +import {HandLandmarkerOptions} from './hand_landmarker_options'; +import {HandLandmarkerResult} from './hand_landmarker_result'; + +export {ImageSource}; + +// The OSS JS API does not support the builder pattern. +// tslint:disable:jspb-use-builder-pattern + +const IMAGE_STREAM = 'image_in'; +const NORM_RECT_STREAM = 'norm_rect'; +const LANDMARKS_STREAM = 'hand_landmarks'; +const WORLD_LANDMARKS_STREAM = 'world_hand_landmarks'; +const HANDEDNESS_STREAM = 'handedness'; +const HAND_LANDMARKER_GRAPH = + 'mediapipe.tasks.vision.hand_landmarker.HandLandmarkerGraph'; + +const DEFAULT_NUM_HANDS = 1; +const DEFAULT_SCORE_THRESHOLD = 0.5; +const DEFAULT_CATEGORY_INDEX = -1; +const FULL_IMAGE_RECT = new NormalizedRect(); +FULL_IMAGE_RECT.setXCenter(0.5); +FULL_IMAGE_RECT.setYCenter(0.5); +FULL_IMAGE_RECT.setWidth(1); +FULL_IMAGE_RECT.setHeight(1); + +/** Performs hand landmarks detection on images. */ +export class HandLandmarker extends TaskRunner { + private landmarks: Landmark[][] = []; + private worldLandmarks: Landmark[][] = []; + private handednesses: Category[][] = []; + + private readonly options: HandLandmarkerGraphOptions; + private readonly handLandmarksDetectorGraphOptions: + HandLandmarksDetectorGraphOptions; + private readonly handDetectorGraphOptions: HandDetectorGraphOptions; + + /** + * Initializes the Wasm runtime and creates a new `HandLandmarker` from the + * provided options. + * @param wasmLoaderOptions A configuration object that provides the location + * of the Wasm binary and its loader. + * @param handLandmarkerOptions The options for the HandLandmarker. + * Note that either a path to the model asset or a model buffer needs to + * be provided (via `baseOptions`). 
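+   *
+   * A rough usage sketch (the Wasm paths here are placeholders, not assets
+   * shipped with this change; `hand_landmarker.task` is the test model added
+   * to the testdata BUILD above):
+   *
+   *   const landmarker = await HandLandmarker.createFromOptions(
+   *       {wasmLoaderPath: '/loader.js', wasmBinaryPath: '/binary.wasm'},
+   *       {baseOptions: {modelAssetPath: 'hand_landmarker.task'}, numHands: 2});
+   *   const result = landmarker.detect(image);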
+ */ + static async createFromOptions( + wasmLoaderOptions: WasmLoaderOptions, + handLandmarkerOptions: HandLandmarkerOptions): Promise { + // Create a file locator based on the loader options + const fileLocator: FileLocator = { + locateFile() { + // The only file we load via this mechanism is the Wasm binary + return wasmLoaderOptions.wasmBinaryPath.toString(); + } + }; + + const landmarker = await createMediaPipeLib( + HandLandmarker, wasmLoaderOptions.wasmLoaderPath, + /* assetLoaderScript= */ undefined, + /* glCanvas= */ undefined, fileLocator); + await landmarker.setOptions(handLandmarkerOptions); + return landmarker; + } + + /** + * Initializes the Wasm runtime and creates a new `HandLandmarker` based on + * the provided model asset buffer. + * @param wasmLoaderOptions A configuration object that provides the location + * of the Wasm binary and its loader. + * @param modelAssetBuffer A binary representation of the model. + */ + static createFromModelBuffer( + wasmLoaderOptions: WasmLoaderOptions, + modelAssetBuffer: Uint8Array): Promise { + return HandLandmarker.createFromOptions( + wasmLoaderOptions, {baseOptions: {modelAssetBuffer}}); + } + + /** + * Initializes the Wasm runtime and creates a new `HandLandmarker` based on + * the path to the model asset. + * @param wasmLoaderOptions A configuration object that provides the location + * of the Wasm binary and its loader. + * @param modelAssetPath The path to the model asset. + */ + static async createFromModelPath( + wasmLoaderOptions: WasmLoaderOptions, + modelAssetPath: string): Promise { + const response = await fetch(modelAssetPath.toString()); + const graphData = await response.arrayBuffer(); + return HandLandmarker.createFromModelBuffer( + wasmLoaderOptions, new Uint8Array(graphData)); + } + + constructor(wasmModule: WasmModule) { + super(wasmModule); + + this.options = new HandLandmarkerGraphOptions(); + this.handLandmarksDetectorGraphOptions = + new HandLandmarksDetectorGraphOptions(); + this.options.setHandLandmarksDetectorGraphOptions( + this.handLandmarksDetectorGraphOptions); + this.handDetectorGraphOptions = new HandDetectorGraphOptions(); + this.options.setHandDetectorGraphOptions(this.handDetectorGraphOptions); + + this.initDefaults(); + + // Disables the automatic render-to-screen code, which allows for pure + // CPU processing. + this.setAutoRenderToScreen(false); + } + + /** + * Sets new options for this `HandLandmarker`. + * + * Calling `setOptions()` with a subset of options only affects those options. + * You can reset an option back to its default value by explicitly setting it + * to `undefined`. + * + * @param options The options for the hand landmarker. + */ + async setOptions(options: HandLandmarkerOptions): Promise { + if (options.baseOptions) { + const baseOptionsProto = await convertBaseOptionsToProto( + options.baseOptions, this.options.getBaseOptions()); + this.options.setBaseOptions(baseOptionsProto); + } + + // Configure hand detector options. + if ('numHands' in options) { + this.handDetectorGraphOptions.setNumHands( + options.numHands ?? DEFAULT_NUM_HANDS); + } + if ('minHandDetectionConfidence' in options) { + this.handDetectorGraphOptions.setMinDetectionConfidence( + options.minHandDetectionConfidence ?? DEFAULT_SCORE_THRESHOLD); + } + + // Configure hand landmark detector options. + if ('minTrackingConfidence' in options) { + this.options.setMinTrackingConfidence( + options.minTrackingConfidence ?? 
DEFAULT_SCORE_THRESHOLD); + } + if ('minHandPresenceConfidence' in options) { + this.handLandmarksDetectorGraphOptions.setMinDetectionConfidence( + options.minHandPresenceConfidence ?? DEFAULT_SCORE_THRESHOLD); + } + + this.refreshGraph(); + } + + /** + * Performs hand landmarks detection on the provided single image and waits + * synchronously for the response. + * @param imageSource An image source to process. + * @param timestamp The timestamp of the current frame, in ms. If not + * provided, defaults to `performance.now()`. + * @return The detected hand landmarks. + */ + detect(imageSource: ImageSource, timestamp: number = performance.now()): + HandLandmarkerResult { + this.landmarks = []; + this.worldLandmarks = []; + this.handednesses = []; + + this.addGpuBufferAsImageToStream(imageSource, IMAGE_STREAM, timestamp); + this.addProtoToStream( + FULL_IMAGE_RECT.serializeBinary(), 'mediapipe.NormalizedRect', + NORM_RECT_STREAM, timestamp); + this.finishProcessing(); + + return { + landmarks: this.landmarks, + worldLandmarks: this.worldLandmarks, + handednesses: this.handednesses + }; + } + + /** Sets the default values for the graph. */ + private initDefaults(): void { + this.handDetectorGraphOptions.setNumHands(DEFAULT_NUM_HANDS); + this.handDetectorGraphOptions.setMinDetectionConfidence( + DEFAULT_SCORE_THRESHOLD); + this.handLandmarksDetectorGraphOptions.setMinDetectionConfidence( + DEFAULT_SCORE_THRESHOLD); + this.options.setMinTrackingConfidence(DEFAULT_SCORE_THRESHOLD); + } + + /** Converts the proto data to a Category[][] structure. */ + private toJsCategories(data: Uint8Array[]): Category[][] { + const result: Category[][] = []; + for (const binaryProto of data) { + const inputList = ClassificationList.deserializeBinary(binaryProto); + const outputList: Category[] = []; + for (const classification of inputList.getClassificationList()) { + outputList.push({ + score: classification.getScore() ?? 0, + index: classification.getIndex() ?? DEFAULT_CATEGORY_INDEX, + categoryName: classification.getLabel() ?? '', + displayName: classification.getDisplayName() ?? '', + }); + } + result.push(outputList); + } + return result; + } + + /** Converts raw data into a landmark, and adds it to our landmarks list. */ + private addJsLandmarks(data: Uint8Array[]): void { + for (const binaryProto of data) { + const handLandmarksProto = + NormalizedLandmarkList.deserializeBinary(binaryProto); + const landmarks: Landmark[] = []; + for (const handLandmarkProto of handLandmarksProto.getLandmarkList()) { + landmarks.push({ + x: handLandmarkProto.getX() ?? 0, + y: handLandmarkProto.getY() ?? 0, + z: handLandmarkProto.getZ() ?? 0, + normalized: true + }); + } + this.landmarks.push(landmarks); + } + } + + /** + * Converts raw data into a landmark, and adds it to our worldLandmarks + * list. + */ + private adddJsWorldLandmarks(data: Uint8Array[]): void { + for (const binaryProto of data) { + const handWorldLandmarksProto = + LandmarkList.deserializeBinary(binaryProto); + const worldLandmarks: Landmark[] = []; + for (const handWorldLandmarkProto of + handWorldLandmarksProto.getLandmarkList()) { + worldLandmarks.push({ + x: handWorldLandmarkProto.getX() ?? 0, + y: handWorldLandmarkProto.getY() ?? 0, + z: handWorldLandmarkProto.getZ() ?? 0, + normalized: false + }); + } + this.worldLandmarks.push(worldLandmarks); + } + } + + /** Updates the MediaPipe graph configuration. 
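+   * The graph consists of a single HandLandmarkerGraph node that consumes the
+   * image and norm-rect input streams and produces the landmark, world-landmark
+   * and handedness output streams wired up below.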
*/ + private refreshGraph(): void { + const graphConfig = new CalculatorGraphConfig(); + graphConfig.addInputStream(IMAGE_STREAM); + graphConfig.addInputStream(NORM_RECT_STREAM); + graphConfig.addOutputStream(LANDMARKS_STREAM); + graphConfig.addOutputStream(WORLD_LANDMARKS_STREAM); + graphConfig.addOutputStream(HANDEDNESS_STREAM); + + const calculatorOptions = new CalculatorOptions(); + calculatorOptions.setExtension( + HandLandmarkerGraphOptions.ext, this.options); + + const landmarkerNode = new CalculatorGraphConfig.Node(); + landmarkerNode.setCalculator(HAND_LANDMARKER_GRAPH); + landmarkerNode.addInputStream('IMAGE:' + IMAGE_STREAM); + landmarkerNode.addInputStream('NORM_RECT:' + NORM_RECT_STREAM); + landmarkerNode.addOutputStream('LANDMARKS:' + LANDMARKS_STREAM); + landmarkerNode.addOutputStream('WORLD_LANDMARKS:' + WORLD_LANDMARKS_STREAM); + landmarkerNode.addOutputStream('HANDEDNESS:' + HANDEDNESS_STREAM); + landmarkerNode.setOptions(calculatorOptions); + + graphConfig.addNode(landmarkerNode); + + this.attachProtoVectorListener(LANDMARKS_STREAM, binaryProto => { + this.addJsLandmarks(binaryProto); + }); + this.attachProtoVectorListener(WORLD_LANDMARKS_STREAM, binaryProto => { + this.adddJsWorldLandmarks(binaryProto); + }); + this.attachProtoVectorListener(HANDEDNESS_STREAM, binaryProto => { + this.handednesses.push(...this.toJsCategories(binaryProto)); + }); + + const binaryGraph = graphConfig.serializeBinary(); + this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true); + } +} + + diff --git a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_options.ts b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_options.ts new file mode 100644 index 000000000..53ad9440a --- /dev/null +++ b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_options.ts @@ -0,0 +1,47 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import {BaseOptions} from '../../../../tasks/web/core/base_options'; + +/** Options to configure the MediaPipe HandLandmarker Task */ +export declare interface HandLandmarkerOptions { + /** Options to configure the loading of the model assets. */ + baseOptions?: BaseOptions; + + /** + * The maximum number of hands can be detected by the HandLandmarker. + * Defaults to 1. + */ + numHands?: number|undefined; + + /** + * The minimum confidence score for the hand detection to be considered + * successful. Defaults to 0.5. + */ + minHandDetectionConfidence?: number|undefined; + + /** + * The minimum confidence score of hand presence score in the hand landmark + * detection. Defaults to 0.5. + */ + minHandPresenceConfidence?: number|undefined; + + /** + * The minimum confidence score for the hand tracking to be considered + * successful. Defaults to 0.5. 
+ */ + minTrackingConfidence?: number|undefined; +} diff --git a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_result.ts b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_result.ts new file mode 100644 index 000000000..044bdfbe7 --- /dev/null +++ b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_result.ts @@ -0,0 +1,32 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import {Category} from '../../../../tasks/web/components/containers/category'; +import {Landmark} from '../../../../tasks/web/components/containers/landmark'; + +/** + * Represents the hand landmarks deection results generated by `HandLandmarker`. + */ +export declare interface HandLandmarkerResult { + /** Hand landmarks of detected hands. */ + landmarks: Landmark[][]; + + /** Hand landmarks in world coordniates of detected hands. */ + worldLandmarks: Landmark[][]; + + /** Handedness of detected hands. */ + handednesses: Category[][]; +} diff --git a/mediapipe/tasks/web/vision/index.ts b/mediapipe/tasks/web/vision/index.ts index 7cc915f25..2c46dbd3b 100644 --- a/mediapipe/tasks/web/vision/index.ts +++ b/mediapipe/tasks/web/vision/index.ts @@ -24,6 +24,11 @@ export * from '../../../tasks/web/vision/gesture_recognizer/gesture_recognizer_o export * from '../../../tasks/web/vision/gesture_recognizer/gesture_recognizer_result'; export * from '../../../tasks/web/vision/gesture_recognizer/gesture_recognizer'; +// Hand Landmarker +export * from '../../../tasks/web/vision/hand_landmarker/hand_landmarker_options'; +export * from '../../../tasks/web/vision/hand_landmarker/hand_landmarker_result'; +export * from '../../../tasks/web/vision/hand_landmarker/hand_landmarker'; + // Object Detector export * from '../../../tasks/web/vision/object_detector/object_detector_options'; export * from '../../../tasks/web/vision/object_detector/object_detector_result'; From ca7b5e9d8bbd24dd08717dacb7abec2bd28740b9 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Mon, 14 Nov 2022 12:38:23 -0800 Subject: [PATCH 011/469] Fix Script loading PiperOrigin-RevId: 488440736 --- mediapipe/web/graph_runner/wasm_mediapipe_lib.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mediapipe/web/graph_runner/wasm_mediapipe_lib.ts b/mediapipe/web/graph_runner/wasm_mediapipe_lib.ts index 82a3a3f16..9ecf094ca 100644 --- a/mediapipe/web/graph_runner/wasm_mediapipe_lib.ts +++ b/mediapipe/web/graph_runner/wasm_mediapipe_lib.ts @@ -971,7 +971,7 @@ async function runScript(scriptUrl: string) { importScripts(scriptUrl.toString()); } else { const script = document.createElement('script'); - script.setAttribute('url', scriptUrl); + script.setAttribute('src', scriptUrl); script.setAttribute('crossorigin', 'anonymous'); return new Promise((resolve) => { script.addEventListener('load', () => { From b4fba6fe6104a943de3b46052255635d84c5d744 Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Mon, 14 Nov 2022 13:40:49 -0800 Subject: [PATCH 
012/469] MediaPipe Tasks AudioEmbedder Java API PiperOrigin-RevId: 488456442 --- .../cc/audio/audio_embedder/audio_embedder.h | 7 +- .../com/google/mediapipe/tasks/audio/BUILD | 30 ++ .../audioclassifier/AudioClassifier.java | 6 +- .../audio/audioembedder/AndroidManifest.xml | 8 + .../audio/audioembedder/AudioEmbedder.java | 388 ++++++++++++++++++ .../audioembedder/AudioEmbedderResult.java | 75 ++++ .../mediapipe/tasks/mediapipe_tasks_aar.bzl | 1 + 7 files changed, 511 insertions(+), 4 deletions(-) create mode 100644 mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioembedder/AndroidManifest.xml create mode 100644 mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioembedder/AudioEmbedder.java create mode 100644 mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioembedder/AudioEmbedderResult.java diff --git a/mediapipe/tasks/cc/audio/audio_embedder/audio_embedder.h b/mediapipe/tasks/cc/audio/audio_embedder/audio_embedder.h index 4e7e20530..31cb61422 100644 --- a/mediapipe/tasks/cc/audio/audio_embedder/audio_embedder.h +++ b/mediapipe/tasks/cc/audio/audio_embedder/audio_embedder.h @@ -58,9 +58,12 @@ struct AudioEmbedderOptions { nullptr; }; -// Performs embedding extraction on audio clips or audio stream. +// Performs audio embedding extraction on audio clips or audio stream. // -// The API expects a TFLite model with TFLite Model Metadata. +// This API expects a TFLite model with mandatory TFLite Model Metadata that +// contains the mandatory AudioProperties of the solo input audio tensor and the +// optional (but recommended) label items as AssociatedFiles with type +// TENSOR_AXIS_LABELS per output embedding tensor. // // Input tensor: // (kTfLiteFloat32) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/BUILD index b162d7dac..6771335ad 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/BUILD @@ -39,6 +39,7 @@ cc_binary( deps = [ "//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni", "//mediapipe/tasks/cc/audio/audio_classifier:audio_classifier_graph", + "//mediapipe/tasks/cc/audio/audio_embedder:audio_embedder_graph", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni:model_resources_cache_jni", ], ) @@ -75,6 +76,35 @@ android_library( ], ) +android_library( + name = "audioembedder", + srcs = [ + "audioembedder/AudioEmbedder.java", + "audioembedder/AudioEmbedderResult.java", + ], + javacopts = [ + "-Xep:AndroidJdkLibsChecker:OFF", + ], + manifest = "audioembedder/AndroidManifest.xml", + deps = [ + ":core", + "//mediapipe/framework:calculator_options_java_proto_lite", + "//mediapipe/java/com/google/mediapipe/framework:android_framework", + "//mediapipe/tasks/cc/audio/audio_embedder/proto:audio_embedder_graph_options_java_proto_lite", + "//mediapipe/tasks/cc/components/containers/proto:embeddings_java_proto_lite", + "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/audio:libmediapipe_tasks_audio_jni_lib", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:audiodata", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:embedding", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:embeddingresult", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:embedderoptions", + 
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/utils:cosinesimilarity", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", + "//third_party:autovalue", + "@maven//:com_google_guava_guava", + ], +) + load("//mediapipe/tasks/java/com/google/mediapipe/tasks:mediapipe_tasks_aar.bzl", "mediapipe_tasks_audio_aar") mediapipe_tasks_audio_aar( diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioclassifier/AudioClassifier.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioclassifier/AudioClassifier.java index 5a82eecaa..0f3374175 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioclassifier/AudioClassifier.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioclassifier/AudioClassifier.java @@ -265,8 +265,10 @@ public final class AudioClassifier extends BaseAudioTaskApi { } /* - * Sends audio data (a block in a continuous audio stream) to perform audio classification. Only - * use this method when the AudioClassifier is created with the audio stream mode. + * Sends audio data (a block in a continuous audio stream) to perform audio classification, and + * the results will be available via the {@link ResultListener} provided in the + * {@link AudioClassifierOptions}. Only use this method when the AudioClassifier is created with + * the audio stream mode. * *

The audio block is represented as a MediaPipe {@link AudioData} object. The audio data will * be resampled, accumulated, and framed to the proper size for the underlying model to consume. diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioembedder/AndroidManifest.xml b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioembedder/AndroidManifest.xml new file mode 100644 index 000000000..4cd033db8 --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioembedder/AndroidManifest.xml @@ -0,0 +1,8 @@ + + + + + + diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioembedder/AudioEmbedder.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioembedder/AudioEmbedder.java new file mode 100644 index 000000000..c0bc04a4e --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioembedder/AudioEmbedder.java @@ -0,0 +1,388 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.tasks.audio.audioembedder; + +import android.content.Context; +import android.os.ParcelFileDescriptor; +import com.google.auto.value.AutoValue; +import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions; +import com.google.mediapipe.framework.MediaPipeException; +import com.google.mediapipe.framework.Packet; +import com.google.mediapipe.framework.PacketGetter; +import com.google.mediapipe.framework.ProtoUtil; +import com.google.mediapipe.tasks.audio.audioembedder.proto.AudioEmbedderGraphOptionsProto; +import com.google.mediapipe.tasks.audio.core.BaseAudioTaskApi; +import com.google.mediapipe.tasks.audio.core.RunningMode; +import com.google.mediapipe.tasks.components.containers.AudioData; +import com.google.mediapipe.tasks.components.containers.Embedding; +import com.google.mediapipe.tasks.components.containers.proto.EmbeddingsProto; +import com.google.mediapipe.tasks.components.processors.EmbedderOptions; +import com.google.mediapipe.tasks.components.utils.CosineSimilarity; +import com.google.mediapipe.tasks.core.BaseOptions; +import com.google.mediapipe.tasks.core.ErrorListener; +import com.google.mediapipe.tasks.core.OutputHandler; +import com.google.mediapipe.tasks.core.OutputHandler.PureResultListener; +import com.google.mediapipe.tasks.core.OutputHandler.ResultListener; +import com.google.mediapipe.tasks.core.TaskInfo; +import com.google.mediapipe.tasks.core.TaskOptions; +import com.google.mediapipe.tasks.core.TaskRunner; +import com.google.mediapipe.tasks.core.proto.BaseOptionsProto; +import java.io.File; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Optional; + +/** + * Performs audio embedding extraction on audio clips or audio stream. + * + *

+ * <p>This API expects a TFLite model with mandatory TFLite Model Metadata that contains the
+ * mandatory AudioProperties of the solo input audio tensor and the optional (but recommended)
+ * label items as AssociatedFiles with type TENSOR_AXIS_LABELS per output embedding tensor.
+ *
+ * <p>Input tensor: (kTfLiteFloat32)
+ *
+ * <ul>
+ *   <li>input audio buffer of size `[batch * samples]`.
+ *   <li>batch inference is not supported (`batch` is required to be 1).
+ *   <li>for multi-channel models, the channels need to be interleaved.
+ * </ul>
+ *
+ * <p>At least one output tensor with: (kTfLiteFloat32)
+ *
+ * <ul>
+ *   <li>`N` components corresponding to the `N` dimensions of the returned feature vector for
+ *       this output layer.
+ *   <li>Either 2 or 4 dimensions, i.e. `[1 x N]` or `[1 x 1 x 1 x N]`.
+ * </ul>
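+ *
+ * <p>A minimal usage sketch for the audio clips mode; the model file name and the {@code
+ * audioClip}, {@code embeddingU} and {@code embeddingV} values are illustrative placeholders,
+ * not assets provided by this change:
+ *
+ * <pre>{@code
+ * AudioEmbedder embedder = AudioEmbedder.createFromFile(context, "audio_embedder.tflite");
+ * AudioEmbedderResult result = embedder.embed(audioClip);
+ * // Cosine similarity between two embeddings lies in [-1, 1]; closer to 1 means more similar.
+ * double similarity = AudioEmbedder.cosineSimilarity(embeddingU, embeddingV);
+ * }</pre>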
+ */ +public final class AudioEmbedder extends BaseAudioTaskApi { + private static final String TAG = AudioEmbedder.class.getSimpleName(); + private static final String AUDIO_IN_STREAM_NAME = "audio_in"; + private static final String SAMPLE_RATE_IN_STREAM_NAME = "sample_rate_in"; + private static final List INPUT_STREAMS = + Collections.unmodifiableList( + Arrays.asList( + "AUDIO:" + AUDIO_IN_STREAM_NAME, "SAMPLE_RATE:" + SAMPLE_RATE_IN_STREAM_NAME)); + private static final List OUTPUT_STREAMS = + Collections.unmodifiableList( + Arrays.asList( + "EMBEDDINGS:embeddings_out", "TIMESTAMPED_EMBEDDINGS:timestamped_embeddings_out")); + private static final int EMBEDDINGS_OUT_STREAM_INDEX = 0; + private static final int TIMESTAMPED_EMBEDDINGS_OUT_STREAM_INDEX = 1; + private static final String TASK_GRAPH_NAME = + "mediapipe.tasks.audio.audio_embedder.AudioEmbedderGraph"; + private static final long MICROSECONDS_PER_MILLISECOND = 1000; + + static { + ProtoUtil.registerTypeName( + EmbeddingsProto.EmbeddingResult.class, + "mediapipe.tasks.components.containers.proto.EmbeddingResult"); + } + + /** + * Creates an {@link AudioEmbedder} instance from a model file and default {@link + * AudioEmbedderOptions}. + * + * @param context an Android {@link Context}. + * @param modelPath path to the embedding model in the assets. + * @throws MediaPipeException if there is an error during {@link AudioEmbedder} creation. + */ + public static AudioEmbedder createFromFile(Context context, String modelPath) { + BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(modelPath).build(); + return createFromOptions( + context, AudioEmbedderOptions.builder().setBaseOptions(baseOptions).build()); + } + + /** + * Creates an {@link AudioEmbedder} instance from a model file and default {@link + * AudioEmbedderOptions}. + * + * @param context an Android {@link Context}. + * @param modelFile the embedding model {@link File} instance. + * @throws IOException if an I/O error occurs when opening the tflite model file. + * @throws MediaPipeException if there is an error during {@link AudioEmbedder} creation. + */ + public static AudioEmbedder createFromFile(Context context, File modelFile) throws IOException { + try (ParcelFileDescriptor descriptor = + ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) { + BaseOptions baseOptions = + BaseOptions.builder().setModelAssetFileDescriptor(descriptor.getFd()).build(); + return createFromOptions( + context, AudioEmbedderOptions.builder().setBaseOptions(baseOptions).build()); + } + } + + /** + * Creates an {@link AudioEmbedder} instance from a model buffer and default {@link + * AudioEmbedderOptions}. + * + * @param context an Android {@link Context}. + * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the embedding + * model. + * @throws MediaPipeException if there is an error during {@link AudioEmbedder} creation. + */ + public static AudioEmbedder createFromBuffer(Context context, final ByteBuffer modelBuffer) { + BaseOptions baseOptions = BaseOptions.builder().setModelAssetBuffer(modelBuffer).build(); + return createFromOptions( + context, AudioEmbedderOptions.builder().setBaseOptions(baseOptions).build()); + } + + /** + * Creates an {@link AudioEmbedder} instance from an {@link AudioEmbedderOptions} instance. + * + * @param context an Android {@link Context}. + * @param options an {@link AudioEmbedderOptions} instance. + * @throws MediaPipeException if there is an error during {@link AudioEmbedder} creation. 
+ */ + public static AudioEmbedder createFromOptions(Context context, AudioEmbedderOptions options) { + OutputHandler handler = new OutputHandler<>(); + handler.setOutputPacketConverter( + new OutputHandler.OutputPacketConverter() { + @Override + public AudioEmbedderResult convertToTaskResult(List packets) { + try { + if (!packets.get(EMBEDDINGS_OUT_STREAM_INDEX).isEmpty()) { + // For audio stream mode. + return AudioEmbedderResult.createFromProto( + PacketGetter.getProto( + packets.get(EMBEDDINGS_OUT_STREAM_INDEX), + EmbeddingsProto.EmbeddingResult.getDefaultInstance()), + packets.get(EMBEDDINGS_OUT_STREAM_INDEX).getTimestamp() + / MICROSECONDS_PER_MILLISECOND); + } else { + // For audio clips mode. + return AudioEmbedderResult.createFromProtoList( + PacketGetter.getProtoVector( + packets.get(TIMESTAMPED_EMBEDDINGS_OUT_STREAM_INDEX), + EmbeddingsProto.EmbeddingResult.parser()), + -1); + } + } catch (IOException e) { + throw new MediaPipeException( + MediaPipeException.StatusCode.INTERNAL.ordinal(), e.getMessage()); + } + } + + @Override + public Void convertToTaskInput(List packets) { + return null; + } + }); + if (options.resultListener().isPresent()) { + ResultListener resultListener = + new ResultListener() { + @Override + public void run(AudioEmbedderResult audioEmbedderResult, Void input) { + options.resultListener().get().run(audioEmbedderResult); + } + }; + handler.setResultListener(resultListener); + } + options.errorListener().ifPresent(handler::setErrorListener); + // Audio tasks should not drop input audio due to flow limiting, which may cause data + // inconsistency. + TaskRunner runner = + TaskRunner.create( + context, + TaskInfo.builder() + .setTaskGraphName(TASK_GRAPH_NAME) + .setInputStreams(INPUT_STREAMS) + .setOutputStreams(OUTPUT_STREAMS) + .setTaskOptions(options) + .setEnableFlowLimiting(false) + .build(), + handler); + return new AudioEmbedder(runner, options.runningMode()); + } + + /** + * Constructor to initialize an {@link AudioEmbedder} from a {@link TaskRunner} and {@link + * RunningMode}. + * + * @param taskRunner a {@link TaskRunner}. + * @param runningMode a mediapipe audio task {@link RunningMode}. + */ + private AudioEmbedder(TaskRunner taskRunner, RunningMode runningMode) { + super(taskRunner, runningMode, AUDIO_IN_STREAM_NAME, SAMPLE_RATE_IN_STREAM_NAME); + } + + /* + * Performs embedding extraction on the provided audio clips. Only use this method when the + * AudioEmbedder is created with the audio clips mode. + * + *

+   * <p>The audio clip is represented as a MediaPipe {@link AudioData} object. The method accepts
+   * audio clips with various lengths and audio sample rates. It's required to provide the
+   * corresponding audio sample rate within the {@link AudioData} object.
+   *
+   * <p>The input audio clip may be longer than what the model is able to process in a single
+   * inference. When this occurs, the input audio clip is split into multiple chunks starting at
+   * different timestamps. For this reason, this function returns an {@link AudioEmbedderResult}
+   * holding a list of EmbeddingResult objects, each associated with a timestamp corresponding to
+   * the start (in milliseconds) of the chunk data that was extracted.
+   *
+   * @param audioClip a MediaPipe {@link AudioData} object for processing.
+   * @throws MediaPipeException if there is an internal error.
+   */
+  public AudioEmbedderResult embed(AudioData audioClip) {
+    return (AudioEmbedderResult) processAudioClip(audioClip);
+  }
+
+  /*
+   * Sends audio data (a block in a continuous audio stream) to perform audio embedding, and
+   * the results will be available via the {@link ResultListener} provided in the
+   * {@link AudioEmbedderOptions}. Only use this method when the AudioEmbedder is created with
+   * the audio stream mode.
+   *

+   * <p>The audio block is represented as a MediaPipe {@link AudioData} object. The audio data
+   * will be resampled, accumulated, and framed to the proper size for the underlying model to
+   * consume. It's required to provide the corresponding audio sample rate within the {@link
+   * AudioData} object as well as a timestamp (in milliseconds) to indicate the start time of the
+   * input audio block. The timestamps must be monotonically increasing. This method will return
+   * immediately after the input audio data is accepted. The results will be available in the
+   * `resultListener` provided in the `AudioEmbedderOptions`. The `embedAsync` method is designed
+   * to process audio stream data such as microphone input.
+   *
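+   * <p>For example (the timestamp source here is illustrative only):
+   * {@code embedder.embedAsync(audioBlock, SystemClock.uptimeMillis());}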

+   * <p>The input audio block may be longer than what the model is able to process in a single
+   * inference. When this occurs, the input audio block is split into multiple chunks. For this
+   * reason, the callback may be called multiple times (once per chunk) for each call to this
+   * function.
+   *
+   * @param audioBlock a MediaPipe {@link AudioData} object for processing.
+   * @param timestampMs the input timestamp (in milliseconds).
+   * @throws MediaPipeException if there is an internal error.
+   */
+  public void embedAsync(AudioData audioBlock, long timestampMs) {
+    checkOrSetSampleRate(audioBlock.getFormat().getSampleRate());
+    sendAudioStreamData(audioBlock, timestampMs);
+  }
+
+  /**
+   * Utility function to compute cosine similarity between two {@link Embedding} objects.
+   *
+   * @throws IllegalArgumentException if the embeddings are of different types (float vs.
+   *     quantized), have different sizes, or have an L2-norm of 0.
+   */
+  public static double cosineSimilarity(Embedding u, Embedding v) {
+    return CosineSimilarity.compute(u, v);
+  }
+
+  /** Options for setting up an {@link AudioEmbedder}. */
+  @AutoValue
+  public abstract static class AudioEmbedderOptions extends TaskOptions {
+
+    /** Builder for {@link AudioEmbedderOptions}. */
+    @AutoValue.Builder
+    public abstract static class Builder {
+      /** Sets the {@link BaseOptions} for the audio embedder task. */
+      public abstract Builder setBaseOptions(BaseOptions baseOptions);
+
+      /**
+       * Sets the {@link RunningMode} for the audio embedder task. Default to the audio clips
+       * mode. The audio embedder has two running modes:
+       *
+       * <ul>
+       *   <li>AUDIO_CLIPS: The mode for running audio embedding on audio clips. Users feed audio
+       *       clips to the `embed` method, and will receive the embedding results as the return
+       *       value.
+       *   <li>AUDIO_STREAM: The mode for running audio embedding on the audio stream, such as
+       *       from a microphone. Users call `embedAsync` to push the audio data into the
+       *       AudioEmbedder, and the embedding results will be available in the result callback
+       *       when the audio embedder finishes the work.
+       * </ul>
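+       *
+       * <p>A stream-mode setup might look like the following sketch, where {@code baseOptions}
+       * and {@code handleResult} are placeholders for user code:
+       *
+       * <pre>{@code
+       * AudioEmbedderOptions options =
+       *     AudioEmbedderOptions.builder()
+       *         .setBaseOptions(baseOptions)
+       *         .setRunningMode(RunningMode.AUDIO_STREAM)
+       *         .setResultListener(result -> handleResult(result))
+       *         .build();
+       * }</pre>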
+ */ + public abstract Builder setRunningMode(RunningMode runningMode); + + /** + * Sets the optional {@link EmbedderOptions} controling embedding behavior, such as score + * threshold, number of results, etc. + */ + public abstract Builder setEmbedderOptions(EmbedderOptions embedderOptions); + + /** + * Sets the {@link ResultListener} to receive the embedding results asynchronously when the + * audio embedder is in the audio stream mode. + */ + public abstract Builder setResultListener( + PureResultListener resultListener); + + /** Sets an optional {@link ErrorListener}. */ + public abstract Builder setErrorListener(ErrorListener errorListener); + + abstract AudioEmbedderOptions autoBuild(); + + /** + * Validates and builds the {@link AudioEmbedderOptions} instance. + * + * @throws IllegalArgumentException if the result listener and the running mode are not + * properly configured. The result listener should only be set when the audio embedder is + * in the audio stream mode. + */ + public final AudioEmbedderOptions build() { + AudioEmbedderOptions options = autoBuild(); + if (options.runningMode() == RunningMode.AUDIO_STREAM) { + if (!options.resultListener().isPresent()) { + throw new IllegalArgumentException( + "The audio embedder is in the audio stream mode, a user-defined result listener" + + " must be provided in the AudioEmbedderOptions."); + } + } else if (options.resultListener().isPresent()) { + throw new IllegalArgumentException( + "The audio embedder is in the audio clips mode, a user-defined result listener" + + " shouldn't be provided in AudioEmbedderOptions."); + } + return options; + } + } + + abstract BaseOptions baseOptions(); + + abstract RunningMode runningMode(); + + abstract Optional embedderOptions(); + + abstract Optional> resultListener(); + + abstract Optional errorListener(); + + public static Builder builder() { + return new AutoValue_AudioEmbedder_AudioEmbedderOptions.Builder() + .setRunningMode(RunningMode.AUDIO_CLIPS); + } + + /** Converts a {@link AudioEmbedderOptions} to a {@link CalculatorOptions} protobuf message. */ + @Override + public CalculatorOptions convertToCalculatorOptionsProto() { + BaseOptionsProto.BaseOptions.Builder baseOptionsBuilder = + BaseOptionsProto.BaseOptions.newBuilder(); + baseOptionsBuilder.setUseStreamMode(runningMode() == RunningMode.AUDIO_STREAM); + baseOptionsBuilder.mergeFrom(convertBaseOptionsToProto(baseOptions())); + AudioEmbedderGraphOptionsProto.AudioEmbedderGraphOptions.Builder taskOptionsBuilder = + AudioEmbedderGraphOptionsProto.AudioEmbedderGraphOptions.newBuilder() + .setBaseOptions(baseOptionsBuilder); + if (embedderOptions().isPresent()) { + taskOptionsBuilder.setEmbedderOptions(embedderOptions().get().convertToProto()); + } + return CalculatorOptions.newBuilder() + .setExtension( + AudioEmbedderGraphOptionsProto.AudioEmbedderGraphOptions.ext, + taskOptionsBuilder.build()) + .build(); + } + } +} diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioembedder/AudioEmbedderResult.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioembedder/AudioEmbedderResult.java new file mode 100644 index 000000000..ee4df0198 --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioembedder/AudioEmbedderResult.java @@ -0,0 +1,75 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package com.google.mediapipe.tasks.audio.audioembedder;
+
+import com.google.auto.value.AutoValue;
+import com.google.mediapipe.tasks.components.containers.EmbeddingResult;
+import com.google.mediapipe.tasks.components.containers.proto.EmbeddingsProto;
+import com.google.mediapipe.tasks.core.TaskResult;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Optional;
+
+/** Represents the embedding results generated by {@link AudioEmbedder}. */
+@AutoValue
+public abstract class AudioEmbedderResult implements TaskResult {
+
+  /**
+   * Creates an {@link AudioEmbedderResult} instance from a list of {@link
+   * EmbeddingsProto.EmbeddingResult} protobuf messages.
+   *
+   * @param protoList a list of {@link EmbeddingsProto.EmbeddingResult} protobuf messages to
+   *     convert.
+   * @param timestampMs a timestamp for this result.
+   */
+  static AudioEmbedderResult createFromProtoList(
+      List<EmbeddingsProto.EmbeddingResult> protoList, long timestampMs) {
+    List<EmbeddingResult> embeddingResultList = new ArrayList<>();
+    for (EmbeddingsProto.EmbeddingResult proto : protoList) {
+      embeddingResultList.add(EmbeddingResult.createFromProto(proto));
+    }
+    return new AutoValue_AudioEmbedderResult(
+        Optional.of(embeddingResultList), Optional.empty(), timestampMs);
+  }
+
+  /**
+   * Creates an {@link AudioEmbedderResult} instance from a {@link EmbeddingsProto.EmbeddingResult}
+   * protobuf message.
+   *
+   * @param proto the {@link EmbeddingsProto.EmbeddingResult} protobuf message to convert.
+   * @param timestampMs a timestamp for this result.
+   */
+  static AudioEmbedderResult createFromProto(
+      EmbeddingsProto.EmbeddingResult proto, long timestampMs) {
+    return new AutoValue_AudioEmbedderResult(
+        Optional.empty(), Optional.of(EmbeddingResult.createFromProto(proto)), timestampMs);
+  }
+
+  /**
+   * A list of timestamped {@link EmbeddingResult} objects, each containing one set of results per
+   * embedder head. The list represents the audio embedding result of an audio clip, and is only
+   * available when running with the audio clips mode.
+   */
+  public abstract Optional<List<EmbeddingResult>> embeddingResultList();
+
+  /**
+   * Contains one set of results per embedder head. An {@link EmbeddingResult} usually represents
+   * one audio embedding result in an audio stream, and is only available when running with the
+   * audio stream mode.
+ */ + public abstract Optional embeddingResult(); + + @Override + public abstract long timestampMs(); +} diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl b/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl index ab7ad6616..358dd8d10 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl @@ -32,6 +32,7 @@ _CORE_TASKS_JAVA_PROTO_LITE_TARGETS = [ _AUDIO_TASKS_JAVA_PROTO_LITE_TARGETS = [ "//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_graph_options_java_proto_lite", + "//mediapipe/tasks/cc/audio/audio_embedder/proto:audio_embedder_graph_options_java_proto_lite", ] _VISION_TASKS_JAVA_PROTO_LITE_TARGETS = [ From 11270d0c93c456dcee6ed736c5f0f9ed304a8916 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Mon, 14 Nov 2022 14:24:11 -0800 Subject: [PATCH 013/469] Image Embedder for Web PiperOrigin-RevId: 488468214 --- mediapipe/tasks/web/vision/BUILD | 1 + mediapipe/tasks/web/vision/core/BUILD | 11 + .../tasks/web/vision/core/running_mode.ts | 36 +++ .../tasks/web/vision/image_embedder/BUILD | 33 +++ .../vision/image_embedder/image_embedder.ts | 214 ++++++++++++++++++ .../image_embedder/image_embedder_options.ts | 31 +++ .../image_embedder/image_embedder_result.ts | 17 ++ mediapipe/tasks/web/vision/index.ts | 5 + 8 files changed, 348 insertions(+) create mode 100644 mediapipe/tasks/web/vision/core/BUILD create mode 100644 mediapipe/tasks/web/vision/core/running_mode.ts create mode 100644 mediapipe/tasks/web/vision/image_embedder/BUILD create mode 100644 mediapipe/tasks/web/vision/image_embedder/image_embedder.ts create mode 100644 mediapipe/tasks/web/vision/image_embedder/image_embedder_options.ts create mode 100644 mediapipe/tasks/web/vision/image_embedder/image_embedder_result.ts diff --git a/mediapipe/tasks/web/vision/BUILD b/mediapipe/tasks/web/vision/BUILD index 395860892..3c45fbfa6 100644 --- a/mediapipe/tasks/web/vision/BUILD +++ b/mediapipe/tasks/web/vision/BUILD @@ -11,6 +11,7 @@ mediapipe_ts_library( "//mediapipe/tasks/web/vision/gesture_recognizer", "//mediapipe/tasks/web/vision/hand_landmarker", "//mediapipe/tasks/web/vision/image_classifier", + "//mediapipe/tasks/web/vision/image_embedder", "//mediapipe/tasks/web/vision/object_detector", ], ) diff --git a/mediapipe/tasks/web/vision/core/BUILD b/mediapipe/tasks/web/vision/core/BUILD new file mode 100644 index 000000000..7ab822b7c --- /dev/null +++ b/mediapipe/tasks/web/vision/core/BUILD @@ -0,0 +1,11 @@ +# This package contains options shared by all MediaPipe Tasks for Web. + +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +mediapipe_ts_library( + name = "running_mode", + srcs = ["running_mode.ts"], + deps = ["//mediapipe/tasks/cc/core/proto:base_options_jspb_proto"], +) diff --git a/mediapipe/tasks/web/vision/core/running_mode.ts b/mediapipe/tasks/web/vision/core/running_mode.ts new file mode 100644 index 000000000..1e9b1b9a7 --- /dev/null +++ b/mediapipe/tasks/web/vision/core/running_mode.ts @@ -0,0 +1,36 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; + +/** + * The running mode of a task. + * 1) The image mode for processing single image inputs. + * 2) The video mode for processing decoded frames of a video. + */ +export type RunningMode = 'image'|'video'; + +/** Configues the `useStreamMode` option . */ +export function configureRunningMode( + options: {runningMode?: RunningMode}, + proto?: BaseOptionsProto): BaseOptionsProto { + proto = proto ?? new BaseOptionsProto(); + if ('runningMode' in options) { + const useStreamMode = options.runningMode === 'video'; + proto.setUseStreamMode(useStreamMode); + } + return proto; +} diff --git a/mediapipe/tasks/web/vision/image_embedder/BUILD b/mediapipe/tasks/web/vision/image_embedder/BUILD new file mode 100644 index 000000000..d12a05ad9 --- /dev/null +++ b/mediapipe/tasks/web/vision/image_embedder/BUILD @@ -0,0 +1,33 @@ +# This contains the MediaPipe Image Embedder Task. +# +# This task performs embedding extraction on images. + +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +mediapipe_ts_library( + name = "image_embedder", + srcs = [ + "image_embedder.ts", + "image_embedder_options.ts", + "image_embedder_result.ts", + ], + deps = [ + "//mediapipe/framework:calculator_jspb_proto", + "//mediapipe/framework:calculator_options_jspb_proto", + "//mediapipe/tasks/cc/components/containers/proto:embeddings_jspb_proto", + "//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_jspb_proto", + "//mediapipe/tasks/web/components/containers:embedding_result", + "//mediapipe/tasks/web/components/processors:base_options", + "//mediapipe/tasks/web/components/processors:embedder_options", + "//mediapipe/tasks/web/components/processors:embedder_result", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:embedder_options", + "//mediapipe/tasks/web/core:task_runner", + "//mediapipe/tasks/web/vision/core:running_mode", + "//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts", + ], +) diff --git a/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts b/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts new file mode 100644 index 000000000..4184e763c --- /dev/null +++ b/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts @@ -0,0 +1,214 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; +import {CalculatorOptions} from '../../../../framework/calculator_options_pb'; +import {EmbeddingResult} from '../../../../tasks/cc/components/containers/proto/embeddings_pb'; +import {ImageEmbedderGraphOptions} from '../../../../tasks/cc/vision/image_embedder/proto/image_embedder_graph_options_pb'; +import {convertBaseOptionsToProto} from '../../../../tasks/web/components/processors/base_options'; +import {convertEmbedderOptionsToProto} from '../../../../tasks/web/components/processors/embedder_options'; +import {convertFromEmbeddingResultProto} from '../../../../tasks/web/components/processors/embedder_result'; +import {TaskRunner} from '../../../../tasks/web/core/task_runner'; +import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options'; +import {configureRunningMode} from '../../../../tasks/web/vision/core/running_mode'; +import {createMediaPipeLib, FileLocator, ImageSource} from '../../../../web/graph_runner/wasm_mediapipe_lib'; +// Placeholder for internal dependency on trusted resource url + +import {ImageEmbedderOptions} from './image_embedder_options'; +import {ImageEmbedderResult} from './image_embedder_result'; + +// The OSS JS API does not support the builder pattern. +// tslint:disable:jspb-use-builder-pattern + +const INPUT_STREAM = 'image_in'; +const EMBEDDINGS_STREAM = 'embeddings_out'; +const TEXT_EMBEDDER_CALCULATOR = + 'mediapipe.tasks.vision.image_embedder.ImageEmbedderGraph'; + +export {ImageSource}; // Used in the public API + +/** Performs embedding extraction on images. */ +export class ImageEmbedder extends TaskRunner { + private readonly options = new ImageEmbedderGraphOptions(); + private embeddings: ImageEmbedderResult = {embeddings: []}; + + /** + * Initializes the Wasm runtime and creates a new image embedder from the + * provided options. + * @param wasmLoaderOptions A configuration object that provides the location + * of the Wasm binary and its loader. + * @param imageEmbedderOptions The options for the image embedder. Note that + * either a path to the TFLite model or the model itself needs to be + * provided (via `baseOptions`). + */ + static async createFromOptions( + wasmLoaderOptions: WasmLoaderOptions, + imageEmbedderOptions: ImageEmbedderOptions): Promise { + // Create a file locator based on the loader options + const fileLocator: FileLocator = { + locateFile() { + // The only file we load is the Wasm binary + return wasmLoaderOptions.wasmBinaryPath.toString(); + } + }; + + const embedder = await createMediaPipeLib( + ImageEmbedder, wasmLoaderOptions.wasmLoaderPath, + /* assetLoaderScript= */ undefined, + /* glCanvas= */ undefined, fileLocator); + await embedder.setOptions(imageEmbedderOptions); + return embedder; + } + + /** + * Initializes the Wasm runtime and creates a new image embedder based on the + * provided model asset buffer. + * @param wasmLoaderOptions A configuration object that provides the location + * of the Wasm binary and its loader. + * @param modelAssetBuffer A binary representation of the TFLite model. + */ + static createFromModelBuffer( + wasmLoaderOptions: WasmLoaderOptions, + modelAssetBuffer: Uint8Array): Promise { + return ImageEmbedder.createFromOptions( + wasmLoaderOptions, {baseOptions: {modelAssetBuffer}}); + } + + /** + * Initializes the Wasm runtime and creates a new image embedder based on the + * path to the model asset. 
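+   * The model is fetched with `fetch()` and handed off to
+   * `createFromModelBuffer()` as a binary model buffer.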
+   * @param wasmLoaderOptions A configuration object that provides the location
+   *     of the Wasm binary and its loader.
+   * @param modelAssetPath The path to the TFLite model.
+   */
+  static async createFromModelPath(
+      wasmLoaderOptions: WasmLoaderOptions,
+      modelAssetPath: string): Promise<ImageEmbedder> {
+    const response = await fetch(modelAssetPath.toString());
+    const graphData = await response.arrayBuffer();
+    return ImageEmbedder.createFromModelBuffer(
+        wasmLoaderOptions, new Uint8Array(graphData));
+  }
+
+  /**
+   * Sets new options for the image embedder.
+   *
+   * Calling `setOptions()` with a subset of options only affects those options.
+   * You can reset an option back to its default value by explicitly setting it
+   * to `undefined`.
+   *
+   * @param options The options for the image embedder.
+   */
+  async setOptions(options: ImageEmbedderOptions): Promise<void> {
+    let baseOptionsProto = this.options.getBaseOptions();
+    if (options.baseOptions) {
+      baseOptionsProto = await convertBaseOptionsToProto(
+          options.baseOptions, baseOptionsProto);
+    }
+    baseOptionsProto = configureRunningMode(options, baseOptionsProto);
+    this.options.setBaseOptions(baseOptionsProto);
+
+    this.options.setEmbedderOptions(convertEmbedderOptionsToProto(
+        options, this.options.getEmbedderOptions()));
+
+    this.refreshGraph();
+  }
+
+  /**
+   * Performs embedding extraction on the provided image and waits synchronously
+   * for the response.
+   *
+   * Only use this method when the `useStreamMode` option is not set or
+   * explicitly set to `false`.
+   *
+   * @param image The image to process.
+   * @return The embedding result of the image
+   */
+  embed(image: ImageSource): ImageEmbedderResult {
+    if (!!this.options.getBaseOptions()?.getUseStreamMode()) {
+      throw new Error(
+          'Task is not initialized with image mode. ' +
+          '\'runningMode\' must be set to \'image\'.');
+    }
+    return this.performEmbeddingExtraction(image, performance.now());
+  }
+
+  /**
+   * Performs embedding extraction on the provided video frame and waits
+   * synchronously for the response.
+   *
+   * Only use this method when the `useStreamMode` option is set to `true`.
+   *
+   * @param imageFrame The image frame to process.
+   * @param timestamp The timestamp of the current frame, in ms.
+   * @return The embedding result of the video frame
+   */
+  embedForVideo(imageFrame: ImageSource, timestamp: number):
+      ImageEmbedderResult {
+    if (!this.options.getBaseOptions()?.getUseStreamMode()) {
+      throw new Error(
+          'Task is not initialized with video mode. ' +
+          '\'runningMode\' must be set to \'video\'.');
+    }
+    return this.performEmbeddingExtraction(imageFrame, timestamp);
+  }
+
+  /** Runs the embedding extraction and blocks on the response. */
+  private performEmbeddingExtraction(image: ImageSource, timestamp: number):
+      ImageEmbedderResult {
+    // Get embeddings by running our MediaPipe graph.
+    this.addGpuBufferAsImageToStream(
+        image, INPUT_STREAM, timestamp ?? performance.now());
+    this.finishProcessing();
+    return this.embeddings;
+  }
+
+  /**
+   * Internal function for converting raw data into an embedding, and setting it
+   * as our embeddings result.
+   */
+  private addJsImageEmbedding(binaryProto: Uint8Array): void {
+    const embeddingResult = EmbeddingResult.deserializeBinary(binaryProto);
+    this.embeddings = convertFromEmbeddingResultProto(embeddingResult);
+  }
+
+  /** Updates the MediaPipe graph configuration.
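+   *
+   * Builds a single-node graph: the input image stream is routed into the
+   * ImageEmbedderGraph via its IMAGE tag, and serialized EmbeddingResult
+   * protos arriving on the EMBEDDINGS tag are converted into the task's
+   * result via `addJsImageEmbedding()`.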
+   */
+  private refreshGraph(): void {
+    const graphConfig = new CalculatorGraphConfig();
+    graphConfig.addInputStream(INPUT_STREAM);
+    graphConfig.addOutputStream(EMBEDDINGS_STREAM);
+
+    const calculatorOptions = new CalculatorOptions();
+    calculatorOptions.setExtension(ImageEmbedderGraphOptions.ext, this.options);
+
+    const embedderNode = new CalculatorGraphConfig.Node();
+    embedderNode.setCalculator(IMAGE_EMBEDDER_CALCULATOR);
+    embedderNode.addInputStream('IMAGE:' + INPUT_STREAM);
+    embedderNode.addOutputStream('EMBEDDINGS:' + EMBEDDINGS_STREAM);
+    embedderNode.setOptions(calculatorOptions);
+
+    graphConfig.addNode(embedderNode);
+
+    this.attachProtoListener(EMBEDDINGS_STREAM, binaryProto => {
+      this.addJsImageEmbedding(binaryProto);
+    });
+
+    const binaryGraph = graphConfig.serializeBinary();
+    this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true);
+  }
+}
+
+
diff --git a/mediapipe/tasks/web/vision/image_embedder/image_embedder_options.ts b/mediapipe/tasks/web/vision/image_embedder/image_embedder_options.ts
new file mode 100644
index 000000000..4d795d0d8
--- /dev/null
+++ b/mediapipe/tasks/web/vision/image_embedder/image_embedder_options.ts
@@ -0,0 +1,31 @@
+/**
+ * Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import {EmbedderOptions} from '../../../../tasks/web/core/embedder_options';
+import {RunningMode} from '../../../../tasks/web/vision/core/running_mode';
+
+/** The options for configuring a MediaPipe image embedder task. */
+export declare interface ImageEmbedderOptions extends EmbedderOptions {
+  /**
+   * The running mode of the task. Defaults to the image mode.
+   * The image embedder has three running modes:
+   * 1) The image mode for embedding extraction on single image inputs.
+   * 2) The video mode for embedding extraction on decoded frames of a video.
+   * 3) The live stream mode for embedding extraction on a live stream of
+   *    input data, such as from a camera.
+   */
+  runningMode?: RunningMode;
+}
diff --git a/mediapipe/tasks/web/vision/image_embedder/image_embedder_result.ts b/mediapipe/tasks/web/vision/image_embedder/image_embedder_result.ts
new file mode 100644
index 000000000..156636505
--- /dev/null
+++ b/mediapipe/tasks/web/vision/image_embedder/image_embedder_result.ts
@@ -0,0 +1,17 @@
+/**
+ * Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +export {Embedding, EmbeddingResult as ImageEmbedderResult} from '../../../../tasks/web/components/containers/embedding_result'; diff --git a/mediapipe/tasks/web/vision/index.ts b/mediapipe/tasks/web/vision/index.ts index 2c46dbd3b..6dda83e55 100644 --- a/mediapipe/tasks/web/vision/index.ts +++ b/mediapipe/tasks/web/vision/index.ts @@ -19,6 +19,11 @@ export * from '../../../tasks/web/vision/image_classifier/image_classifier_optio export * from '../../../tasks/web/vision/image_classifier/image_classifier_result'; export * from '../../../tasks/web/vision/image_classifier/image_classifier'; +// Image Embedder +export * from '../../../tasks/web/vision/image_embedder/image_embedder_options'; +export * from '../../../tasks/web/vision/image_embedder/image_embedder_result'; +export * from '../../../tasks/web/vision/image_embedder/image_embedder'; + // Gesture Recognizer export * from '../../../tasks/web/vision/gesture_recognizer/gesture_recognizer_options'; export * from '../../../tasks/web/vision/gesture_recognizer/gesture_recognizer_result'; From c02737368860667e6cccae959fc2dd2fbcfb7971 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Mon, 14 Nov 2022 15:13:20 -0800 Subject: [PATCH 014/469] Internal change PiperOrigin-RevId: 488481286 --- mediapipe/framework/port/build_config.bzl | 39 +++++++++++++++++-- .../tasks/web/components/containers/BUILD | 10 ++--- .../containers/embedding_result.d.ts | 2 +- .../web/components/containers/landmark.d.ts | 2 +- mediapipe/tasks/web/core/BUILD | 16 +++----- .../tasks/web/vision/gesture_recognizer/BUILD | 23 ++++++++--- .../gesture_recognizer/gesture_recognizer.ts | 2 + ...ons.ts => gesture_recognizer_options.d.ts} | 0 ...sult.ts => gesture_recognizer_result.d.ts} | 0 .../tasks/web/vision/hand_landmarker/BUILD | 22 ++++++++--- .../vision/hand_landmarker/hand_landmarker.ts | 2 + ...ptions.ts => hand_landmarker_options.d.ts} | 0 ..._result.ts => hand_landmarker_result.d.ts} | 0 .../tasks/web/vision/image_classifier/BUILD | 22 ++++++++--- .../image_classifier/image_classifier.ts | 2 + ...tions.ts => image_classifier_options.d.ts} | 0 ...result.ts => image_classifier_result.d.ts} | 0 mediapipe/tasks/web/vision/index.ts | 8 ---- .../tasks/web/vision/object_detector/BUILD | 21 +++++++--- .../vision/object_detector/object_detector.ts | 2 + ...ptions.ts => object_detector_options.d.ts} | 0 ..._result.ts => object_detector_result.d.ts} | 0 22 files changed, 121 insertions(+), 52 deletions(-) rename mediapipe/tasks/web/vision/gesture_recognizer/{gesture_recognizer_options.ts => gesture_recognizer_options.d.ts} (100%) rename mediapipe/tasks/web/vision/gesture_recognizer/{gesture_recognizer_result.ts => gesture_recognizer_result.d.ts} (100%) rename mediapipe/tasks/web/vision/hand_landmarker/{hand_landmarker_options.ts => hand_landmarker_options.d.ts} (100%) rename mediapipe/tasks/web/vision/hand_landmarker/{hand_landmarker_result.ts => hand_landmarker_result.d.ts} (100%) rename mediapipe/tasks/web/vision/image_classifier/{image_classifier_options.ts => image_classifier_options.d.ts} (100%) rename mediapipe/tasks/web/vision/image_classifier/{image_classifier_result.ts => image_classifier_result.d.ts} (100%) rename mediapipe/tasks/web/vision/object_detector/{object_detector_options.ts => object_detector_options.d.ts} (100%) rename mediapipe/tasks/web/vision/object_detector/{object_detector_result.ts => object_detector_result.d.ts} (100%) diff --git a/mediapipe/framework/port/build_config.bzl b/mediapipe/framework/port/build_config.bzl index 80e9bfc4d..eaabda856 
100644 --- a/mediapipe/framework/port/build_config.bzl +++ b/mediapipe/framework/port/build_config.bzl @@ -214,10 +214,10 @@ def mediapipe_ts_library( """Generate ts_project for MediaPipe open source version. Args: - name: the name of the cc_proto_library. - srcs: the .proto files of the cc_proto_library for Bazel use. + name: the name of the mediapipe_ts_library. + srcs: the .ts files of the mediapipe_ts_library for Bazel use. visibility: visibility of this target. - deps: a list of dependency labels for Bazel use; must be cc_proto_library. + deps: a list of dependency labels for Bazel use. testonly: test only or not. allow_unoptimized_namespaces: ignored, used only internally """ @@ -235,3 +235,36 @@ def mediapipe_ts_library( declaration = True, tsconfig = "//:tsconfig.json", )) + +def mediapipe_ts_declaration( + name, + srcs, + visibility = None, + deps = []): + """Generate ts_declaration for MediaPipe open source version. + + Args: + name: the name of the mediapipe_ts_declaration. + srcs: the .d.ts files of the mediapipe_ts_declaration for Bazel use. + visibility: visibility of this target. + deps: a list of dependency labels for Bazel use + """ + + # Bazel does not create JS files for .d.ts files, which leads to import + # failures in our open source build. We simply re-name the .d.ts files + # to .ts to work around this problem. + for src in srcs: + native.genrule( + name = replace_suffix(src, ".d.ts", "_d_ts"), + srcs = [src], + outs = [replace_suffix(src, ".d.ts", ".ts")], + visibility = visibility, + cmd = "cp -n $< $@;", + ) + + mediapipe_ts_library( + name = name, + srcs = [replace_suffix(src, ".d.ts", "_d_ts") for src in srcs], + visibility = visibility, + deps = deps, + ) diff --git a/mediapipe/tasks/web/components/containers/BUILD b/mediapipe/tasks/web/components/containers/BUILD index d1bc480db..fb0fdff16 100644 --- a/mediapipe/tasks/web/components/containers/BUILD +++ b/mediapipe/tasks/web/components/containers/BUILD @@ -1,26 +1,26 @@ # This package contains options shared by all MediaPipe Tasks for Web. -load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration") package(default_visibility = ["//mediapipe/tasks:internal"]) -mediapipe_ts_library( +mediapipe_ts_declaration( name = "category", srcs = ["category.d.ts"], ) -mediapipe_ts_library( +mediapipe_ts_declaration( name = "classification_result", srcs = ["classification_result.d.ts"], deps = [":category"], ) -mediapipe_ts_library( +mediapipe_ts_declaration( name = "landmark", srcs = ["landmark.d.ts"], ) -mediapipe_ts_library( +mediapipe_ts_declaration( name = "embedding_result", srcs = ["embedding_result.d.ts"], ) diff --git a/mediapipe/tasks/web/components/containers/embedding_result.d.ts b/mediapipe/tasks/web/components/containers/embedding_result.d.ts index e1efd94ce..3779abd96 100644 --- a/mediapipe/tasks/web/components/containers/embedding_result.d.ts +++ b/mediapipe/tasks/web/components/containers/embedding_result.d.ts @@ -21,7 +21,7 @@ * contain data, based on whether or not the embedder was configured to perform * scalar quantization. */ -export interface Embedding { +export declare interface Embedding { /** * Floating-point embedding. Empty if the embedder was configured to perform * scalar-quantization. 
diff --git a/mediapipe/tasks/web/components/containers/landmark.d.ts b/mediapipe/tasks/web/components/containers/landmark.d.ts index f790d8a0b..c887303d0 100644 --- a/mediapipe/tasks/web/components/containers/landmark.d.ts +++ b/mediapipe/tasks/web/components/containers/landmark.d.ts @@ -20,7 +20,7 @@ * dimension of image, and the coordinates values are in the range of [0,1]. * Otherwise, it represenet a point in world coordinates. */ -export declare class Landmark { +export declare interface Landmark { /** The x coordinates of the landmark. */ x: number; diff --git a/mediapipe/tasks/web/core/BUILD b/mediapipe/tasks/web/core/BUILD index edfc1e5c5..e9ef85d46 100644 --- a/mediapipe/tasks/web/core/BUILD +++ b/mediapipe/tasks/web/core/BUILD @@ -1,10 +1,10 @@ # This package contains options shared by all MediaPipe Tasks for Web. -load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library") package(default_visibility = ["//mediapipe/tasks:internal"]) -mediapipe_ts_library( +mediapipe_ts_declaration( name = "core", srcs = [ "base_options.d.ts", @@ -24,18 +24,14 @@ mediapipe_ts_library( ], ) -mediapipe_ts_library( +mediapipe_ts_declaration( name = "classifier_options", - srcs = [ - "classifier_options.d.ts", - ], + srcs = ["classifier_options.d.ts"], deps = [":core"], ) -mediapipe_ts_library( +mediapipe_ts_declaration( name = "embedder_options", - srcs = [ - "embedder_options.d.ts", - ], + srcs = ["embedder_options.d.ts"], deps = [":core"], ) diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/BUILD b/mediapipe/tasks/web/vision/gesture_recognizer/BUILD index 6b99f6ce4..d67974a16 100644 --- a/mediapipe/tasks/web/vision/gesture_recognizer/BUILD +++ b/mediapipe/tasks/web/vision/gesture_recognizer/BUILD @@ -3,7 +3,7 @@ # This task takes video frames and outputs synchronized frames along with # the detection results for one or more gesture categories, using Gesture Recognizer. 
-load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library") package(default_visibility = ["//mediapipe/tasks:internal"]) @@ -11,12 +11,9 @@ licenses(["notice"]) mediapipe_ts_library( name = "gesture_recognizer", - srcs = [ - "gesture_recognizer.ts", - "gesture_recognizer_options.ts", - "gesture_recognizer_result.ts", - ], + srcs = ["gesture_recognizer.ts"], deps = [ + ":gesture_recognizer_types", "//mediapipe/framework:calculator_jspb_proto", "//mediapipe/framework:calculator_options_jspb_proto", "//mediapipe/framework/formats:classification_jspb_proto", @@ -38,3 +35,17 @@ mediapipe_ts_library( "//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts", ], ) + +mediapipe_ts_declaration( + name = "gesture_recognizer_types", + srcs = [ + "gesture_recognizer_options.d.ts", + "gesture_recognizer_result.d.ts", + ], + deps = [ + "//mediapipe/tasks/web/components/containers:category", + "//mediapipe/tasks/web/components/containers:landmark", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:classifier_options", + ], +) diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts index c24d1a7b3..6c8072ff5 100644 --- a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts +++ b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts @@ -37,6 +37,8 @@ import {createMediaPipeLib, FileLocator, ImageSource, WasmModule} from '../../.. import {GestureRecognizerOptions} from './gesture_recognizer_options'; import {GestureRecognizerResult} from './gesture_recognizer_result'; +export * from './gesture_recognizer_options'; +export * from './gesture_recognizer_result'; export {ImageSource}; // The OSS JS API does not support the builder pattern. diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_options.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_options.d.ts similarity index 100% rename from mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_options.ts rename to mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_options.d.ts diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_result.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_result.d.ts similarity index 100% rename from mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_result.ts rename to mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_result.d.ts diff --git a/mediapipe/tasks/web/vision/hand_landmarker/BUILD b/mediapipe/tasks/web/vision/hand_landmarker/BUILD index 9006b54ef..25c70e0a5 100644 --- a/mediapipe/tasks/web/vision/hand_landmarker/BUILD +++ b/mediapipe/tasks/web/vision/hand_landmarker/BUILD @@ -3,7 +3,7 @@ # This task takes video frames and outputs synchronized frames along with # the detection results for one or more hand categories, using Hand Landmarker. 
-load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library") package(default_visibility = ["//mediapipe/tasks:internal"]) @@ -11,12 +11,9 @@ licenses(["notice"]) mediapipe_ts_library( name = "hand_landmarker", - srcs = [ - "hand_landmarker.ts", - "hand_landmarker_options.ts", - "hand_landmarker_result.ts", - ], + srcs = ["hand_landmarker.ts"], deps = [ + ":hand_landmarker_types", "//mediapipe/framework:calculator_jspb_proto", "//mediapipe/framework:calculator_options_jspb_proto", "//mediapipe/framework/formats:classification_jspb_proto", @@ -33,3 +30,16 @@ mediapipe_ts_library( "//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts", ], ) + +mediapipe_ts_declaration( + name = "hand_landmarker_types", + srcs = [ + "hand_landmarker_options.d.ts", + "hand_landmarker_result.d.ts", + ], + deps = [ + "//mediapipe/tasks/web/components/containers:category", + "//mediapipe/tasks/web/components/containers:landmark", + "//mediapipe/tasks/web/core", + ], +) diff --git a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts index 017a9098c..af10305b2 100644 --- a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts +++ b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts @@ -33,6 +33,8 @@ import {createMediaPipeLib, FileLocator, ImageSource, WasmModule} from '../../.. import {HandLandmarkerOptions} from './hand_landmarker_options'; import {HandLandmarkerResult} from './hand_landmarker_result'; +export * from './hand_landmarker_options'; +export * from './hand_landmarker_result'; export {ImageSource}; // The OSS JS API does not support the builder pattern. diff --git a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_options.ts b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_options.d.ts similarity index 100% rename from mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_options.ts rename to mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_options.d.ts diff --git a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_result.ts b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_result.d.ts similarity index 100% rename from mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_result.ts rename to mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_result.d.ts diff --git a/mediapipe/tasks/web/vision/image_classifier/BUILD b/mediapipe/tasks/web/vision/image_classifier/BUILD index e96d6a8e3..8506f3574 100644 --- a/mediapipe/tasks/web/vision/image_classifier/BUILD +++ b/mediapipe/tasks/web/vision/image_classifier/BUILD @@ -2,7 +2,7 @@ # # This task takes video or image frames and outputs the classification result. 
-load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library") package(default_visibility = ["//mediapipe/tasks:internal"]) @@ -10,12 +10,9 @@ licenses(["notice"]) mediapipe_ts_library( name = "image_classifier", - srcs = [ - "image_classifier.ts", - "image_classifier_options.ts", - "image_classifier_result.ts", - ], + srcs = ["image_classifier.ts"], deps = [ + ":image_classifier_types", "//mediapipe/framework:calculator_jspb_proto", "//mediapipe/framework:calculator_options_jspb_proto", "//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto", @@ -31,3 +28,16 @@ mediapipe_ts_library( "//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts", ], ) + +mediapipe_ts_declaration( + name = "image_classifier_types", + srcs = [ + "image_classifier_options.d.ts", + "image_classifier_result.d.ts", + ], + deps = [ + "//mediapipe/tasks/web/components/containers:category", + "//mediapipe/tasks/web/components/containers:classification_result", + "//mediapipe/tasks/web/core:classifier_options", + ], +) diff --git a/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts b/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts index ba4b6c907..5d60e4a21 100644 --- a/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts +++ b/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts @@ -34,6 +34,8 @@ const IMAGE_CLASSIFIER_GRAPH = const INPUT_STREAM = 'input_image'; const CLASSIFICATIONS_STREAM = 'classifications'; +export * from './image_classifier_options'; +export * from './image_classifier_result'; export {ImageSource}; // Used in the public API // The OSS JS API does not support the builder pattern. 
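(Aside — a consumer-side sketch of the re-export pattern above; the import
below is illustrative and assumes the published `@mediapipe/tasks-vision`
entry point, it is not part of this patch:

    import {ImageClassifier, ImageClassifierOptions} from '@mediapipe/tasks-vision';

Since each task module now re-exports its own option and result declarations,
the per-file exports are dropped from index.ts in the hunks that follow.)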
diff --git a/mediapipe/tasks/web/vision/image_classifier/image_classifier_options.ts b/mediapipe/tasks/web/vision/image_classifier/image_classifier_options.d.ts similarity index 100% rename from mediapipe/tasks/web/vision/image_classifier/image_classifier_options.ts rename to mediapipe/tasks/web/vision/image_classifier/image_classifier_options.d.ts diff --git a/mediapipe/tasks/web/vision/image_classifier/image_classifier_result.ts b/mediapipe/tasks/web/vision/image_classifier/image_classifier_result.d.ts similarity index 100% rename from mediapipe/tasks/web/vision/image_classifier/image_classifier_result.ts rename to mediapipe/tasks/web/vision/image_classifier/image_classifier_result.d.ts diff --git a/mediapipe/tasks/web/vision/index.ts b/mediapipe/tasks/web/vision/index.ts index 6dda83e55..0ea844fc9 100644 --- a/mediapipe/tasks/web/vision/index.ts +++ b/mediapipe/tasks/web/vision/index.ts @@ -15,8 +15,6 @@ */ // Image Classifier -export * from '../../../tasks/web/vision/image_classifier/image_classifier_options'; -export * from '../../../tasks/web/vision/image_classifier/image_classifier_result'; export * from '../../../tasks/web/vision/image_classifier/image_classifier'; // Image Embedder @@ -25,16 +23,10 @@ export * from '../../../tasks/web/vision/image_embedder/image_embedder_result'; export * from '../../../tasks/web/vision/image_embedder/image_embedder'; // Gesture Recognizer -export * from '../../../tasks/web/vision/gesture_recognizer/gesture_recognizer_options'; -export * from '../../../tasks/web/vision/gesture_recognizer/gesture_recognizer_result'; export * from '../../../tasks/web/vision/gesture_recognizer/gesture_recognizer'; // Hand Landmarker -export * from '../../../tasks/web/vision/hand_landmarker/hand_landmarker_options'; -export * from '../../../tasks/web/vision/hand_landmarker/hand_landmarker_result'; export * from '../../../tasks/web/vision/hand_landmarker/hand_landmarker'; // Object Detector -export * from '../../../tasks/web/vision/object_detector/object_detector_options'; -export * from '../../../tasks/web/vision/object_detector/object_detector_result'; export * from '../../../tasks/web/vision/object_detector/object_detector'; diff --git a/mediapipe/tasks/web/vision/object_detector/BUILD b/mediapipe/tasks/web/vision/object_detector/BUILD index 095a84b52..a74dc9211 100644 --- a/mediapipe/tasks/web/vision/object_detector/BUILD +++ b/mediapipe/tasks/web/vision/object_detector/BUILD @@ -3,7 +3,7 @@ # This task takes video frames and outputs synchronized frames along with # the detection results for one or more object categories, using Object Detector. 
-load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library") package(default_visibility = ["//mediapipe/tasks:internal"]) @@ -11,12 +11,9 @@ licenses(["notice"]) mediapipe_ts_library( name = "object_detector", - srcs = [ - "object_detector.ts", - "object_detector_options.ts", - "object_detector_result.ts", - ], + srcs = ["object_detector.ts"], deps = [ + ":object_detector_types", "//mediapipe/framework:calculator_jspb_proto", "//mediapipe/framework:calculator_options_jspb_proto", "//mediapipe/framework/formats:detection_jspb_proto", @@ -28,3 +25,15 @@ mediapipe_ts_library( "//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts", ], ) + +mediapipe_ts_declaration( + name = "object_detector_types", + srcs = [ + "object_detector_options.d.ts", + "object_detector_result.d.ts", + ], + deps = [ + "//mediapipe/tasks/web/components/containers:category", + "//mediapipe/tasks/web/core", + ], +) diff --git a/mediapipe/tasks/web/vision/object_detector/object_detector.ts b/mediapipe/tasks/web/vision/object_detector/object_detector.ts index 022bf6301..e17a42020 100644 --- a/mediapipe/tasks/web/vision/object_detector/object_detector.ts +++ b/mediapipe/tasks/web/vision/object_detector/object_detector.ts @@ -33,6 +33,8 @@ const OBJECT_DETECTOR_GRAPH = 'mediapipe.tasks.vision.ObjectDetectorGraph'; const DEFAULT_CATEGORY_INDEX = -1; +export * from './object_detector_options'; +export * from './object_detector_result'; export {ImageSource}; // Used in the public API // The OSS JS API does not support the builder pattern. diff --git a/mediapipe/tasks/web/vision/object_detector/object_detector_options.ts b/mediapipe/tasks/web/vision/object_detector/object_detector_options.d.ts similarity index 100% rename from mediapipe/tasks/web/vision/object_detector/object_detector_options.ts rename to mediapipe/tasks/web/vision/object_detector/object_detector_options.d.ts diff --git a/mediapipe/tasks/web/vision/object_detector/object_detector_result.ts b/mediapipe/tasks/web/vision/object_detector/object_detector_result.d.ts similarity index 100% rename from mediapipe/tasks/web/vision/object_detector/object_detector_result.ts rename to mediapipe/tasks/web/vision/object_detector/object_detector_result.d.ts From e714e656fe07568290d569478ca308c03d8e6b40 Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Mon, 14 Nov 2022 15:16:36 -0800 Subject: [PATCH 015/469] Update python documentation. 
PiperOrigin-RevId: 488482142 --- .../tasks/python/audio/audio_classifier.py | 25 ++++++++++- .../tasks/python/audio/audio_embedder.py | 19 +++++++- .../tasks/python/text/text_classifier.py | 33 +++++++++++++- mediapipe/tasks/python/text/text_embedder.py | 22 +++++++++- .../tasks/python/vision/image_classifier.py | 35 ++++++++++++++- .../tasks/python/vision/image_embedder.py | 19 +++++++- .../tasks/python/vision/image_segmenter.py | 24 +++++++++- .../tasks/python/vision/object_detector.py | 44 ++++++++++++++++++- 8 files changed, 213 insertions(+), 8 deletions(-) diff --git a/mediapipe/tasks/python/audio/audio_classifier.py b/mediapipe/tasks/python/audio/audio_classifier.py index a081e5ecd..7955cc4dc 100644 --- a/mediapipe/tasks/python/audio/audio_classifier.py +++ b/mediapipe/tasks/python/audio/audio_classifier.py @@ -86,7 +86,30 @@ class AudioClassifierOptions: class AudioClassifier(base_audio_task_api.BaseAudioTaskApi): - """Class that performs audio classification on audio data.""" + """Class that performs audio classification on audio data. + + This API expects a TFLite model with mandatory TFLite Model Metadata that + contains the mandatory AudioProperties of the solo input audio tensor and the + optional (but recommended) category labels as AssociatedFiles with type + TENSOR_AXIS_LABELS per output classification tensor. + + Input tensor: + (kTfLiteFloat32) + - input audio buffer of size `[batch * samples]`. + - batch inference is not supported (`batch` is required to be 1). + - for multi-channel models, the channels must be interleaved. + At least one output tensor with: + (kTfLiteFloat32) + - `[1 x N]` array with `N` represents the number of categories. + - optional (but recommended) category labels as AssociatedFiles with type + TENSOR_AXIS_LABELS, containing one label per line. The first such + AssociatedFile (if any) is used to fill the `category_name` field of the + results. The `display_name` field is filled from the AssociatedFile (if + any) whose locale matches the `display_names_locale` field of the + `AudioClassifierOptions` used at creation time ("en" by default, i.e. + English). If none of these are available, only the `index` field of the + results will be filled. + """ @classmethod def create_from_model_path(cls, model_path: str) -> 'AudioClassifier': diff --git a/mediapipe/tasks/python/audio/audio_embedder.py b/mediapipe/tasks/python/audio/audio_embedder.py index 98afe490f..a774d71e9 100644 --- a/mediapipe/tasks/python/audio/audio_embedder.py +++ b/mediapipe/tasks/python/audio/audio_embedder.py @@ -87,7 +87,24 @@ class AudioEmbedderOptions: class AudioEmbedder(base_audio_task_api.BaseAudioTaskApi): - """Class that performs embedding extraction on audio clips or audio stream.""" + """Class that performs embedding extraction on audio clips or audio stream. + + This API expects a TFLite model with mandatory TFLite Model Metadata that + contains the mandatory AudioProperties of the solo input audio tensor and the + optional (but recommended) label items as AssociatedFiles with type + TENSOR_AXIS_LABELS per output embedding tensor. + + Input tensor: + (kTfLiteFloat32) + - input audio buffer of size `[batch * samples]`. + - batch inference is not supported (`batch` is required to be 1). + - for multi-channel models, the channels must be interleaved. + At least one output tensor with: + (kTfLiteUInt8/kTfLiteFloat32) + - `N` components corresponding to the `N` dimensions of the returned + feature vector for this output layer. + - Either 2 or 4 dimensions, i.e. 
`[1 x N]` or `[1 x 1 x 1 x N]`. + """ @classmethod def create_from_model_path(cls, model_path: str) -> 'AudioEmbedder': diff --git a/mediapipe/tasks/python/text/text_classifier.py b/mediapipe/tasks/python/text/text_classifier.py index 00f35fada..92d547f20 100644 --- a/mediapipe/tasks/python/text/text_classifier.py +++ b/mediapipe/tasks/python/text/text_classifier.py @@ -62,7 +62,38 @@ class TextClassifierOptions: class TextClassifier(base_text_task_api.BaseTextTaskApi): - """Class that performs classification on text.""" + """Class that performs classification on text. + + This API expects a TFLite model with (optional) TFLite Model Metadata that + contains the mandatory (described below) input tensors, output tensor, + and the optional (but recommended) category labels as AssociatedFiles with + type + TENSOR_AXIS_LABELS per output classification tensor. Metadata is required for + models with int32 input tensors because it contains the input process unit + for the model's Tokenizer. No metadata is required for models with string + input tensors. + + Input tensors: + (kTfLiteInt32) + - 3 input tensors of size `[batch_size x bert_max_seq_len]` representing + the input ids, segment ids, and mask ids + - or 1 input tensor of size `[batch_size x max_seq_len]` representing the + input ids + or (kTfLiteString) + - 1 input tensor that is shapeless or has shape [1] containing the input + string + At least one output tensor with: + (kTfLiteFloat32/kBool) + - `[1 x N]` array with `N` represents the number of categories. + - optional (but recommended) category labels as AssociatedFiles with type + TENSOR_AXIS_LABELS, containing one label per line. The first such + AssociatedFile (if any) is used to fill the `category_name` field of the + results. The `display_name` field is filled from the AssociatedFile (if + any) whose locale matches the `display_names_locale` field of the + `TextClassifierOptions` used at creation time ("en" by default, i.e. + English). If none of these are available, only the `index` field of the + results will be filled. + """ @classmethod def create_from_model_path(cls, model_path: str) -> 'TextClassifier': diff --git a/mediapipe/tasks/python/text/text_embedder.py b/mediapipe/tasks/python/text/text_embedder.py index 2395f6d6b..f3e5eecbe 100644 --- a/mediapipe/tasks/python/text/text_embedder.py +++ b/mediapipe/tasks/python/text/text_embedder.py @@ -63,7 +63,27 @@ class TextEmbedderOptions: class TextEmbedder(base_text_task_api.BaseTextTaskApi): - """Class that performs embedding extraction on text.""" + """Class that performs embedding extraction on text. + + This API expects a TFLite model with TFLite Model Metadata that contains the + mandatory (described below) input tensors and output tensors. Metadata should + contain the input process unit for the model's Tokenizer as well as input / + output tensor metadata. + + Input tensors: + (kTfLiteInt32) + - 3 input tensors of size `[batch_size x bert_max_seq_len]` with names + "ids", "mask", and "segment_ids" representing the input ids, mask ids, and + segment ids respectively. + - or 1 input tensor of size `[batch_size x max_seq_len]` representing the + input ids. + + At least one output tensor with: + (kTfLiteFloat32) + - `N` components corresponding to the `N` dimensions of the returned + feature vector for this output layer. + - Either 2 or 4 dimensions, i.e. `[1 x N]` or `[1 x 1 x 1 x N]`. 
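+
+  A minimal usage sketch (the model path below is an illustrative
+  placeholder, not a bundled asset):
+
+    embedder = TextEmbedder.create_from_model_path('/path/to/model.tflite')
+    result = embedder.embed('input text')
+    # result.embeddings[0].embedding holds the first head's feature vector.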
+  """
 
   @classmethod
   def create_from_model_path(cls, model_path: str) -> 'TextEmbedder':
diff --git a/mediapipe/tasks/python/vision/image_classifier.py b/mediapipe/tasks/python/vision/image_classifier.py
index 0537e7dbb..763160e1e 100644
--- a/mediapipe/tasks/python/vision/image_classifier.py
+++ b/mediapipe/tasks/python/vision/image_classifier.py
@@ -87,7 +87,40 @@ class ImageClassifierOptions:
 
 
 class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
-  """Class that performs image classification on images."""
+  """Class that performs image classification on images.
+
+  The API expects a TFLite model with optional, but strongly recommended,
+  TFLite Model Metadata.
+
+  Input tensor:
+    (kTfLiteUInt8/kTfLiteFloat32)
+    - image input of size `[batch x height x width x channels]`.
+    - batch inference is not supported (`batch` is required to be 1).
+    - only RGB inputs are supported (`channels` is required to be 3).
+    - if type is kTfLiteFloat32, NormalizationOptions are required to be
+      attached to the metadata for input normalization.
+  At least one output tensor with:
+    (kTfLiteUInt8/kTfLiteFloat32)
+    - `N` classes and either 2 or 4 dimensions, i.e. `[1 x N]` or
+      `[1 x 1 x 1 x N]`
+    - optional (but recommended) label map(s) as AssociatedFiles with type
+      TENSOR_AXIS_LABELS, containing one label per line. The first such
+      AssociatedFile (if any) is used to fill the `class_name` field of the
+      results. The `display_name` field is filled from the AssociatedFile (if
+      any) whose locale matches the `display_names_locale` field of the
+      `ImageClassifierOptions` used at creation time ("en" by default, i.e.
+      English). If none of these are available, only the `index` field of the
+      results will be filled.
+    - optional score calibration can be attached using ScoreCalibrationOptions
+      and an AssociatedFile with type TENSOR_AXIS_SCORE_CALIBRATION. See
+      metadata_schema.fbs [1] for more details.
+
+  An example of such a model can be found at:
+  https://tfhub.dev/bohemian-visual-recognition-alliance/lite-model/models/mushroom-identification_v1/1
+
+  [1]:
+  https://github.com/google/mediapipe/blob/6cdc6443b6a7ed662744e2a2ce2d58d9c83e6d6f/mediapipe/tasks/metadata/metadata_schema.fbs#L456
+  """
 
   @classmethod
   def create_from_model_path(cls, model_path: str) -> 'ImageClassifier':
diff --git a/mediapipe/tasks/python/vision/image_embedder.py b/mediapipe/tasks/python/vision/image_embedder.py
index 922040397..f299fa590 100644
--- a/mediapipe/tasks/python/vision/image_embedder.py
+++ b/mediapipe/tasks/python/vision/image_embedder.py
@@ -86,7 +86,31 @@ class ImageEmbedderOptions:
 
 
 class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi):
-  """Class that performs embedding extraction on images."""
+  """Class that performs embedding extraction on images.
+
+  The API expects a TFLite model with optional, but strongly recommended,
+  TFLite Model Metadata.
+
+  Input tensor:
+    (kTfLiteUInt8/kTfLiteFloat32)
+    - image input of size `[batch x height x width x channels]`.
+    - batch inference is not supported (`batch` is required to be 1).
+    - only RGB inputs are supported (`channels` is required to be 3).
+    - if type is kTfLiteFloat32, NormalizationOptions are required to be
+      attached to the metadata for input normalization.
+  At least one output tensor with:
+    (kTfLiteUInt8/kTfLiteFloat32)
+    - `N` components corresponding to the `N` dimensions of the returned
+      feature vector for this output layer.
+    - Either 2 or 4 dimensions, i.e. `[1 x N]` or `[1 x 1 x 1 x N]`.
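+
+  A minimal usage sketch (paths are illustrative placeholders; assumes
+  `import mediapipe as mp`):
+
+    embedder = ImageEmbedder.create_from_model_path('/path/to/model.tflite')
+    image = mp.Image.create_from_file('/path/to/image.jpg')
+    result = embedder.embed(image)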
+  """
 
   @classmethod
   def create_from_model_path(cls, model_path: str) -> 'ImageEmbedder':
diff --git a/mediapipe/tasks/python/vision/image_segmenter.py b/mediapipe/tasks/python/vision/image_segmenter.py
index 1740d41ef..9ef911f75 100644
--- a/mediapipe/tasks/python/vision/image_segmenter.py
+++ b/mediapipe/tasks/python/vision/image_segmenter.py
@@ -93,7 +93,29 @@ class ImageSegmenterOptions:
 
 
 class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi):
-  """Class that performs image segmentation on images."""
+  """Class that performs image segmentation on images.
+
+  The API expects a TFLite model with mandatory TFLite Model Metadata.
+
+  Input tensor:
+    (kTfLiteUInt8/kTfLiteFloat32)
+    - image input of size `[batch x height x width x channels]`.
+    - batch inference is not supported (`batch` is required to be 1).
+    - RGB and greyscale inputs are supported (`channels` is required to be
+      1 or 3).
+    - if type is kTfLiteFloat32, NormalizationOptions are required to be
+      attached to the metadata for input normalization.
+  Output tensors:
+    (kTfLiteUInt8/kTfLiteFloat32)
+    - list of segmented masks.
+    - if `output_type` is CATEGORY_MASK, uint8 Image, Image vector of size 1.
+    - if `output_type` is CONFIDENCE_MASK, float32 Image list of size
+      `channels`.
+    - batch is always 1
+
+  An example of such a model can be found at:
+  https://tfhub.dev/tensorflow/lite-model/deeplabv3/1/metadata/2
+  """
 
   @classmethod
   def create_from_model_path(cls, model_path: str) -> 'ImageSegmenter':
diff --git a/mediapipe/tasks/python/vision/object_detector.py b/mediapipe/tasks/python/vision/object_detector.py
index f6177cda2..7c9993d62 100644
--- a/mediapipe/tasks/python/vision/object_detector.py
+++ b/mediapipe/tasks/python/vision/object_detector.py
@@ -98,7 +98,49 @@ class ObjectDetectorOptions:
 
 
 class ObjectDetector(base_vision_task_api.BaseVisionTaskApi):
-  """Class that performs object detection on images."""
+  """Class that performs object detection on images.
+
+  The API expects a TFLite model with mandatory TFLite Model Metadata.
+
+  Input tensor:
+    (kTfLiteUInt8/kTfLiteFloat32)
+    - image input of size `[batch x height x width x channels]`.
+    - batch inference is not supported (`batch` is required to be 1).
+    - only RGB inputs are supported (`channels` is required to be 3).
+    - if type is kTfLiteFloat32, NormalizationOptions are required to be
+      attached to the metadata for input normalization.
+  Output tensors must be the 4 outputs of a `DetectionPostProcess` op, i.e.:
+    (kTfLiteFloat32)
+    - locations tensor of size `[num_results x 4]`, the inner array
+      representing bounding boxes in the form [top, left, right, bottom].
+    - BoundingBoxProperties are required to be attached to the metadata
+      and must specify type=BOUNDARIES and coordinate_type=RATIO.
+    (kTfLiteFloat32)
+    - classes tensor of size `[num_results]`, each value representing the
+      integer index of a class.
+    - optional (but recommended) label map(s) can be attached as
+      AssociatedFile-s with type TENSOR_VALUE_LABELS, containing one label per
+      line. The first such AssociatedFile (if any) is used to fill the
+      `class_name` field of the results. The `display_name` field is filled
+      from the AssociatedFile (if any) whose locale matches the
+      `display_names_locale` field of the `ObjectDetectorOptions` used at
+      creation time ("en" by default, i.e. English). If none of these are
+      available, only the `index` field of the results will be filled.
+ (kTfLiteFloat32) + - scores tensor of size `[num_results]`, each value representing the score + of the detected object. + - optional score calibration can be attached using ScoreCalibrationOptions + and an AssociatedFile with type TENSOR_AXIS_SCORE_CALIBRATION. See + metadata_schema.fbs [1] for more details. + (kTfLiteFloat32) + - integer num_results as a tensor of size `[1]` + + An example of such model can be found at: + https://tfhub.dev/google/lite-model/object_detection/mobile_object_localizer_v1/1/metadata/1 + + [1]: + https://github.com/google/mediapipe/blob/6cdc6443b6a7ed662744e2a2ce2d58d9c83e6d6f/mediapipe/tasks/metadata/metadata_schema.fbs#L456 + """ @classmethod def create_from_model_path(cls, model_path: str) -> 'ObjectDetector': From cce1751dbf7d6cd52cc080b312f080f7972150d6 Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Mon, 14 Nov 2022 15:39:16 -0800 Subject: [PATCH 016/469] Expose tasks components API in PyPI packages. PiperOrigin-RevId: 488487666 --- mediapipe/tasks/python/audio/__init__.py | 6 +++ .../python/components/containers/__init__.py | 40 +++++++++++++++++++ .../python/components/processors/__init__.py | 13 ++++++ mediapipe/tasks/python/text/__init__.py | 6 +++ 4 files changed, 65 insertions(+) diff --git a/mediapipe/tasks/python/audio/__init__.py b/mediapipe/tasks/python/audio/__init__.py index 947f95d9d..e129800a3 100644 --- a/mediapipe/tasks/python/audio/__init__.py +++ b/mediapipe/tasks/python/audio/__init__.py @@ -16,12 +16,18 @@ import mediapipe.tasks.python.audio.core import mediapipe.tasks.python.audio.audio_classifier +import mediapipe.tasks.python.audio.audio_embedder AudioClassifier = audio_classifier.AudioClassifier AudioClassifierOptions = audio_classifier.AudioClassifierOptions +AudioClassifierResult = audio_classifier.AudioClassifierResult +AudioEmbedder = audio_embedder.AudioEmbedder +AudioEmbedderOptions = audio_embedder.AudioEmbedderOptions +AudioEmbedderResult = audio_embedder.AudioEmbedderResult RunningMode = core.audio_task_running_mode.AudioTaskRunningMode # Remove unnecessary modules to avoid duplication in API docs. del audio_classifier +del audio_embedder del core del mediapipe diff --git a/mediapipe/tasks/python/components/containers/__init__.py b/mediapipe/tasks/python/components/containers/__init__.py index 65c1214af..17464db36 100644 --- a/mediapipe/tasks/python/components/containers/__init__.py +++ b/mediapipe/tasks/python/components/containers/__init__.py @@ -11,3 +11,43 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+ +"""MediaPipe Tasks Components Containers API.""" + +import mediapipe.tasks.python.components.containers.audio_data +import mediapipe.tasks.python.components.containers.bounding_box +import mediapipe.tasks.python.components.containers.category +import mediapipe.tasks.python.components.containers.classification_result +import mediapipe.tasks.python.components.containers.detections +import mediapipe.tasks.python.components.containers.embedding_result +import mediapipe.tasks.python.components.containers.landmark +import mediapipe.tasks.python.components.containers.landmark_detection_result +import mediapipe.tasks.python.components.containers.rect + +AudioDataFormat = audio_data.AudioDataFormat +AudioData = audio_data.AudioData +BoundingBox = bounding_box.BoundingBox +Category = category.Category +Classifications = classification_result.Classifications +ClassificationResult = classification_result.ClassificationResult +Detection = detections.Detection +DetectionResult = detections.DetectionResult +Embedding = embedding_result.Embedding +EmbeddingResult = embedding_result.EmbeddingResult +Landmark = landmark.Landmark +NormalizedLandmark = landmark.NormalizedLandmark +LandmarksDetectionResult = landmark_detection_result.LandmarksDetectionResult +Rect = rect.Rect +NormalizedRect = rect.NormalizedRect + +# Remove unnecessary modules to avoid duplication in API docs. +del audio_data +del bounding_box +del category +del classification_result +del detections +del embedding_result +del landmark +del landmark_detection_result +del rect +del mediapipe diff --git a/mediapipe/tasks/python/components/processors/__init__.py b/mediapipe/tasks/python/components/processors/__init__.py index 65c1214af..adcb38757 100644 --- a/mediapipe/tasks/python/components/processors/__init__.py +++ b/mediapipe/tasks/python/components/processors/__init__.py @@ -11,3 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +"""MediaPipe Tasks Components Processors API.""" + +import mediapipe.tasks.python.components.processors.classifier_options +import mediapipe.tasks.python.components.processors.embedder_options + +ClassifierOptions = classifier_options.ClassifierOptions +EmbedderOptions = embedder_options.EmbedderOptions + +# Remove unnecessary modules to avoid duplication in API docs. +del classifier_options +del embedder_options +del mediapipe diff --git a/mediapipe/tasks/python/text/__init__.py b/mediapipe/tasks/python/text/__init__.py index e2473f56b..ecf3a0ad2 100644 --- a/mediapipe/tasks/python/text/__init__.py +++ b/mediapipe/tasks/python/text/__init__.py @@ -15,10 +15,16 @@ """MediaPipe Tasks Text API.""" import mediapipe.tasks.python.text.text_classifier +import mediapipe.tasks.python.text.text_embedder TextClassifier = text_classifier.TextClassifier TextClassifierOptions = text_classifier.TextClassifierOptions +TextClassifierResult = text_classifier.TextClassifierResult +TextEmbedder = text_embedder.TextEmbedder +TextEmbedderOptions = text_embedder.TextEmbedderOptions +TextEmbedderResult = text_embedder.TextEmbedderResult # Remove unnecessary modules to avoid duplication in API docs. del mediapipe del text_classifier +del text_embedder From 794f64db555e6bb1f30bc77bc2343dbc9ec3d72a Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Mon, 14 Nov 2022 15:49:22 -0800 Subject: [PATCH 017/469] Fix the wrong path of "text_embedder_graph_options_java_proto_lite". 
PiperOrigin-RevId: 488490050 --- .../java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl b/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl index 358dd8d10..cae3ccfe9 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl @@ -50,7 +50,7 @@ _VISION_TASKS_JAVA_PROTO_LITE_TARGETS = [ _TEXT_TASKS_JAVA_PROTO_LITE_TARGETS = [ "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_java_proto_lite", - "//mediapipe/tasks/cc/text/text_classifier/proto:text_embedder_graph_options_java_proto_lite", + "//mediapipe/tasks/cc/text/text_embedder/proto:text_embedder_graph_options_java_proto_lite", ] def mediapipe_tasks_core_aar(name, srcs, manifest): From a12bc3fd0e947fd8ad5d209ffa3a89dd7c8522b1 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Mon, 14 Nov 2022 16:49:58 -0800 Subject: [PATCH 018/469] Add IIFE bundles for NPM packages PiperOrigin-RevId: 488504360 --- mediapipe/tasks/web/BUILD | 69 ++++++++++++++++--- mediapipe/tasks/web/package.json | 7 +- mediapipe/tasks/web/rollup.config.cjs.mjs | 15 ++++ mediapipe/tasks/web/rollup.config.iife.mjs | 21 ++++++ mediapipe/tasks/web/rollup.config.mjs | 9 --- package.json | 2 + yarn.lock | 80 +++++++++++++++++++++- 7 files changed, 181 insertions(+), 22 deletions(-) create mode 100644 mediapipe/tasks/web/rollup.config.cjs.mjs create mode 100644 mediapipe/tasks/web/rollup.config.iife.mjs delete mode 100644 mediapipe/tasks/web/rollup.config.mjs diff --git a/mediapipe/tasks/web/BUILD b/mediapipe/tasks/web/BUILD index 2c0ea57ef..ddc35ab21 100644 --- a/mediapipe/tasks/web/BUILD +++ b/mediapipe/tasks/web/BUILD @@ -28,8 +28,8 @@ mediapipe_ts_library( ) rollup_bundle( - name = "audio_bundle", - config_file = "rollup.config.mjs", + name = "audio_cjs_bundle", + config_file = "rollup.config.cjs.mjs", entry_point = "audio.ts", format = "cjs", output_dir = False, @@ -37,6 +37,22 @@ rollup_bundle( ":audio_lib", "@npm//@rollup/plugin-commonjs", "@npm//@rollup/plugin-node-resolve", + "@npm//@rollup/plugin-replace", + ], +) + +rollup_bundle( + name = "audio_iife_bundle", + config_file = "rollup.config.iife.mjs", + entry_point = "audio.ts", + format = "iife", + output_dir = False, + deps = [ + ":audio_lib", + "@npm//@rollup/plugin-commonjs", + "@npm//@rollup/plugin-node-resolve", + "@npm//@rollup/plugin-replace", + "@npm//@rollup/plugin-terser", ], ) @@ -52,7 +68,8 @@ pkg_npm( deps = [ "wasm/audio_wasm_internal.js", "wasm/audio_wasm_internal.wasm", - ":audio_bundle", + ":audio_cjs_bundle", + ":audio_iife_bundle", ], ) @@ -65,8 +82,8 @@ mediapipe_ts_library( ) rollup_bundle( - name = "text_bundle", - config_file = "rollup.config.mjs", + name = "text_cjs_bundle", + config_file = "rollup.config.cjs.mjs", entry_point = "text.ts", format = "cjs", output_dir = False, @@ -74,6 +91,22 @@ rollup_bundle( ":text_lib", "@npm//@rollup/plugin-commonjs", "@npm//@rollup/plugin-node-resolve", + "@npm//@rollup/plugin-replace", + ], +) + +rollup_bundle( + name = "text_iife_bundle", + config_file = "rollup.config.iife.mjs", + entry_point = "text.ts", + format = "iife", + output_dir = False, + deps = [ + ":text_lib", + "@npm//@rollup/plugin-commonjs", + "@npm//@rollup/plugin-node-resolve", + "@npm//@rollup/plugin-replace", + "@npm//@rollup/plugin-terser", ], ) @@ -89,7 +122,8 @@ pkg_npm( deps = [ 
"wasm/text_wasm_internal.js", "wasm/text_wasm_internal.wasm", - ":text_bundle", + ":text_cjs_bundle", + ":text_iife_bundle", ], ) @@ -102,8 +136,8 @@ mediapipe_ts_library( ) rollup_bundle( - name = "vision_bundle", - config_file = "rollup.config.mjs", + name = "vision_cjs_bundle", + config_file = "rollup.config.cjs.mjs", entry_point = "vision.ts", format = "cjs", output_dir = False, @@ -111,6 +145,22 @@ rollup_bundle( ":vision_lib", "@npm//@rollup/plugin-commonjs", "@npm//@rollup/plugin-node-resolve", + "@npm//@rollup/plugin-replace", + ], +) + +rollup_bundle( + name = "vision_iife_bundle", + config_file = "rollup.config.iife.mjs", + entry_point = "vision.ts", + format = "iife", + output_dir = False, + deps = [ + ":vision_lib", + "@npm//@rollup/plugin-commonjs", + "@npm//@rollup/plugin-node-resolve", + "@npm//@rollup/plugin-replace", + "@npm//@rollup/plugin-terser", ], ) @@ -126,6 +176,7 @@ pkg_npm( deps = [ "wasm/vision_wasm_internal.js", "wasm/vision_wasm_internal.wasm", - ":vision_bundle", + ":vision_cjs_bundle", + ":vision_iife_bundle", ], ) diff --git a/mediapipe/tasks/web/package.json b/mediapipe/tasks/web/package.json index d7d484ca4..5b2dce245 100644 --- a/mediapipe/tasks/web/package.json +++ b/mediapipe/tasks/web/package.json @@ -2,10 +2,11 @@ "name": "@mediapipe/tasks-__NAME__", "version": "__VERSION__", "description": "__DESCRIPTION__", - "main": "__NAME___bundle.js", - "module": "__NAME___bundle.js", + "main": "__NAME___cjs_bundle.js", + "module": "__NAME___cjs_bundle.js", + "jsdeliver": "__NAME___iife_bundle.js", "exports": { - ".": "./__NAME___bundle.js", + ".": "./__NAME___cjs_bundle.js", "./loader": "./wasm/__NAME___wasm_internal.js", "./wasm": "./wasm/__NAME___wasm_internal.wasm" }, diff --git a/mediapipe/tasks/web/rollup.config.cjs.mjs b/mediapipe/tasks/web/rollup.config.cjs.mjs new file mode 100644 index 000000000..5f8ca1848 --- /dev/null +++ b/mediapipe/tasks/web/rollup.config.cjs.mjs @@ -0,0 +1,15 @@ +import resolve from '@rollup/plugin-node-resolve'; +import commonjs from '@rollup/plugin-commonjs'; +import replace from '@rollup/plugin-replace'; + +export default { + plugins: [ + // Workaround for https://github.com/protocolbuffers/protobuf-javascript/issues/151 + replace({ + 'var calculator_options_pb = {};': 'var calculator_options_pb = {}; var mediapipe_framework_calculator_options_pb = calculator_options_pb;', + delimiters: ['', ''] + }), + resolve(), + commonjs() + ] +} diff --git a/mediapipe/tasks/web/rollup.config.iife.mjs b/mediapipe/tasks/web/rollup.config.iife.mjs new file mode 100644 index 000000000..1320927aa --- /dev/null +++ b/mediapipe/tasks/web/rollup.config.iife.mjs @@ -0,0 +1,21 @@ +import resolve from '@rollup/plugin-node-resolve'; +import commonjs from '@rollup/plugin-commonjs'; +import terser from '@rollup/plugin-terser'; +import replace from '@rollup/plugin-replace'; + +export default { + output: { + name: 'bundle', + sourcemap: false + }, + plugins: [ + // Workaround for https://github.com/protocolbuffers/protobuf-javascript/issues/151 + replace({ + 'var calculator_options_pb = {};': 'var calculator_options_pb = {}; var mediapipe_framework_calculator_options_pb = calculator_options_pb;', + delimiters: ['', ''] + }), + resolve({browser: true}), + commonjs(), + terser() + ] +} diff --git a/mediapipe/tasks/web/rollup.config.mjs b/mediapipe/tasks/web/rollup.config.mjs deleted file mode 100644 index 392b235fc..000000000 --- a/mediapipe/tasks/web/rollup.config.mjs +++ /dev/null @@ -1,9 +0,0 @@ -import resolve from '@rollup/plugin-node-resolve'; -import 
commonjs from '@rollup/plugin-commonjs'; - -export default { - plugins: [ - resolve(), - commonjs() - ] -} diff --git a/package.json b/package.json index 298157cbc..22a035b74 100644 --- a/package.json +++ b/package.json @@ -7,6 +7,8 @@ "@bazel/typescript": "^5.7.1", "@rollup/plugin-commonjs": "^23.0.2", "@rollup/plugin-node-resolve": "^15.0.1", + "@rollup/plugin-replace": "^5.0.1", + "@rollup/plugin-terser": "^0.1.0", "@types/google-protobuf": "^3.15.6", "@types/offscreencanvas": "^2019.7.0", "google-protobuf": "^3.21.2", diff --git a/yarn.lock b/yarn.lock index a5ec6fb13..19c32e322 100644 --- a/yarn.lock +++ b/yarn.lock @@ -31,6 +31,46 @@ dependencies: google-protobuf "^3.6.1" +"@jridgewell/gen-mapping@^0.3.0": + version "0.3.2" + resolved "https://registry.yarnpkg.com/@jridgewell/gen-mapping/-/gen-mapping-0.3.2.tgz#c1aedc61e853f2bb9f5dfe6d4442d3b565b253b9" + integrity sha512-mh65xKQAzI6iBcFzwv28KVWSmCkdRBWoOh+bYQGW3+6OZvbbN3TqMGo5hqYxQniRcH9F2VZIoJCm4pa3BPDK/A== + dependencies: + "@jridgewell/set-array" "^1.0.1" + "@jridgewell/sourcemap-codec" "^1.4.10" + "@jridgewell/trace-mapping" "^0.3.9" + +"@jridgewell/resolve-uri@3.1.0": + version "3.1.0" + resolved "https://registry.yarnpkg.com/@jridgewell/resolve-uri/-/resolve-uri-3.1.0.tgz#2203b118c157721addfe69d47b70465463066d78" + integrity sha512-F2msla3tad+Mfht5cJq7LSXcdudKTWCVYUgw6pLFOOHSTtZlj6SWNYAp+AhuqLmWdBO2X5hPrLcu8cVP8fy28w== + +"@jridgewell/set-array@^1.0.1": + version "1.1.2" + resolved "https://registry.yarnpkg.com/@jridgewell/set-array/-/set-array-1.1.2.tgz#7c6cf998d6d20b914c0a55a91ae928ff25965e72" + integrity sha512-xnkseuNADM0gt2bs+BvhO0p78Mk762YnZdsuzFV018NoG1Sj1SCQvpSqa7XUaTam5vAGasABV9qXASMKnFMwMw== + +"@jridgewell/source-map@^0.3.2": + version "0.3.2" + resolved "https://registry.yarnpkg.com/@jridgewell/source-map/-/source-map-0.3.2.tgz#f45351aaed4527a298512ec72f81040c998580fb" + integrity sha512-m7O9o2uR8k2ObDysZYzdfhb08VuEml5oWGiosa1VdaPZ/A6QyPkAJuwN0Q1lhULOf6B7MtQmHENS743hWtCrgw== + dependencies: + "@jridgewell/gen-mapping" "^0.3.0" + "@jridgewell/trace-mapping" "^0.3.9" + +"@jridgewell/sourcemap-codec@1.4.14", "@jridgewell/sourcemap-codec@^1.4.10": + version "1.4.14" + resolved "https://registry.yarnpkg.com/@jridgewell/sourcemap-codec/-/sourcemap-codec-1.4.14.tgz#add4c98d341472a289190b424efbdb096991bb24" + integrity sha512-XPSJHWmi394fuUuzDnGz1wiKqWfo1yXecHQMRf2l6hztTO+nPru658AyDngaBe7isIxEkRsPR3FZh+s7iVa4Uw== + +"@jridgewell/trace-mapping@^0.3.9": + version "0.3.17" + resolved "https://registry.yarnpkg.com/@jridgewell/trace-mapping/-/trace-mapping-0.3.17.tgz#793041277af9073b0951a7fe0f0d8c4c98c36985" + integrity sha512-MCNzAp77qzKca9+W/+I0+sEpaUnZoeasnghNeVc41VZCEKaCH73Vq3BZZ/SzWIgrqE4H4ceI+p+b6C0mHf9T4g== + dependencies: + "@jridgewell/resolve-uri" "3.1.0" + "@jridgewell/sourcemap-codec" "1.4.14" + "@protobufjs/aspromise@^1.1.1", "@protobufjs/aspromise@^1.1.2": version "1.1.2" resolved "https://registry.yarnpkg.com/@protobufjs/aspromise/-/aspromise-1.1.2.tgz#9b8b0cc663d669a7d8f6f5d0893a14d348f30fbf" @@ -108,6 +148,21 @@ is-module "^1.0.0" resolve "^1.22.1" +"@rollup/plugin-replace@^5.0.1": + version "5.0.1" + resolved "https://registry.yarnpkg.com/@rollup/plugin-replace/-/plugin-replace-5.0.1.tgz#49a57af3e6df111a9e75dea3f3572741f4c5c83e" + integrity sha512-Z3MfsJ4CK17BfGrZgvrcp/l6WXoKb0kokULO+zt/7bmcyayokDaQ2K3eDJcRLCTAlp5FPI4/gz9MHAsosz4Rag== + dependencies: + "@rollup/pluginutils" "^5.0.1" + magic-string "^0.26.4" + +"@rollup/plugin-terser@^0.1.0": + version "0.1.0" + resolved 
"https://registry.yarnpkg.com/@rollup/plugin-terser/-/plugin-terser-0.1.0.tgz#7530c0f11667637419d71820461646c418526041" + integrity sha512-N2KK+qUfHX2hBzVzM41UWGLrEmcjVC37spC8R3c9mt3oEDFKh3N2e12/lLp9aVSt86veR0TQiCNQXrm8C6aiUQ== + dependencies: + terser "^5.15.1" + "@rollup/pluginutils@^5.0.1": version "5.0.2" resolved "https://registry.yarnpkg.com/@rollup/pluginutils/-/pluginutils-5.0.2.tgz#012b8f53c71e4f6f9cb317e311df1404f56e7a33" @@ -165,7 +220,7 @@ acorn-jsx@^5.3.2: resolved "https://registry.yarnpkg.com/acorn-jsx/-/acorn-jsx-5.3.2.tgz#7ed5bb55908b3b2f1bc55c6af1653bada7f07937" integrity sha512-rq9s+JNhf0IChjtDXxllJ7g41oZk5SlXtp0LHwyA5cejwn7vKmKp4pPri6YEePv2PU65sAsegbXtIinmDFDXgQ== -acorn@^8.8.0: +acorn@^8.5.0, acorn@^8.8.0: version "8.8.1" resolved "https://registry.yarnpkg.com/acorn/-/acorn-8.8.1.tgz#0a3f9cbecc4ec3bea6f0a80b66ae8dd2da250b73" integrity sha512-7zFpHzhnqYKrkYdUjF1HI1bzd0VygEGX8lFk4k5zVMqHEoES+P+7TKI+EvLO9WVMJ8eekdO0aDEK044xTXwPPA== @@ -244,6 +299,11 @@ color-name@~1.1.4: resolved "https://registry.yarnpkg.com/color-name/-/color-name-1.1.4.tgz#c2a09a87acbde69543de6f63fa3995c826c536a2" integrity sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA== +commander@^2.20.0: + version "2.20.3" + resolved "https://registry.yarnpkg.com/commander/-/commander-2.20.3.tgz#fd485e84c03eb4881c20722ba48035e8531aeb33" + integrity sha512-GpVkmM8vF2vQUkj2LvZmD35JxeJOLCwJ9cUkugyk2nuhbv3+mJvpLYYt+0+USMxE+oj+ey/lJEnhZw75x/OMcQ== + commondir@^1.0.1: version "1.0.1" resolved "https://registry.yarnpkg.com/commondir/-/commondir-1.0.1.tgz#ddd800da0c66127393cca5950ea968a3aaf1253b" @@ -676,6 +736,14 @@ source-map-support@0.5.9: buffer-from "^1.0.0" source-map "^0.6.0" +source-map-support@~0.5.20: + version "0.5.21" + resolved "https://registry.yarnpkg.com/source-map-support/-/source-map-support-0.5.21.tgz#04fe7c7f9e1ed2d662233c28cb2b35b9f63f6e4f" + integrity sha512-uBHU3L3czsIyYXKX88fdrGovxdSCoTGDRZ6SYXtSRxLZUzHg5P/66Ht6uoUlHu9EZod+inXhKo3qQgwXUT/y1w== + dependencies: + buffer-from "^1.0.0" + source-map "^0.6.0" + source-map@^0.6.0, source-map@~0.6.1: version "0.6.1" resolved "https://registry.yarnpkg.com/source-map/-/source-map-0.6.1.tgz#74722af32e9614e9c287a8d0bbde48b5e2f1a263" @@ -708,6 +776,16 @@ taffydb@2.6.2: resolved "https://registry.yarnpkg.com/taffydb/-/taffydb-2.6.2.tgz#7cbcb64b5a141b6a2efc2c5d2c67b4e150b2a268" integrity sha512-y3JaeRSplks6NYQuCOj3ZFMO3j60rTwbuKCvZxsAraGYH2epusatvZ0baZYA01WsGqJBq/Dl6vOrMUJqyMj8kA== +terser@^5.15.1: + version "5.15.1" + resolved "https://registry.yarnpkg.com/terser/-/terser-5.15.1.tgz#8561af6e0fd6d839669c73b92bdd5777d870ed6c" + integrity sha512-K1faMUvpm/FBxjBXud0LWVAGxmvoPbZbfTCYbSgaaYQaIXI3/TdI7a7ZGA73Zrou6Q8Zmz3oeUTsp/dj+ag2Xw== + dependencies: + "@jridgewell/source-map" "^0.3.2" + acorn "^8.5.0" + commander "^2.20.0" + source-map-support "~0.5.20" + tmp@^0.2.1: version "0.2.1" resolved "https://registry.yarnpkg.com/tmp/-/tmp-0.2.1.tgz#8457fc3037dcf4719c251367a1af6500ee1ccf14" From f16e63694e201b9d47cbf3a258f4d4ae7b254ad2 Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Mon, 14 Nov 2022 17:16:27 -0800 Subject: [PATCH 019/469] Build embedding tasks into tasks AARs. 
PiperOrigin-RevId: 488509942 --- .../java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl b/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl index cae3ccfe9..2b648bc43 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl @@ -282,8 +282,12 @@ def _mediapipe_tasks_aar(name, srcs, manifest, java_proto_lite_targets, native_l "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:category", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classificationresult", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classifications", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:embedding", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:embeddingresult", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:landmark", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:classifieroptions", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:embedderoptions", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/utils:cosinesimilarity", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", "//third_party:autovalue", "@maven//:com_google_guava_guava", From 9bd8b432c3308fb3728ddbadfc2b06dfcf47c250 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Mon, 14 Nov 2022 18:08:29 -0800 Subject: [PATCH 020/469] Add typings support to d.ts package PiperOrigin-RevId: 488519074 --- mediapipe/tasks/web/BUILD | 3 +++ mediapipe/tasks/web/package.json | 1 + 2 files changed, 4 insertions(+) diff --git a/mediapipe/tasks/web/BUILD b/mediapipe/tasks/web/BUILD index ddc35ab21..b8777e785 100644 --- a/mediapipe/tasks/web/BUILD +++ b/mediapipe/tasks/web/BUILD @@ -63,6 +63,7 @@ pkg_npm( substitutions = { "__NAME__": "audio", "__DESCRIPTION__": "MediaPipe Audio Tasks", + "__TYPES__": "audio.d.ts", }, tgz = "audio.tgz", deps = [ @@ -117,6 +118,7 @@ pkg_npm( substitutions = { "__NAME__": "text", "__DESCRIPTION__": "MediaPipe Text Tasks", + "__TYPES__": "text.d.ts", }, tgz = "text.tgz", deps = [ @@ -171,6 +173,7 @@ pkg_npm( substitutions = { "__NAME__": "vision", "__DESCRIPTION__": "MediaPipe Vision Tasks", + "__TYPES__": "vision.d.ts", }, tgz = "vision_pkg.tgz", deps = [ diff --git a/mediapipe/tasks/web/package.json b/mediapipe/tasks/web/package.json index 5b2dce245..1870f18a6 100644 --- a/mediapipe/tasks/web/package.json +++ b/mediapipe/tasks/web/package.json @@ -12,6 +12,7 @@ }, "author": "mediapipe@google.com", "license": "Apache-2.0", + "types": "__TYPES__", "dependencies": { "google-protobuf": "^3.21.2" }, From 87dff8142c1eb0d1cf578e14a944d99321843904 Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Mon, 14 Nov 2022 18:08:43 -0800 Subject: [PATCH 021/469] Fix a typo. 
PiperOrigin-RevId: 488519113 --- .../tasks/audio/audioclassifier/AudioClassifierResult.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioclassifier/AudioClassifierResult.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioclassifier/AudioClassifierResult.java index fcc3c6e22..3102aa8cd 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioclassifier/AudioClassifierResult.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioclassifier/AudioClassifierResult.java @@ -58,7 +58,7 @@ public abstract class AudioClassifierResult implements TaskResult { } /** - * A list of of timpstamed {@link ClassificationResult} objects, each contains one set of results + * A list of of timestamped {@link ClassificationResult} objects, each contains one set of results * per classifier head. The list represents the audio classification result of an audio clip, and * is only available when running with the audio clips mode. */ From 6f54308c257478d9b1a7a98c7cab90079a60e94d Mon Sep 17 00:00:00 2001 From: Camillo Lugaresi Date: Mon, 14 Nov 2022 21:40:40 -0800 Subject: [PATCH 022/469] Internal change PiperOrigin-RevId: 488552135 --- mediapipe/framework/tool/sink_test.cc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mediapipe/framework/tool/sink_test.cc b/mediapipe/framework/tool/sink_test.cc index 2b5f94f9f..c5316af4d 100644 --- a/mediapipe/framework/tool/sink_test.cc +++ b/mediapipe/framework/tool/sink_test.cc @@ -171,6 +171,7 @@ class TimestampBoundTestCalculator : public CalculatorBase { }; REGISTER_CALCULATOR(TimestampBoundTestCalculator); +#if 0 // test is flaky, try it with --runs_per_test=200 TEST(CallbackTest, TestAddMultiStreamCallbackWithTimestampNotification) { std::string config_str = R"( node { @@ -203,6 +204,7 @@ TEST(CallbackTest, TestAddMultiStreamCallbackWithTimestampNotification) { EXPECT_THAT(sums, testing::ElementsAre(10, 20)); } +#endif } // namespace } // namespace mediapipe From ebba119f151ec1963eac0b2bda3e10f4cfb7624f Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Tue, 15 Nov 2022 01:22:38 -0800 Subject: [PATCH 023/469] Add Java ImageEmbedder API. 
PiperOrigin-RevId: 488588010 --- .../proto/image_embedder_graph_options.proto | 3 + .../mediapipe/tasks/mediapipe_tasks_aar.bzl | 1 + .../com/google/mediapipe/tasks/vision/BUILD | 29 ++ .../vision/imageembedder/AndroidManifest.xml | 8 + .../vision/imageembedder/ImageEmbedder.java | 448 ++++++++++++++++++ .../imageembedder/ImageEmbedderResult.java | 54 +++ .../vision/imageembedder/AndroidManifest.xml | 24 + .../tasks/vision/imageembedder/BUILD | 19 + .../imageembedder/ImageEmbedderTest.java | 444 +++++++++++++++++ 9 files changed, 1030 insertions(+) create mode 100644 mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageembedder/AndroidManifest.xml create mode 100644 mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedder.java create mode 100644 mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedderResult.java create mode 100644 mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageembedder/AndroidManifest.xml create mode 100644 mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageembedder/BUILD create mode 100644 mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedderTest.java diff --git a/mediapipe/tasks/cc/vision/image_embedder/proto/image_embedder_graph_options.proto b/mediapipe/tasks/cc/vision/image_embedder/proto/image_embedder_graph_options.proto index 4adba5ab7..72b3e7ee3 100644 --- a/mediapipe/tasks/cc/vision/image_embedder/proto/image_embedder_graph_options.proto +++ b/mediapipe/tasks/cc/vision/image_embedder/proto/image_embedder_graph_options.proto @@ -21,6 +21,9 @@ import "mediapipe/framework/calculator.proto"; import "mediapipe/tasks/cc/components/processors/proto/embedder_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; +option java_package = "com.google.mediapipe.tasks.vision.imageembedder.proto"; +option java_outer_classname = "ImageEmbedderGraphOptionsProto"; + message ImageEmbedderGraphOptions { extend mediapipe.CalculatorOptions { optional ImageEmbedderGraphOptions ext = 476348187; diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl b/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl index 2b648bc43..8b09260bd 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl @@ -42,6 +42,7 @@ _VISION_TASKS_JAVA_PROTO_LITE_TARGETS = [ "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_embedder_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_java_proto_lite", + "//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_java_proto_lite", diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD index 4dc4a547e..289e3000d 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD @@ -43,6 +43,7 @@ cc_binary( 
"//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni", "//mediapipe/tasks/cc/vision/gesture_recognizer:gesture_recognizer_graph", "//mediapipe/tasks/cc/vision/image_classifier:image_classifier_graph", + "//mediapipe/tasks/cc/vision/image_embedder:image_embedder_graph", "//mediapipe/tasks/cc/vision/object_detector:object_detector_graph", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni:model_resources_cache_jni", ], @@ -172,6 +173,34 @@ android_library( ], ) +android_library( + name = "imageembedder", + srcs = [ + "imageembedder/ImageEmbedder.java", + "imageembedder/ImageEmbedderResult.java", + ], + javacopts = [ + "-Xep:AndroidJdkLibsChecker:OFF", + ], + manifest = "imageembedder/AndroidManifest.xml", + deps = [ + ":core", + "//mediapipe/framework:calculator_options_java_proto_lite", + "//mediapipe/java/com/google/mediapipe/framework:android_framework", + "//mediapipe/java/com/google/mediapipe/framework/image", + "//mediapipe/tasks/cc/components/containers/proto:embeddings_java_proto_lite", + "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", + "//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_java_proto_lite", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:embedding", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:embeddingresult", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:embedderoptions", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/utils:cosinesimilarity", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", + "//third_party:autovalue", + "@maven//:com_google_guava_guava", + ], +) + load("//mediapipe/tasks/java/com/google/mediapipe/tasks:mediapipe_tasks_aar.bzl", "mediapipe_tasks_vision_aar") mediapipe_tasks_vision_aar( diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageembedder/AndroidManifest.xml b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageembedder/AndroidManifest.xml new file mode 100644 index 000000000..ebdb037d6 --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageembedder/AndroidManifest.xml @@ -0,0 +1,8 @@ + + + + + + diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedder.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedder.java new file mode 100644 index 000000000..0d8ecd5c3 --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedder.java @@ -0,0 +1,448 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package com.google.mediapipe.tasks.vision.imageembedder; + +import android.content.Context; +import android.os.ParcelFileDescriptor; +import com.google.auto.value.AutoValue; +import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions; +import com.google.mediapipe.framework.AndroidPacketGetter; +import com.google.mediapipe.framework.MediaPipeException; +import com.google.mediapipe.framework.Packet; +import com.google.mediapipe.framework.PacketGetter; +import com.google.mediapipe.framework.ProtoUtil; +import com.google.mediapipe.framework.image.BitmapImageBuilder; +import com.google.mediapipe.framework.image.MPImage; +import com.google.mediapipe.tasks.components.containers.Embedding; +import com.google.mediapipe.tasks.components.containers.EmbeddingResult; +import com.google.mediapipe.tasks.components.containers.proto.EmbeddingsProto; +import com.google.mediapipe.tasks.components.processors.EmbedderOptions; +import com.google.mediapipe.tasks.components.utils.CosineSimilarity; +import com.google.mediapipe.tasks.core.BaseOptions; +import com.google.mediapipe.tasks.core.ErrorListener; +import com.google.mediapipe.tasks.core.OutputHandler; +import com.google.mediapipe.tasks.core.OutputHandler.ResultListener; +import com.google.mediapipe.tasks.core.TaskInfo; +import com.google.mediapipe.tasks.core.TaskOptions; +import com.google.mediapipe.tasks.core.TaskRunner; +import com.google.mediapipe.tasks.core.proto.BaseOptionsProto; +import com.google.mediapipe.tasks.vision.core.BaseVisionTaskApi; +import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions; +import com.google.mediapipe.tasks.vision.core.RunningMode; +import com.google.mediapipe.tasks.vision.imageembedder.proto.ImageEmbedderGraphOptionsProto; +import java.io.File; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Optional; + +/** + * Performs embedding extraction on images. + * + *
<p>The API expects a TFLite model with optional, but strongly recommended, TFLite Model
+ * Metadata.
+ *
+ * <p>The API supports models with one image input tensor and one or more output tensors. To be more
+ * specific, here are the requirements.
+ *
+ * <ul>
+ *   <li>Input image tensor ({@code kTfLiteUInt8}/{@code kTfLiteFloat32})
+ *       <ul>
+ *         <li>image input of size {@code [batch x height x width x channels]}.
+ *         <li>batch inference is not supported ({@code batch} is required to be 1).
+ *         <li>only RGB inputs are supported ({@code channels} is required to be 3).
+ *         <li>if type is kTfLiteFloat32, NormalizationOptions are required to be attached to the
+ *             metadata for input normalization.
+ *       </ul>
+ *   <li>At least one output tensor ({@code kTfLiteUInt8}/{@code kTfLiteFloat32}) with shape {@code
+ *       [1 x N]} where N is the number of dimensions in the produced embeddings.
+ * </ul>
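+ *
+ * <p>A minimal usage sketch; the model path, {@code context} and the two input images are
+ * illustrative placeholders, not fixed names:
+ *
+ * <pre>{@code
+ * // Hypothetical model asset name; substitute the path of an actual embedder model.
+ * ImageEmbedder embedder = ImageEmbedder.createFromFile(context, "embedder_model.tflite");
+ * ImageEmbedderResult firstResult = embedder.embed(firstImage);
+ * ImageEmbedderResult secondResult = embedder.embed(secondImage);
+ * // Compare the two feature vectors with cosine similarity.
+ * double similarity =
+ *     ImageEmbedder.cosineSimilarity(
+ *         firstResult.embeddingResult().embeddings().get(0),
+ *         secondResult.embeddingResult().embeddings().get(0));
+ * }</pre>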
+ */ +public final class ImageEmbedder extends BaseVisionTaskApi { + private static final String TAG = ImageEmbedder.class.getSimpleName(); + private static final String IMAGE_IN_STREAM_NAME = "image_in"; + private static final String NORM_RECT_IN_STREAM_NAME = "norm_rect_in"; + private static final List INPUT_STREAMS = + Collections.unmodifiableList( + Arrays.asList("IMAGE:" + IMAGE_IN_STREAM_NAME, "NORM_RECT:" + NORM_RECT_IN_STREAM_NAME)); + private static final List OUTPUT_STREAMS = + Collections.unmodifiableList(Arrays.asList("EMBEDDINGS:embeddings_out", "IMAGE:image_out")); + private static final int EMBEDDINGS_OUT_STREAM_INDEX = 0; + private static final int IMAGE_OUT_STREAM_INDEX = 1; + private static final String TASK_GRAPH_NAME = + "mediapipe.tasks.vision.image_embedder.ImageEmbedderGraph"; + + static { + ProtoUtil.registerTypeName( + EmbeddingsProto.EmbeddingResult.class, + "mediapipe.tasks.components.containers.proto.EmbeddingResult"); + } + + /** + * Creates an {@link ImageEmbedder} instance from a model file and default {@link + * ImageEmbedderOptions}. + * + * @param context an Android {@link Context}. + * @param modelPath path to the embedding model in the assets. + * @throws MediaPipeException if there is an error during {@link ImageEmbedder} creation. + */ + public static ImageEmbedder createFromFile(Context context, String modelPath) { + BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(modelPath).build(); + return createFromOptions( + context, ImageEmbedderOptions.builder().setBaseOptions(baseOptions).build()); + } + + /** + * Creates an {@link ImageEmbedder} instance from a model file and default {@link + * ImageEmbedderOptions}. + * + * @param context an Android {@link Context}. + * @param modelFile the embedding model {@link File} instance. + * @throws IOException if an I/O error occurs when opening the tflite model file. + * @throws MediaPipeException if there is an error during {@link ImageEmbedder} creation. + */ + public static ImageEmbedder createFromFile(Context context, File modelFile) throws IOException { + try (ParcelFileDescriptor descriptor = + ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) { + BaseOptions baseOptions = + BaseOptions.builder().setModelAssetFileDescriptor(descriptor.getFd()).build(); + return createFromOptions( + context, ImageEmbedderOptions.builder().setBaseOptions(baseOptions).build()); + } + } + + /** + * Creates an {@link ImageEmbedder} instance from a model buffer and default {@link + * ImageEmbedderOptions}. + * + * @param context an Android {@link Context}. + * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the embedding + * model. + * @throws MediaPipeException if there is an error during {@link ImageEmbedder} creation. + */ + public static ImageEmbedder createFromBuffer(Context context, final ByteBuffer modelBuffer) { + BaseOptions baseOptions = BaseOptions.builder().setModelAssetBuffer(modelBuffer).build(); + return createFromOptions( + context, ImageEmbedderOptions.builder().setBaseOptions(baseOptions).build()); + } + + /** + * Creates an {@link ImageEmbedder} instance from an {@link ImageEmbedderOptions} instance. + * + * @param context an Android {@link Context}. + * @param options an {@link ImageEmbedderOptions} instance. + * @throws MediaPipeException if there is an error during {@link ImageEmbedder} creation. 
+ */ + public static ImageEmbedder createFromOptions(Context context, ImageEmbedderOptions options) { + OutputHandler handler = new OutputHandler<>(); + handler.setOutputPacketConverter( + new OutputHandler.OutputPacketConverter() { + @Override + public ImageEmbedderResult convertToTaskResult(List packets) { + try { + return ImageEmbedderResult.create( + EmbeddingResult.createFromProto( + PacketGetter.getProto( + packets.get(EMBEDDINGS_OUT_STREAM_INDEX), + EmbeddingsProto.EmbeddingResult.getDefaultInstance())), + BaseVisionTaskApi.generateResultTimestampMs( + options.runningMode(), packets.get(EMBEDDINGS_OUT_STREAM_INDEX))); + } catch (IOException e) { + throw new MediaPipeException( + MediaPipeException.StatusCode.INTERNAL.ordinal(), e.getMessage()); + } + } + + @Override + public MPImage convertToTaskInput(List packets) { + return new BitmapImageBuilder( + AndroidPacketGetter.getBitmapFromRgb(packets.get(IMAGE_OUT_STREAM_INDEX))) + .build(); + } + }); + options.resultListener().ifPresent(handler::setResultListener); + options.errorListener().ifPresent(handler::setErrorListener); + TaskRunner runner = + TaskRunner.create( + context, + TaskInfo.builder() + .setTaskGraphName(TASK_GRAPH_NAME) + .setInputStreams(INPUT_STREAMS) + .setOutputStreams(OUTPUT_STREAMS) + .setTaskOptions(options) + .setEnableFlowLimiting(options.runningMode() == RunningMode.LIVE_STREAM) + .build(), + handler); + return new ImageEmbedder(runner, options.runningMode()); + } + + /** + * Constructor to initialize an {@link ImageEmbedder} from a {@link TaskRunner} and {@link + * RunningMode}. + * + * @param taskRunner a {@link TaskRunner}. + * @param runningMode a mediapipe vision task {@link RunningMode}. + */ + private ImageEmbedder(TaskRunner taskRunner, RunningMode runningMode) { + super(taskRunner, runningMode, IMAGE_IN_STREAM_NAME, NORM_RECT_IN_STREAM_NAME); + } + + /** + * Performs embedding extraction on the provided single image with default image processing + * options, i.e. using the whole image as region-of-interest and without any rotation applied. + * Only use this method when the {@link ImageEmbedder} is created with {@link RunningMode.IMAGE}. + * + *
<p>{@link ImageEmbedder} supports the following color space types:
+ *
+ * <ul>
+ *   <li>{@link Bitmap.Config.ARGB_8888}
+ * </ul>
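+ *
+ * <p>For example, assuming {@code embedder} was created in the image mode, the raw feature vector
+ * can be read back as follows:
+ *
+ * <pre>{@code
+ * ImageEmbedderResult result = embedder.embed(image);
+ * // One embedding per embedder head; float-valued unless quantization is enabled.
+ * float[] vector = result.embeddingResult().embeddings().get(0).floatEmbedding();
+ * }</pre>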
+ * + * @param image a MediaPipe {@link MPImage} object for processing. + * @throws MediaPipeException if there is an internal error. + */ + public ImageEmbedderResult embed(MPImage image) { + return embed(image, ImageProcessingOptions.builder().build()); + } + + /** + * Performs embedding extraction on the provided single image. Only use this method when the + * {@link ImageEmbedder} is created with {@link RunningMode.IMAGE}. + * + *
<p>{@link ImageEmbedder} supports the following color space types:
+ *
+ * <ul>
+ *   <li>{@link Bitmap.Config.ARGB_8888}
+ * </ul>
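+ *
+ * <p>For instance, to embed a rotated region of interest (the coordinates below are illustrative):
+ *
+ * <pre>{@code
+ * ImageProcessingOptions imageProcessingOptions =
+ *     ImageProcessingOptions.builder()
+ *         .setRegionOfInterest(new RectF(0.0f, 0.0f, 0.5f, 0.5f)) // normalized coordinates
+ *         .setRotationDegrees(-90)
+ *         .build();
+ * ImageEmbedderResult result = embedder.embed(image, imageProcessingOptions);
+ * }</pre>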
+ * + * @param image a MediaPipe {@link MPImage} object for processing. + * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the + * input image before running inference. + * @throws MediaPipeException if there is an internal error. + */ + public ImageEmbedderResult embed(MPImage image, ImageProcessingOptions imageProcessingOptions) { + return (ImageEmbedderResult) processImageData(image, imageProcessingOptions); + } + + /** + * Performs embedding extraction on the provided video frame with default image processing + * options, i.e. using the whole image as region-of-interest and without any rotation applied. + * Only use this method when the {@link ImageEmbedder} is created with {@link RunningMode.VIDEO}. + * + *
<p>It's required to provide the video frame's timestamp (in milliseconds). The input timestamps
+ * must be monotonically increasing.
+ *
+ * <p>{@link ImageEmbedder} supports the following color space types:
+ *
+ * <ul>
+ *   <li>{@link Bitmap.Config.ARGB_8888}
+ * </ul>
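+ *
+ * <p>A sketch of video-mode usage; the frame and timestamp sources are assumed to exist:
+ *
+ * <pre>{@code
+ * for (int i = 0; i < frames.size(); i++) {
+ *   // Timestamps must increase monotonically across calls.
+ *   ImageEmbedderResult result = embedder.embedForVideo(frames.get(i), frameTimestampsMs.get(i));
+ * }
+ * }</pre>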
+ * + * @param image a MediaPipe {@link MPImage} object for processing. + * @param timestampMs the input timestamp (in milliseconds). + * @throws MediaPipeException if there is an internal error. + */ + public ImageEmbedderResult embedForVideo(MPImage image, long timestampMs) { + return embedForVideo(image, ImageProcessingOptions.builder().build(), timestampMs); + } + + /** + * Performs embedding extraction on the provided video frame. Only use this method when the {@link + * ImageEmbedder} is created with {@link RunningMode.VIDEO}. + * + *
<p>It's required to provide the video frame's timestamp (in milliseconds). The input timestamps
+ * must be monotonically increasing.
+ *
+ * <p>{@link ImageEmbedder} supports the following color space types:
+ *
+ * <ul>
+ *   <li>{@link Bitmap.Config.ARGB_8888}
+ * </ul>
+ * + * @param image a MediaPipe {@link MPImage} object for processing. + * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the + * input image before running inference. + * @param timestampMs the input timestamp (in milliseconds). + * @throws MediaPipeException if there is an internal error. + */ + public ImageEmbedderResult embedForVideo( + MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) { + return (ImageEmbedderResult) processVideoData(image, imageProcessingOptions, timestampMs); + } + + /** + * Sends live image data to perform embedding extraction with default image processing options, + * i.e. using the whole image as region-of-interest and without any rotation applied, and the + * results will be available via the {@link ResultListener} provided in the {@link + * ImageEmbedderOptions}. Only use this method when the {@link ImageEmbedder} is created with + * {@link RunningMode.LIVE_STREAM}. + * + *
<p>It's required to provide a timestamp (in milliseconds) to indicate when the input image is
+ * sent to the image embedder. The input timestamps must be monotonically increasing.
+ *
+ * <p>{@link ImageEmbedder} supports the following color space types:
+ *
+ * <ul>
+ *   <li>{@link Bitmap.Config.ARGB_8888}
+ * </ul>
+ * + * @param image a MediaPipe {@link MPImage} object for processing. + * @param timestampMs the input timestamp (in milliseconds). + * @throws MediaPipeException if there is an internal error. + */ + public void embedAsync(MPImage image, long timestampMs) { + embedAsync(image, ImageProcessingOptions.builder().build(), timestampMs); + } + + /** + * Sends live image data to perform embedding extraction, and the results will be available via + * the {@link ResultListener} provided in the {@link ImageEmbedderOptions}. Only use this method + * when the {@link ImageEmbedder} is created with {@link RunningMode.LIVE_STREAM}. + * + *
<p>It's required to provide a timestamp (in milliseconds) to indicate when the input image is
+ * sent to the image embedder. The input timestamps must be monotonically increasing.
+ *
+ * <p>{@link ImageEmbedder} supports the following color space types:
+ *
+ * <ul>
+ *   <li>{@link Bitmap.Config.ARGB_8888}
+ * </ul>
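+ *
+ * <p>A sketch of live-stream usage; the listener body, model path and timestamp source are
+ * illustrative:
+ *
+ * <pre>{@code
+ * ImageEmbedderOptions options =
+ *     ImageEmbedderOptions.builder()
+ *         .setBaseOptions(BaseOptions.builder().setModelAssetPath("embedder_model.tflite").build())
+ *         .setRunningMode(RunningMode.LIVE_STREAM)
+ *         .setResultListener((result, inputImage) -> {
+ *           // Consume the embedding result asynchronously.
+ *         })
+ *         .build();
+ * ImageEmbedder embedder = ImageEmbedder.createFromOptions(context, options);
+ * embedder.embedAsync(image, ImageProcessingOptions.builder().build(), timestampMs);
+ * }</pre>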
+ * + * @param image a MediaPipe {@link MPImage} object for processing. + * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the + * input image before running inference. + * @param timestampMs the input timestamp (in milliseconds). + * @throws MediaPipeException if there is an internal error. + */ + public void embedAsync( + MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) { + sendLiveStreamData(image, imageProcessingOptions, timestampMs); + } + + /** + * Utility function to compute cosine + * similarity between two {@link Embedding} objects. + * + * @throws IllegalArgumentException if the embeddings are of different types (float vs. + * quantized), have different sizes, or have an L2-norm of 0. + */ + public static double cosineSimilarity(Embedding u, Embedding v) { + return CosineSimilarity.compute(u, v); + } + + /** Options for setting up and {@link ImageEmbedder}. */ + @AutoValue + public abstract static class ImageEmbedderOptions extends TaskOptions { + + /** Builder for {@link ImageEmbedderOptions}. */ + @AutoValue.Builder + public abstract static class Builder { + /** Sets the {@link BaseOptions} for the image embedder task. */ + public abstract Builder setBaseOptions(BaseOptions baseOptions); + + /** + * Sets the {@link RunningMode} for the image embedder task. Default to the image mode. Image + * embedder has three modes: + * + *
<ul>
+ *   <li>IMAGE: The mode for performing embedding extraction on single image inputs.
+ *   <li>VIDEO: The mode for performing embedding extraction on the decoded frames of a video.
+ *   <li>LIVE_STREAM: The mode for performing embedding extraction on a live stream of
+ *       input data, such as from camera. In this mode, {@code setResultListener} must be
+ *       called to set up a listener to receive the embedding results asynchronously.
+ * </ul>
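+ *
+ * <p>For example, a video-mode embedder could be configured as follows (the model path is
+ * illustrative):
+ *
+ * <pre>{@code
+ * ImageEmbedderOptions options =
+ *     ImageEmbedderOptions.builder()
+ *         .setBaseOptions(BaseOptions.builder().setModelAssetPath("embedder_model.tflite").build())
+ *         .setRunningMode(RunningMode.VIDEO)
+ *         .build();
+ * }</pre>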
+ */ + public abstract Builder setRunningMode(RunningMode runningMode); + + /** + * Sets the optional {@link EmbedderOptions} controling embedder behavior, such as + * L2-normalization and scalar quantization. + */ + public abstract Builder setEmbedderOptions(EmbedderOptions embedderOptions); + + /** + * Sets the {@link ResultListener} to receive the embedding results asynchronously when the + * image embedder is in the live stream mode. + */ + public abstract Builder setResultListener( + ResultListener resultListener); + + /** Sets an optional {@link ErrorListener}. */ + public abstract Builder setErrorListener(ErrorListener errorListener); + + abstract ImageEmbedderOptions autoBuild(); + + /** + * Validates and builds the {@link ImageEmbedderOptions} instance. * + * + * @throws IllegalArgumentException if the result listener and the running mode are not + * properly configured. The result listener should only be set when the image embedder is + * in the live stream mode. + */ + public final ImageEmbedderOptions build() { + ImageEmbedderOptions options = autoBuild(); + if (options.runningMode() == RunningMode.LIVE_STREAM) { + if (!options.resultListener().isPresent()) { + throw new IllegalArgumentException( + "The image embedder is in the live stream mode, a user-defined result listener" + + " must be provided in the ImageEmbedderOptions."); + } + } else if (options.resultListener().isPresent()) { + throw new IllegalArgumentException( + "The image embedder is in the image or video mode, a user-defined result listener" + + " shouldn't be provided in ImageEmbedderOptions."); + } + return options; + } + } + + abstract BaseOptions baseOptions(); + + abstract RunningMode runningMode(); + + abstract Optional embedderOptions(); + + abstract Optional> resultListener(); + + abstract Optional errorListener(); + + public static Builder builder() { + return new AutoValue_ImageEmbedder_ImageEmbedderOptions.Builder() + .setRunningMode(RunningMode.IMAGE); + } + + /** Converts a {@link ImageEmbedderOptions} to a {@link CalculatorOptions} protobuf message. */ + @Override + public CalculatorOptions convertToCalculatorOptionsProto() { + BaseOptionsProto.BaseOptions.Builder baseOptionsBuilder = + BaseOptionsProto.BaseOptions.newBuilder(); + baseOptionsBuilder.setUseStreamMode(runningMode() != RunningMode.IMAGE); + baseOptionsBuilder.mergeFrom(convertBaseOptionsToProto(baseOptions())); + ImageEmbedderGraphOptionsProto.ImageEmbedderGraphOptions.Builder taskOptionsBuilder = + ImageEmbedderGraphOptionsProto.ImageEmbedderGraphOptions.newBuilder() + .setBaseOptions(baseOptionsBuilder); + if (embedderOptions().isPresent()) { + taskOptionsBuilder.setEmbedderOptions(embedderOptions().get().convertToProto()); + } + return CalculatorOptions.newBuilder() + .setExtension( + ImageEmbedderGraphOptionsProto.ImageEmbedderGraphOptions.ext, + taskOptionsBuilder.build()) + .build(); + } + } +} diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedderResult.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedderResult.java new file mode 100644 index 000000000..ee3f4abc9 --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedderResult.java @@ -0,0 +1,54 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.tasks.vision.imageembedder; + +import com.google.auto.value.AutoValue; +import com.google.mediapipe.tasks.components.containers.EmbeddingResult; +import com.google.mediapipe.tasks.components.containers.proto.EmbeddingsProto; +import com.google.mediapipe.tasks.core.TaskResult; + +/** Represents the embedding results generated by {@link ImageEmbedder}. */ +@AutoValue +public abstract class ImageEmbedderResult implements TaskResult { + + /** + * Creates an {@link ImageEmbedderResult} instance. + * + * @param embeddingResult the {@link EmbeddingResult} object containing one embedding per embedder + * head. + * @param timestampMs a timestamp for this result. + */ + static ImageEmbedderResult create(EmbeddingResult embeddingResult, long timestampMs) { + return new AutoValue_ImageEmbedderResult(embeddingResult, timestampMs); + } + + /** + * Creates an {@link ImageEmbedderResult} instance from a {@link EmbeddingsProto.EmbeddingResult} + * protobuf message. + * + * @param proto the {@link EmbeddingsProto.EmbeddingResult} protobuf message to convert. + * @param timestampMs a timestamp for this result. + */ + static ImageEmbedderResult createFromProto( + EmbeddingsProto.EmbeddingResult proto, long timestampMs) { + return create(EmbeddingResult.createFromProto(proto), timestampMs); + } + + /** Contains one embedding per embedder head. */ + public abstract EmbeddingResult embeddingResult(); + + @Override + public abstract long timestampMs(); +} diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageembedder/AndroidManifest.xml b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageembedder/AndroidManifest.xml new file mode 100644 index 000000000..db303a439 --- /dev/null +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageembedder/AndroidManifest.xml @@ -0,0 +1,24 @@ + + + + + + + + + + + + + + + diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageembedder/BUILD b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageembedder/BUILD new file mode 100644 index 000000000..a7f804c64 --- /dev/null +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageembedder/BUILD @@ -0,0 +1,19 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +# TODO: Enable this in OSS diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedderTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedderTest.java new file mode 100644 index 000000000..56249ead9 --- /dev/null +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedderTest.java @@ -0,0 +1,444 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.tasks.vision.imageembedder; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; + +import android.content.res.AssetManager; +import android.graphics.BitmapFactory; +import android.graphics.RectF; +import androidx.test.core.app.ApplicationProvider; +import androidx.test.ext.junit.runners.AndroidJUnit4; +import com.google.mediapipe.framework.MediaPipeException; +import com.google.mediapipe.framework.image.BitmapImageBuilder; +import com.google.mediapipe.framework.image.MPImage; +import com.google.mediapipe.tasks.components.processors.EmbedderOptions; +import com.google.mediapipe.tasks.core.BaseOptions; +import com.google.mediapipe.tasks.core.TestUtils; +import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions; +import com.google.mediapipe.tasks.vision.core.RunningMode; +import com.google.mediapipe.tasks.vision.imageembedder.ImageEmbedder.ImageEmbedderOptions; +import java.io.InputStream; +import java.nio.ByteBuffer; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Suite; +import org.junit.runners.Suite.SuiteClasses; + +/** Test for {@link ImageEmbedder}/ */ +@RunWith(Suite.class) +@SuiteClasses({ImageEmbedderTest.General.class, ImageEmbedderTest.RunningModeTest.class}) +public class ImageEmbedderTest { + private static final String MOBILENET_EMBEDDER = "mobilenet_v3_small_100_224_embedder.tflite"; + private static final String BURGER_IMAGE = "burger.jpg"; + private static final String BURGER_CROP_IMAGE = "burger_crop.jpg"; + private static final String BURGER_ROTATED_IMAGE = "burger_rotated.jpg"; + + private static final double DOUBLE_DIFF_TOLERANCE = 1e-4; + + @RunWith(AndroidJUnit4.class) + public static final class General extends ImageEmbedderTest { + + @Test + public void create_failsWithMissingModel() throws Exception { + String nonExistentFile = "/path/to/non/existent/file"; + MediaPipeException exception = + assertThrows( + MediaPipeException.class, + () -> + ImageEmbedder.createFromFile( + ApplicationProvider.getApplicationContext(), nonExistentFile)); + assertThat(exception).hasMessageThat().contains(nonExistentFile); + } + + @Test + public void create_failsWithInvalidModelBuffer() throws Exception { + // Create a non-direct model ByteBuffer. 
+ ByteBuffer modelBuffer = + TestUtils.loadToNonDirectByteBuffer( + ApplicationProvider.getApplicationContext(), MOBILENET_EMBEDDER); + + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> + ImageEmbedder.createFromBuffer( + ApplicationProvider.getApplicationContext(), modelBuffer)); + + assertThat(exception) + .hasMessageThat() + .contains("The model buffer should be either a direct ByteBuffer or a MappedByteBuffer."); + } + + @Test + public void embed_succeedsWithNoOptions() throws Exception { + ImageEmbedder imageEmbedder = + ImageEmbedder.createFromFile( + ApplicationProvider.getApplicationContext(), MOBILENET_EMBEDDER); + ImageEmbedderResult result = imageEmbedder.embed(getImageFromAsset(BURGER_IMAGE)); + ImageEmbedderResult resultCrop = imageEmbedder.embed(getImageFromAsset(BURGER_CROP_IMAGE)); + + // Check results. + assertHasOneHeadAndCorrectDimension(result, /*quantized=*/ false); + assertHasOneHeadAndCorrectDimension(resultCrop, /*quantized=*/ false); + // Check similarity. + double similarity = + ImageEmbedder.cosineSimilarity( + result.embeddingResult().embeddings().get(0), + resultCrop.embeddingResult().embeddings().get(0)); + assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.925272); + } + + @Test + public void embed_succeedsWithL2Normalization() throws Exception { + BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(MOBILENET_EMBEDDER).build(); + EmbedderOptions embedderOptions = EmbedderOptions.builder().setL2Normalize(true).build(); + ImageEmbedderOptions options = + ImageEmbedderOptions.builder() + .setBaseOptions(baseOptions) + .setEmbedderOptions(embedderOptions) + .build(); + + ImageEmbedder imageEmbedder = + ImageEmbedder.createFromOptions(ApplicationProvider.getApplicationContext(), options); + ImageEmbedderResult result = imageEmbedder.embed(getImageFromAsset(BURGER_IMAGE)); + ImageEmbedderResult resultCrop = imageEmbedder.embed(getImageFromAsset(BURGER_CROP_IMAGE)); + + // Check results. + assertHasOneHeadAndCorrectDimension(result, /*quantized=*/ false); + assertHasOneHeadAndCorrectDimension(resultCrop, /*quantized=*/ false); + // Check similarity. + double similarity = + ImageEmbedder.cosineSimilarity( + result.embeddingResult().embeddings().get(0), + resultCrop.embeddingResult().embeddings().get(0)); + assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.925272); + } + + @Test + public void embed_succeedsWithQuantization() throws Exception { + BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(MOBILENET_EMBEDDER).build(); + EmbedderOptions embedderOptions = EmbedderOptions.builder().setQuantize(true).build(); + ImageEmbedderOptions options = + ImageEmbedderOptions.builder() + .setBaseOptions(baseOptions) + .setEmbedderOptions(embedderOptions) + .build(); + + ImageEmbedder imageEmbedder = + ImageEmbedder.createFromOptions(ApplicationProvider.getApplicationContext(), options); + ImageEmbedderResult result = imageEmbedder.embed(getImageFromAsset(BURGER_IMAGE)); + ImageEmbedderResult resultCrop = imageEmbedder.embed(getImageFromAsset(BURGER_CROP_IMAGE)); + + // Check results. + assertHasOneHeadAndCorrectDimension(result, /*quantized=*/ true); + assertHasOneHeadAndCorrectDimension(resultCrop, /*quantized=*/ true); + // Check similarity. 
+ double similarity = + ImageEmbedder.cosineSimilarity( + result.embeddingResult().embeddings().get(0), + resultCrop.embeddingResult().embeddings().get(0)); + assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.926776); + } + + @Test + public void embed_succeedsWithRegionOfInterest() throws Exception { + ImageEmbedder imageEmbedder = + ImageEmbedder.createFromFile( + ApplicationProvider.getApplicationContext(), MOBILENET_EMBEDDER); + // RectF around the region in "burger.jpg" corresponding to "burger_crop.jpg". + RectF roi = new RectF(0.0f, 0.0f, 0.833333f, 1.0f); + ImageProcessingOptions imageProcessingOptions = + ImageProcessingOptions.builder().setRegionOfInterest(roi).build(); + ImageEmbedderResult resultRoi = + imageEmbedder.embed(getImageFromAsset(BURGER_IMAGE), imageProcessingOptions); + ImageEmbedderResult resultCrop = imageEmbedder.embed(getImageFromAsset(BURGER_CROP_IMAGE)); + + // Check results. + assertHasOneHeadAndCorrectDimension(resultRoi, /*quantized=*/ false); + assertHasOneHeadAndCorrectDimension(resultCrop, /*quantized=*/ false); + // Check similarity. + double similarity = + ImageEmbedder.cosineSimilarity( + resultRoi.embeddingResult().embeddings().get(0), + resultCrop.embeddingResult().embeddings().get(0)); + assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.999931f); + } + + @Test + public void embed_succeedsWithRotation() throws Exception { + ImageEmbedder imageEmbedder = + ImageEmbedder.createFromFile( + ApplicationProvider.getApplicationContext(), MOBILENET_EMBEDDER); + ImageEmbedderResult result = imageEmbedder.embed(getImageFromAsset(BURGER_IMAGE)); + ImageProcessingOptions imageProcessingOptions = + ImageProcessingOptions.builder().setRotationDegrees(-90).build(); + ImageEmbedderResult resultRotated = + imageEmbedder.embed(getImageFromAsset(BURGER_ROTATED_IMAGE), imageProcessingOptions); + + // Check results. + assertHasOneHeadAndCorrectDimension(result, /*quantized=*/ false); + assertHasOneHeadAndCorrectDimension(resultRotated, /*quantized=*/ false); + // Check similarity. + double similarity = + ImageEmbedder.cosineSimilarity( + result.embeddingResult().embeddings().get(0), + resultRotated.embeddingResult().embeddings().get(0)); + assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.571648426f); + } + + @Test + public void embed_succeedsWithRegionOfInterestAndRotation() throws Exception { + ImageEmbedder imageEmbedder = + ImageEmbedder.createFromFile( + ApplicationProvider.getApplicationContext(), MOBILENET_EMBEDDER); + // RectF around the region in "burger_rotated.jpg" corresponding to "burger_crop.jpg". + RectF roi = new RectF(0.0f, 0.0f, 1.0f, 0.833333f); + ImageProcessingOptions imageProcessingOptions = + ImageProcessingOptions.builder().setRegionOfInterest(roi).setRotationDegrees(-90).build(); + ImageEmbedderResult resultRoiRotated = + imageEmbedder.embed(getImageFromAsset(BURGER_ROTATED_IMAGE), imageProcessingOptions); + ImageEmbedderResult resultCrop = imageEmbedder.embed(getImageFromAsset(BURGER_CROP_IMAGE)); + + // Check results. + assertHasOneHeadAndCorrectDimension(resultRoiRotated, /*quantized=*/ false); + assertHasOneHeadAndCorrectDimension(resultCrop, /*quantized=*/ false); + // Check similarity. 
+ double similarity = + ImageEmbedder.cosineSimilarity( + resultRoiRotated.embeddingResult().embeddings().get(0), + resultCrop.embeddingResult().embeddings().get(0)); + assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.62780395f); + } + } + + @RunWith(AndroidJUnit4.class) + public static final class RunningModeTest extends ImageEmbedderTest { + + @Test + public void create_failsWithIllegalResultListenerInNonLiveStreamMode() throws Exception { + for (RunningMode mode : new RunningMode[] {RunningMode.IMAGE, RunningMode.VIDEO}) { + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> + ImageEmbedderOptions.builder() + .setBaseOptions( + BaseOptions.builder().setModelAssetPath(MOBILENET_EMBEDDER).build()) + .setRunningMode(mode) + .setResultListener((result, inputImage) -> {}) + .build()); + assertThat(exception) + .hasMessageThat() + .contains("a user-defined result listener shouldn't be provided"); + } + } + + @Test + public void create_failsWithMissingResultListenerInLiveSteamMode() throws Exception { + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> + ImageEmbedderOptions.builder() + .setBaseOptions( + BaseOptions.builder().setModelAssetPath(MOBILENET_EMBEDDER).build()) + .setRunningMode(RunningMode.LIVE_STREAM) + .build()); + assertThat(exception) + .hasMessageThat() + .contains("a user-defined result listener must be provided"); + } + + @Test + public void embed_failsWithCallingWrongApiInImageMode() throws Exception { + ImageEmbedderOptions options = + ImageEmbedderOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(MOBILENET_EMBEDDER).build()) + .setRunningMode(RunningMode.IMAGE) + .build(); + + ImageEmbedder imageEmbedder = + ImageEmbedder.createFromOptions(ApplicationProvider.getApplicationContext(), options); + MediaPipeException exception = + assertThrows( + MediaPipeException.class, + () -> + imageEmbedder.embedForVideo(getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 0)); + assertThat(exception).hasMessageThat().contains("not initialized with the video mode"); + exception = + assertThrows( + MediaPipeException.class, + () -> imageEmbedder.embedAsync(getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 0)); + assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode"); + } + + @Test + public void embed_failsWithCallingWrongApiInVideoMode() throws Exception { + ImageEmbedderOptions options = + ImageEmbedderOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(MOBILENET_EMBEDDER).build()) + .setRunningMode(RunningMode.VIDEO) + .build(); + + ImageEmbedder imageEmbedder = + ImageEmbedder.createFromOptions(ApplicationProvider.getApplicationContext(), options); + MediaPipeException exception = + assertThrows( + MediaPipeException.class, () -> imageEmbedder.embed(getImageFromAsset(BURGER_IMAGE))); + assertThat(exception).hasMessageThat().contains("not initialized with the image mode"); + exception = + assertThrows( + MediaPipeException.class, + () -> imageEmbedder.embedAsync(getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 0)); + assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode"); + } + + @Test + public void embed_failsWithCallingWrongApiInLiveSteamMode() throws Exception { + ImageEmbedderOptions options = + ImageEmbedderOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(MOBILENET_EMBEDDER).build()) + .setRunningMode(RunningMode.LIVE_STREAM) + 
.setResultListener((imageClassificationResult, inputImage) -> {}) + .build(); + + ImageEmbedder imageEmbedder = + ImageEmbedder.createFromOptions(ApplicationProvider.getApplicationContext(), options); + + MediaPipeException exception = + assertThrows( + MediaPipeException.class, () -> imageEmbedder.embed(getImageFromAsset(BURGER_IMAGE))); + assertThat(exception).hasMessageThat().contains("not initialized with the image mode"); + exception = + assertThrows( + MediaPipeException.class, + () -> + imageEmbedder.embedForVideo(getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 0)); + assertThat(exception).hasMessageThat().contains("not initialized with the video mode"); + } + + @Test + public void embed_succeedsWithImageMode() throws Exception { + ImageEmbedder imageEmbedder = + ImageEmbedder.createFromFile( + ApplicationProvider.getApplicationContext(), MOBILENET_EMBEDDER); + ImageEmbedderResult result = imageEmbedder.embed(getImageFromAsset(BURGER_IMAGE)); + ImageEmbedderResult resultCrop = imageEmbedder.embed(getImageFromAsset(BURGER_CROP_IMAGE)); + + // Check results. + assertHasOneHeadAndCorrectDimension(result, /*quantized=*/ false); + assertHasOneHeadAndCorrectDimension(resultCrop, /*quantized=*/ false); + // Check similarity. + double similarity = + ImageEmbedder.cosineSimilarity( + result.embeddingResult().embeddings().get(0), + resultCrop.embeddingResult().embeddings().get(0)); + assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.925272); + } + + @Test + public void embed_succeedsWithVideoMode() throws Exception { + BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(MOBILENET_EMBEDDER).build(); + ImageEmbedderOptions options = + ImageEmbedderOptions.builder() + .setBaseOptions(baseOptions) + .setRunningMode(RunningMode.VIDEO) + .build(); + ImageEmbedder imageEmbedder = + ImageEmbedder.createFromOptions(ApplicationProvider.getApplicationContext(), options); + + for (int i = 0; i < 3; ++i) { + ImageEmbedderResult result = + imageEmbedder.embedForVideo(getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ i); + assertHasOneHeadAndCorrectDimension(result, /*quantized=*/ false); + } + } + + @Test + public void embed_failsWithOutOfOrderInputTimestamps() throws Exception { + MPImage image = getImageFromAsset(BURGER_IMAGE); + BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(MOBILENET_EMBEDDER).build(); + ImageEmbedderOptions options = + ImageEmbedderOptions.builder() + .setBaseOptions(baseOptions) + .setRunningMode(RunningMode.LIVE_STREAM) + .setResultListener( + (imageEmbedderResult, inputImage) -> { + assertHasOneHeadAndCorrectDimension(imageEmbedderResult, /*quantized=*/ false); + assertImageSizeIsExpected(inputImage); + }) + .build(); + try (ImageEmbedder imageEmbedder = + ImageEmbedder.createFromOptions(ApplicationProvider.getApplicationContext(), options)) { + imageEmbedder.embedAsync(getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 1); + MediaPipeException exception = + assertThrows( + MediaPipeException.class, + () -> imageEmbedder.embedAsync(image, /*timestampMs=*/ 0)); + assertThat(exception) + .hasMessageThat() + .contains("having a smaller timestamp than the processed timestamp"); + } + } + + @Test + public void embed_succeedsWithLiveStreamMode() throws Exception { + MPImage image = getImageFromAsset(BURGER_IMAGE); + BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(MOBILENET_EMBEDDER).build(); + ImageEmbedderOptions options = + ImageEmbedderOptions.builder() + .setBaseOptions(baseOptions) + 
.setRunningMode(RunningMode.LIVE_STREAM) + .setResultListener( + (imageEmbedderResult, inputImage) -> { + assertHasOneHeadAndCorrectDimension(imageEmbedderResult, /*quantized=*/ false); + assertImageSizeIsExpected(inputImage); + }) + .build(); + try (ImageEmbedder imageEmbedder = + ImageEmbedder.createFromOptions(ApplicationProvider.getApplicationContext(), options)) { + for (int i = 0; i < 3; ++i) { + imageEmbedder.embedAsync(image, /*timestampMs=*/ i); + } + } + } + } + + private static MPImage getImageFromAsset(String filePath) throws Exception { + AssetManager assetManager = ApplicationProvider.getApplicationContext().getAssets(); + InputStream istr = assetManager.open(filePath); + return new BitmapImageBuilder(BitmapFactory.decodeStream(istr)).build(); + } + + private static void assertHasOneHeadAndCorrectDimension( + ImageEmbedderResult result, boolean quantized) { + assertThat(result.embeddingResult().embeddings()).hasSize(1); + assertThat(result.embeddingResult().embeddings().get(0).headIndex()).isEqualTo(0); + assertThat(result.embeddingResult().embeddings().get(0).headName().get()).isEqualTo("feature"); + if (quantized) { + assertThat(result.embeddingResult().embeddings().get(0).quantizedEmbedding()).hasLength(1024); + } else { + assertThat(result.embeddingResult().embeddings().get(0).floatEmbedding()).hasLength(1024); + } + } + + private static void assertImageSizeIsExpected(MPImage inputImage) { + assertThat(inputImage).isNotNull(); + assertThat(inputImage.getWidth()).isEqualTo(480); + assertThat(inputImage.getHeight()).isEqualTo(325); + } +} From f14645cb06376cd1a6818a6155118ad0667d2d84 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Tue, 15 Nov 2022 10:48:41 -0800 Subject: [PATCH 024/469] Model maker gesture recognizer test changes PiperOrigin-RevId: 488702055 --- .../gesture_recognizer_test.py | 33 ++++++++++++------- 1 file changed, 22 insertions(+), 11 deletions(-) diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py b/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py index eb2b1d171..7e7a1ca30 100644 --- a/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py @@ -14,6 +14,7 @@ import io import os +import tempfile from unittest import mock as unittest_mock import zipfile @@ -40,30 +41,35 @@ class GestureRecognizerTest(tf.test.TestCase): def setUp(self): super().setUp() - self._model_options = gesture_recognizer.ModelOptions() - self._hparams = gesture_recognizer.HParams(epochs=2) - self._gesture_recognizer_options = ( - gesture_recognizer.GestureRecognizerOptions( - model_options=self._model_options, hparams=self._hparams)) all_data = self._load_data() # Splits data, 90% data for training, 10% for testing self._train_data, self._test_data = all_data.split(0.9) def test_gesture_recognizer_model(self): + model_options = gesture_recognizer.ModelOptions() + hparams = gesture_recognizer.HParams( + export_dir=tempfile.mkdtemp(), epochs=2) + gesture_recognizer_options = gesture_recognizer.GestureRecognizerOptions( + model_options=model_options, hparams=hparams) model = gesture_recognizer.GestureRecognizer.create( train_data=self._train_data, validation_data=self._test_data, - options=self._gesture_recognizer_options) + options=gesture_recognizer_options) self._test_accuracy(model) def test_export_gesture_recognizer_model(self): + model_options = gesture_recognizer.ModelOptions() + hparams 
= gesture_recognizer.HParams( + export_dir=tempfile.mkdtemp(), epochs=2) + gesture_recognizer_options = gesture_recognizer.GestureRecognizerOptions( + model_options=model_options, hparams=hparams) model = gesture_recognizer.GestureRecognizer.create( train_data=self._train_data, validation_data=self._test_data, - options=self._gesture_recognizer_options) + options=gesture_recognizer_options) model.export_model() - model_bundle_file = os.path.join(self._hparams.export_dir, + model_bundle_file = os.path.join(hparams.export_dir, 'gesture_recognizer.task') with zipfile.ZipFile(model_bundle_file) as zf: self.assertEqual( @@ -102,7 +108,7 @@ class GestureRecognizerTest(tf.test.TestCase): 'GestureRecognizerModelOptions', autospec=True, return_value=gesture_recognizer.ModelOptions()) - def test_create_hparams_and_model_options_if_none_in_image_classifier_options( + def test_create_hparams_and_model_options_if_none_in_gesture_recognizer_options( self, mock_hparams, mock_model_options): options = gesture_recognizer.GestureRecognizerOptions() gesture_recognizer.GestureRecognizer.create( @@ -113,16 +119,21 @@ class GestureRecognizerTest(tf.test.TestCase): mock_model_options.assert_called_once() def test_continual_training_by_loading_checkpoint(self): + model_options = gesture_recognizer.ModelOptions() + hparams = gesture_recognizer.HParams( + export_dir=tempfile.mkdtemp(), epochs=2) + gesture_recognizer_options = gesture_recognizer.GestureRecognizerOptions( + model_options=model_options, hparams=hparams) mock_stdout = io.StringIO() with mock.patch('sys.stdout', mock_stdout): model = gesture_recognizer.GestureRecognizer.create( train_data=self._train_data, validation_data=self._test_data, - options=self._gesture_recognizer_options) + options=gesture_recognizer_options) model = gesture_recognizer.GestureRecognizer.create( train_data=self._train_data, validation_data=self._test_data, - options=self._gesture_recognizer_options) + options=gesture_recognizer_options) self._test_accuracy(model) self.assertRegex(mock_stdout.getvalue(), 'Resuming from') From a94564540bc22af9d02c4df3102a1f0d3424929e Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Tue, 15 Nov 2022 11:49:21 -0800 Subject: [PATCH 025/469] Bump up the dependency library pybind11's version to 2.10.1. PiperOrigin-RevId: 488718815 --- WORKSPACE | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/WORKSPACE b/WORKSPACE index 702d1899e..fea96d941 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -212,14 +212,14 @@ http_archive( sha256 = "75922da3a1bdb417d820398eb03d4e9bd067c4905a4246d35a44c01d62154d91", ) -# Point to the commit that deprecates the usage of Eigen::MappedSparseMatrix. +# 2022-10-20 http_archive( name = "pybind11", urls = [ - "https://github.com/pybind/pybind11/archive/70a58c577eaf067748c2ec31bfd0b0a614cffba6.zip", + "https://github.com/pybind/pybind11/archive/v2.10.1.zip", ], - sha256 = "b971842fab1b5b8f3815a2302331782b7d137fef0e06502422bc4bc360f4956c", - strip_prefix = "pybind11-70a58c577eaf067748c2ec31bfd0b0a614cffba6", + sha256 = "fcf94065efcfd0a7a828bacf118fa11c43f6390d0c805e3e6342ac119f2e9976", + strip_prefix = "pybind11-2.10.1", build_file = "@pybind11_bazel//:pybind11.BUILD", ) From 1689112b23fc6038114a143baf0253e0b6c043c6 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Tue, 15 Nov 2022 14:02:21 -0800 Subject: [PATCH 026/469] Improve model_util_test code. 
PiperOrigin-RevId: 488752497 --- .../model_maker/python/core/utils/model_util_test.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/mediapipe/model_maker/python/core/utils/model_util_test.py b/mediapipe/model_maker/python/core/utils/model_util_test.py index bef9c8a97..05c6ffe3f 100644 --- a/mediapipe/model_maker/python/core/utils/model_util_test.py +++ b/mediapipe/model_maker/python/core/utils/model_util_test.py @@ -13,6 +13,7 @@ # limitations under the License. import os +from typing import Optional from absl.testing import parameterized import tensorflow as tf @@ -76,8 +77,10 @@ class ModelUtilTest(tf.test.TestCase, parameterized.TestCase): train_data=tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0], [1, 0]]), expected_steps_per_epoch=2)) - def test_get_steps_per_epoch(self, steps_per_epoch, batch_size, train_data, - expected_steps_per_epoch): + def test_get_steps_per_epoch(self, steps_per_epoch: Optional[int], + batch_size: Optional[int], + train_data: Optional[tf.data.Dataset], + expected_steps_per_epoch: int): estimated_steps_per_epoch = model_util.get_steps_per_epoch( steps_per_epoch=steps_per_epoch, batch_size=batch_size, @@ -130,7 +133,9 @@ class ModelUtilTest(tf.test.TestCase, parameterized.TestCase): testcase_name='float16_quantize', config=quantization.QuantizationConfig.for_float16(), model_size=1468)) - def test_convert_to_tflite_quantized(self, config, model_size): + def test_convert_to_tflite_quantized(self, + config: quantization.QuantizationConfig, + model_size: int): input_dim = 16 num_classes = 2 max_input_value = 5 @@ -157,5 +162,6 @@ class ModelUtilTest(tf.test.TestCase, parameterized.TestCase): test_util.test_tflite_file( keras_model=model, tflite_file=tflite_file, size=[1, input_dim]) + if __name__ == '__main__': tf.test.main() From 496720308c66d02832038090e1a6562ca5b6342f Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Tue, 15 Nov 2022 14:03:17 -0800 Subject: [PATCH 027/469] Migrate remaining MP Tasks Libraries to ts_declarations PiperOrigin-RevId: 488752799 --- .../tasks/web/audio/audio_classifier/BUILD | 23 ++++++++++++++----- .../audio_classifier/audio_classifier.ts | 3 +++ ...tions.ts => audio_classifier_options.d.ts} | 0 ...result.ts => audio_classifier_result.d.ts} | 0 mediapipe/tasks/web/audio/index.ts | 3 --- mediapipe/tasks/web/text/BUILD | 1 + mediapipe/tasks/web/text/index.ts | 4 +--- .../tasks/web/text/text_classifier/BUILD | 22 +++++++++++++----- .../text/text_classifier/text_classifier.ts | 3 +++ ...ptions.ts => text_classifier_options.d.ts} | 0 ..._result.ts => text_classifier_result.d.ts} | 0 mediapipe/tasks/web/text/text_embedder/BUILD | 23 ++++++++++++++----- .../web/text/text_embedder/text_embedder.ts | 2 ++ .../tasks/web/vision/image_embedder/BUILD | 23 ++++++++++++++----- .../vision/image_embedder/image_embedder.ts | 2 ++ ...options.ts => image_embedder_options.d.ts} | 0 ...r_result.ts => image_embedder_result.d.ts} | 0 mediapipe/tasks/web/vision/index.ts | 11 --------- 18 files changed, 79 insertions(+), 41 deletions(-) rename mediapipe/tasks/web/audio/audio_classifier/{audio_classifier_options.ts => audio_classifier_options.d.ts} (100%) rename mediapipe/tasks/web/audio/audio_classifier/{audio_classifier_result.ts => audio_classifier_result.d.ts} (100%) rename mediapipe/tasks/web/text/text_classifier/{text_classifier_options.ts => text_classifier_options.d.ts} (100%) rename mediapipe/tasks/web/text/text_classifier/{text_classifier_result.ts => text_classifier_result.d.ts} (100%) rename 
mediapipe/tasks/web/vision/image_embedder/{image_embedder_options.ts => image_embedder_options.d.ts} (100%) rename mediapipe/tasks/web/vision/image_embedder/{image_embedder_result.ts => image_embedder_result.d.ts} (100%) diff --git a/mediapipe/tasks/web/audio/audio_classifier/BUILD b/mediapipe/tasks/web/audio/audio_classifier/BUILD index 1bc4af309..6a78116c3 100644 --- a/mediapipe/tasks/web/audio/audio_classifier/BUILD +++ b/mediapipe/tasks/web/audio/audio_classifier/BUILD @@ -2,7 +2,7 @@ # # This task takes audio data and outputs the classification result. -load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library") package(default_visibility = ["//mediapipe/tasks:internal"]) @@ -10,12 +10,9 @@ licenses(["notice"]) mediapipe_ts_library( name = "audio_classifier", - srcs = [ - "audio_classifier.ts", - "audio_classifier_options.ts", - "audio_classifier_result.ts", - ], + srcs = ["audio_classifier.ts"], deps = [ + ":audio_classifier_types", "//mediapipe/framework:calculator_jspb_proto", "//mediapipe/framework:calculator_options_jspb_proto", "//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_graph_options_jspb_proto", @@ -31,3 +28,17 @@ mediapipe_ts_library( "//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts", ], ) + +mediapipe_ts_declaration( + name = "audio_classifier_types", + srcs = [ + "audio_classifier_options.d.ts", + "audio_classifier_result.d.ts", + ], + deps = [ + "//mediapipe/tasks/web/components/containers:category", + "//mediapipe/tasks/web/components/containers:classification_result", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:classifier_options", + ], +) diff --git a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts index e3700cd7a..76b926723 100644 --- a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts +++ b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts @@ -29,6 +29,9 @@ import {createMediaPipeLib, FileLocator} from '../../../../web/graph_runner/wasm import {AudioClassifierOptions} from './audio_classifier_options'; import {AudioClassifierResult} from './audio_classifier_result'; +export * from './audio_classifier_options'; +export * from './audio_classifier_result'; + const MEDIAPIPE_GRAPH = 'mediapipe.tasks.audio.audio_classifier.AudioClassifierGraph'; diff --git a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier_options.ts b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier_options.d.ts similarity index 100% rename from mediapipe/tasks/web/audio/audio_classifier/audio_classifier_options.ts rename to mediapipe/tasks/web/audio/audio_classifier/audio_classifier_options.d.ts diff --git a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier_result.ts b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier_result.d.ts similarity index 100% rename from mediapipe/tasks/web/audio/audio_classifier/audio_classifier_result.ts rename to mediapipe/tasks/web/audio/audio_classifier/audio_classifier_result.d.ts diff --git a/mediapipe/tasks/web/audio/index.ts b/mediapipe/tasks/web/audio/index.ts index 114a8ceca..a5083b326 100644 --- a/mediapipe/tasks/web/audio/index.ts +++ b/mediapipe/tasks/web/audio/index.ts @@ -14,7 +14,4 @@ * limitations under the License. 
*/ -// Audio Classifier -export * from '../../../tasks/web/audio/audio_classifier/audio_classifier_options'; -export * from '../../../tasks/web/audio/audio_classifier/audio_classifier_result'; export * from '../../../tasks/web/audio/audio_classifier/audio_classifier'; diff --git a/mediapipe/tasks/web/text/BUILD b/mediapipe/tasks/web/text/BUILD index a369d0af0..4b465b0f5 100644 --- a/mediapipe/tasks/web/text/BUILD +++ b/mediapipe/tasks/web/text/BUILD @@ -9,5 +9,6 @@ mediapipe_ts_library( srcs = ["index.ts"], deps = [ "//mediapipe/tasks/web/text/text_classifier", + "//mediapipe/tasks/web/text/text_embedder", ], ) diff --git a/mediapipe/tasks/web/text/index.ts b/mediapipe/tasks/web/text/index.ts index dc511a426..d50db209c 100644 --- a/mediapipe/tasks/web/text/index.ts +++ b/mediapipe/tasks/web/text/index.ts @@ -14,7 +14,5 @@ * limitations under the License. */ -// Text Classifier -export * from '../../../tasks/web/text/text_classifier/text_classifier_options'; -export * from '../../../tasks/web/text/text_classifier/text_classifier_result'; export * from '../../../tasks/web/text/text_classifier/text_classifier'; +export * from '../../../tasks/web/text/text_embedder/text_embedder'; diff --git a/mediapipe/tasks/web/text/text_classifier/BUILD b/mediapipe/tasks/web/text/text_classifier/BUILD index 4ebdce18a..7dbbb18ca 100644 --- a/mediapipe/tasks/web/text/text_classifier/BUILD +++ b/mediapipe/tasks/web/text/text_classifier/BUILD @@ -3,7 +3,7 @@ # This task takes text input performs Natural Language classification (including # BERT-based text classification). -load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library") package(default_visibility = ["//mediapipe/tasks:internal"]) @@ -11,12 +11,9 @@ licenses(["notice"]) mediapipe_ts_library( name = "text_classifier", - srcs = [ - "text_classifier.ts", - "text_classifier_options.ts", - "text_classifier_result.ts", - ], + srcs = ["text_classifier.ts"], deps = [ + ":text_classifier_types", "//mediapipe/framework:calculator_jspb_proto", "//mediapipe/framework:calculator_options_jspb_proto", "//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto", @@ -32,3 +29,16 @@ mediapipe_ts_library( "//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts", ], ) + +mediapipe_ts_declaration( + name = "text_classifier_types", + srcs = [ + "text_classifier_options.d.ts", + "text_classifier_result.d.ts", + ], + deps = [ + "//mediapipe/tasks/web/components/containers:category", + "//mediapipe/tasks/web/components/containers:classification_result", + "//mediapipe/tasks/web/core:classifier_options", + ], +) diff --git a/mediapipe/tasks/web/text/text_classifier/text_classifier.ts b/mediapipe/tasks/web/text/text_classifier/text_classifier.ts index e1d0c9601..d4f413efa 100644 --- a/mediapipe/tasks/web/text/text_classifier/text_classifier.ts +++ b/mediapipe/tasks/web/text/text_classifier/text_classifier.ts @@ -29,6 +29,9 @@ import {createMediaPipeLib, FileLocator} from '../../../../web/graph_runner/wasm import {TextClassifierOptions} from './text_classifier_options'; import {TextClassifierResult} from './text_classifier_result'; +export * from './text_classifier_options'; +export * from './text_classifier_result'; + const INPUT_STREAM = 'text_in'; const CLASSIFICATIONS_STREAM = 'classifications_out'; const TEXT_CLASSIFIER_GRAPH = diff --git a/mediapipe/tasks/web/text/text_classifier/text_classifier_options.ts 
b/mediapipe/tasks/web/text/text_classifier/text_classifier_options.d.ts similarity index 100% rename from mediapipe/tasks/web/text/text_classifier/text_classifier_options.ts rename to mediapipe/tasks/web/text/text_classifier/text_classifier_options.d.ts diff --git a/mediapipe/tasks/web/text/text_classifier/text_classifier_result.ts b/mediapipe/tasks/web/text/text_classifier/text_classifier_result.d.ts similarity index 100% rename from mediapipe/tasks/web/text/text_classifier/text_classifier_result.ts rename to mediapipe/tasks/web/text/text_classifier/text_classifier_result.d.ts diff --git a/mediapipe/tasks/web/text/text_embedder/BUILD b/mediapipe/tasks/web/text/text_embedder/BUILD index 8e397ce6f..bebd612dd 100644 --- a/mediapipe/tasks/web/text/text_embedder/BUILD +++ b/mediapipe/tasks/web/text/text_embedder/BUILD @@ -3,7 +3,7 @@ # This task takes text input and performs embedding # -load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library") package(default_visibility = ["//mediapipe/tasks:internal"]) @@ -11,13 +11,11 @@ licenses(["notice"]) mediapipe_ts_library( name = "text_embedder", - srcs = [ - "text_embedder.ts", - "text_embedder_options.d.ts", - "text_embedder_result.d.ts", - ], + srcs = ["text_embedder.ts"], deps = [ + ":text_embedder_types", "//mediapipe/framework:calculator_jspb_proto", + "//mediapipe/framework:calculator_options_jspb_proto", "//mediapipe/tasks/cc/components/containers/proto:embeddings_jspb_proto", "//mediapipe/tasks/cc/text/text_embedder/proto:text_embedder_graph_options_jspb_proto", "//mediapipe/tasks/web/components/containers:embedding_result", @@ -30,3 +28,16 @@ mediapipe_ts_library( "//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts", ], ) + +mediapipe_ts_declaration( + name = "text_embedder_types", + srcs = [ + "text_embedder_options.d.ts", + "text_embedder_result.d.ts", + ], + deps = [ + "//mediapipe/tasks/web/components/containers:embedding_result", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:embedder_options", + ], +) diff --git a/mediapipe/tasks/web/text/text_embedder/text_embedder.ts b/mediapipe/tasks/web/text/text_embedder/text_embedder.ts index 65df5df6a..7c631683d 100644 --- a/mediapipe/tasks/web/text/text_embedder/text_embedder.ts +++ b/mediapipe/tasks/web/text/text_embedder/text_embedder.ts @@ -29,6 +29,8 @@ import {createMediaPipeLib, FileLocator} from '../../../../web/graph_runner/wasm import {TextEmbedderOptions} from './text_embedder_options'; import {TextEmbedderResult} from './text_embedder_result'; +export * from './text_embedder_options'; +export * from './text_embedder_result'; // The OSS JS API does not support the builder pattern. // tslint:disable:jspb-use-builder-pattern diff --git a/mediapipe/tasks/web/vision/image_embedder/BUILD b/mediapipe/tasks/web/vision/image_embedder/BUILD index d12a05ad9..13ff2e4d6 100644 --- a/mediapipe/tasks/web/vision/image_embedder/BUILD +++ b/mediapipe/tasks/web/vision/image_embedder/BUILD @@ -2,7 +2,7 @@ # # This task performs embedding extraction on images. 
-load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library") package(default_visibility = ["//mediapipe/tasks:internal"]) @@ -10,12 +10,9 @@ licenses(["notice"]) mediapipe_ts_library( name = "image_embedder", - srcs = [ - "image_embedder.ts", - "image_embedder_options.ts", - "image_embedder_result.ts", - ], + srcs = ["image_embedder.ts"], deps = [ + ":image_embedder_types", "//mediapipe/framework:calculator_jspb_proto", "//mediapipe/framework:calculator_options_jspb_proto", "//mediapipe/tasks/cc/components/containers/proto:embeddings_jspb_proto", @@ -31,3 +28,17 @@ mediapipe_ts_library( "//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts", ], ) + +mediapipe_ts_declaration( + name = "image_embedder_types", + srcs = [ + "image_embedder_options.d.ts", + "image_embedder_result.d.ts", + ], + deps = [ + "//mediapipe/tasks/web/components/containers:embedding_result", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:embedder_options", + "//mediapipe/tasks/web/vision/core:running_mode", + ], +) diff --git a/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts b/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts index 4184e763c..91d9b5119 100644 --- a/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts +++ b/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts @@ -38,6 +38,8 @@ const EMBEDDINGS_STREAM = 'embeddings_out'; const TEXT_EMBEDDER_CALCULATOR = 'mediapipe.tasks.vision.image_embedder.ImageEmbedderGraph'; +export * from './image_embedder_options'; +export * from './image_embedder_result'; export {ImageSource}; // Used in the public API /** Performs embedding extraction on images. */ diff --git a/mediapipe/tasks/web/vision/image_embedder/image_embedder_options.ts b/mediapipe/tasks/web/vision/image_embedder/image_embedder_options.d.ts similarity index 100% rename from mediapipe/tasks/web/vision/image_embedder/image_embedder_options.ts rename to mediapipe/tasks/web/vision/image_embedder/image_embedder_options.d.ts diff --git a/mediapipe/tasks/web/vision/image_embedder/image_embedder_result.ts b/mediapipe/tasks/web/vision/image_embedder/image_embedder_result.d.ts similarity index 100% rename from mediapipe/tasks/web/vision/image_embedder/image_embedder_result.ts rename to mediapipe/tasks/web/vision/image_embedder/image_embedder_result.d.ts diff --git a/mediapipe/tasks/web/vision/index.ts b/mediapipe/tasks/web/vision/index.ts index 0ea844fc9..d68c00cc7 100644 --- a/mediapipe/tasks/web/vision/index.ts +++ b/mediapipe/tasks/web/vision/index.ts @@ -14,19 +14,8 @@ * limitations under the License. 
*/ -// Image Classifier export * from '../../../tasks/web/vision/image_classifier/image_classifier'; - -// Image Embedder -export * from '../../../tasks/web/vision/image_embedder/image_embedder_options'; -export * from '../../../tasks/web/vision/image_embedder/image_embedder_result'; export * from '../../../tasks/web/vision/image_embedder/image_embedder'; - -// Gesture Recognizer export * from '../../../tasks/web/vision/gesture_recognizer/gesture_recognizer'; - -// Hand Landmarker export * from '../../../tasks/web/vision/hand_landmarker/hand_landmarker'; - -// Object Detector export * from '../../../tasks/web/vision/object_detector/object_detector'; From e65f21e2d85f9f08097e953ed9948de481065024 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Tue, 15 Nov 2022 14:34:45 -0800 Subject: [PATCH 028/469] Update the docstring to make it consistent with the model option update. PiperOrigin-RevId: 488761331 --- .../python/vision/image_classifier/image_classifier.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mediapipe/model_maker/python/vision/image_classifier/image_classifier.py b/mediapipe/model_maker/python/vision/image_classifier/image_classifier.py index 1ff6132b4..df71a8fef 100644 --- a/mediapipe/model_maker/python/vision/image_classifier/image_classifier.py +++ b/mediapipe/model_maker/python/vision/image_classifier/image_classifier.py @@ -177,7 +177,7 @@ class ImageClassifier(classifier.Classifier): Args: model_name: File name to save TFLite model with metadata. The full export - path is {self._hparams.model_dir}/{model_name}. + path is {self._hparams.export_dir}/{model_name}. quantization_config: The configuration for model quantization. """ if not tf.io.gfile.exists(self._hparams.export_dir): From 7a87546c30c347f8fc8d046431dbb27208a0f920 Mon Sep 17 00:00:00 2001 From: Camillo Lugaresi Date: Tue, 15 Nov 2022 14:35:54 -0800 Subject: [PATCH 029/469] Internal change PiperOrigin-RevId: 488761646 --- mediapipe/framework/tool/test_util.cc | 22 +++++++++++++--------- mediapipe/framework/tool/test_util.h | 4 ++++ 2 files changed, 17 insertions(+), 9 deletions(-) diff --git a/mediapipe/framework/tool/test_util.cc b/mediapipe/framework/tool/test_util.cc index 6433c93d2..c7ed063e0 100644 --- a/mediapipe/framework/tool/test_util.cc +++ b/mediapipe/framework/tool/test_util.cc @@ -258,11 +258,8 @@ std::string GetTestFilePath(absl::string_view relative_path) { return file::JoinPath(GetTestRootDir(), relative_path); } -absl::StatusOr> LoadTestImage( - absl::string_view path, ImageFormat::Format format) { - std::string encoded; - MP_RETURN_IF_ERROR(mediapipe::file::GetContents(path, &encoded)); - +absl::StatusOr> DecodeTestImage( + absl::string_view encoded, ImageFormat::Format format) { // stbi_load determines the output pixel format based on the desired channels. // 0 means "use whatever's in the file". int desired_channels = format == ImageFormat::UNKNOWN ? 
0 @@ -274,10 +271,10 @@ absl::StatusOr> LoadTestImage( << "unsupported output format requested: " << format; int width, height, channels_in_file; - auto data = stbi_load_from_memory(reinterpret_cast(encoded.data()), - encoded.size(), &width, &height, - &channels_in_file, desired_channels); - RET_CHECK(data) << "failed to decode image data from: " << path; + auto data = stbi_load_from_memory( + reinterpret_cast(encoded.data()), encoded.size(), &width, + &height, &channels_in_file, desired_channels); + RET_CHECK(data) << "failed to decode image data"; // If we didn't specify a desired format, it will be determined by what the // file contains. @@ -295,6 +292,13 @@ absl::StatusOr> LoadTestImage( format, width, height, width * output_channels, data, stbi_image_free); } +absl::StatusOr> LoadTestImage( + absl::string_view path, ImageFormat::Format format) { + std::string encoded; + MP_RETURN_IF_ERROR(mediapipe::file::GetContents(path, &encoded)); + return DecodeTestImage(encoded, format); +} + std::unique_ptr LoadTestPng(absl::string_view path, ImageFormat::Format format) { return nullptr; diff --git a/mediapipe/framework/tool/test_util.h b/mediapipe/framework/tool/test_util.h index 71c096db7..80b768e3d 100644 --- a/mediapipe/framework/tool/test_util.h +++ b/mediapipe/framework/tool/test_util.h @@ -81,6 +81,10 @@ std::string GetTestDataDir(absl::string_view package_base_path); // Loads a binary graph from path. Returns true iff successful. bool LoadTestGraph(CalculatorGraphConfig* proto, const std::string& path); +// Loads an image from memory. +absl::StatusOr> DecodeTestImage( + absl::string_view encoded, ImageFormat::Format format = ImageFormat::SRGBA); + // Loads an image from path. absl::StatusOr> LoadTestImage( absl::string_view path, ImageFormat::Format format = ImageFormat::SRGBA); From 38b636f7ee6c952832bc869475d47a1bf5e1c453 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Tue, 15 Nov 2022 15:10:36 -0800 Subject: [PATCH 030/469] Internal change PiperOrigin-RevId: 488770794 --- mediapipe/framework/deps/BUILD | 1 + mediapipe/framework/deps/registration.h | 39 +++++++++++++------------ 2 files changed, 22 insertions(+), 18 deletions(-) diff --git a/mediapipe/framework/deps/BUILD b/mediapipe/framework/deps/BUILD index a39d7476e..95ab21707 100644 --- a/mediapipe/framework/deps/BUILD +++ b/mediapipe/framework/deps/BUILD @@ -225,6 +225,7 @@ cc_library( "//mediapipe/framework/port:status", "//mediapipe/framework/port:statusor", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/meta:type_traits", "@com_google_absl//absl/strings", diff --git a/mediapipe/framework/deps/registration.h b/mediapipe/framework/deps/registration.h index b39a1e293..1a33b2b24 100644 --- a/mediapipe/framework/deps/registration.h +++ b/mediapipe/framework/deps/registration.h @@ -26,10 +26,12 @@ #include "absl/base/macros.h" #include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/meta/type_traits.h" #include "absl/strings/str_join.h" #include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" #include "mediapipe/framework/deps/registration_token.h" #include "mediapipe/framework/port/canonical_errors.h" @@ -159,7 +161,7 @@ class FunctionRegistry { FunctionRegistry(const FunctionRegistry&) = delete; FunctionRegistry& operator=(const FunctionRegistry&) = delete; - 
RegistrationToken Register(const std::string& name, Function func) + RegistrationToken Register(absl::string_view name, Function func) ABSL_LOCKS_EXCLUDED(lock_) { std::string normalized_name = GetNormalizedName(name); absl::WriterMutexLock lock(&lock_); @@ -189,14 +191,15 @@ class FunctionRegistry { absl::enable_if_t, std::tuple>::value, int> = 0> - ReturnType Invoke(const std::string& name, Args2&&... args) + ReturnType Invoke(absl::string_view name, Args2&&... args) ABSL_LOCKS_EXCLUDED(lock_) { Function function; { absl::ReaderMutexLock lock(&lock_); auto it = functions_.find(name); if (it == functions_.end()) { - return absl::NotFoundError("No registered object with name: " + name); + return absl::NotFoundError( + absl::StrCat("No registered object with name: ", name)); } function = it->second; } @@ -206,7 +209,7 @@ class FunctionRegistry { // Invokes the specified factory function and returns the result. // Namespaces in |name| and |ns| are separated by kNameSep. template - ReturnType Invoke(const std::string& ns, const std::string& name, + ReturnType Invoke(absl::string_view ns, absl::string_view name, Args2&&... args) ABSL_LOCKS_EXCLUDED(lock_) { return Invoke(GetQualifiedName(ns, name), args...); } @@ -214,14 +217,14 @@ class FunctionRegistry { // Note that it's possible for registered implementations to be subsequently // unregistered, though this will never happen with registrations made via // MEDIAPIPE_REGISTER_FACTORY_FUNCTION. - bool IsRegistered(const std::string& name) const ABSL_LOCKS_EXCLUDED(lock_) { + bool IsRegistered(absl::string_view name) const ABSL_LOCKS_EXCLUDED(lock_) { absl::ReaderMutexLock lock(&lock_); return functions_.count(name) != 0; } // Returns true if the specified factory function is available. // Namespaces in |name| and |ns| are separated by kNameSep. - bool IsRegistered(const std::string& ns, const std::string& name) const + bool IsRegistered(absl::string_view ns, absl::string_view name) const ABSL_LOCKS_EXCLUDED(lock_) { return IsRegistered(GetQualifiedName(ns, name)); } @@ -244,7 +247,7 @@ class FunctionRegistry { // Normalizes a C++ qualified name. Validates the name qualification. // The name must be either unqualified or fully qualified with a leading "::". // The leading "::" in a fully qualified name is stripped. - std::string GetNormalizedName(const std::string& name) { + std::string GetNormalizedName(absl::string_view name) { using ::mediapipe::registration_internal::kCxxSep; std::vector names = absl::StrSplit(name, kCxxSep); if (names[0].empty()) { @@ -259,8 +262,8 @@ class FunctionRegistry { // Returns the registry key for a name specified within a namespace. // Namespaces are separated by kNameSep. - std::string GetQualifiedName(const std::string& ns, - const std::string& name) const { + std::string GetQualifiedName(absl::string_view ns, + absl::string_view name) const { using ::mediapipe::registration_internal::kCxxSep; using ::mediapipe::registration_internal::kNameSep; std::vector names = absl::StrSplit(name, kNameSep); @@ -287,10 +290,10 @@ class FunctionRegistry { private: mutable absl::Mutex lock_; - std::unordered_map functions_ ABSL_GUARDED_BY(lock_); + absl::flat_hash_map functions_ ABSL_GUARDED_BY(lock_); // For names included in NamespaceAllowlist, strips the namespace. 
- std::string GetAdjustedName(const std::string& name) { + std::string GetAdjustedName(absl::string_view name) { using ::mediapipe::registration_internal::kCxxSep; std::vector names = absl::StrSplit(name, kCxxSep); std::string base_name = names.back(); @@ -299,10 +302,10 @@ class FunctionRegistry { if (NamespaceAllowlist::TopNamespaces().count(ns)) { return base_name; } - return name; + return std::string(name); } - void Unregister(const std::string& name) { + void Unregister(absl::string_view name) { absl::WriterMutexLock lock(&lock_); std::string adjusted_name = GetAdjustedName(name); if (adjusted_name != name) { @@ -317,7 +320,7 @@ class GlobalFactoryRegistry { using Functions = FunctionRegistry; public: - static RegistrationToken Register(const std::string& name, + static RegistrationToken Register(absl::string_view name, typename Functions::Function func) { return functions()->Register(name, std::move(func)); } @@ -326,7 +329,7 @@ class GlobalFactoryRegistry { // If using namespaces with this registry, the variant with a namespace // argument should be used. template - static typename Functions::ReturnType CreateByName(const std::string& name, + static typename Functions::ReturnType CreateByName(absl::string_view name, Args2&&... args) { return functions()->Invoke(name, std::forward(args)...); } @@ -334,7 +337,7 @@ class GlobalFactoryRegistry { // Returns true if the specified factory function is available. // If using namespaces with this registry, the variant with a namespace // argument should be used. - static bool IsRegistered(const std::string& name) { + static bool IsRegistered(absl::string_view name) { return functions()->IsRegistered(name); } @@ -350,13 +353,13 @@ class GlobalFactoryRegistry { std::tuple>::value, int> = 0> static typename Functions::ReturnType CreateByNameInNamespace( - const std::string& ns, const std::string& name, Args2&&... args) { + absl::string_view ns, absl::string_view name, Args2&&... args) { return functions()->Invoke(ns, name, std::forward(args)...); } // Returns true if the specified factory function is available. // Namespaces in |name| and |ns| are separated by kNameSep. - static bool IsRegistered(const std::string& ns, const std::string& name) { + static bool IsRegistered(absl::string_view ns, absl::string_view name) { return functions()->IsRegistered(ns, name); } From a67069156e8d42f18403d5c47aa6219f4379b00d Mon Sep 17 00:00:00 2001 From: Camillo Lugaresi Date: Tue, 15 Nov 2022 15:16:11 -0800 Subject: [PATCH 031/469] Use flat_hash_map in ResourceCache This is the recommended hashmap in most cases. 
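A note for reviewers on the storage change that comes with this: absl::flat_hash_map keeps its values inline and may move them whenever the table rehashes, so entries that other structures point into must live behind an owning pointer. A minimal sketch of that pattern, with illustrative names rather than the actual ResourceCache code:

    #include <memory>
    #include <tuple>

    #include "absl/container/flat_hash_map.h"

    struct Entry {
      explicit Entry(int k) : key(k) {}
      int key;
      Entry* prev = nullptr;  // intrusive list links need a stable address
    };

    // Holding entries by unique_ptr restores pointer stability: a rehash may
    // move the unique_ptr itself, but the Entry it owns never changes address.
    absl::flat_hash_map<int, std::unique_ptr<Entry>> cache;

    Entry* LookupOrCreate(int key) {
      auto it = cache.find(key);
      if (it == cache.end()) {
        std::tie(it, std::ignore) =
            cache.try_emplace(key, std::make_unique<Entry>(key));
      }
      return it->second.get();
    }

This is why the emplace call in the diff below becomes try_emplace with std::make_unique, and why map_ now stores unique_ptr values.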
PiperOrigin-RevId: 488772031
---
 mediapipe/util/BUILD            |  1 +
 mediapipe/util/resource_cache.h | 13 +++++++------
 2 files changed, 8 insertions(+), 6 deletions(-)

diff --git a/mediapipe/util/BUILD b/mediapipe/util/BUILD
index ab3390e0a..15835aea5 100644
--- a/mediapipe/util/BUILD
+++ b/mediapipe/util/BUILD
@@ -228,6 +228,7 @@ cc_library(
     visibility = ["//visibility:public"],
     deps = [
         "//mediapipe/framework/port:logging",
+        "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/functional:function_ref",
     ],
 )
diff --git a/mediapipe/util/resource_cache.h b/mediapipe/util/resource_cache.h
index 4cd869f6a..2b3ccbc7d 100644
--- a/mediapipe/util/resource_cache.h
+++ b/mediapipe/util/resource_cache.h
@@ -17,6 +17,7 @@
 
 #include <unordered_map>
 
+#include "absl/container/flat_hash_map.h"
 #include "absl/functional/function_ref.h"
 #include "mediapipe/framework/port/logging.h"
 
@@ -26,7 +27,8 @@ namespace mediapipe {
 // resource (e.g., image dimension for an image pool) is described by the `Key`
 // type. The `Value` type must include an unset value, with implicit conversion
 // to bool reflecting set/unset state.
-template <typename Key, typename Value>
+template <typename Key, typename Value,
+          typename KeyHash = typename std::unordered_map<Key, Value>::hasher>
 class ResourceCache {
  public:
   Value Lookup(
@@ -36,15 +38,14 @@ class ResourceCache {
     Entry* entry;
     if (map_it == map_.end()) {
       std::tie(map_it, std::ignore) =
-          map_.emplace(std::piecewise_construct, std::forward_as_tuple(key),
-                       std::forward_as_tuple(key));
-      entry = &map_it->second;
+          map_.try_emplace(key, std::make_unique<Entry>(key));
+      entry = map_it->second.get();
       CHECK_EQ(entry->request_count, 0);
       entry->request_count = 1;
       entry_list_.Append(entry);
       if (entry->prev != nullptr) CHECK_GE(entry->prev->request_count, 1);
     } else {
-      entry = &map_it->second;
+      entry = map_it->second.get();
       ++entry->request_count;
       Entry* larger = entry->prev;
       while (larger != nullptr &&
@@ -171,7 +172,7 @@ class ResourceCache {
     size_t size_ = 0;
   };
 
-  std::unordered_map<Key, Entry> map_;
+  absl::flat_hash_map<Key, std::unique_ptr<Entry>, KeyHash> map_;
   EntryList entry_list_;
   int total_request_count_ = 0;
 };

From 3c71c64be12409ed2019ac16a02263d3ebf96335 Mon Sep 17 00:00:00 2001
From: Camillo Lugaresi
Date: Tue, 15 Nov 2022 15:30:59 -0800
Subject: [PATCH 032/469] Remove shared_ptr from SimplePool definition

This makes the types more explicit and will help with factoring out
platform-specific code.
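The cleanup pattern here, as a hedged sketch (Pool stands in for the real wrapper types, it is not a MediaPipe class):

    #include <memory>

    struct Pool {
      int GetBuffer() { return 0; }
    };

    // Before, the alias hid the ownership model inside the type:
    //   using SimplePool = std::shared_ptr<Pool>;
    // After, the alias names the pool itself:
    using SimplePool = Pool;

    // Callers that merely use a pool take a plain reference...
    int Use(SimplePool& pool) { return pool.GetBuffer(); }

    // ...while shared ownership is spelled out only where it actually exists.
    std::shared_ptr<SimplePool> MakePool() {
      return std::make_shared<SimplePool>();
    }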
PiperOrigin-RevId: 488775470 --- mediapipe/gpu/gpu_buffer_multi_pool.cc | 26 +++++++++++++------------- mediapipe/gpu/gpu_buffer_multi_pool.h | 13 +++++++------ 2 files changed, 20 insertions(+), 19 deletions(-) diff --git a/mediapipe/gpu/gpu_buffer_multi_pool.cc b/mediapipe/gpu/gpu_buffer_multi_pool.cc index 6e4fd38ea..5e8ce06b9 100644 --- a/mediapipe/gpu/gpu_buffer_multi_pool.cc +++ b/mediapipe/gpu/gpu_buffer_multi_pool.cc @@ -90,8 +90,8 @@ std::string CvPixelBufferPoolWrapper::GetDebugString() const { void CvPixelBufferPoolWrapper::Flush() { CVPixelBufferPoolFlush(*pool_, 0); } -GpuBufferMultiPool::SimplePool GpuBufferMultiPool::MakeSimplePool( - const GpuBufferMultiPool::BufferSpec& spec) { +std::shared_ptr +GpuBufferMultiPool::MakeSimplePool(const GpuBufferMultiPool::BufferSpec& spec) { return std::make_shared(spec, kMaxInactiveBufferAge); } @@ -123,7 +123,7 @@ void GpuBufferMultiPool::FlushTextureCaches() { #define FORCE_CONTIGUOUS_PIXEL_BUFFER_ON_IPHONE_SIMULATOR 0 GpuBuffer GpuBufferMultiPool::GetBufferFromSimplePool( - BufferSpec spec, const GpuBufferMultiPool::SimplePool& pool) { + BufferSpec spec, GpuBufferMultiPool::SimplePool& pool) { #if TARGET_IPHONE_SIMULATOR && FORCE_CONTIGUOUS_PIXEL_BUFFER_ON_IPHONE_SIMULATOR // On the simulator, syncing the texture with the pixelbuffer does not work, // and we have to use glReadPixels. Since GL_UNPACK_ROW_LENGTH is not @@ -134,14 +134,14 @@ GpuBuffer GpuBufferMultiPool::GetBufferFromSimplePool( // pool to give us contiguous data. return GetBufferWithoutPool(spec); #else - return pool->GetBuffer([this]() { FlushTextureCaches(); }); + return pool.GetBuffer([this]() { FlushTextureCaches(); }); #endif // TARGET_IPHONE_SIMULATOR } #else -GpuBufferMultiPool::SimplePool GpuBufferMultiPool::MakeSimplePool( - const BufferSpec& spec) { +std::shared_ptr +GpuBufferMultiPool::MakeSimplePool(const BufferSpec& spec) { return GlTextureBufferPool::Create(spec.width, spec.height, spec.format, kKeepCount); } @@ -152,16 +152,16 @@ GpuBuffer GpuBufferMultiPool::GetBufferWithoutPool(const BufferSpec& spec) { } GpuBuffer GpuBufferMultiPool::GetBufferFromSimplePool( - BufferSpec spec, const GpuBufferMultiPool::SimplePool& pool) { - return GpuBuffer(pool->GetBuffer()); + BufferSpec spec, GpuBufferMultiPool::SimplePool& pool) { + return GpuBuffer(pool.GetBuffer()); } #endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER -GpuBufferMultiPool::SimplePool GpuBufferMultiPool::RequestPool( +std::shared_ptr GpuBufferMultiPool::RequestPool( const BufferSpec& spec) { - SimplePool pool; - std::vector evicted; + std::shared_ptr pool; + std::vector> evicted; { absl::MutexLock lock(&mutex_); pool = @@ -180,10 +180,10 @@ GpuBufferMultiPool::SimplePool GpuBufferMultiPool::RequestPool( GpuBuffer GpuBufferMultiPool::GetBuffer(int width, int height, GpuBufferFormat format) { BufferSpec key(width, height, format); - SimplePool pool = RequestPool(key); + std::shared_ptr pool = RequestPool(key); if (pool) { // Note: we release our multipool lock before accessing the simple pool. 
- return GetBufferFromSimplePool(key, pool); + return GetBufferFromSimplePool(key, *pool); } else { return GetBufferWithoutPool(key); } diff --git a/mediapipe/gpu/gpu_buffer_multi_pool.h b/mediapipe/gpu/gpu_buffer_multi_pool.h index 5ea6e314f..287b3b2a7 100644 --- a/mediapipe/gpu/gpu_buffer_multi_pool.h +++ b/mediapipe/gpu/gpu_buffer_multi_pool.h @@ -83,22 +83,23 @@ class GpuBufferMultiPool { private: #if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER - using SimplePool = std::shared_ptr; + using SimplePool = CvPixelBufferPoolWrapper; #else - using SimplePool = std::shared_ptr; + using SimplePool = GlTextureBufferPool; #endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER - SimplePool MakeSimplePool(const BufferSpec& spec); + std::shared_ptr MakeSimplePool(const BufferSpec& spec); // Requests a simple buffer pool for the given spec. This may return nullptr // if we have not yet reached a sufficient number of requests to allocate a // pool, in which case the caller should invoke GetBufferWithoutPool instead // of GetBufferFromSimplePool. - SimplePool RequestPool(const BufferSpec& spec); - GpuBuffer GetBufferFromSimplePool(BufferSpec spec, const SimplePool& pool); + std::shared_ptr RequestPool(const BufferSpec& spec); + GpuBuffer GetBufferFromSimplePool(BufferSpec spec, SimplePool& pool); GpuBuffer GetBufferWithoutPool(const BufferSpec& spec); absl::Mutex mutex_; - mediapipe::ResourceCache> + mediapipe::ResourceCache, + absl::Hash> cache_ ABSL_GUARDED_BY(mutex_); #ifdef __APPLE__ From a520d6cc38dd13c68bf7fac24a919ec8b0bfcdfe Mon Sep 17 00:00:00 2001 From: Camillo Lugaresi Date: Tue, 15 Nov 2022 15:39:41 -0800 Subject: [PATCH 033/469] Remove FORCE_CONTIGUOUS_PIXEL_BUFFER_ON_IPHONE_SIMULATOR This workaround code is no longer necessary, as per the comment. PiperOrigin-RevId: 488777606 --- mediapipe/gpu/gpu_buffer_multi_pool.cc | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/mediapipe/gpu/gpu_buffer_multi_pool.cc b/mediapipe/gpu/gpu_buffer_multi_pool.cc index 5e8ce06b9..2bceb1c05 100644 --- a/mediapipe/gpu/gpu_buffer_multi_pool.cc +++ b/mediapipe/gpu/gpu_buffer_multi_pool.cc @@ -117,25 +117,9 @@ void GpuBufferMultiPool::FlushTextureCaches() { } } -// Turning this on disables the pixel buffer pools when using the simulator. -// It is no longer necessary, since the helper code now supports non-contiguous -// buffers. We leave the code in for now for the sake of documentation. -#define FORCE_CONTIGUOUS_PIXEL_BUFFER_ON_IPHONE_SIMULATOR 0 - GpuBuffer GpuBufferMultiPool::GetBufferFromSimplePool( BufferSpec spec, GpuBufferMultiPool::SimplePool& pool) { -#if TARGET_IPHONE_SIMULATOR && FORCE_CONTIGUOUS_PIXEL_BUFFER_ON_IPHONE_SIMULATOR - // On the simulator, syncing the texture with the pixelbuffer does not work, - // and we have to use glReadPixels. Since GL_UNPACK_ROW_LENGTH is not - // available in OpenGL ES 2, we should create the buffer so the pixels are - // contiguous. - // - // TODO: verify if we can use kIOSurfaceBytesPerRow to force the - // pool to give us contiguous data. 
- return GetBufferWithoutPool(spec); -#else return pool.GetBuffer([this]() { FlushTextureCaches(); }); -#endif // TARGET_IPHONE_SIMULATOR } #else From fae55910f44370b86bd04f0cea106cec43be5374 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Tue, 15 Nov 2022 15:56:36 -0800 Subject: [PATCH 034/469] Enable absl::string_view kCalculatorName PiperOrigin-RevId: 488781493 --- mediapipe/framework/api2/builder.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mediapipe/framework/api2/builder.h b/mediapipe/framework/api2/builder.h index 6d3323b97..19273bf44 100644 --- a/mediapipe/framework/api2/builder.h +++ b/mediapipe/framework/api2/builder.h @@ -412,11 +412,11 @@ using GenericNode = Node; template class Node : public NodeBase { public: - Node() : NodeBase(Calc::kCalculatorName) {} + Node() : NodeBase(std::string(Calc::kCalculatorName)) {} // Overrides the built-in calculator type string with the provided argument. // Can be used to create nodes from pure interfaces. // TODO: only use this for pure interfaces - Node(const std::string& type_override) : NodeBase(type_override) {} + Node(std::string type_override) : NodeBase(std::move(type_override)) {} // These methods only allow access to ports declared in the contract. // The argument must be a tag object created with the MPP_TAG macro. From ab2dd779e73a6756bce09d107fc9a738d9e09edd Mon Sep 17 00:00:00 2001 From: Camillo Lugaresi Date: Tue, 15 Nov 2022 15:57:43 -0800 Subject: [PATCH 035/469] Factor out CvTextureCacheManager This is a platform-specific component that is only used with CVPixelBufferPool. PiperOrigin-RevId: 488781757 --- mediapipe/gpu/BUILD | 16 +++++++ mediapipe/gpu/cv_texture_cache_manager.cc | 55 +++++++++++++++++++++++ mediapipe/gpu/cv_texture_cache_manager.h | 49 ++++++++++++++++++++ mediapipe/gpu/gpu_buffer_multi_pool.cc | 40 +---------------- mediapipe/gpu/gpu_buffer_multi_pool.h | 28 +++--------- mediapipe/gpu/gpu_shared_data_internal.cc | 19 +++++--- mediapipe/gpu/gpu_shared_data_internal.h | 3 ++ 7 files changed, 143 insertions(+), 67 deletions(-) create mode 100644 mediapipe/gpu/cv_texture_cache_manager.cc create mode 100644 mediapipe/gpu/cv_texture_cache_manager.h diff --git a/mediapipe/gpu/BUILD b/mediapipe/gpu/BUILD index 9c2f47469..93527b565 100644 --- a/mediapipe/gpu/BUILD +++ b/mediapipe/gpu/BUILD @@ -344,6 +344,18 @@ cc_library( ], ) +cc_library( + name = "cv_texture_cache_manager", + srcs = ["cv_texture_cache_manager.cc"], + hdrs = ["cv_texture_cache_manager.h"], + deps = [ + ":pixel_buffer_pool_util", + "//mediapipe/framework/port:logging", + "//mediapipe/objc:CFHolder", + "@com_google_absl//absl/synchronization", + ], +) + cc_library( name = "gpu_buffer_storage_image_frame", hdrs = ["gpu_buffer_storage_image_frame.h"], @@ -440,6 +452,7 @@ objc_library( ":gpu_buffer_multi_pool", ":gpu_shared_data_header", ":graph_support", + ":cv_texture_cache_manager", "//mediapipe/gpu:gl_context_options_cc_proto", "//mediapipe/framework:calculator_context", "//mediapipe/framework/port:ret_check", @@ -555,6 +568,7 @@ cc_library( "//conditions:default": [], "//mediapipe:apple": [ ":MPPGraphGPUData", + ":cv_texture_cache_manager", ], }), ) @@ -617,11 +631,13 @@ cc_library( ":gl_texture_buffer_pool", ], "//mediapipe:ios": [ + ":cv_texture_cache_manager", ":pixel_buffer_pool_util", "//mediapipe/objc:CFHolder", "//mediapipe/objc:util", ], "//mediapipe:macos": [ + ":cv_texture_cache_manager", ":pixel_buffer_pool_util", ":gl_texture_buffer", ":gl_texture_buffer_pool", diff --git 
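The inversion at the heart of this change: the pool no longer knows about texture caches at all, it only holds a flush callback that the platform layer installs. A hedged sketch of the wiring, with stand-in types (the real hookup is in gpu_shared_data_internal.cc below):

    #include <functional>
    #include <memory>
    #include <utility>

    struct CacheManager {  // stand-in for the CvTextureCacheManager added below
      void FlushTextureCaches() {}
    };

    struct MultiPool {  // stand-in for GpuBufferMultiPool
      void SetFlushPlatformCaches(std::function<void(void)> flush) {
        flush_platform_caches_ = std::move(flush);
      }
      std::function<void(void)> flush_platform_caches_;
    };

    int main() {
      auto caches = std::make_shared<CacheManager>();
      MultiPool pool;
      // The lambda captures the manager by shared_ptr, keeping it alive for as
      // long as the pool might flush.
      pool.SetFlushPlatformCaches([caches] { caches->FlushTextureCaches(); });
      if (pool.flush_platform_caches_) pool.flush_platform_caches_();
    }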
a/mediapipe/gpu/cv_texture_cache_manager.cc b/mediapipe/gpu/cv_texture_cache_manager.cc new file mode 100644 index 000000000..b977a8993 --- /dev/null +++ b/mediapipe/gpu/cv_texture_cache_manager.cc @@ -0,0 +1,55 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/gpu/cv_texture_cache_manager.h" + +#include "mediapipe/framework/port/logging.h" + +namespace mediapipe { + +void CvTextureCacheManager::FlushTextureCaches() { + absl::MutexLock lock(&mutex_); + for (const auto& cache : texture_caches_) { +#if TARGET_OS_OSX + CVOpenGLTextureCacheFlush(*cache, 0); +#else + CVOpenGLESTextureCacheFlush(*cache, 0); +#endif // TARGET_OS_OSX + } +} + +void CvTextureCacheManager::RegisterTextureCache(CVTextureCacheType cache) { + absl::MutexLock lock(&mutex_); + + CHECK(std::find(texture_caches_.begin(), texture_caches_.end(), cache) == + texture_caches_.end()) + << "Attempting to register a texture cache twice"; + texture_caches_.emplace_back(cache); +} + +void CvTextureCacheManager::UnregisterTextureCache(CVTextureCacheType cache) { + absl::MutexLock lock(&mutex_); + + auto it = std::find(texture_caches_.begin(), texture_caches_.end(), cache); + CHECK(it != texture_caches_.end()) + << "Attempting to unregister an unknown texture cache"; + texture_caches_.erase(it); +} + +CvTextureCacheManager::~CvTextureCacheManager() { + CHECK_EQ(texture_caches_.size(), 0) + << "Failed to unregister texture caches before deleting manager"; +} + +} // namespace mediapipe diff --git a/mediapipe/gpu/cv_texture_cache_manager.h b/mediapipe/gpu/cv_texture_cache_manager.h new file mode 100644 index 000000000..17e44fc6e --- /dev/null +++ b/mediapipe/gpu/cv_texture_cache_manager.h @@ -0,0 +1,49 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MEDIAPIPE_GPU_CV_TEXTURE_CACHE_MANAGER_H_ +#define MEDIAPIPE_GPU_CV_TEXTURE_CACHE_MANAGER_H_ + +#include + +#include "absl/synchronization/mutex.h" +#include "mediapipe/gpu/pixel_buffer_pool_util.h" +#include "mediapipe/objc/CFHolder.h" + +namespace mediapipe { + +class CvTextureCacheManager { + public: + ~CvTextureCacheManager(); + + // TODO: add tests for the texture cache registration. + + // Inform the pool of a cache that should be flushed when it is low on + // reusable buffers. + void RegisterTextureCache(CVTextureCacheType cache); + + // Remove a texture cache from the list of caches to be flushed. 
+ void UnregisterTextureCache(CVTextureCacheType cache); + + void FlushTextureCaches(); + + private: + absl::Mutex mutex_; + std::vector> texture_caches_ + ABSL_GUARDED_BY(mutex_); +}; + +} // namespace mediapipe + +#endif // MEDIAPIPE_GPU_CV_TEXTURE_CACHE_MANAGER_H_ diff --git a/mediapipe/gpu/gpu_buffer_multi_pool.cc b/mediapipe/gpu/gpu_buffer_multi_pool.cc index 2bceb1c05..f76833f24 100644 --- a/mediapipe/gpu/gpu_buffer_multi_pool.cc +++ b/mediapipe/gpu/gpu_buffer_multi_pool.cc @@ -106,20 +106,9 @@ GpuBuffer GpuBufferMultiPool::GetBufferWithoutPool(const BufferSpec& spec) { return GpuBuffer(MakeCFHolderAdopting(buffer)); } -void GpuBufferMultiPool::FlushTextureCaches() { - absl::MutexLock lock(&mutex_); - for (const auto& cache : texture_caches_) { -#if TARGET_OS_OSX - CVOpenGLTextureCacheFlush(*cache, 0); -#else - CVOpenGLESTextureCacheFlush(*cache, 0); -#endif // TARGET_OS_OSX - } -} - GpuBuffer GpuBufferMultiPool::GetBufferFromSimplePool( BufferSpec spec, GpuBufferMultiPool::SimplePool& pool) { - return pool.GetBuffer([this]() { FlushTextureCaches(); }); + return pool.GetBuffer(flush_platform_caches_); } #else @@ -173,31 +162,4 @@ GpuBuffer GpuBufferMultiPool::GetBuffer(int width, int height, } } -GpuBufferMultiPool::~GpuBufferMultiPool() { -#ifdef __APPLE__ - CHECK_EQ(texture_caches_.size(), 0) - << "Failed to unregister texture caches before deleting pool"; -#endif // defined(__APPLE__) -} - -#ifdef __APPLE__ -void GpuBufferMultiPool::RegisterTextureCache(CVTextureCacheType cache) { - absl::MutexLock lock(&mutex_); - - CHECK(std::find(texture_caches_.begin(), texture_caches_.end(), cache) == - texture_caches_.end()) - << "Attempting to register a texture cache twice"; - texture_caches_.emplace_back(cache); -} - -void GpuBufferMultiPool::UnregisterTextureCache(CVTextureCacheType cache) { - absl::MutexLock lock(&mutex_); - - auto it = std::find(texture_caches_.begin(), texture_caches_.end(), cache); - CHECK(it != texture_caches_.end()) - << "Attempting to unregister an unknown texture cache"; - texture_caches_.erase(it); -} -#endif // defined(__APPLE__) - } // namespace mediapipe diff --git a/mediapipe/gpu/gpu_buffer_multi_pool.h b/mediapipe/gpu/gpu_buffer_multi_pool.h index 287b3b2a7..7317ac60e 100644 --- a/mediapipe/gpu/gpu_buffer_multi_pool.h +++ b/mediapipe/gpu/gpu_buffer_multi_pool.h @@ -43,25 +43,14 @@ class CvPixelBufferPoolWrapper; class GpuBufferMultiPool { public: GpuBufferMultiPool() {} - explicit GpuBufferMultiPool(void* ignored) {} - ~GpuBufferMultiPool(); // Obtains a buffer. May either be reused or created anew. GpuBuffer GetBuffer(int width, int height, GpuBufferFormat format = GpuBufferFormat::kBGRA32); -#ifdef __APPLE__ - // TODO: add tests for the texture cache registration. - - // Inform the pool of a cache that should be flushed when it is low on - // reusable buffers. - void RegisterTextureCache(CVTextureCacheType cache); - - // Remove a texture cache from the list of caches to be flushed. - void UnregisterTextureCache(CVTextureCacheType cache); - - void FlushTextureCaches(); -#endif // defined(__APPLE__) + void SetFlushPlatformCaches(std::function flush_platform_caches) { + flush_platform_caches_ = flush_platform_caches; + } // This class is not intended as part of the public api of this class. 
It is // public only because it is used as a map key type, and the map @@ -98,15 +87,10 @@ class GpuBufferMultiPool { GpuBuffer GetBufferWithoutPool(const BufferSpec& spec); absl::Mutex mutex_; - mediapipe::ResourceCache, - absl::Hash> - cache_ ABSL_GUARDED_BY(mutex_); - -#ifdef __APPLE__ - // Texture caches used with this pool. - std::vector> texture_caches_ + mediapipe::ResourceCache> cache_ ABSL_GUARDED_BY(mutex_); -#endif // defined(__APPLE__) + // This is used to hook up the TextureCacheManager on Apple platforms. + std::function flush_platform_caches_; }; #if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER diff --git a/mediapipe/gpu/gpu_shared_data_internal.cc b/mediapipe/gpu/gpu_shared_data_internal.cc index a8bf0c3a3..457b04fd3 100644 --- a/mediapipe/gpu/gpu_shared_data_internal.cc +++ b/mediapipe/gpu/gpu_shared_data_internal.cc @@ -85,7 +85,12 @@ GpuResources::GpuResources(std::shared_ptr gl_context) { named_executors_[kGpuExecutorName] = std::make_shared(gl_context.get()); #if __APPLE__ - gpu_buffer_pool().RegisterTextureCache(gl_context->cv_texture_cache()); + texture_caches_ = std::make_shared(); + gpu_buffer_pool().SetFlushPlatformCaches( + [tc = texture_caches_] { tc->FlushTextureCaches(); }); +#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER + texture_caches_->RegisterTextureCache(gl_context->cv_texture_cache()); +#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER ios_gpu_data_ = [[MPPGraphGPUData alloc] initWithContext:gl_context.get() multiPool:&gpu_buffer_pool_]; #endif // __APPLE__ @@ -98,10 +103,12 @@ GpuResources::~GpuResources() { #if !__has_feature(objc_arc) #error This file must be built with ARC. #endif +#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER for (auto& kv : gl_key_context_) { - gpu_buffer_pool().UnregisterTextureCache(kv.second->cv_texture_cache()); + texture_caches_->UnregisterTextureCache(kv.second->cv_texture_cache()); } -#endif +#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER +#endif // __APPLE__ } absl::Status GpuResources::PrepareGpuNode(CalculatorNode* node) { @@ -174,9 +181,9 @@ GlContext::StatusOrGlContext GpuResources::GetOrCreateGlContext( GlContext::Create(*gl_key_context_[SharedContextKey()], kGlContextUseDedicatedThread)); it = gl_key_context_.emplace(key, new_context).first; -#if __APPLE__ - gpu_buffer_pool_.RegisterTextureCache(it->second->cv_texture_cache()); -#endif +#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER + texture_caches_->RegisterTextureCache(it->second->cv_texture_cache()); +#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER } return it->second; } diff --git a/mediapipe/gpu/gpu_shared_data_internal.h b/mediapipe/gpu/gpu_shared_data_internal.h index 62d6bb27e..12a7a1296 100644 --- a/mediapipe/gpu/gpu_shared_data_internal.h +++ b/mediapipe/gpu/gpu_shared_data_internal.h @@ -30,6 +30,7 @@ #include "mediapipe/gpu/gpu_buffer_multi_pool.h" #ifdef __APPLE__ +#include "mediapipe/gpu/cv_texture_cache_manager.h" #ifdef __OBJC__ @class MPPGraphGPUData; #else @@ -91,6 +92,8 @@ class GpuResources { GpuBufferMultiPool gpu_buffer_pool_; #ifdef __APPLE__ + std::shared_ptr texture_caches_; + // Note that this is an Objective-C object. MPPGraphGPUData* ios_gpu_data_; #endif // defined(__APPLE__) From 0d273dd11aac9701c241f1097377614b80690fc3 Mon Sep 17 00:00:00 2001 From: Camillo Lugaresi Date: Tue, 15 Nov 2022 15:58:32 -0800 Subject: [PATCH 036/469] Factor out CvPixelBufferPoolWrapper This is platform-specific and does not need to live in the main multi_pool sources. 
PiperOrigin-RevId: 488781934 --- mediapipe/gpu/BUILD | 22 ++++++ mediapipe/gpu/cv_pixel_buffer_pool_wrapper.cc | 71 +++++++++++++++++++ mediapipe/gpu/cv_pixel_buffer_pool_wrapper.h | 50 +++++++++++++ mediapipe/gpu/gpu_buffer_multi_pool.cc | 49 +------------ mediapipe/gpu/gpu_buffer_multi_pool.h | 24 ++----- 5 files changed, 149 insertions(+), 67 deletions(-) create mode 100644 mediapipe/gpu/cv_pixel_buffer_pool_wrapper.cc create mode 100644 mediapipe/gpu/cv_pixel_buffer_pool_wrapper.h diff --git a/mediapipe/gpu/BUILD b/mediapipe/gpu/BUILD index 93527b565..26df167c4 100644 --- a/mediapipe/gpu/BUILD +++ b/mediapipe/gpu/BUILD @@ -356,6 +356,26 @@ cc_library( ], ) +cc_library( + name = "cv_pixel_buffer_pool_wrapper", + srcs = ["cv_pixel_buffer_pool_wrapper.cc"], + hdrs = ["cv_pixel_buffer_pool_wrapper.h"], + copts = select({ + "//conditions:default": [], + "//mediapipe:apple": [ + "-x objective-c++", + "-fobjc-arc", + ], + }), + deps = [ + ":gpu_buffer", + ":pixel_buffer_pool_util", + "//mediapipe/framework/port:logging", + "//mediapipe/objc:CFHolder", + "@com_google_absl//absl/synchronization", + ], +) + cc_library( name = "gpu_buffer_storage_image_frame", hdrs = ["gpu_buffer_storage_image_frame.h"], @@ -631,12 +651,14 @@ cc_library( ":gl_texture_buffer_pool", ], "//mediapipe:ios": [ + ":cv_pixel_buffer_pool_wrapper", ":cv_texture_cache_manager", ":pixel_buffer_pool_util", "//mediapipe/objc:CFHolder", "//mediapipe/objc:util", ], "//mediapipe:macos": [ + ":cv_pixel_buffer_pool_wrapper", ":cv_texture_cache_manager", ":pixel_buffer_pool_util", ":gl_texture_buffer", diff --git a/mediapipe/gpu/cv_pixel_buffer_pool_wrapper.cc b/mediapipe/gpu/cv_pixel_buffer_pool_wrapper.cc new file mode 100644 index 000000000..3293b0238 --- /dev/null +++ b/mediapipe/gpu/cv_pixel_buffer_pool_wrapper.cc @@ -0,0 +1,71 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "mediapipe/gpu/cv_pixel_buffer_pool_wrapper.h" + +#include + +#include "CoreFoundation/CFBase.h" +#include "mediapipe/framework/port/logging.h" +#include "mediapipe/objc/CFHolder.h" + +namespace mediapipe { + +CvPixelBufferPoolWrapper::CvPixelBufferPoolWrapper(int width, int height, + GpuBufferFormat format, + CFTimeInterval maxAge) { + OSType cv_format = CVPixelFormatForGpuBufferFormat(format); + CHECK_NE(cv_format, -1) << "unsupported pixel format"; + pool_ = MakeCFHolderAdopting( + /* keep count is 0 because the age param keeps buffers around anyway */ + CreateCVPixelBufferPool(width, height, cv_format, 0, maxAge)); +} + +GpuBuffer CvPixelBufferPoolWrapper::GetBuffer(std::function flush) { + CVPixelBufferRef buffer; + int threshold = 1; + NSMutableDictionary* auxAttributes = + [NSMutableDictionary dictionaryWithCapacity:1]; + CVReturn err; + bool tried_flushing = false; + while (1) { + auxAttributes[(id)kCVPixelBufferPoolAllocationThresholdKey] = @(threshold); + err = CVPixelBufferPoolCreatePixelBufferWithAuxAttributes( + kCFAllocatorDefault, *pool_, (__bridge CFDictionaryRef)auxAttributes, + &buffer); + if (err != kCVReturnWouldExceedAllocationThreshold) break; + if (flush && !tried_flushing) { + // Call the flush function to potentially release old holds on buffers + // and try again to create a pixel buffer. + // This is used to flush CV texture caches, which may retain buffers until + // flushed. + flush(); + tried_flushing = true; + } else { + ++threshold; + } + } + CHECK(!err) << "Error creating pixel buffer: " << err; + count_ = threshold; + return GpuBuffer(MakeCFHolderAdopting(buffer)); +} + +std::string CvPixelBufferPoolWrapper::GetDebugString() const { + auto description = MakeCFHolderAdopting(CFCopyDescription(*pool_)); + return [(__bridge NSString*)*description UTF8String]; +} + +void CvPixelBufferPoolWrapper::Flush() { CVPixelBufferPoolFlush(*pool_, 0); } + +} // namespace mediapipe diff --git a/mediapipe/gpu/cv_pixel_buffer_pool_wrapper.h b/mediapipe/gpu/cv_pixel_buffer_pool_wrapper.h new file mode 100644 index 000000000..081df4676 --- /dev/null +++ b/mediapipe/gpu/cv_pixel_buffer_pool_wrapper.h @@ -0,0 +1,50 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This class lets calculators allocate GpuBuffers of various sizes, caching +// and reusing them as needed. It does so by automatically creating and using +// platform-specific buffer pools for the requested sizes. +// +// This class is not meant to be used directly by calculators, but is instead +// used by GlCalculatorHelper to allocate buffers. 
+
+#ifndef MEDIAPIPE_GPU_CV_PIXEL_BUFFER_POOL_WRAPPER_H_
+#define MEDIAPIPE_GPU_CV_PIXEL_BUFFER_POOL_WRAPPER_H_
+
+#include "CoreFoundation/CFBase.h"
+#include "mediapipe/gpu/gpu_buffer.h"
+#include "mediapipe/gpu/pixel_buffer_pool_util.h"
+#include "mediapipe/objc/CFHolder.h"
+
+namespace mediapipe {
+
+class CvPixelBufferPoolWrapper {
+ public:
+  CvPixelBufferPoolWrapper(int width, int height, GpuBufferFormat format,
+                           CFTimeInterval maxAge);
+  GpuBuffer GetBuffer(std::function<void(void)> flush);
+
+  int GetBufferCount() const { return count_; }
+  std::string GetDebugString() const;
+
+  void Flush();
+
+ private:
+  CFHolder<CVPixelBufferPoolRef> pool_;
+  int count_ = 0;
+};
+
+} // namespace mediapipe
+
+#endif // MEDIAPIPE_GPU_CV_PIXEL_BUFFER_POOL_WRAPPER_H_
diff --git a/mediapipe/gpu/gpu_buffer_multi_pool.cc b/mediapipe/gpu/gpu_buffer_multi_pool.cc
index f76833f24..1909d116e 100644
--- a/mediapipe/gpu/gpu_buffer_multi_pool.cc
+++ b/mediapipe/gpu/gpu_buffer_multi_pool.cc
@@ -45,55 +45,10 @@ static constexpr int kRequestCountScrubInterval = 50;
 
 #if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
 
-CvPixelBufferPoolWrapper::CvPixelBufferPoolWrapper(
-    const GpuBufferMultiPool::BufferSpec& spec, CFTimeInterval maxAge) {
-  OSType cv_format = CVPixelFormatForGpuBufferFormat(spec.format);
-  CHECK_NE(cv_format, -1) << "unsupported pixel format";
-  pool_ = MakeCFHolderAdopting(
-      /* keep count is 0 because the age param keeps buffers around anyway */
-      CreateCVPixelBufferPool(spec.width, spec.height, cv_format, 0, maxAge));
-}
-
-GpuBuffer CvPixelBufferPoolWrapper::GetBuffer(std::function<void(void)> flush) {
-  CVPixelBufferRef buffer;
-  int threshold = 1;
-  NSMutableDictionary* auxAttributes =
-      [NSMutableDictionary dictionaryWithCapacity:1];
-  CVReturn err;
-  bool tried_flushing = false;
-  while (1) {
-    auxAttributes[(id)kCVPixelBufferPoolAllocationThresholdKey] = @(threshold);
-    err = CVPixelBufferPoolCreatePixelBufferWithAuxAttributes(
-        kCFAllocatorDefault, *pool_, (__bridge CFDictionaryRef)auxAttributes,
-        &buffer);
-    if (err != kCVReturnWouldExceedAllocationThreshold) break;
-    if (flush && !tried_flushing) {
-      // Call the flush function to potentially release old holds on buffers
-      // and try again to create a pixel buffer.
-      // This is used to flush CV texture caches, which may retain buffers until
-      // flushed.
- flush(); - tried_flushing = true; - } else { - ++threshold; - } - } - CHECK(!err) << "Error creating pixel buffer: " << err; - count_ = threshold; - return GpuBuffer(MakeCFHolderAdopting(buffer)); -} - -std::string CvPixelBufferPoolWrapper::GetDebugString() const { - auto description = MakeCFHolderAdopting(CFCopyDescription(*pool_)); - return [(__bridge NSString*)*description UTF8String]; -} - -void CvPixelBufferPoolWrapper::Flush() { CVPixelBufferPoolFlush(*pool_, 0); } - std::shared_ptr GpuBufferMultiPool::MakeSimplePool(const GpuBufferMultiPool::BufferSpec& spec) { - return std::make_shared(spec, - kMaxInactiveBufferAge); + return std::make_shared( + spec.width, spec.height, spec.format, kMaxInactiveBufferAge); } GpuBuffer GpuBufferMultiPool::GetBufferWithoutPool(const BufferSpec& spec) { diff --git a/mediapipe/gpu/gpu_buffer_multi_pool.h b/mediapipe/gpu/gpu_buffer_multi_pool.h index 7317ac60e..f48577854 100644 --- a/mediapipe/gpu/gpu_buffer_multi_pool.h +++ b/mediapipe/gpu/gpu_buffer_multi_pool.h @@ -31,9 +31,11 @@ #include "mediapipe/gpu/pixel_buffer_pool_util.h" #endif // __APPLE__ -#if !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER +#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER +#include "mediapipe/gpu/cv_pixel_buffer_pool_wrapper.h" +#else #include "mediapipe/gpu/gl_texture_buffer_pool.h" -#endif // !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER +#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER namespace mediapipe { @@ -93,24 +95,6 @@ class GpuBufferMultiPool { std::function flush_platform_caches_; }; -#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER -class CvPixelBufferPoolWrapper { - public: - CvPixelBufferPoolWrapper(const GpuBufferMultiPool::BufferSpec& spec, - CFTimeInterval maxAge); - GpuBuffer GetBuffer(std::function flush); - - int GetBufferCount() const { return count_; } - std::string GetDebugString() const; - - void Flush(); - - private: - CFHolder pool_; - int count_ = 0; -}; -#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER - // BufferSpec equality operators inline bool operator==(const GpuBufferMultiPool::BufferSpec& lhs, const GpuBufferMultiPool::BufferSpec& rhs) { From a4fe3eb0941e9571bbc4ade95147c4959f8aa67f Mon Sep 17 00:00:00 2001 From: Camillo Lugaresi Date: Tue, 15 Nov 2022 15:59:01 -0800 Subject: [PATCH 037/469] Add CreateBufferWithoutPool method to base pools This may not fit exactly in a pool class, but it makes it easy for the multi-pool to find the appropriate method by depending only on the type of the base pool. For the CVPixelBuffer case, the buffer type is CFHolder, and it seems even less appropriate to specialize that template to add such a method there. An alternative would be to allow defining a creation function separately. 
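To make the intent concrete, here is a minimal sketch (illustrative only; GetOrCreate and its signature are not part of this patch) of how a generic multi-pool can resolve both allocation paths through the base pool type alone:

    #include <memory>

    // Sketch: SimplePool stands for GlTextureBufferPool or
    // CvPixelBufferPoolWrapper; both now expose a static
    // CreateBufferWithoutPool, so no per-platform branch is needed here.
    template <class SimplePool, class Spec, class Item>
    Item GetOrCreate(const std::shared_ptr<SimplePool>& pool,
                     const Spec& spec) {
      if (pool) return Item(pool->GetBuffer());  // pooled path
      // Unpooled fallback, found via the pool type rather than a separately
      // registered creation function.
      return Item(SimplePool::CreateBufferWithoutPool(spec.width, spec.height,
                                                      spec.format));
    }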
PiperOrigin-RevId: 488782054 --- mediapipe/gpu/BUILD | 3 ++- mediapipe/gpu/cv_pixel_buffer_pool_wrapper.cc | 17 ++++++++++++-- mediapipe/gpu/cv_pixel_buffer_pool_wrapper.h | 7 ++++-- mediapipe/gpu/gl_texture_buffer_pool.h | 5 +++++ mediapipe/gpu/gpu_buffer_multi_pool.cc | 22 +++++-------------- 5 files changed, 33 insertions(+), 21 deletions(-) diff --git a/mediapipe/gpu/BUILD b/mediapipe/gpu/BUILD index 26df167c4..2f06fe1d5 100644 --- a/mediapipe/gpu/BUILD +++ b/mediapipe/gpu/BUILD @@ -368,10 +368,11 @@ cc_library( ], }), deps = [ - ":gpu_buffer", + ":gpu_buffer_format", ":pixel_buffer_pool_util", "//mediapipe/framework/port:logging", "//mediapipe/objc:CFHolder", + "//mediapipe/objc:util", "@com_google_absl//absl/synchronization", ], ) diff --git a/mediapipe/gpu/cv_pixel_buffer_pool_wrapper.cc b/mediapipe/gpu/cv_pixel_buffer_pool_wrapper.cc index 3293b0238..c97268307 100644 --- a/mediapipe/gpu/cv_pixel_buffer_pool_wrapper.cc +++ b/mediapipe/gpu/cv_pixel_buffer_pool_wrapper.cc @@ -19,6 +19,7 @@ #include "CoreFoundation/CFBase.h" #include "mediapipe/framework/port/logging.h" #include "mediapipe/objc/CFHolder.h" +#include "mediapipe/objc/util.h" namespace mediapipe { @@ -32,7 +33,8 @@ CvPixelBufferPoolWrapper::CvPixelBufferPoolWrapper(int width, int height, CreateCVPixelBufferPool(width, height, cv_format, 0, maxAge)); } -GpuBuffer CvPixelBufferPoolWrapper::GetBuffer(std::function flush) { +CFHolder CvPixelBufferPoolWrapper::GetBuffer( + std::function flush) { CVPixelBufferRef buffer; int threshold = 1; NSMutableDictionary* auxAttributes = @@ -58,7 +60,7 @@ GpuBuffer CvPixelBufferPoolWrapper::GetBuffer(std::function flush) { } CHECK(!err) << "Error creating pixel buffer: " << err; count_ = threshold; - return GpuBuffer(MakeCFHolderAdopting(buffer)); + return MakeCFHolderAdopting(buffer); } std::string CvPixelBufferPoolWrapper::GetDebugString() const { @@ -68,4 +70,15 @@ std::string CvPixelBufferPoolWrapper::GetDebugString() const { void CvPixelBufferPoolWrapper::Flush() { CVPixelBufferPoolFlush(*pool_, 0); } +CFHolder CvPixelBufferPoolWrapper::CreateBufferWithoutPool( + int width, int height, GpuBufferFormat format) { + OSType cv_format = CVPixelFormatForGpuBufferFormat(format); + CHECK_NE(cv_format, -1) << "unsupported pixel format"; + CVPixelBufferRef buffer; + CVReturn err = + CreateCVPixelBufferWithoutPool(width, height, cv_format, &buffer); + CHECK(!err) << "Error creating pixel buffer: " << err; + return MakeCFHolderAdopting(buffer); +} + } // namespace mediapipe diff --git a/mediapipe/gpu/cv_pixel_buffer_pool_wrapper.h b/mediapipe/gpu/cv_pixel_buffer_pool_wrapper.h index 081df4676..7412b776f 100644 --- a/mediapipe/gpu/cv_pixel_buffer_pool_wrapper.h +++ b/mediapipe/gpu/cv_pixel_buffer_pool_wrapper.h @@ -23,7 +23,7 @@ #define MEDIAPIPE_GPU_CV_PIXEL_BUFFER_POOL_WRAPPER_H_ #include "CoreFoundation/CFBase.h" -#include "mediapipe/gpu/gpu_buffer.h" +#include "mediapipe/gpu/gpu_buffer_format.h" #include "mediapipe/gpu/pixel_buffer_pool_util.h" #include "mediapipe/objc/CFHolder.h" @@ -33,13 +33,16 @@ class CvPixelBufferPoolWrapper { public: CvPixelBufferPoolWrapper(int width, int height, GpuBufferFormat format, CFTimeInterval maxAge); - GpuBuffer GetBuffer(std::function flush); + CFHolder GetBuffer(std::function flush); int GetBufferCount() const { return count_; } std::string GetDebugString() const; void Flush(); + static CFHolder CreateBufferWithoutPool( + int width, int height, GpuBufferFormat format); + private: CFHolder pool_; int count_ = 0; diff --git 
a/mediapipe/gpu/gl_texture_buffer_pool.h b/mediapipe/gpu/gl_texture_buffer_pool.h index 4dcad305e..cd755b4aa 100644 --- a/mediapipe/gpu/gl_texture_buffer_pool.h +++ b/mediapipe/gpu/gl_texture_buffer_pool.h @@ -51,6 +51,11 @@ class GlTextureBufferPool // This method is meant for testing. std::pair GetInUseAndAvailableCounts(); + static GlTextureBufferSharedPtr CreateBufferWithoutPool( + int width, int height, GpuBufferFormat format) { + return GlTextureBuffer::Create(width, height, format); + } + private: GlTextureBufferPool(int width, int height, GpuBufferFormat format, int keep_count); diff --git a/mediapipe/gpu/gpu_buffer_multi_pool.cc b/mediapipe/gpu/gpu_buffer_multi_pool.cc index 1909d116e..fdff3e692 100644 --- a/mediapipe/gpu/gpu_buffer_multi_pool.cc +++ b/mediapipe/gpu/gpu_buffer_multi_pool.cc @@ -51,19 +51,9 @@ GpuBufferMultiPool::MakeSimplePool(const GpuBufferMultiPool::BufferSpec& spec) { spec.width, spec.height, spec.format, kMaxInactiveBufferAge); } -GpuBuffer GpuBufferMultiPool::GetBufferWithoutPool(const BufferSpec& spec) { - OSType cv_format = CVPixelFormatForGpuBufferFormat(spec.format); - CHECK_NE(cv_format, -1) << "unsupported pixel format"; - CVPixelBufferRef buffer; - CVReturn err = CreateCVPixelBufferWithoutPool(spec.width, spec.height, - cv_format, &buffer); - CHECK(!err) << "Error creating pixel buffer: " << err; - return GpuBuffer(MakeCFHolderAdopting(buffer)); -} - GpuBuffer GpuBufferMultiPool::GetBufferFromSimplePool( BufferSpec spec, GpuBufferMultiPool::SimplePool& pool) { - return pool.GetBuffer(flush_platform_caches_); + return GpuBuffer(pool.GetBuffer(flush_platform_caches_)); } #else @@ -74,11 +64,6 @@ GpuBufferMultiPool::MakeSimplePool(const BufferSpec& spec) { kKeepCount); } -GpuBuffer GpuBufferMultiPool::GetBufferWithoutPool(const BufferSpec& spec) { - return GpuBuffer( - GlTextureBuffer::Create(spec.width, spec.height, spec.format)); -} - GpuBuffer GpuBufferMultiPool::GetBufferFromSimplePool( BufferSpec spec, GpuBufferMultiPool::SimplePool& pool) { return GpuBuffer(pool.GetBuffer()); @@ -117,4 +102,9 @@ GpuBuffer GpuBufferMultiPool::GetBuffer(int width, int height, } } +GpuBuffer GpuBufferMultiPool::GetBufferWithoutPool(const BufferSpec& spec) { + return GpuBuffer(SimplePool::CreateBufferWithoutPool(spec.width, spec.height, + spec.format)); +} + } // namespace mediapipe From 0c4522cb9fb7ce7fc940581ae2553f7282b419ca Mon Sep 17 00:00:00 2001 From: Camillo Lugaresi Date: Tue, 15 Nov 2022 15:59:33 -0800 Subject: [PATCH 038/469] Move flush hook to CvPixelBufferPoolWrapper constructor This unifies the implementation of GpuBufferMultiPool::GetBufferFromSimplePool. 
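In sketch form (assumed, simplified types; the real class is CvPixelBufferPoolWrapper), the hook moves from a per-call argument to constructor state, which is what lets every pool type share a zero-argument GetBuffer():

    #include <functional>
    #include <utility>

    // Sketch: the flush hook is captured once at construction.
    class FlushingPool {
     public:
      explicit FlushingPool(std::function<void(void)> flush_texture_caches)
          : flush_texture_caches_(std::move(flush_texture_caches)) {}

      void GetBuffer() {
        // Stub standing in for buffer creation: on an allocation-threshold
        // failure the real code invokes the stored hook once before retrying
        // with a higher threshold.
        if (flush_texture_caches_) flush_texture_caches_();
      }

     private:
      std::function<void(void)> flush_texture_caches_;
    };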
PiperOrigin-RevId: 488782173 --- mediapipe/gpu/cv_pixel_buffer_pool_wrapper.cc | 14 +++++++------- mediapipe/gpu/cv_pixel_buffer_pool_wrapper.h | 6 ++++-- mediapipe/gpu/gpu_buffer_multi_pool.cc | 18 +++++++----------- 3 files changed, 18 insertions(+), 20 deletions(-) diff --git a/mediapipe/gpu/cv_pixel_buffer_pool_wrapper.cc b/mediapipe/gpu/cv_pixel_buffer_pool_wrapper.cc index c97268307..b1c135afa 100644 --- a/mediapipe/gpu/cv_pixel_buffer_pool_wrapper.cc +++ b/mediapipe/gpu/cv_pixel_buffer_pool_wrapper.cc @@ -23,18 +23,18 @@ namespace mediapipe { -CvPixelBufferPoolWrapper::CvPixelBufferPoolWrapper(int width, int height, - GpuBufferFormat format, - CFTimeInterval maxAge) { +CvPixelBufferPoolWrapper::CvPixelBufferPoolWrapper( + int width, int height, GpuBufferFormat format, CFTimeInterval maxAge, + std::function flush_texture_caches) { OSType cv_format = CVPixelFormatForGpuBufferFormat(format); CHECK_NE(cv_format, -1) << "unsupported pixel format"; pool_ = MakeCFHolderAdopting( /* keep count is 0 because the age param keeps buffers around anyway */ CreateCVPixelBufferPool(width, height, cv_format, 0, maxAge)); + flush_texture_caches_ = std::move(flush_texture_caches); } -CFHolder CvPixelBufferPoolWrapper::GetBuffer( - std::function flush) { +CFHolder CvPixelBufferPoolWrapper::GetBuffer() { CVPixelBufferRef buffer; int threshold = 1; NSMutableDictionary* auxAttributes = @@ -47,12 +47,12 @@ CFHolder CvPixelBufferPoolWrapper::GetBuffer( kCFAllocatorDefault, *pool_, (__bridge CFDictionaryRef)auxAttributes, &buffer); if (err != kCVReturnWouldExceedAllocationThreshold) break; - if (flush && !tried_flushing) { + if (flush_texture_caches_ && !tried_flushing) { // Call the flush function to potentially release old holds on buffers // and try again to create a pixel buffer. // This is used to flush CV texture caches, which may retain buffers until // flushed. 
- flush(); + flush_texture_caches_(); tried_flushing = true; } else { ++threshold; diff --git a/mediapipe/gpu/cv_pixel_buffer_pool_wrapper.h b/mediapipe/gpu/cv_pixel_buffer_pool_wrapper.h index 7412b776f..9d9328ca1 100644 --- a/mediapipe/gpu/cv_pixel_buffer_pool_wrapper.h +++ b/mediapipe/gpu/cv_pixel_buffer_pool_wrapper.h @@ -32,8 +32,9 @@ namespace mediapipe { class CvPixelBufferPoolWrapper { public: CvPixelBufferPoolWrapper(int width, int height, GpuBufferFormat format, - CFTimeInterval maxAge); - CFHolder GetBuffer(std::function flush); + CFTimeInterval maxAge, + std::function flush_texture_caches); + CFHolder GetBuffer(); int GetBufferCount() const { return count_; } std::string GetDebugString() const; @@ -46,6 +47,7 @@ class CvPixelBufferPoolWrapper { private: CFHolder pool_; int count_ = 0; + std::function flush_texture_caches_; }; } // namespace mediapipe diff --git a/mediapipe/gpu/gpu_buffer_multi_pool.cc b/mediapipe/gpu/gpu_buffer_multi_pool.cc index fdff3e692..9c3c9a33e 100644 --- a/mediapipe/gpu/gpu_buffer_multi_pool.cc +++ b/mediapipe/gpu/gpu_buffer_multi_pool.cc @@ -48,12 +48,8 @@ static constexpr int kRequestCountScrubInterval = 50; std::shared_ptr GpuBufferMultiPool::MakeSimplePool(const GpuBufferMultiPool::BufferSpec& spec) { return std::make_shared( - spec.width, spec.height, spec.format, kMaxInactiveBufferAge); -} - -GpuBuffer GpuBufferMultiPool::GetBufferFromSimplePool( - BufferSpec spec, GpuBufferMultiPool::SimplePool& pool) { - return GpuBuffer(pool.GetBuffer(flush_platform_caches_)); + spec.width, spec.height, spec.format, kMaxInactiveBufferAge, + flush_platform_caches_); } #else @@ -64,11 +60,6 @@ GpuBufferMultiPool::MakeSimplePool(const BufferSpec& spec) { kKeepCount); } -GpuBuffer GpuBufferMultiPool::GetBufferFromSimplePool( - BufferSpec spec, GpuBufferMultiPool::SimplePool& pool) { - return GpuBuffer(pool.GetBuffer()); -} - #endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER std::shared_ptr GpuBufferMultiPool::RequestPool( @@ -102,6 +93,11 @@ GpuBuffer GpuBufferMultiPool::GetBuffer(int width, int height, } } +GpuBuffer GpuBufferMultiPool::GetBufferFromSimplePool( + BufferSpec spec, GpuBufferMultiPool::SimplePool& pool) { + return GpuBuffer(pool.GetBuffer()); +} + GpuBuffer GpuBufferMultiPool::GetBufferWithoutPool(const BufferSpec& spec) { return GpuBuffer(SimplePool::CreateBufferWithoutPool(spec.width, spec.height, spec.format)); From f13903b7c5ba53cf383f8f3c67816274f9307db0 Mon Sep 17 00:00:00 2001 From: Camillo Lugaresi Date: Tue, 15 Nov 2022 16:01:08 -0800 Subject: [PATCH 039/469] Call SimplePool methods directly This removes redundant helper functions in GpuBufferMultiPool. PiperOrigin-RevId: 488782516 --- mediapipe/gpu/gpu_buffer_multi_pool.cc | 21 +++------------------ mediapipe/gpu/gpu_buffer_multi_pool.h | 5 +---- 2 files changed, 4 insertions(+), 22 deletions(-) diff --git a/mediapipe/gpu/gpu_buffer_multi_pool.cc b/mediapipe/gpu/gpu_buffer_multi_pool.cc index 9c3c9a33e..d03ae06aa 100644 --- a/mediapipe/gpu/gpu_buffer_multi_pool.cc +++ b/mediapipe/gpu/gpu_buffer_multi_pool.cc @@ -21,12 +21,6 @@ #include "mediapipe/framework/port/logging.h" #include "mediapipe/gpu/gpu_shared_data_internal.h" -#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER -#include "CoreFoundation/CFBase.h" -#include "mediapipe/objc/CFHolder.h" -#include "mediapipe/objc/util.h" -#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER - namespace mediapipe { // Keep this many buffers allocated for a given frame size. 
@@ -87,20 +81,11 @@ GpuBuffer GpuBufferMultiPool::GetBuffer(int width, int height, std::shared_ptr pool = RequestPool(key); if (pool) { // Note: we release our multipool lock before accessing the simple pool. - return GetBufferFromSimplePool(key, *pool); + return GpuBuffer(pool->GetBuffer()); } else { - return GetBufferWithoutPool(key); + return GpuBuffer( + SimplePool::CreateBufferWithoutPool(width, height, format)); } } -GpuBuffer GpuBufferMultiPool::GetBufferFromSimplePool( - BufferSpec spec, GpuBufferMultiPool::SimplePool& pool) { - return GpuBuffer(pool.GetBuffer()); -} - -GpuBuffer GpuBufferMultiPool::GetBufferWithoutPool(const BufferSpec& spec) { - return GpuBuffer(SimplePool::CreateBufferWithoutPool(spec.width, spec.height, - spec.format)); -} - } // namespace mediapipe diff --git a/mediapipe/gpu/gpu_buffer_multi_pool.h b/mediapipe/gpu/gpu_buffer_multi_pool.h index f48577854..7feb39ad4 100644 --- a/mediapipe/gpu/gpu_buffer_multi_pool.h +++ b/mediapipe/gpu/gpu_buffer_multi_pool.h @@ -82,11 +82,8 @@ class GpuBufferMultiPool { std::shared_ptr MakeSimplePool(const BufferSpec& spec); // Requests a simple buffer pool for the given spec. This may return nullptr // if we have not yet reached a sufficient number of requests to allocate a - // pool, in which case the caller should invoke GetBufferWithoutPool instead - // of GetBufferFromSimplePool. + // pool, in which case the caller should invoke CreateBufferWithoutPool. std::shared_ptr RequestPool(const BufferSpec& spec); - GpuBuffer GetBufferFromSimplePool(BufferSpec spec, SimplePool& pool); - GpuBuffer GetBufferWithoutPool(const BufferSpec& spec); absl::Mutex mutex_; mediapipe::ResourceCache> cache_ From 7ef3185ecbb84567c8759350f1baa30907756c02 Mon Sep 17 00:00:00 2001 From: Camillo Lugaresi Date: Tue, 15 Nov 2022 16:01:56 -0800 Subject: [PATCH 040/469] Allow customizing MultiPool options These don't need to be constants. PiperOrigin-RevId: 488782713 --- mediapipe/gpu/gpu_buffer_multi_pool.cc | 23 +++++------------------ mediapipe/gpu/gpu_buffer_multi_pool.h | 22 +++++++++++++++++++++- 2 files changed, 26 insertions(+), 19 deletions(-) diff --git a/mediapipe/gpu/gpu_buffer_multi_pool.cc b/mediapipe/gpu/gpu_buffer_multi_pool.cc index d03ae06aa..44f1d40df 100644 --- a/mediapipe/gpu/gpu_buffer_multi_pool.cc +++ b/mediapipe/gpu/gpu_buffer_multi_pool.cc @@ -23,26 +23,12 @@ namespace mediapipe { -// Keep this many buffers allocated for a given frame size. -static constexpr int kKeepCount = 2; -// The maximum size of the GpuBufferMultiPool. When the limit is reached, the -// oldest BufferSpec will be dropped. -static constexpr int kMaxPoolCount = 10; -// Time in seconds after which an inactive buffer can be dropped from the pool. -// Currently only used with CVPixelBufferPool. -static constexpr float kMaxInactiveBufferAge = 0.25; -// Skip allocating a buffer pool until at least this many requests have been -// made for a given BufferSpec. -static constexpr int kMinRequestsBeforePool = 2; -// Do a deeper flush every this many requests. 
-static constexpr int kRequestCountScrubInterval = 50; - #if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER std::shared_ptr GpuBufferMultiPool::MakeSimplePool(const GpuBufferMultiPool::BufferSpec& spec) { return std::make_shared( - spec.width, spec.height, spec.format, kMaxInactiveBufferAge, + spec.width, spec.height, spec.format, options_.max_inactive_buffer_age, flush_platform_caches_); } @@ -51,7 +37,7 @@ GpuBufferMultiPool::MakeSimplePool(const GpuBufferMultiPool::BufferSpec& spec) { std::shared_ptr GpuBufferMultiPool::MakeSimplePool(const BufferSpec& spec) { return GlTextureBufferPool::Create(spec.width, spec.height, spec.format, - kKeepCount); + options_.keep_count); } #endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER @@ -64,11 +50,12 @@ std::shared_ptr GpuBufferMultiPool::RequestPool( absl::MutexLock lock(&mutex_); pool = cache_.Lookup(spec, [this](const BufferSpec& spec, int request_count) { - return (request_count >= kMinRequestsBeforePool) + return (request_count >= options_.min_requests_before_pool) ? MakeSimplePool(spec) : nullptr; }); - evicted = cache_.Evict(kMaxPoolCount, kRequestCountScrubInterval); + evicted = cache_.Evict(options_.max_pool_count, + options_.request_count_scrub_interval); } // Evicted pools, and their buffers, will be released without holding the // lock. diff --git a/mediapipe/gpu/gpu_buffer_multi_pool.h b/mediapipe/gpu/gpu_buffer_multi_pool.h index 7feb39ad4..1396bcdb3 100644 --- a/mediapipe/gpu/gpu_buffer_multi_pool.h +++ b/mediapipe/gpu/gpu_buffer_multi_pool.h @@ -42,9 +42,28 @@ namespace mediapipe { struct GpuSharedData; class CvPixelBufferPoolWrapper; +struct MultiPoolOptions { + // Keep this many buffers allocated for a given frame size. + int keep_count = 2; + // The maximum size of the GpuBufferMultiPool. When the limit is reached, the + // oldest BufferSpec will be dropped. + int max_pool_count = 10; + // Time in seconds after which an inactive buffer can be dropped from the + // pool. Currently only used with CVPixelBufferPool. + float max_inactive_buffer_age = 0.25; + // Skip allocating a buffer pool until at least this many requests have been + // made for a given BufferSpec. + int min_requests_before_pool = 2; + // Do a deeper flush every this many requests. + int request_count_scrub_interval = 50; +}; + +static constexpr MultiPoolOptions kDefaultMultiPoolOptions; + class GpuBufferMultiPool { public: - GpuBufferMultiPool() {} + GpuBufferMultiPool(MultiPoolOptions options = kDefaultMultiPoolOptions) + : options_(options) {} // Obtains a buffer. May either be reused or created anew. GpuBuffer GetBuffer(int width, int height, @@ -85,6 +104,7 @@ class GpuBufferMultiPool { // pool, in which case the caller should invoke CreateBufferWithoutPool. std::shared_ptr RequestPool(const BufferSpec& spec); + MultiPoolOptions options_; absl::Mutex mutex_; mediapipe::ResourceCache> cache_ ABSL_GUARDED_BY(mutex_); From 267476657d18598dc993dc6bb7f5f084a951d8ff Mon Sep 17 00:00:00 2001 From: Camillo Lugaresi Date: Tue, 15 Nov 2022 16:02:32 -0800 Subject: [PATCH 041/469] MultiPool options header refactoring Passing MultiPool options to the base pool factories means that we don't have to specialize which options we pass to them. 
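In other words (call shape only; SimplePool is whichever pool type the platform #if selects): both base-pool factories now take the whole options struct and read just the fields they care about, keep_count for GL texture pools and max_inactive_buffer_age for CVPixelBuffer pools, so the call site is uniform:

    // One call site, either pool type:
    std::shared_ptr<SimplePool> pool =
        SimplePool::Create(spec.width, spec.height, spec.format, options);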
PiperOrigin-RevId: 488782861 --- mediapipe/gpu/BUILD | 8 ++++ mediapipe/gpu/cv_pixel_buffer_pool_wrapper.h | 11 +++++ mediapipe/gpu/gl_texture_buffer_pool.h | 7 +++ mediapipe/gpu/gpu_buffer_multi_pool.cc | 23 ++++------ mediapipe/gpu/gpu_buffer_multi_pool.h | 24 +++------- mediapipe/gpu/multi_pool.h | 47 ++++++++++++++++++++ 6 files changed, 86 insertions(+), 34 deletions(-) create mode 100644 mediapipe/gpu/multi_pool.h diff --git a/mediapipe/gpu/BUILD b/mediapipe/gpu/BUILD index 2f06fe1d5..b94623ca5 100644 --- a/mediapipe/gpu/BUILD +++ b/mediapipe/gpu/BUILD @@ -369,6 +369,7 @@ cc_library( }), deps = [ ":gpu_buffer_format", + ":multi_pool", ":pixel_buffer_pool_util", "//mediapipe/framework/port:logging", "//mediapipe/objc:CFHolder", @@ -604,6 +605,7 @@ cc_library( ":gl_texture_buffer", ":gpu_buffer", ":gpu_shared_data_header", + ":multi_pool", "//mediapipe/framework:calculator_context", "//mediapipe/framework:calculator_node", "//mediapipe/framework/port:logging", @@ -612,6 +614,11 @@ cc_library( ], ) +cc_library( + name = "multi_pool", + hdrs = ["multi_pool.h"], +) + cc_library( name = "gpu_buffer_multi_pool", srcs = ["gpu_buffer_multi_pool.cc"], @@ -639,6 +646,7 @@ cc_library( ":gl_base", ":gpu_buffer", ":gpu_shared_data_header", + ":multi_pool", "//mediapipe/framework:calculator_context", "//mediapipe/framework:calculator_node", "//mediapipe/framework/port:logging", diff --git a/mediapipe/gpu/cv_pixel_buffer_pool_wrapper.h b/mediapipe/gpu/cv_pixel_buffer_pool_wrapper.h index 9d9328ca1..185ba37c6 100644 --- a/mediapipe/gpu/cv_pixel_buffer_pool_wrapper.h +++ b/mediapipe/gpu/cv_pixel_buffer_pool_wrapper.h @@ -24,6 +24,7 @@ #include "CoreFoundation/CFBase.h" #include "mediapipe/gpu/gpu_buffer_format.h" +#include "mediapipe/gpu/multi_pool.h" #include "mediapipe/gpu/pixel_buffer_pool_util.h" #include "mediapipe/objc/CFHolder.h" @@ -34,6 +35,16 @@ class CvPixelBufferPoolWrapper { CvPixelBufferPoolWrapper(int width, int height, GpuBufferFormat format, CFTimeInterval maxAge, std::function flush_texture_caches); + + static std::shared_ptr Create( + int width, int height, GpuBufferFormat format, + const MultiPoolOptions& options, + std::function flush_texture_caches = nullptr) { + return std::make_shared( + width, height, format, options.max_inactive_buffer_age, + flush_texture_caches); + } + CFHolder GetBuffer(); int GetBufferCount() const { return count_; } diff --git a/mediapipe/gpu/gl_texture_buffer_pool.h b/mediapipe/gpu/gl_texture_buffer_pool.h index cd755b4aa..fee46915e 100644 --- a/mediapipe/gpu/gl_texture_buffer_pool.h +++ b/mediapipe/gpu/gl_texture_buffer_pool.h @@ -23,6 +23,7 @@ #include "absl/synchronization/mutex.h" #include "mediapipe/gpu/gl_texture_buffer.h" +#include "mediapipe/gpu/multi_pool.h" namespace mediapipe { @@ -40,6 +41,12 @@ class GlTextureBufferPool new GlTextureBufferPool(width, height, format, keep_count)); } + static std::shared_ptr Create( + int width, int height, GpuBufferFormat format, + const MultiPoolOptions& options) { + return Create(width, height, format, options.keep_count); + } + // Obtains a buffers. May either be reused or created anew. // A GlContext must be current when this is called. 
GlTextureBufferSharedPtr GetBuffer(); diff --git a/mediapipe/gpu/gpu_buffer_multi_pool.cc b/mediapipe/gpu/gpu_buffer_multi_pool.cc index 44f1d40df..df228b7dd 100644 --- a/mediapipe/gpu/gpu_buffer_multi_pool.cc +++ b/mediapipe/gpu/gpu_buffer_multi_pool.cc @@ -23,24 +23,17 @@ namespace mediapipe { +std::shared_ptr +GpuBufferMultiPool::MakeSimplePool(const GpuBufferMultiPool::BufferSpec& spec, + const MultiPoolOptions& options) { #if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER - -std::shared_ptr -GpuBufferMultiPool::MakeSimplePool(const GpuBufferMultiPool::BufferSpec& spec) { - return std::make_shared( - spec.width, spec.height, spec.format, options_.max_inactive_buffer_age, - flush_platform_caches_); -} - + return CvPixelBufferPoolWrapper::Create(spec.width, spec.height, spec.format, + options, flush_platform_caches_); #else - -std::shared_ptr -GpuBufferMultiPool::MakeSimplePool(const BufferSpec& spec) { return GlTextureBufferPool::Create(spec.width, spec.height, spec.format, - options_.keep_count); -} - + options); #endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER +} std::shared_ptr GpuBufferMultiPool::RequestPool( const BufferSpec& spec) { @@ -51,7 +44,7 @@ std::shared_ptr GpuBufferMultiPool::RequestPool( pool = cache_.Lookup(spec, [this](const BufferSpec& spec, int request_count) { return (request_count >= options_.min_requests_before_pool) - ? MakeSimplePool(spec) + ? MakeSimplePool(spec, options_) : nullptr; }); evicted = cache_.Evict(options_.max_pool_count, diff --git a/mediapipe/gpu/gpu_buffer_multi_pool.h b/mediapipe/gpu/gpu_buffer_multi_pool.h index 1396bcdb3..3ea299f78 100644 --- a/mediapipe/gpu/gpu_buffer_multi_pool.h +++ b/mediapipe/gpu/gpu_buffer_multi_pool.h @@ -25,6 +25,7 @@ #include "absl/hash/hash.h" #include "absl/synchronization/mutex.h" #include "mediapipe/gpu/gpu_buffer.h" +#include "mediapipe/gpu/multi_pool.h" #include "mediapipe/util/resource_cache.h" #ifdef __APPLE__ @@ -42,24 +43,6 @@ namespace mediapipe { struct GpuSharedData; class CvPixelBufferPoolWrapper; -struct MultiPoolOptions { - // Keep this many buffers allocated for a given frame size. - int keep_count = 2; - // The maximum size of the GpuBufferMultiPool. When the limit is reached, the - // oldest BufferSpec will be dropped. - int max_pool_count = 10; - // Time in seconds after which an inactive buffer can be dropped from the - // pool. Currently only used with CVPixelBufferPool. - float max_inactive_buffer_age = 0.25; - // Skip allocating a buffer pool until at least this many requests have been - // made for a given BufferSpec. - int min_requests_before_pool = 2; - // Do a deeper flush every this many requests. - int request_count_scrub_interval = 50; -}; - -static constexpr MultiPoolOptions kDefaultMultiPoolOptions; - class GpuBufferMultiPool { public: GpuBufferMultiPool(MultiPoolOptions options = kDefaultMultiPoolOptions) @@ -98,7 +81,10 @@ class GpuBufferMultiPool { using SimplePool = GlTextureBufferPool; #endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER - std::shared_ptr MakeSimplePool(const BufferSpec& spec); + std::shared_ptr MakeSimplePool( + const GpuBufferMultiPool::BufferSpec& spec, + const MultiPoolOptions& options); + // Requests a simple buffer pool for the given spec. This may return nullptr // if we have not yet reached a sufficient number of requests to allocate a // pool, in which case the caller should invoke CreateBufferWithoutPool. 
diff --git a/mediapipe/gpu/multi_pool.h b/mediapipe/gpu/multi_pool.h new file mode 100644 index 000000000..e504fc820 --- /dev/null +++ b/mediapipe/gpu/multi_pool.h @@ -0,0 +1,47 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This class lets calculators allocate GpuBuffers of various sizes, caching +// and reusing them as needed. It does so by automatically creating and using +// platform-specific buffer pools for the requested sizes. +// +// This class is not meant to be used directly by calculators, but is instead +// used by GlCalculatorHelper to allocate buffers. + +#ifndef MEDIAPIPE_GPU_MULTI_POOL_H_ +#define MEDIAPIPE_GPU_MULTI_POOL_H_ + +namespace mediapipe { + +struct MultiPoolOptions { + // Keep this many buffers allocated for a given frame size. + int keep_count = 2; + // The maximum size of the GpuBufferMultiPool. When the limit is reached, the + // oldest BufferSpec will be dropped. + int max_pool_count = 10; + // Time in seconds after which an inactive buffer can be dropped from the + // pool. Currently only used with CVPixelBufferPool. + float max_inactive_buffer_age = 0.25; + // Skip allocating a buffer pool until at least this many requests have been + // made for a given BufferSpec. + int min_requests_before_pool = 2; + // Do a deeper flush every this many requests. + int request_count_scrub_interval = 50; +}; + +static constexpr MultiPoolOptions kDefaultMultiPoolOptions; + +} // namespace mediapipe + +#endif // MEDIAPIPE_GPU_MULTI_POOL_H_ From b9fa2e3496ad0879556162a738f4f608ebe1bb5b Mon Sep 17 00:00:00 2001 From: Camillo Lugaresi Date: Tue, 15 Nov 2022 16:03:03 -0800 Subject: [PATCH 042/469] Make it possible to override the SimplePool factory used by MultiPool This means MultiPool no longer needs a SetFlushPlatformCaches method, which was too specific to the CVPixelBufferPool. 
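A hypothetical use this enables (not part of this patch): a test can inject a counting factory without the multi-pool knowing anything platform-specific. The variable names and the counting lambda below are illustrative only:

    // Hypothetical test hook built on the SetSimplePoolFactory added here.
    GpuBufferMultiPool multi_pool;
    int pools_created = 0;
    multi_pool.SetSimplePoolFactory(
        [&pools_created](const GpuBufferMultiPool::BufferSpec& spec,
                         const MultiPoolOptions& options) {
          ++pools_created;  // observe pool allocation decisions
          return GlTextureBufferPool::Create(spec.width, spec.height,
                                             spec.format, options);
        });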
PiperOrigin-RevId: 488783003 --- mediapipe/gpu/BUILD | 1 + mediapipe/gpu/cv_pixel_buffer_pool_wrapper.cc | 8 ++++---- mediapipe/gpu/cv_pixel_buffer_pool_wrapper.h | 10 +++++----- mediapipe/gpu/gpu_buffer_multi_pool.cc | 15 +++++---------- mediapipe/gpu/gpu_buffer_multi_pool.h | 18 ++++++++++-------- mediapipe/gpu/gpu_shared_data_internal.cc | 8 ++++++-- 6 files changed, 31 insertions(+), 29 deletions(-) diff --git a/mediapipe/gpu/BUILD b/mediapipe/gpu/BUILD index b94623ca5..36527736b 100644 --- a/mediapipe/gpu/BUILD +++ b/mediapipe/gpu/BUILD @@ -368,6 +368,7 @@ cc_library( ], }), deps = [ + ":cv_texture_cache_manager", ":gpu_buffer_format", ":multi_pool", ":pixel_buffer_pool_util", diff --git a/mediapipe/gpu/cv_pixel_buffer_pool_wrapper.cc b/mediapipe/gpu/cv_pixel_buffer_pool_wrapper.cc index b1c135afa..d8155f5cf 100644 --- a/mediapipe/gpu/cv_pixel_buffer_pool_wrapper.cc +++ b/mediapipe/gpu/cv_pixel_buffer_pool_wrapper.cc @@ -25,13 +25,13 @@ namespace mediapipe { CvPixelBufferPoolWrapper::CvPixelBufferPoolWrapper( int width, int height, GpuBufferFormat format, CFTimeInterval maxAge, - std::function flush_texture_caches) { + CvTextureCacheManager* texture_caches) { OSType cv_format = CVPixelFormatForGpuBufferFormat(format); CHECK_NE(cv_format, -1) << "unsupported pixel format"; pool_ = MakeCFHolderAdopting( /* keep count is 0 because the age param keeps buffers around anyway */ CreateCVPixelBufferPool(width, height, cv_format, 0, maxAge)); - flush_texture_caches_ = std::move(flush_texture_caches); + texture_caches_ = texture_caches; } CFHolder CvPixelBufferPoolWrapper::GetBuffer() { @@ -47,12 +47,12 @@ CFHolder CvPixelBufferPoolWrapper::GetBuffer() { kCFAllocatorDefault, *pool_, (__bridge CFDictionaryRef)auxAttributes, &buffer); if (err != kCVReturnWouldExceedAllocationThreshold) break; - if (flush_texture_caches_ && !tried_flushing) { + if (texture_caches_ && !tried_flushing) { // Call the flush function to potentially release old holds on buffers // and try again to create a pixel buffer. // This is used to flush CV texture caches, which may retain buffers until // flushed. 
- flush_texture_caches_(); + texture_caches_->FlushTextureCaches(); tried_flushing = true; } else { ++threshold; diff --git a/mediapipe/gpu/cv_pixel_buffer_pool_wrapper.h b/mediapipe/gpu/cv_pixel_buffer_pool_wrapper.h index 185ba37c6..7d0aec4eb 100644 --- a/mediapipe/gpu/cv_pixel_buffer_pool_wrapper.h +++ b/mediapipe/gpu/cv_pixel_buffer_pool_wrapper.h @@ -23,6 +23,7 @@ #define MEDIAPIPE_GPU_CV_PIXEL_BUFFER_POOL_WRAPPER_H_ #include "CoreFoundation/CFBase.h" +#include "mediapipe/gpu/cv_texture_cache_manager.h" #include "mediapipe/gpu/gpu_buffer_format.h" #include "mediapipe/gpu/multi_pool.h" #include "mediapipe/gpu/pixel_buffer_pool_util.h" @@ -34,15 +35,14 @@ class CvPixelBufferPoolWrapper { public: CvPixelBufferPoolWrapper(int width, int height, GpuBufferFormat format, CFTimeInterval maxAge, - std::function flush_texture_caches); + CvTextureCacheManager* texture_caches); static std::shared_ptr Create( int width, int height, GpuBufferFormat format, const MultiPoolOptions& options, - std::function flush_texture_caches = nullptr) { + CvTextureCacheManager* texture_caches = nullptr) { return std::make_shared( - width, height, format, options.max_inactive_buffer_age, - flush_texture_caches); + width, height, format, options.max_inactive_buffer_age, texture_caches); } CFHolder GetBuffer(); @@ -58,7 +58,7 @@ class CvPixelBufferPoolWrapper { private: CFHolder pool_; int count_ = 0; - std::function flush_texture_caches_; + CvTextureCacheManager* texture_caches_; }; } // namespace mediapipe diff --git a/mediapipe/gpu/gpu_buffer_multi_pool.cc b/mediapipe/gpu/gpu_buffer_multi_pool.cc index df228b7dd..744ccea2d 100644 --- a/mediapipe/gpu/gpu_buffer_multi_pool.cc +++ b/mediapipe/gpu/gpu_buffer_multi_pool.cc @@ -24,15 +24,10 @@ namespace mediapipe { std::shared_ptr -GpuBufferMultiPool::MakeSimplePool(const GpuBufferMultiPool::BufferSpec& spec, - const MultiPoolOptions& options) { -#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER - return CvPixelBufferPoolWrapper::Create(spec.width, spec.height, spec.format, - options, flush_platform_caches_); -#else - return GlTextureBufferPool::Create(spec.width, spec.height, spec.format, - options); -#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER +GpuBufferMultiPool::DefaultMakeSimplePool( + const GpuBufferMultiPool::BufferSpec& spec, + const MultiPoolOptions& options) { + return SimplePool::Create(spec.width, spec.height, spec.format, options); } std::shared_ptr GpuBufferMultiPool::RequestPool( @@ -44,7 +39,7 @@ std::shared_ptr GpuBufferMultiPool::RequestPool( pool = cache_.Lookup(spec, [this](const BufferSpec& spec, int request_count) { return (request_count >= options_.min_requests_before_pool) - ? MakeSimplePool(spec, options_) + ? create_simple_pool_(spec, options_) : nullptr; }); evicted = cache_.Evict(options_.max_pool_count, diff --git a/mediapipe/gpu/gpu_buffer_multi_pool.h b/mediapipe/gpu/gpu_buffer_multi_pool.h index 3ea299f78..88428d053 100644 --- a/mediapipe/gpu/gpu_buffer_multi_pool.h +++ b/mediapipe/gpu/gpu_buffer_multi_pool.h @@ -52,10 +52,6 @@ class GpuBufferMultiPool { GpuBuffer GetBuffer(int width, int height, GpuBufferFormat format = GpuBufferFormat::kBGRA32); - void SetFlushPlatformCaches(std::function flush_platform_caches) { - flush_platform_caches_ = flush_platform_caches; - } - // This class is not intended as part of the public api of this class. It is // public only because it is used as a map key type, and the map // implementation needs access to, e.g., the equality operator. 
@@ -74,14 +70,21 @@ class GpuBufferMultiPool { mediapipe::GpuBufferFormat format; }; - private: #if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER using SimplePool = CvPixelBufferPoolWrapper; #else using SimplePool = GlTextureBufferPool; #endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER - std::shared_ptr MakeSimplePool( + using SimplePoolFactory = std::function( + const BufferSpec& spec, const MultiPoolOptions& options)>; + + void SetSimplePoolFactory(SimplePoolFactory create_simple_pool) { + create_simple_pool_ = create_simple_pool; + } + + private: + static std::shared_ptr DefaultMakeSimplePool( const GpuBufferMultiPool::BufferSpec& spec, const MultiPoolOptions& options); @@ -94,8 +97,7 @@ class GpuBufferMultiPool { absl::Mutex mutex_; mediapipe::ResourceCache> cache_ ABSL_GUARDED_BY(mutex_); - // This is used to hook up the TextureCacheManager on Apple platforms. - std::function flush_platform_caches_; + SimplePoolFactory create_simple_pool_ = DefaultMakeSimplePool; }; // BufferSpec equality operators diff --git a/mediapipe/gpu/gpu_shared_data_internal.cc b/mediapipe/gpu/gpu_shared_data_internal.cc index 457b04fd3..6633c2f00 100644 --- a/mediapipe/gpu/gpu_shared_data_internal.cc +++ b/mediapipe/gpu/gpu_shared_data_internal.cc @@ -86,8 +86,12 @@ GpuResources::GpuResources(std::shared_ptr gl_context) { std::make_shared(gl_context.get()); #if __APPLE__ texture_caches_ = std::make_shared(); - gpu_buffer_pool().SetFlushPlatformCaches( - [tc = texture_caches_] { tc->FlushTextureCaches(); }); + gpu_buffer_pool().SetSimplePoolFactory( + [tc = texture_caches_](const GpuBufferMultiPool::BufferSpec& spec, + const MultiPoolOptions& options) { + return CvPixelBufferPoolWrapper::Create(spec.width, spec.height, + spec.format, options, tc.get()); + }); #if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER texture_caches_->RegisterTextureCache(gl_context->cv_texture_cache()); #endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER From 53d015af08c96d39ecee97bdfa11cc5b5a882cec Mon Sep 17 00:00:00 2001 From: Camillo Lugaresi Date: Tue, 15 Nov 2022 16:03:41 -0800 Subject: [PATCH 043/469] Generic MultiPool template PiperOrigin-RevId: 488783176 --- mediapipe/gpu/BUILD | 1 + mediapipe/gpu/cv_pixel_buffer_pool_wrapper.cc | 8 +- mediapipe/gpu/cv_pixel_buffer_pool_wrapper.h | 8 +- mediapipe/gpu/gl_texture_buffer_pool.h | 9 +- mediapipe/gpu/gpu_buffer_format.h | 28 +++++++ mediapipe/gpu/gpu_buffer_multi_pool.cc | 46 +--------- mediapipe/gpu/gpu_buffer_multi_pool.h | 77 ++--------------- mediapipe/gpu/gpu_shared_data_internal.cc | 18 ++-- mediapipe/gpu/gpu_shared_data_internal.h | 6 +- mediapipe/gpu/multi_pool.h | 84 +++++++++++++++++-- 10 files changed, 142 insertions(+), 143 deletions(-) diff --git a/mediapipe/gpu/BUILD b/mediapipe/gpu/BUILD index 36527736b..1efe75b52 100644 --- a/mediapipe/gpu/BUILD +++ b/mediapipe/gpu/BUILD @@ -618,6 +618,7 @@ cc_library( cc_library( name = "multi_pool", hdrs = ["multi_pool.h"], + deps = ["//mediapipe/util:resource_cache"], ) cc_library( diff --git a/mediapipe/gpu/cv_pixel_buffer_pool_wrapper.cc b/mediapipe/gpu/cv_pixel_buffer_pool_wrapper.cc index d8155f5cf..6e077ae6e 100644 --- a/mediapipe/gpu/cv_pixel_buffer_pool_wrapper.cc +++ b/mediapipe/gpu/cv_pixel_buffer_pool_wrapper.cc @@ -71,12 +71,12 @@ std::string CvPixelBufferPoolWrapper::GetDebugString() const { void CvPixelBufferPoolWrapper::Flush() { CVPixelBufferPoolFlush(*pool_, 0); } CFHolder CvPixelBufferPoolWrapper::CreateBufferWithoutPool( - int width, int height, GpuBufferFormat format) { - OSType cv_format = 
CVPixelFormatForGpuBufferFormat(format); + const internal::GpuBufferSpec& spec) { + OSType cv_format = CVPixelFormatForGpuBufferFormat(spec.format); CHECK_NE(cv_format, -1) << "unsupported pixel format"; CVPixelBufferRef buffer; - CVReturn err = - CreateCVPixelBufferWithoutPool(width, height, cv_format, &buffer); + CVReturn err = CreateCVPixelBufferWithoutPool(spec.width, spec.height, + cv_format, &buffer); CHECK(!err) << "Error creating pixel buffer: " << err; return MakeCFHolderAdopting(buffer); } diff --git a/mediapipe/gpu/cv_pixel_buffer_pool_wrapper.h b/mediapipe/gpu/cv_pixel_buffer_pool_wrapper.h index 7d0aec4eb..4d71adbf2 100644 --- a/mediapipe/gpu/cv_pixel_buffer_pool_wrapper.h +++ b/mediapipe/gpu/cv_pixel_buffer_pool_wrapper.h @@ -38,11 +38,11 @@ class CvPixelBufferPoolWrapper { CvTextureCacheManager* texture_caches); static std::shared_ptr Create( - int width, int height, GpuBufferFormat format, - const MultiPoolOptions& options, + const internal::GpuBufferSpec& spec, const MultiPoolOptions& options, CvTextureCacheManager* texture_caches = nullptr) { return std::make_shared( - width, height, format, options.max_inactive_buffer_age, texture_caches); + spec.width, spec.height, spec.format, options.max_inactive_buffer_age, + texture_caches); } CFHolder GetBuffer(); @@ -53,7 +53,7 @@ class CvPixelBufferPoolWrapper { void Flush(); static CFHolder CreateBufferWithoutPool( - int width, int height, GpuBufferFormat format); + const internal::GpuBufferSpec& spec); private: CFHolder pool_; diff --git a/mediapipe/gpu/gl_texture_buffer_pool.h b/mediapipe/gpu/gl_texture_buffer_pool.h index fee46915e..29fc3c01c 100644 --- a/mediapipe/gpu/gl_texture_buffer_pool.h +++ b/mediapipe/gpu/gl_texture_buffer_pool.h @@ -42,9 +42,8 @@ class GlTextureBufferPool } static std::shared_ptr Create( - int width, int height, GpuBufferFormat format, - const MultiPoolOptions& options) { - return Create(width, height, format, options.keep_count); + const internal::GpuBufferSpec& spec, const MultiPoolOptions& options) { + return Create(spec.width, spec.height, spec.format, options.keep_count); } // Obtains a buffers. May either be reused or created anew. 
@@ -59,8 +58,8 @@ class GlTextureBufferPool std::pair GetInUseAndAvailableCounts(); static GlTextureBufferSharedPtr CreateBufferWithoutPool( - int width, int height, GpuBufferFormat format) { - return GlTextureBuffer::Create(width, height, format); + const internal::GpuBufferSpec& spec) { + return GlTextureBuffer::Create(spec.width, spec.height, spec.format); } private: diff --git a/mediapipe/gpu/gpu_buffer_format.h b/mediapipe/gpu/gpu_buffer_format.h index 45f054d31..06c5a0439 100644 --- a/mediapipe/gpu/gpu_buffer_format.h +++ b/mediapipe/gpu/gpu_buffer_format.h @@ -153,6 +153,34 @@ inline GpuBufferFormat GpuBufferFormatForCVPixelFormat(OSType format) { #endif // __APPLE__ +namespace internal { + +struct GpuBufferSpec { + GpuBufferSpec(int w, int h, GpuBufferFormat f) + : width(w), height(h), format(f) {} + + template + friend H AbslHashValue(H h, const GpuBufferSpec& spec) { + return H::combine(std::move(h), spec.width, spec.height, + static_cast(spec.format)); + } + + int width; + int height; + GpuBufferFormat format; +}; + +// BufferSpec equality operators +inline bool operator==(const GpuBufferSpec& lhs, const GpuBufferSpec& rhs) { + return lhs.width == rhs.width && lhs.height == rhs.height && + lhs.format == rhs.format; +} +inline bool operator!=(const GpuBufferSpec& lhs, const GpuBufferSpec& rhs) { + return !operator==(lhs, rhs); +} + +} // namespace internal + } // namespace mediapipe #endif // MEDIAPIPE_GPU_GPU_BUFFER_FORMAT_H_ diff --git a/mediapipe/gpu/gpu_buffer_multi_pool.cc b/mediapipe/gpu/gpu_buffer_multi_pool.cc index 744ccea2d..e2ed523e4 100644 --- a/mediapipe/gpu/gpu_buffer_multi_pool.cc +++ b/mediapipe/gpu/gpu_buffer_multi_pool.cc @@ -16,51 +16,7 @@ #include -#include "absl/memory/memory.h" #include "absl/synchronization/mutex.h" #include "mediapipe/framework/port/logging.h" -#include "mediapipe/gpu/gpu_shared_data_internal.h" -namespace mediapipe { - -std::shared_ptr -GpuBufferMultiPool::DefaultMakeSimplePool( - const GpuBufferMultiPool::BufferSpec& spec, - const MultiPoolOptions& options) { - return SimplePool::Create(spec.width, spec.height, spec.format, options); -} - -std::shared_ptr GpuBufferMultiPool::RequestPool( - const BufferSpec& spec) { - std::shared_ptr pool; - std::vector> evicted; - { - absl::MutexLock lock(&mutex_); - pool = - cache_.Lookup(spec, [this](const BufferSpec& spec, int request_count) { - return (request_count >= options_.min_requests_before_pool) - ? create_simple_pool_(spec, options_) - : nullptr; - }); - evicted = cache_.Evict(options_.max_pool_count, - options_.request_count_scrub_interval); - } - // Evicted pools, and their buffers, will be released without holding the - // lock. - return pool; -} - -GpuBuffer GpuBufferMultiPool::GetBuffer(int width, int height, - GpuBufferFormat format) { - BufferSpec key(width, height, format); - std::shared_ptr pool = RequestPool(key); - if (pool) { - // Note: we release our multipool lock before accessing the simple pool. 
- return GpuBuffer(pool->GetBuffer()); - } else { - return GpuBuffer( - SimplePool::CreateBufferWithoutPool(width, height, format)); - } -} - -} // namespace mediapipe +namespace mediapipe {} // namespace mediapipe diff --git a/mediapipe/gpu/gpu_buffer_multi_pool.h b/mediapipe/gpu/gpu_buffer_multi_pool.h index 88428d053..827cf514a 100644 --- a/mediapipe/gpu/gpu_buffer_multi_pool.h +++ b/mediapipe/gpu/gpu_buffer_multi_pool.h @@ -22,15 +22,9 @@ #ifndef MEDIAPIPE_GPU_GPU_BUFFER_MULTI_POOL_H_ #define MEDIAPIPE_GPU_GPU_BUFFER_MULTI_POOL_H_ -#include "absl/hash/hash.h" #include "absl/synchronization/mutex.h" #include "mediapipe/gpu/gpu_buffer.h" #include "mediapipe/gpu/multi_pool.h" -#include "mediapipe/util/resource_cache.h" - -#ifdef __APPLE__ -#include "mediapipe/gpu/pixel_buffer_pool_util.h" -#endif // __APPLE__ #if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER #include "mediapipe/gpu/cv_pixel_buffer_pool_wrapper.h" @@ -40,77 +34,24 @@ namespace mediapipe { -struct GpuSharedData; class CvPixelBufferPoolWrapper; -class GpuBufferMultiPool { - public: - GpuBufferMultiPool(MultiPoolOptions options = kDefaultMultiPoolOptions) - : options_(options) {} - - // Obtains a buffer. May either be reused or created anew. - GpuBuffer GetBuffer(int width, int height, - GpuBufferFormat format = GpuBufferFormat::kBGRA32); - - // This class is not intended as part of the public api of this class. It is - // public only because it is used as a map key type, and the map - // implementation needs access to, e.g., the equality operator. - struct BufferSpec { - BufferSpec(int w, int h, mediapipe::GpuBufferFormat f) - : width(w), height(h), format(f) {} - - template - friend H AbslHashValue(H h, const BufferSpec& spec) { - return H::combine(std::move(h), spec.width, spec.height, - static_cast(spec.format)); - } - - int width; - int height; - mediapipe::GpuBufferFormat format; - }; - +class GpuBufferMultiPool : public MultiPool< #if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER - using SimplePool = CvPixelBufferPoolWrapper; + CvPixelBufferPoolWrapper, #else - using SimplePool = GlTextureBufferPool; + GlTextureBufferPool, #endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER + internal::GpuBufferSpec, GpuBuffer> { + public: + using MultiPool::MultiPool; - using SimplePoolFactory = std::function( - const BufferSpec& spec, const MultiPoolOptions& options)>; - - void SetSimplePoolFactory(SimplePoolFactory create_simple_pool) { - create_simple_pool_ = create_simple_pool; + GpuBuffer GetBuffer(int width, int height, + GpuBufferFormat format = GpuBufferFormat::kBGRA32) { + return Get(internal::GpuBufferSpec(width, height, format)); } - - private: - static std::shared_ptr DefaultMakeSimplePool( - const GpuBufferMultiPool::BufferSpec& spec, - const MultiPoolOptions& options); - - // Requests a simple buffer pool for the given spec. This may return nullptr - // if we have not yet reached a sufficient number of requests to allocate a - // pool, in which case the caller should invoke CreateBufferWithoutPool. 
- std::shared_ptr RequestPool(const BufferSpec& spec); - - MultiPoolOptions options_; - absl::Mutex mutex_; - mediapipe::ResourceCache> cache_ - ABSL_GUARDED_BY(mutex_); - SimplePoolFactory create_simple_pool_ = DefaultMakeSimplePool; }; -// BufferSpec equality operators -inline bool operator==(const GpuBufferMultiPool::BufferSpec& lhs, - const GpuBufferMultiPool::BufferSpec& rhs) { - return lhs.width == rhs.width && lhs.height == rhs.height && - lhs.format == rhs.format; -} -inline bool operator!=(const GpuBufferMultiPool::BufferSpec& lhs, - const GpuBufferMultiPool::BufferSpec& rhs) { - return !operator==(lhs, rhs); -} - } // namespace mediapipe #endif // MEDIAPIPE_GPU_GPU_BUFFER_MULTI_POOL_H_ diff --git a/mediapipe/gpu/gpu_shared_data_internal.cc b/mediapipe/gpu/gpu_shared_data_internal.cc index 6633c2f00..52db88633 100644 --- a/mediapipe/gpu/gpu_shared_data_internal.cc +++ b/mediapipe/gpu/gpu_shared_data_internal.cc @@ -80,18 +80,20 @@ GpuResources::StatusOrGpuResources GpuResources::Create( return gpu_resources; } -GpuResources::GpuResources(std::shared_ptr gl_context) { +GpuResources::GpuResources(std::shared_ptr gl_context) +#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER + : texture_caches_(std::make_shared()), + gpu_buffer_pool_( + [tc = texture_caches_](const internal::GpuBufferSpec& spec, + const MultiPoolOptions& options) { + return CvPixelBufferPoolWrapper::Create(spec, options, tc.get()); + }) +#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER +{ gl_key_context_[SharedContextKey()] = gl_context; named_executors_[kGpuExecutorName] = std::make_shared(gl_context.get()); #if __APPLE__ - texture_caches_ = std::make_shared(); - gpu_buffer_pool().SetSimplePoolFactory( - [tc = texture_caches_](const GpuBufferMultiPool::BufferSpec& spec, - const MultiPoolOptions& options) { - return CvPixelBufferPoolWrapper::Create(spec.width, spec.height, - spec.format, options, tc.get()); - }); #if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER texture_caches_->RegisterTextureCache(gl_context->cv_texture_cache()); #endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER diff --git a/mediapipe/gpu/gpu_shared_data_internal.h b/mediapipe/gpu/gpu_shared_data_internal.h index 12a7a1296..4fe6ba04e 100644 --- a/mediapipe/gpu/gpu_shared_data_internal.h +++ b/mediapipe/gpu/gpu_shared_data_internal.h @@ -87,13 +87,15 @@ class GpuResources { std::map node_key_; std::map> gl_key_context_; +#ifdef MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER + std::shared_ptr texture_caches_; +#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER + // The pool must be destructed before the gl_context, but after the // ios_gpu_data, so the declaration order is important. GpuBufferMultiPool gpu_buffer_pool_; #ifdef __APPLE__ - std::shared_ptr texture_caches_; - // Note that this is an Objective-C object. MPPGraphGPUData* ios_gpu_data_; #endif // defined(__APPLE__) diff --git a/mediapipe/gpu/multi_pool.h b/mediapipe/gpu/multi_pool.h index e504fc820..8a3cf6be0 100644 --- a/mediapipe/gpu/multi_pool.h +++ b/mediapipe/gpu/multi_pool.h @@ -12,16 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. -// This class lets calculators allocate GpuBuffers of various sizes, caching -// and reusing them as needed. It does so by automatically creating and using -// platform-specific buffer pools for the requested sizes. -// -// This class is not meant to be used directly by calculators, but is instead -// used by GlCalculatorHelper to allocate buffers. 
-
 #ifndef MEDIAPIPE_GPU_MULTI_POOL_H_
 #define MEDIAPIPE_GPU_MULTI_POOL_H_
 
+#include "mediapipe/util/resource_cache.h"
+
 namespace mediapipe {
 
 struct MultiPoolOptions {
@@ -42,6 +37,81 @@ struct MultiPoolOptions {
 
 static constexpr MultiPoolOptions kDefaultMultiPoolOptions;
 
+// MultiPool is a generic class for vending reusable resources of type Item,
+// which are assumed to be relatively expensive to create, so that reusing them
+// is beneficial.
+// Items are classified by Spec; when an item with a given Spec is requested,
+// an old Item with the same Spec can be reused, if available; otherwise a new
+// Item will be created. When user code is done with an Item, it is returned
+// to the pool for reuse.
+// In order to manage this, a MultiPool contains a map of Specs to SimplePool;
+// each SimplePool manages Items with the same Spec, which are thus considered
+// interchangeable.
+// Item retention and eviction policies are controlled by options.
+// A concrete example would be a pool of GlTextureBuffer, grouped by dimensions
+// and format.
+template <class SimplePool, class Spec, class Item>
+class MultiPool {
+ public:
+  using SimplePoolFactory = std::function<std::shared_ptr<SimplePool>(
+      const Spec& spec, const MultiPoolOptions& options)>;
+
+  MultiPool(SimplePoolFactory factory = DefaultMakeSimplePool,
+            MultiPoolOptions options = kDefaultMultiPoolOptions)
+      : create_simple_pool_(factory), options_(options) {}
+
+  // Obtains an item. May either be reused or created anew.
+  Item Get(const Spec& spec);
+
+ private:
+  static std::shared_ptr<SimplePool> DefaultMakeSimplePool(
+      const Spec& spec, const MultiPoolOptions& options) {
+    return SimplePool::Create(spec, options);
+  }
+
+  // Requests a simple buffer pool for the given spec. This may return nullptr
+  // if we have not yet reached a sufficient number of requests to allocate a
+  // pool, in which case the caller should invoke CreateBufferWithoutPool.
+  std::shared_ptr<SimplePool> RequestPool(const Spec& spec);
+
+  absl::Mutex mutex_;
+  mediapipe::ResourceCache<Spec, std::shared_ptr<SimplePool>> cache_
+      ABSL_GUARDED_BY(mutex_);
+  SimplePoolFactory create_simple_pool_ = DefaultMakeSimplePool;
+  MultiPoolOptions options_;
+};
+
+template <class SimplePool, class Spec, class Item>
+std::shared_ptr<SimplePool> MultiPool<SimplePool, Spec, Item>::RequestPool(
+    const Spec& spec) {
+  std::shared_ptr<SimplePool> pool;
+  std::vector<std::shared_ptr<SimplePool>> evicted;
+  {
+    absl::MutexLock lock(&mutex_);
+    pool = cache_.Lookup(spec, [this](const Spec& spec, int request_count) {
+      return (request_count >= options_.min_requests_before_pool)
+                 ? create_simple_pool_(spec, options_)
+                 : nullptr;
+    });
+    evicted = cache_.Evict(options_.max_pool_count,
+                           options_.request_count_scrub_interval);
+  }
+  // Evicted pools, and their buffers, will be released without holding the
+  // lock.
+  return pool;
+}
+
+template <class SimplePool, class Spec, class Item>
+Item MultiPool<SimplePool, Spec, Item>::Get(const Spec& spec) {
+  std::shared_ptr<SimplePool> pool = RequestPool(spec);
+  if (pool) {
+    // Note: we release our multipool lock before accessing the simple pool.
+ return Item(pool->GetBuffer()); + } else { + return Item(SimplePool::CreateBufferWithoutPool(spec)); + } +} + } // namespace mediapipe #endif // MEDIAPIPE_GPU_MULTI_POOL_H_ From ab074a579a206164384f36e76581e784b8a65bd3 Mon Sep 17 00:00:00 2001 From: Camillo Lugaresi Date: Tue, 15 Nov 2022 16:04:11 -0800 Subject: [PATCH 044/469] Internal change PiperOrigin-RevId: 488783325 --- WORKSPACE | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/WORKSPACE b/WORKSPACE index fea96d941..d43394883 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -26,7 +26,7 @@ versions.check(minimum_bazel_version = "3.7.2") http_archive( name = "com_google_absl", urls = [ - "https://github.com/abseil/abseil-cpp/archive/refs/tags/20210324.2.tar.gz", + "https://github.com/abseil/abseil-cpp/archive/refs/tags/20220623.1.tar.gz", ], # Remove after https://github.com/abseil/abseil-cpp/issues/326 is solved. patches = [ @@ -35,8 +35,8 @@ http_archive( patch_args = [ "-p1", ], - strip_prefix = "abseil-cpp-20210324.2", - sha256 = "59b862f50e710277f8ede96f083a5bb8d7c9595376146838b9580be90374ee1f" + strip_prefix = "abseil-cpp-20220623.1", + sha256 = "91ac87d30cc6d79f9ab974c51874a704de9c2647c40f6932597329a282217ba8" ) http_archive( From 583d27636b346ae2e69c4b12f1346e2a8c32401c Mon Sep 17 00:00:00 2001 From: Camillo Lugaresi Date: Tue, 15 Nov 2022 16:04:45 -0800 Subject: [PATCH 045/469] Factor out ReusablePool PiperOrigin-RevId: 488783477 --- mediapipe/gpu/BUILD | 11 ++ mediapipe/gpu/gl_texture_buffer.h | 5 + mediapipe/gpu/gl_texture_buffer_pool.cc | 77 +------------ mediapipe/gpu/gl_texture_buffer_pool.h | 52 +++------ mediapipe/gpu/reusable_pool.h | 145 ++++++++++++++++++++++++ 5 files changed, 178 insertions(+), 112 deletions(-) create mode 100644 mediapipe/gpu/reusable_pool.h diff --git a/mediapipe/gpu/BUILD b/mediapipe/gpu/BUILD index 1efe75b52..747d131ba 100644 --- a/mediapipe/gpu/BUILD +++ b/mediapipe/gpu/BUILD @@ -607,6 +607,7 @@ cc_library( ":gpu_buffer", ":gpu_shared_data_header", ":multi_pool", + ":reusable_pool", "//mediapipe/framework:calculator_context", "//mediapipe/framework:calculator_node", "//mediapipe/framework/port:logging", @@ -615,6 +616,16 @@ cc_library( ], ) +cc_library( + name = "reusable_pool", + hdrs = ["reusable_pool.h"], + deps = [ + ":multi_pool", + "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/synchronization", + ], +) + cc_library( name = "multi_pool", hdrs = ["multi_pool.h"], diff --git a/mediapipe/gpu/gl_texture_buffer.h b/mediapipe/gpu/gl_texture_buffer.h index 124a0ec2f..a770163b5 100644 --- a/mediapipe/gpu/gl_texture_buffer.h +++ b/mediapipe/gpu/gl_texture_buffer.h @@ -71,6 +71,11 @@ class GlTextureBuffer // Create a texture with a copy of the data in image_frame. static std::unique_ptr Create(const ImageFrame& image_frame); + static std::unique_ptr Create( + const internal::GpuBufferSpec& spec) { + return Create(spec.width, spec.height, spec.format); + } + // Wraps an existing texture, but does not take ownership of it. // deletion_callback is invoked when the GlTextureBuffer is released, so // the caller knows that the texture is no longer in use. 
diff --git a/mediapipe/gpu/gl_texture_buffer_pool.cc b/mediapipe/gpu/gl_texture_buffer_pool.cc
index 3d5a8cdaa..599381a34 100644
--- a/mediapipe/gpu/gl_texture_buffer_pool.cc
+++ b/mediapipe/gpu/gl_texture_buffer_pool.cc
@@ -16,79 +16,4 @@

 #include "absl/synchronization/mutex.h"

-namespace mediapipe {
-
-GlTextureBufferPool::GlTextureBufferPool(int width, int height,
-                                         GpuBufferFormat format,
-                                         int keep_count)
-    : width_(width),
-      height_(height),
-      format_(format),
-      keep_count_(keep_count) {}
-
-GlTextureBufferSharedPtr GlTextureBufferPool::GetBuffer() {
-  std::unique_ptr<GlTextureBuffer> buffer;
-  bool reuse = false;
-
-  {
-    absl::MutexLock lock(&mutex_);
-    if (available_.empty()) {
-      buffer = GlTextureBuffer::Create(width_, height_, format_);
-      if (!buffer) return nullptr;
-    } else {
-      buffer = std::move(available_.back());
-      available_.pop_back();
-      reuse = true;
-    }
-
-    ++in_use_count_;
-  }
-
-  // This needs to wait on consumer sync points, therefore it should not be
-  // done while holding the mutex.
-  if (reuse) {
-    buffer->Reuse();
-  }
-
-  // Return a shared_ptr with a custom deleter that adds the buffer back
-  // to our available list.
-  std::weak_ptr<GlTextureBufferPool> weak_pool(shared_from_this());
-  return std::shared_ptr<GlTextureBuffer>(
-      buffer.release(), [weak_pool](GlTextureBuffer* buf) {
-        auto pool = weak_pool.lock();
-        if (pool) {
-          pool->Return(absl::WrapUnique(buf));
-        } else {
-          delete buf;
-        }
-      });
-}
-
-std::pair<int, int> GlTextureBufferPool::GetInUseAndAvailableCounts() {
-  absl::MutexLock lock(&mutex_);
-  return {in_use_count_, available_.size()};
-}
-
-void GlTextureBufferPool::Return(std::unique_ptr<GlTextureBuffer> buf) {
-  std::vector<std::unique_ptr<GlTextureBuffer>> trimmed;
-  {
-    absl::MutexLock lock(&mutex_);
-    --in_use_count_;
-    available_.emplace_back(std::move(buf));
-    TrimAvailable(&trimmed);
-  }
-  // The trimmed buffers will be released without holding the lock.
-}
-
-void GlTextureBufferPool::TrimAvailable(
-    std::vector<std::unique_ptr<GlTextureBuffer>>* trimmed) {
-  int keep = std::max(keep_count_ - in_use_count_, 0);
-  if (available_.size() > keep) {
-    auto trim_it = std::next(available_.begin(), keep);
-    if (trimmed) {
-      std::move(trim_it, available_.end(), std::back_inserter(*trimmed));
-    }
-    available_.erase(trim_it, available_.end());
-  }
-}
-
-}  // namespace mediapipe
+namespace mediapipe {}  // namespace mediapipe
diff --git a/mediapipe/gpu/gl_texture_buffer_pool.h b/mediapipe/gpu/gl_texture_buffer_pool.h
index 29fc3c01c..726d0528d 100644
--- a/mediapipe/gpu/gl_texture_buffer_pool.h
+++ b/mediapipe/gpu/gl_texture_buffer_pool.h
@@ -24,11 +24,11 @@
 #include "absl/synchronization/mutex.h"
 #include "mediapipe/gpu/gl_texture_buffer.h"
 #include "mediapipe/gpu/multi_pool.h"
+#include "mediapipe/gpu/reusable_pool.h"

 namespace mediapipe {

-class GlTextureBufferPool
-    : public std::enable_shared_from_this<GlTextureBufferPool> {
+class GlTextureBufferPool : public ReusablePool<GlTextureBuffer> {
  public:
   // Creates a pool. This pool will manage buffers of the specified dimensions,
   // and will keep keep_count buffers around for reuse.
@@ -37,52 +37,32 @@ class GlTextureBufferPool
   static std::shared_ptr<GlTextureBufferPool> Create(int width, int height,
                                                      GpuBufferFormat format,
                                                      int keep_count) {
-    return std::shared_ptr<GlTextureBufferPool>(
-        new GlTextureBufferPool(width, height, format, keep_count));
+    return Create({width, height, format}, {.keep_count = keep_count});
   }

   static std::shared_ptr<GlTextureBufferPool> Create(
       const internal::GpuBufferSpec& spec, const MultiPoolOptions& options) {
-    return Create(spec.width, spec.height, spec.format, options.keep_count);
+    return std::shared_ptr<GlTextureBufferPool>(
+        new GlTextureBufferPool(spec, options));
   }
-  // Obtains a buffers. May either be reused or created anew.
-  // A GlContext must be current when this is called.
-  GlTextureBufferSharedPtr GetBuffer();
-
-  int width() const { return width_; }
-  int height() const { return height_; }
-  GpuBufferFormat format() const { return format_; }
-
-  // This method is meant for testing.
-  std::pair<int, int> GetInUseAndAvailableCounts();
+  int width() const { return spec_.width; }
+  int height() const { return spec_.height; }
+  GpuBufferFormat format() const { return spec_.format; }

   static GlTextureBufferSharedPtr CreateBufferWithoutPool(
       const internal::GpuBufferSpec& spec) {
-    return GlTextureBuffer::Create(spec.width, spec.height, spec.format);
+    return GlTextureBuffer::Create(spec);
   }

- private:
-  GlTextureBufferPool(int width, int height, GpuBufferFormat format,
-                      int keep_count);
+ protected:
+  GlTextureBufferPool(const internal::GpuBufferSpec& spec,
+                      const MultiPoolOptions& options)
+      : ReusablePool<GlTextureBuffer>(
+            [this] { return GlTextureBuffer::Create(spec_); }, options),
+        spec_(spec) {}

-  // Return a buffer to the pool.
-  void Return(std::unique_ptr<GlTextureBuffer> buf);
-
-  // If the total number of buffers is greater than keep_count, destroys any
-  // surplus buffers that are no longer in use.
-  void TrimAvailable(std::vector<std::unique_ptr<GlTextureBuffer>>* trimmed)
-      ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_);
-
-  const int width_;
-  const int height_;
-  const GpuBufferFormat format_;
-  const int keep_count_;
-
-  absl::Mutex mutex_;
-  int in_use_count_ ABSL_GUARDED_BY(mutex_) = 0;
-  std::vector<std::unique_ptr<GlTextureBuffer>> available_
-      ABSL_GUARDED_BY(mutex_);
+  const internal::GpuBufferSpec spec_;
 };

 }  // namespace mediapipe
diff --git a/mediapipe/gpu/reusable_pool.h b/mediapipe/gpu/reusable_pool.h
new file mode 100644
index 000000000..ddeaa5ba7
--- /dev/null
+++ b/mediapipe/gpu/reusable_pool.h
@@ -0,0 +1,145 @@
+// Copyright 2019 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Consider this file an implementation detail. None of this is part of the
+// public API.
+
+#ifndef MEDIAPIPE_GPU_REUSABLE_POOL_H_
+#define MEDIAPIPE_GPU_REUSABLE_POOL_H_
+
+#include <memory>
+#include <vector>
+
+#include "absl/functional/any_invocable.h"
+#include "absl/synchronization/mutex.h"
+#include "mediapipe/gpu/multi_pool.h"
+
+namespace mediapipe {
+
+template <class Item>
+class ReusablePool : public std::enable_shared_from_this<ReusablePool<Item>> {
+ public:
+  using ItemFactory = absl::AnyInvocable<std::unique_ptr<Item>() const>;
+
+  // Creates a pool. This pool will manage buffers of the specified dimensions,
+  // and will keep keep_count buffers around for reuse.
+  // We enforce creation as a shared_ptr so that we can use a weak reference in
+  // the buffers' deleters.
+  static std::shared_ptr<ReusablePool<Item>> Create(
+      ItemFactory item_factory, const MultiPoolOptions& options) {
+    return std::shared_ptr<ReusablePool<Item>>(
+        new ReusablePool<Item>(std::move(item_factory), options));
+  }
+
+  // Obtains a buffer. May either be reused or created anew.
+  // A GlContext must be current when this is called.
+  std::shared_ptr<Item> GetBuffer();
+
+  // This method is meant for testing.
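+  // Tests can assert on the returned {in_use, available} counts to verify
+  // that buffers are actually being reused rather than re-created.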
+  std::pair<int, int> GetInUseAndAvailableCounts();
+
+ protected:
+  ReusablePool(ItemFactory item_factory, const MultiPoolOptions& options)
+      : item_factory_(std::move(item_factory)),
+        keep_count_(options.keep_count) {}
+
+ private:
+  // Return a buffer to the pool.
+  void Return(std::unique_ptr<Item> buf);
+
+  // If the total number of buffers is greater than keep_count, destroys any
+  // surplus buffers that are no longer in use.
+  void TrimAvailable(std::vector<std::unique_ptr<Item>>* trimmed)
+      ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_);
+
+  const ItemFactory item_factory_;
+  const int keep_count_;
+
+  absl::Mutex mutex_;
+  int in_use_count_ ABSL_GUARDED_BY(mutex_) = 0;
+  std::vector<std::unique_ptr<Item>> available_ ABSL_GUARDED_BY(mutex_);
+};
+
+template <class Item>
+inline std::shared_ptr<Item> ReusablePool<Item>::GetBuffer() {
+  std::unique_ptr<Item> buffer;
+  bool reuse = false;
+
+  {
+    absl::MutexLock lock(&mutex_);
+    if (available_.empty()) {
+      buffer = item_factory_();
+      if (!buffer) return nullptr;
+    } else {
+      buffer = std::move(available_.back());
+      available_.pop_back();
+      reuse = true;
+    }
+
+    ++in_use_count_;
+  }
+
+  // This needs to wait on consumer sync points, therefore it should not be
+  // done while holding the mutex.
+  if (reuse) {
+    buffer->Reuse();
+  }
+
+  // Return a shared_ptr with a custom deleter that adds the buffer back
+  // to our available list.
+  std::weak_ptr<ReusablePool<Item>> weak_pool(this->shared_from_this());
+  return std::shared_ptr<Item>(buffer.release(), [weak_pool](Item* buf) {
+    auto pool = weak_pool.lock();
+    if (pool) {
+      pool->Return(absl::WrapUnique(buf));
+    } else {
+      delete buf;
+    }
+  });
+}
+
+template <class Item>
+inline std::pair<int, int> ReusablePool<Item>::GetInUseAndAvailableCounts() {
+  absl::MutexLock lock(&mutex_);
+  return {in_use_count_, available_.size()};
+}
+
+template <class Item>
+void ReusablePool<Item>::Return(std::unique_ptr<Item> buf) {
+  std::vector<std::unique_ptr<Item>> trimmed;
+  {
+    absl::MutexLock lock(&mutex_);
+    --in_use_count_;
+    available_.emplace_back(std::move(buf));
+    TrimAvailable(&trimmed);
+  }
+  // The trimmed buffers will be released without holding the lock.
+}
+
+template <class Item>
+void ReusablePool<Item>::TrimAvailable(
+    std::vector<std::unique_ptr<Item>>* trimmed) {
+  int keep = std::max(keep_count_ - in_use_count_, 0);
+  if (available_.size() > keep) {
+    auto trim_it = std::next(available_.begin(), keep);
+    if (trimmed) {
+      std::move(trim_it, available_.end(), std::back_inserter(*trimmed));
+    }
+    available_.erase(trim_it, available_.end());
+  }
+}
+
+}  // namespace mediapipe
+
+#endif  // MEDIAPIPE_GPU_REUSABLE_POOL_H_

From 1beca6165057a4d198a09cfb9becca9252529895 Mon Sep 17 00:00:00 2001
From: Camillo Lugaresi
Date: Tue, 15 Nov 2022 16:06:16 -0800
Subject: [PATCH 046/469] Register GlTextureBuffer pool with GpuBuffer

First crack at hooking up pools with the GpuBufferStorage system. Will most
likely be superseded later, but for now this works with minimal code impact:
just overwrite the factory for a storage type with one that uses the pool.
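To make the override approach concrete, the following is a minimal, self-contained C++ sketch of the same idea with toy types (Registry, Buffer, and the in-function pool are illustrative stand-ins, not MediaPipe's API): a default factory is registered first, then replaced by one that consults a pool before allocating.

#include <functional>
#include <map>
#include <memory>
#include <string>
#include <vector>

// Toy stand-ins for illustration; not MediaPipe types.
struct Buffer {
  int width;
  int height;
};

using Factory = std::function<std::shared_ptr<Buffer>(int, int)>;

// A toy registry mapping a storage name to its current factory.
std::map<std::string, Factory>& Registry() {
  static auto* registry = new std::map<std::string, Factory>();
  return *registry;
}

int main() {
  // The storage type's own factory: always allocates a fresh buffer.
  Registry()["gl_texture"] = [](int w, int h) {
    return std::make_shared<Buffer>(Buffer{w, h});
  };

  // Pool-backed override, installed afterwards so it wins: reuse a free
  // buffer of matching size if one exists, otherwise allocate and remember.
  static std::vector<std::shared_ptr<Buffer>> pool;
  Registry()["gl_texture"] = [](int w, int h) -> std::shared_ptr<Buffer> {
    for (const auto& b : pool) {
      // use_count() == 1 means only the pool holds it, i.e. it is free.
      if (b.use_count() == 1 && b->width == w && b->height == h) return b;
    }
    auto b = std::make_shared<Buffer>(Buffer{w, h});
    pool.push_back(b);
    return b;
  };

  auto first = Registry()["gl_texture"](64, 64);
  first.reset();                                   // return to the pool
  auto second = Registry()["gl_texture"](64, 64);  // reuses the same buffer
  (void)second;
  return 0;
}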
PiperOrigin-RevId: 488783854
---
 mediapipe/gpu/gpu_buffer_storage.h        | 20 +++++++++++----
 mediapipe/gpu/gpu_shared_data_internal.cc | 30 ++++++++++++++++++++++
 2 files changed, 45 insertions(+), 5 deletions(-)

diff --git a/mediapipe/gpu/gpu_buffer_storage.h b/mediapipe/gpu/gpu_buffer_storage.h
index 3d872eb66..214f506c0 100644
--- a/mediapipe/gpu/gpu_buffer_storage.h
+++ b/mediapipe/gpu/gpu_buffer_storage.h
@@ -74,13 +74,17 @@ class GpuBufferStorageRegistry {
   template <class Storage>
   RegistryToken Register() {
-    return Register(
+    return RegisterFactory<Storage>(
         [](int width, int height, GpuBufferFormat format)
            -> std::shared_ptr<Storage> {
          return CreateStorage<Storage>(overload_priority<10>{}, width, height,
                                        format);
-        },
-        Storage::GetProviderTypes());
+        });
+  }
+
+  template <class Storage, class F>
+  RegistryToken RegisterFactory(F&& factory) {
+    return Register(factory, Storage::GetProviderTypes());
   }

   template <class StorageFrom, class StorageTo, class F>
@@ -148,6 +152,13 @@ class GpuBufferStorageImpl : public GpuBufferStorage, public U... {
     return kHashes;
   }

+  // Exposing this as a function allows dependent initializers to call this to
+  // ensure proper ordering.
+  static GpuBufferStorageRegistry::RegistryToken RegisterOnce() {
+    static auto registration = GpuBufferStorageRegistry::Get().Register<T>();
+    return registration;
+  }
+
  private:
   virtual const void* down_cast(TypeId to) const override {
     return down_cast_impl(to, types<T, U...>{});
@@ -161,8 +172,7 @@ class GpuBufferStorageImpl : public GpuBufferStorage, public U... {
     return down_cast_impl(to, types<W...>{});
   }

-  inline static auto registration =
-      GpuBufferStorageRegistry::Get().Register<T>();
+  inline static auto registration = RegisterOnce();
   using RequireStatics = ForceStaticInstantiation<&registration>;
 };

diff --git a/mediapipe/gpu/gpu_shared_data_internal.cc b/mediapipe/gpu/gpu_shared_data_internal.cc
index 52db88633..91723a7d1 100644
--- a/mediapipe/gpu/gpu_shared_data_internal.cc
+++ b/mediapipe/gpu/gpu_shared_data_internal.cc
@@ -200,4 +200,34 @@ GpuSharedData::GpuSharedData() : GpuSharedData(kPlatformGlContextNone) {}
 MPPGraphGPUData* GpuResources::ios_gpu_data() { return ios_gpu_data_; }
 #endif  // __APPLE__

+extern const GraphService<GpuResources> kGpuService;
+
+#if !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
+static std::shared_ptr<GlTextureBuffer> GetGlTextureBufferFromPool(
+    int width, int height, GpuBufferFormat format) {
+  std::shared_ptr<GlTextureBuffer> texture_buffer;
+  const auto cc =
+      LegacyCalculatorSupport::Scoped<CalculatorContext>::current();
+
+  if (cc && cc->Service(kGpuService).IsAvailable()) {
+    GpuBufferMultiPool* pool =
+        &cc->Service(kGpuService).GetObject().gpu_buffer_pool();
+    // Note that the "gpu_buffer_pool" serves GlTextureBuffers on non-Apple
+    // platforms. TODO: refactor into storage pools.
+    texture_buffer = pool->GetBuffer(width, height, format)
+                         .internal_storage<GlTextureBuffer>();
+  } else {
+    texture_buffer = GlTextureBuffer::Create(width, height, format);
+  }
+  return texture_buffer;
+}
+
+static auto kGlTextureBufferPoolRegistration = [] {
+  // Ensure that the GlTextureBuffer's own factory is already registered, so we
+  // can override it.
+  GlTextureBuffer::RegisterOnce();
+  return internal::GpuBufferStorageRegistry::Get()
+      .RegisterFactory<GlTextureBuffer>(GetGlTextureBufferFromPool);
+}();
+#endif  // !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
+
 }  // namespace mediapipe

From 7e19bbe35c85e77ba1d99a9824ecb60d06869f52 Mon Sep 17 00:00:00 2001
From: Camillo Lugaresi
Date: Tue, 15 Nov 2022 16:57:46 -0800
Subject: [PATCH 047/469] Internal change

PiperOrigin-RevId: 488795920
---
 mediapipe/gpu/gl_texture_buffer.h  |  4 ++++
 mediapipe/gpu/gpu_buffer_storage.h | 20 ++++++++++++++++++++
 2 files changed, 24 insertions(+)

diff --git a/mediapipe/gpu/gl_texture_buffer.h b/mediapipe/gpu/gl_texture_buffer.h
index a770163b5..1be24a86b 100644
--- a/mediapipe/gpu/gl_texture_buffer.h
+++ b/mediapipe/gpu/gl_texture_buffer.h
@@ -143,6 +143,10 @@ class GlTextureBuffer
     return producer_context_;
   }

+#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
+  static constexpr bool kDisableGpuBufferRegistration = true;
+#endif  // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
+
  private:
   // Creates a texture of dimensions width x height and allocates space for it.
   // If data is provided, it is uploaded to the texture; otherwise, it can be
diff --git a/mediapipe/gpu/gpu_buffer_storage.h b/mediapipe/gpu/gpu_buffer_storage.h
index 214f506c0..0da5f236a 100644
--- a/mediapipe/gpu/gpu_buffer_storage.h
+++ b/mediapipe/gpu/gpu_buffer_storage.h
@@ -84,11 +84,17 @@ class GpuBufferStorageRegistry {
   template <class Storage, class F>
   RegistryToken RegisterFactory(F&& factory) {
+    if constexpr (kDisableRegistration<Storage>) {
+      return {};
+    }
     return Register(factory, Storage::GetProviderTypes());
   }

   template <class StorageFrom, class StorageTo, class F>
   RegistryToken RegisterConverter(F&& converter) {
+    if constexpr (kDisableRegistration<StorageTo>) {
+      return {};
+    }
     return Register(
         [converter](std::shared_ptr<GpuBufferStorage> source)
             -> std::shared_ptr<GpuBufferStorage> {
@@ -119,6 +125,13 @@ class GpuBufferStorageRegistry {
     return std::make_shared<Storage>(args...);
   }

+  // Temporary workaround: a Storage class can define a static constexpr
+  // kDisableGpuBufferRegistration member set to true to prevent registering
+  // any factory or converter that would produce it.
+  // TODO: better solution for storage priorities.
+  template <class Storage, typename = void>
+  static constexpr bool kDisableRegistration = false;
+
   RegistryToken Register(StorageFactory factory,
                          std::vector<TypeId> provider_hashes);
   RegistryToken Register(StorageConverter converter,
@@ -130,6 +143,13 @@ class GpuBufferStorageRegistry {
       converter_for_view_provider_and_existing_storage_;
 };

+// Putting this outside the class body to work around a GCC bug.
+// https://gcc.gnu.org/bugzilla/show_bug.cgi?id=71954
+template <class Storage>
+constexpr bool GpuBufferStorageRegistry::kDisableRegistration<
+    Storage, std::void_t<decltype(Storage::kDisableGpuBufferRegistration)>> =
+    Storage::kDisableGpuBufferRegistration;
+
 // Defining a member of this type causes P to be ODR-used, which forces its
 // instantiation if it's a static member of a template.
template From 6702ef3d57570e66101a7e4535a04b0a75cdb6bb Mon Sep 17 00:00:00 2001 From: Mark McDonald Date: Tue, 15 Nov 2022 16:58:38 -0800 Subject: [PATCH 048/469] Internal change PiperOrigin-RevId: 488796090 --- docs/BUILD | 1 + docs/build_java_api_docs.py | 33 ++++++++++++++++----------------- 2 files changed, 17 insertions(+), 17 deletions(-) diff --git a/docs/BUILD b/docs/BUILD index ad08df66a..8e85dbf86 100644 --- a/docs/BUILD +++ b/docs/BUILD @@ -17,6 +17,7 @@ py_binary( name = "build_java_api_docs", srcs = ["build_java_api_docs.py"], data = [ + "//third_party/android/sdk:api/26.txt", "//third_party/java/doclava/current:doclava.jar", "//third_party/java/jsilver:jsilver_jar", ], diff --git a/docs/build_java_api_docs.py b/docs/build_java_api_docs.py index e96e1fd83..b13e8d1df 100644 --- a/docs/build_java_api_docs.py +++ b/docs/build_java_api_docs.py @@ -20,10 +20,6 @@ from absl import flags from tensorflow_docs.api_generator import gen_java -_JAVA_ROOT = flags.DEFINE_string('java_src', None, - 'Override the Java source path.', - required=False) - _OUT_DIR = flags.DEFINE_string('output_dir', '/tmp/mp_java/', 'Write docs here.') @@ -37,27 +33,30 @@ _ = flags.DEFINE_bool( 'search_hints', True, '[UNUSED] Include metadata search hints in the generated files') +_ANDROID_SDK = pathlib.Path('android/sdk/api/26.txt') + def main(_) -> None: - if not (java_root := _JAVA_ROOT.value): - # Default to using a relative path to find the Java source. - mp_root = pathlib.Path(__file__) - while (mp_root := mp_root.parent).name != 'mediapipe': - # Find the nearest `mediapipe` dir. - pass + # Default to using a relative path to find the Java source. + mp_root = pathlib.Path(__file__) + while (mp_root := mp_root.parent).name != 'mediapipe': + # Find the nearest `mediapipe` dir. + pass - # Externally, parts of the repo are nested inside a mediapipe/ directory - # that does not exist internally. Support both. - if (mp_root / 'mediapipe').exists(): - mp_root = mp_root / 'mediapipe' + # Find the root from which all packages are relative. + root = mp_root.parent - java_root = mp_root / 'tasks/java' + # Externally, parts of the repo are nested inside a mediapipe/ directory + # that does not exist internally. Support both. 
+ if (mp_root / 'mediapipe').exists(): + mp_root = mp_root / 'mediapipe' gen_java.gen_java_docs( package='com.google.mediapipe', - source_path=pathlib.Path(java_root), + source_path=mp_root / 'tasks/java', output_dir=pathlib.Path(_OUT_DIR.value), - site_path=pathlib.Path(_SITE_PATH.value)) + site_path=pathlib.Path(_SITE_PATH.value), + federated_docs={'https://developer.android.com': root / _ANDROID_SDK}) if __name__ == '__main__': From 77b3edbb6757f6afe3446bd237297d62dc14832d Mon Sep 17 00:00:00 2001 From: Camillo Lugaresi Date: Tue, 15 Nov 2022 17:04:39 -0800 Subject: [PATCH 049/469] Internal change PiperOrigin-RevId: 488797407 --- mediapipe/gpu/gpu_buffer.cc | 47 +++++++++++++++++++++++-------------- mediapipe/gpu/gpu_buffer.h | 14 +++++++---- 2 files changed, 38 insertions(+), 23 deletions(-) diff --git a/mediapipe/gpu/gpu_buffer.cc b/mediapipe/gpu/gpu_buffer.cc index e570ce8ba..35a73fd8f 100644 --- a/mediapipe/gpu/gpu_buffer.cc +++ b/mediapipe/gpu/gpu_buffer.cc @@ -1,6 +1,7 @@ #include "mediapipe/gpu/gpu_buffer.h" #include +#include #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" @@ -29,7 +30,7 @@ std::string GpuBuffer::DebugString() const { "]"); } -internal::GpuBufferStorage& GpuBuffer::GetStorageForView( +internal::GpuBufferStorage* GpuBuffer::GetStorageForView( TypeId view_provider_type, bool for_writing) const { const std::shared_ptr* chosen_storage = nullptr; @@ -45,38 +46,48 @@ internal::GpuBufferStorage& GpuBuffer::GetStorageForView( // TODO: choose best conversion. if (!chosen_storage) { for (const auto& s : storages_) { - auto converter = internal::GpuBufferStorageRegistry::Get() - .StorageConverterForViewProvider(view_provider_type, - s->storage_type()); - if (converter) { - storages_.push_back(converter(s)); - chosen_storage = &storages_.back(); + if (auto converter = internal::GpuBufferStorageRegistry::Get() + .StorageConverterForViewProvider( + view_provider_type, s->storage_type())) { + if (auto new_storage = converter(s)) { + storages_.push_back(new_storage); + chosen_storage = &storages_.back(); + break; + } } } } if (for_writing) { - if (!chosen_storage) { - // Allocate a new storage supporting the requested view. - auto factory = internal::GpuBufferStorageRegistry::Get() - .StorageFactoryForViewProvider(view_provider_type); - if (factory) { - storages_ = {factory(width(), height(), format())}; - chosen_storage = &storages_.back(); - } - } else { + if (chosen_storage) { // Discard all other storages. storages_ = {*chosen_storage}; chosen_storage = &storages_.back(); + } else { + // Allocate a new storage supporting the requested view. + if (auto factory = + internal::GpuBufferStorageRegistry::Get() + .StorageFactoryForViewProvider(view_provider_type)) { + if (auto new_storage = factory(width(), height(), format())) { + storages_ = {std::move(new_storage)}; + chosen_storage = &storages_.back(); + } + } } } + return chosen_storage ? 
chosen_storage->get() : nullptr; +} +internal::GpuBufferStorage& GpuBuffer::GetStorageForViewOrDie( + TypeId view_provider_type, bool for_writing) const { + auto* chosen_storage = + GpuBuffer::GetStorageForView(view_provider_type, for_writing); CHECK(chosen_storage) << "no view provider found for requested view " << view_provider_type.name() << "; storages available: " << absl::StrJoin(storages_, ", ", StorageTypeFormatter()); - DCHECK((*chosen_storage)->can_down_cast_to(view_provider_type)); - return **chosen_storage; + DCHECK(chosen_storage->can_down_cast_to(view_provider_type)); + return *chosen_storage; } #if !MEDIAPIPE_DISABLE_GPU diff --git a/mediapipe/gpu/gpu_buffer.h b/mediapipe/gpu/gpu_buffer.h index 57e077151..ad5c130b5 100644 --- a/mediapipe/gpu/gpu_buffer.h +++ b/mediapipe/gpu/gpu_buffer.h @@ -105,7 +105,7 @@ class GpuBuffer { // specific view type; see the corresponding ViewProvider. template decltype(auto) GetReadView(Args... args) const { - return GetViewProvider(false)->GetReadView( + return GetViewProviderOrDie(false).GetReadView( internal::types{}, std::make_shared(*this), std::forward(args)...); } @@ -114,7 +114,7 @@ class GpuBuffer { // specific view type; see the corresponding ViewProvider. template decltype(auto) GetWriteView(Args... args) { - return GetViewProvider(true)->GetWriteView( + return GetViewProviderOrDie(true).GetWriteView( internal::types{}, std::make_shared(*this), std::forward(args)...); } @@ -147,13 +147,17 @@ class GpuBuffer { GpuBufferFormat format_ = GpuBufferFormat::kUnknown; }; - internal::GpuBufferStorage& GetStorageForView(TypeId view_provider_type, + internal::GpuBufferStorage* GetStorageForView(TypeId view_provider_type, bool for_writing) const; + internal::GpuBufferStorage& GetStorageForViewOrDie(TypeId view_provider_type, + bool for_writing) const; + template - internal::ViewProvider* GetViewProvider(bool for_writing) const { + internal::ViewProvider& GetViewProviderOrDie(bool for_writing) const { using VP = internal::ViewProvider; - return GetStorageForView(kTypeId, for_writing).template down_cast(); + return *GetStorageForViewOrDie(kTypeId, for_writing) + .template down_cast(); } std::shared_ptr& no_storage() const { From 4bda012bba8fa7b5e0b4a04ebdfae8519329bc32 Mon Sep 17 00:00:00 2001 From: Camillo Lugaresi Date: Tue, 15 Nov 2022 17:07:26 -0800 Subject: [PATCH 050/469] Factor out gl_texture_util PiperOrigin-RevId: 488797985 --- mediapipe/gpu/BUILD | 11 +++++++ mediapipe/gpu/gl_texture_util.cc | 30 ++++++++++++++++++ mediapipe/gpu/gl_texture_util.h | 34 +++++++++++++++++++++ mediapipe/gpu/gpu_buffer_test.cc | 52 ++++---------------------------- 4 files changed, 81 insertions(+), 46 deletions(-) create mode 100644 mediapipe/gpu/gl_texture_util.cc create mode 100644 mediapipe/gpu/gl_texture_util.h diff --git a/mediapipe/gpu/BUILD b/mediapipe/gpu/BUILD index 747d131ba..68e788c52 100644 --- a/mediapipe/gpu/BUILD +++ b/mediapipe/gpu/BUILD @@ -689,6 +689,17 @@ cc_library( }), ) +cc_library( + name = "gl_texture_util", + srcs = ["gl_texture_util.cc"], + hdrs = ["gl_texture_util.h"], + visibility = ["//visibility:public"], + deps = [ + ":gl_base", + ":gl_texture_view", + ], +) + cc_library( name = "shader_util", srcs = ["shader_util.cc"], diff --git a/mediapipe/gpu/gl_texture_util.cc b/mediapipe/gpu/gl_texture_util.cc new file mode 100644 index 000000000..603e82a46 --- /dev/null +++ b/mediapipe/gpu/gl_texture_util.cc @@ -0,0 +1,30 @@ +#include "mediapipe/gpu/gl_texture_util.h" + +namespace mediapipe { + +void CopyGlTexture(const GlTextureView& 
src, GlTextureView& dst) { + glViewport(0, 0, src.width(), src.height()); + glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, src.target(), + src.name(), 0); + + glActiveTexture(GL_TEXTURE0); + glBindTexture(dst.target(), dst.name()); + glCopyTexSubImage2D(dst.target(), 0, 0, 0, 0, 0, dst.width(), dst.height()); + + glBindTexture(dst.target(), 0); + glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, src.target(), 0, + 0); +} + +void FillGlTextureRgba(GlTextureView& view, float r, float g, float b, + float a) { + glViewport(0, 0, view.width(), view.height()); + glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, view.target(), + view.name(), 0); + glClearColor(r, g, b, a); + glClear(GL_COLOR_BUFFER_BIT); + glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, view.target(), 0, + 0); +} + +} // namespace mediapipe diff --git a/mediapipe/gpu/gl_texture_util.h b/mediapipe/gpu/gl_texture_util.h new file mode 100644 index 000000000..73ac37ade --- /dev/null +++ b/mediapipe/gpu/gl_texture_util.h @@ -0,0 +1,34 @@ +#ifndef MEDIAPIPE_GPU_GL_TEXTURE_UTIL_H_ +#define MEDIAPIPE_GPU_GL_TEXTURE_UTIL_H_ + +#include "mediapipe/gpu/gl_base.h" +#include "mediapipe/gpu/gl_texture_view.h" + +namespace mediapipe { + +// Copies a texture to another. +// Assumes a framebuffer is already set up +void CopyGlTexture(const GlTextureView& src, GlTextureView& dst); + +// Fills a texture with a color. +void FillGlTextureRgba(GlTextureView& view, float r, float g, float b, float a); + +// RAII class to set up a temporary framebuffer. Mainly for test use. +class TempGlFramebuffer { + public: + TempGlFramebuffer() { + glGenFramebuffers(1, &framebuffer_); + glBindFramebuffer(GL_FRAMEBUFFER, framebuffer_); + } + ~TempGlFramebuffer() { + glBindFramebuffer(GL_FRAMEBUFFER, 0); + glDeleteFramebuffers(1, &framebuffer_); + } + + private: + GLuint framebuffer_; +}; + +} // namespace mediapipe + +#endif // MEDIAPIPE_GPU_GL_TEXTURE_UTIL_H_ diff --git a/mediapipe/gpu/gpu_buffer_test.cc b/mediapipe/gpu/gpu_buffer_test.cc index 3fd519b21..796cb1d9d 100644 --- a/mediapipe/gpu/gpu_buffer_test.cc +++ b/mediapipe/gpu/gpu_buffer_test.cc @@ -18,6 +18,7 @@ #include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/tool/test_util.h" +#include "mediapipe/gpu/gl_texture_util.h" #include "mediapipe/gpu/gpu_buffer_storage_ahwb.h" #include "mediapipe/gpu/gpu_buffer_storage_image_frame.h" #include "mediapipe/gpu/gpu_test_base.h" @@ -41,47 +42,6 @@ void FillImageFrameRGBA(ImageFrame& image, uint8 r, uint8 g, uint8 b, uint8 a) { } } -// Assumes a framebuffer is already set up -void CopyGlTexture(const GlTextureView& src, GlTextureView& dst) { - glViewport(0, 0, src.width(), src.height()); - glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, src.target(), - src.name(), 0); - - glActiveTexture(GL_TEXTURE0); - glBindTexture(dst.target(), dst.name()); - glCopyTexSubImage2D(dst.target(), 0, 0, 0, 0, 0, dst.width(), dst.height()); - - glBindTexture(dst.target(), 0); - glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, src.target(), 0, - 0); -} - -void FillGlTextureRgba(GlTextureView& view, float r, float g, float b, - float a) { - glViewport(0, 0, view.width(), view.height()); - glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, view.target(), - view.name(), 0); - glClearColor(r, g, b, a); - glClear(GL_COLOR_BUFFER_BIT); - glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, view.target(), 0, - 0); -} - -class TempGlFramebuffer 
{ - public: - TempGlFramebuffer() { - glGenFramebuffers(1, &framebuffer_); - glBindFramebuffer(GL_FRAMEBUFFER, framebuffer_); - } - ~TempGlFramebuffer() { - glBindFramebuffer(GL_FRAMEBUFFER, 0); - glDeleteFramebuffers(1, &framebuffer_); - } - - private: - GLuint framebuffer_; -}; - class GpuBufferTest : public GpuTestBase {}; TEST_F(GpuBufferTest, BasicTest) { @@ -127,7 +87,7 @@ TEST_F(GpuBufferTest, GlTextureView) { ImageFrame red(ImageFormat::SRGBA, 300, 200); FillImageFrameRGBA(red, 255, 0, 0, 255); - EXPECT_TRUE(mediapipe::CompareImageFrames(*view, red, 0.0, 0.0)); + EXPECT_TRUE(CompareImageFrames(*view, red, 0.0, 0.0)); MP_EXPECT_OK(SavePngTestOutput(red, "gltv_red_gold")); MP_EXPECT_OK(SavePngTestOutput(*view, "gltv_red_view")); } @@ -162,7 +122,7 @@ TEST_F(GpuBufferTest, ImageFrame) { ImageFrame red(ImageFormat::SRGBA, 300, 200); FillImageFrameRGBA(red, 255, 0, 0, 255); - EXPECT_TRUE(mediapipe::CompareImageFrames(*view, red, 0.0, 0.0)); + EXPECT_TRUE(CompareImageFrames(*view, red, 0.0, 0.0)); MP_EXPECT_OK(SavePngTestOutput(red, "if_red_gold")); MP_EXPECT_OK(SavePngTestOutput(*view, "if_red_view")); } @@ -196,7 +156,7 @@ TEST_F(GpuBufferTest, Overwrite) { ImageFrame red(ImageFormat::SRGBA, 300, 200); FillImageFrameRGBA(red, 255, 0, 0, 255); - EXPECT_TRUE(mediapipe::CompareImageFrames(*view, red, 0.0, 0.0)); + EXPECT_TRUE(CompareImageFrames(*view, red, 0.0, 0.0)); MP_EXPECT_OK(SavePngTestOutput(red, "ow_red_gold")); MP_EXPECT_OK(SavePngTestOutput(*view, "ow_red_view")); } @@ -230,7 +190,7 @@ TEST_F(GpuBufferTest, Overwrite) { ImageFrame green(ImageFormat::SRGBA, 300, 200); FillImageFrameRGBA(green, 0, 255, 0, 255); - EXPECT_TRUE(mediapipe::CompareImageFrames(*view, green, 0.0, 0.0)); + EXPECT_TRUE(CompareImageFrames(*view, green, 0.0, 0.0)); MP_EXPECT_OK(SavePngTestOutput(green, "ow_green_gold")); MP_EXPECT_OK(SavePngTestOutput(*view, "ow_green_view")); } @@ -240,7 +200,7 @@ TEST_F(GpuBufferTest, Overwrite) { ImageFrame blue(ImageFormat::SRGBA, 300, 200); FillImageFrameRGBA(blue, 0, 0, 255, 255); - EXPECT_TRUE(mediapipe::CompareImageFrames(*view, blue, 0.0, 0.0)); + EXPECT_TRUE(CompareImageFrames(*view, blue, 0.0, 0.0)); MP_EXPECT_OK(SavePngTestOutput(blue, "ow_blue_gold")); MP_EXPECT_OK(SavePngTestOutput(*view, "ow_blue_view")); } From b308c0dd5e114cbf803dd2864f67589be048b7a0 Mon Sep 17 00:00:00 2001 From: Camillo Lugaresi Date: Tue, 15 Nov 2022 17:08:37 -0800 Subject: [PATCH 051/469] Implement CVPixelBufferRef access as a view. 
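For readers unfamiliar with the view mechanism this patch extends, here is a minimal standalone C++ sketch of the pattern (NativeHandle and the provider/storage names are illustrative stand-ins, not MediaPipe's types): a storage backed by a native platform object exposes that object through the same typed view interface used for texture and CPU views.

// Illustrative stand-ins, not MediaPipe's types.
struct NativeHandle { int id = 0; };
template <class V> struct types {};

// Storages that can expose the native handle implement this provider
// interface; requests are dispatched by the tag type.
class NativeHandleViewProvider {
 public:
  virtual ~NativeHandleViewProvider() = default;
  virtual NativeHandle GetReadView(types<NativeHandle>) const = 0;
  virtual NativeHandle GetWriteView(types<NativeHandle>) = 0;
};

// A storage that is backed by the native object returns it directly, so
// callers reach the platform handle through the same view mechanism.
class NativeStorage : public NativeHandleViewProvider {
 public:
  explicit NativeStorage(NativeHandle handle) : handle_(handle) {}
  NativeHandle GetReadView(types<NativeHandle>) const override {
    return handle_;
  }
  NativeHandle GetWriteView(types<NativeHandle>) override { return handle_; }

 private:
  NativeHandle handle_;
};

int main() {
  NativeStorage storage(NativeHandle{42});
  NativeHandle h = storage.GetReadView(types<NativeHandle>{});
  return h.id == 42 ? 0 : 1;
}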
PiperOrigin-RevId: 488798216 --- mediapipe/gpu/gpu_buffer.cc | 7 ++-- mediapipe/gpu/gpu_buffer.h | 4 +++ .../gpu/gpu_buffer_storage_cv_pixel_buffer.h | 35 ++++++++++++++++++- 3 files changed, 43 insertions(+), 3 deletions(-) diff --git a/mediapipe/gpu/gpu_buffer.cc b/mediapipe/gpu/gpu_buffer.cc index 35a73fd8f..388960b11 100644 --- a/mediapipe/gpu/gpu_buffer.cc +++ b/mediapipe/gpu/gpu_buffer.cc @@ -93,8 +93,11 @@ internal::GpuBufferStorage& GpuBuffer::GetStorageForViewOrDie( #if !MEDIAPIPE_DISABLE_GPU #if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER CVPixelBufferRef GetCVPixelBufferRef(const GpuBuffer& buffer) { - auto p = buffer.internal_storage(); - if (p) return **p; + if (buffer.GetStorageForView( + kTypeId>, + /*for_writing=*/false) != nullptr) { + return *buffer.GetReadView(); + } return nullptr; } #endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER diff --git a/mediapipe/gpu/gpu_buffer.h b/mediapipe/gpu/gpu_buffer.h index ad5c130b5..45146a322 100644 --- a/mediapipe/gpu/gpu_buffer.h +++ b/mediapipe/gpu/gpu_buffer.h @@ -179,6 +179,10 @@ class GpuBuffer { // This is mutable because view methods that do not change the contents may // still need to allocate new storages. mutable std::vector> storages_; + +#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER + friend CVPixelBufferRef GetCVPixelBufferRef(const GpuBuffer& buffer); +#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER }; inline bool GpuBuffer::operator==(std::nullptr_t other) const { diff --git a/mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.h b/mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.h index 017771dc7..e5bc5de43 100644 --- a/mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.h +++ b/mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.h @@ -12,10 +12,27 @@ namespace mediapipe { class GlContext; +namespace internal { + +template <> +class ViewProvider { + public: + virtual ~ViewProvider() = default; + virtual CFHolder GetReadView( + internal::types, + std::shared_ptr gpu_buffer) const = 0; + virtual CFHolder GetWriteView( + internal::types, + std::shared_ptr gpu_buffer) = 0; +}; + +} // namespace internal + class GpuBufferStorageCvPixelBuffer : public internal::GpuBufferStorageImpl< GpuBufferStorageCvPixelBuffer, internal::ViewProvider, - internal::ViewProvider>, + internal::ViewProvider, + internal::ViewProvider>, public CFHolder { public: using CFHolder::CFHolder; @@ -44,6 +61,12 @@ class GpuBufferStorageCvPixelBuffer std::shared_ptr GetWriteView( internal::types, std::shared_ptr gpu_buffer) override; + CFHolder GetReadView( + internal::types, + std::shared_ptr gpu_buffer) const override; + CFHolder GetWriteView( + internal::types, + std::shared_ptr gpu_buffer) override; private: GlTextureView GetTexture(std::shared_ptr gpu_buffer, int plane, @@ -51,6 +74,16 @@ class GpuBufferStorageCvPixelBuffer void ViewDoneWriting(const GlTextureView& view); }; +inline CFHolder GpuBufferStorageCvPixelBuffer::GetReadView( + internal::types, + std::shared_ptr gpu_buffer) const { + return *this; +} +inline CFHolder GpuBufferStorageCvPixelBuffer::GetWriteView( + internal::types, std::shared_ptr gpu_buffer) { + return *this; +} + namespace internal { // These functions enable backward-compatible construction of a GpuBuffer from // CVPixelBufferRef without having to expose that type in the main GpuBuffer From 2f77bf44e3f3a53ff187bd9a39f9cbc413b4e413 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Tue, 15 Nov 2022 18:08:31 -0800 Subject: [PATCH 052/469] Use train_data to evaluate accuracy of unit test for gesture_recognizer due to limited dataset 
size. PiperOrigin-RevId: 488808942 --- .../gesture_recognizer_test.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py b/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py index 7e7a1ca30..9bac22133 100644 --- a/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py @@ -42,8 +42,8 @@ class GestureRecognizerTest(tf.test.TestCase): def setUp(self): super().setUp() all_data = self._load_data() - # Splits data, 90% data for training, 10% for testing - self._train_data, self._test_data = all_data.split(0.9) + # Splits data, 90% data for training, 10% for validation + self._train_data, self._validation_data = all_data.split(0.9) def test_gesture_recognizer_model(self): model_options = gesture_recognizer.ModelOptions() @@ -53,7 +53,7 @@ class GestureRecognizerTest(tf.test.TestCase): model_options=model_options, hparams=hparams) model = gesture_recognizer.GestureRecognizer.create( train_data=self._train_data, - validation_data=self._test_data, + validation_data=self._validation_data, options=gesture_recognizer_options) self._test_accuracy(model) @@ -66,7 +66,7 @@ class GestureRecognizerTest(tf.test.TestCase): model_options=model_options, hparams=hparams) model = gesture_recognizer.GestureRecognizer.create( train_data=self._train_data, - validation_data=self._test_data, + validation_data=self._validation_data, options=gesture_recognizer_options) model.export_model() model_bundle_file = os.path.join(hparams.export_dir, @@ -94,8 +94,9 @@ class GestureRecognizerTest(tf.test.TestCase): size=[1, model.embedding_size]) def _test_accuracy(self, model, threshold=0.5): - _, accuracy = model.evaluate(self._test_data) - tf.compat.v1.logging.info(f'accuracy: {accuracy}') + # Test on _train_data because of our limited dataset size + _, accuracy = model.evaluate(self._train_data) + tf.compat.v1.logging.info(f'train accuracy: {accuracy}') self.assertGreaterEqual(accuracy, threshold) @unittest_mock.patch.object( @@ -113,7 +114,7 @@ class GestureRecognizerTest(tf.test.TestCase): options = gesture_recognizer.GestureRecognizerOptions() gesture_recognizer.GestureRecognizer.create( train_data=self._train_data, - validation_data=self._test_data, + validation_data=self._validation_data, options=options) mock_hparams.assert_called_once() mock_model_options.assert_called_once() @@ -128,11 +129,11 @@ class GestureRecognizerTest(tf.test.TestCase): with mock.patch('sys.stdout', mock_stdout): model = gesture_recognizer.GestureRecognizer.create( train_data=self._train_data, - validation_data=self._test_data, + validation_data=self._validation_data, options=gesture_recognizer_options) model = gesture_recognizer.GestureRecognizer.create( train_data=self._train_data, - validation_data=self._test_data, + validation_data=self._validation_data, options=gesture_recognizer_options) self._test_accuracy(model) From fe66de37149bbd8a706b78e33b210bde5c3a021c Mon Sep 17 00:00:00 2001 From: Camillo Lugaresi Date: Tue, 15 Nov 2022 18:30:58 -0800 Subject: [PATCH 053/469] Internal change PiperOrigin-RevId: 488812677 --- mediapipe/gpu/BUILD | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mediapipe/gpu/BUILD b/mediapipe/gpu/BUILD index 68e788c52..9cb27d2f1 100644 --- a/mediapipe/gpu/BUILD +++ b/mediapipe/gpu/BUILD @@ -221,11 +221,11 @@ cc_library( ":gpu_buffer_format", 
":gpu_buffer_storage", ":gpu_buffer_storage_image_frame", + "@com_google_absl//absl/memory", # TODO: remove this dependency. Some other teams' tests # depend on having an indirect image_frame dependency, need to be # fixed first. "//mediapipe/framework/formats:image_frame", - "@com_google_absl//absl/memory", ], ) From 4c874fe4cd7f8c2fe4afdf1ac7630450264c3eba Mon Sep 17 00:00:00 2001 From: Camillo Lugaresi Date: Tue, 15 Nov 2022 18:31:27 -0800 Subject: [PATCH 054/469] Allow conversion of GlTextureBuffer to CVPixelBufferRef This means that, if an iOS application sends in a GlTextureBuffer but expects a CVPixelBufferRef as output, everything will work even if the graph just forwards the same input. Also, access by Metal calculators will also work transparently. PiperOrigin-RevId: 488812748 --- mediapipe/gpu/BUILD | 27 ++++++++++++++++++++++++++- mediapipe/gpu/gl_texture_buffer.cc | 29 +++++++++++++++++++++++++++++ 2 files changed, 55 insertions(+), 1 deletion(-) diff --git a/mediapipe/gpu/BUILD b/mediapipe/gpu/BUILD index 9cb27d2f1..196de3076 100644 --- a/mediapipe/gpu/BUILD +++ b/mediapipe/gpu/BUILD @@ -226,7 +226,13 @@ cc_library( # depend on having an indirect image_frame dependency, need to be # fixed first. "//mediapipe/framework/formats:image_frame", - ], + ] + select({ + "//conditions:default": [], + ":platform_ios_with_gpu": [ + ":gl_texture_util", + ":gpu_buffer_storage_cv_pixel_buffer", + ], + }), ) cc_library( @@ -344,6 +350,25 @@ cc_library( ], ) +mediapipe_cc_test( + name = "gpu_buffer_storage_cv_pixel_buffer_test", + size = "small", + timeout = "moderate", + srcs = ["gpu_buffer_storage_cv_pixel_buffer_test.cc"], + platforms = ["ios"], + deps = [ + ":gl_texture_buffer", + ":gl_texture_util", + ":gpu_buffer", + ":gpu_buffer_storage_cv_pixel_buffer", + ":gpu_test_base", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/tool:test_util", + "//mediapipe/objc:util", + "@com_google_absl//absl/strings", + ], +) + cc_library( name = "cv_texture_cache_manager", srcs = ["cv_texture_cache_manager.cc"], diff --git a/mediapipe/gpu/gl_texture_buffer.cc b/mediapipe/gpu/gl_texture_buffer.cc index fbb91a8f5..4c2f15a8d 100644 --- a/mediapipe/gpu/gl_texture_buffer.cc +++ b/mediapipe/gpu/gl_texture_buffer.cc @@ -18,6 +18,11 @@ #include "mediapipe/gpu/gl_texture_view.h" #include "mediapipe/gpu/gpu_buffer_storage_image_frame.h" +#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER +#include "mediapipe/gpu/gl_texture_util.h" +#include "mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.h" +#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER + namespace mediapipe { std::unique_ptr GlTextureBuffer::Wrap( @@ -380,4 +385,28 @@ static auto kConverterRegistration2 = .RegisterConverter( ConvertFromImageFrame); +#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER + +static std::shared_ptr ConvertToCvPixelBuffer( + std::shared_ptr buf) { + auto output = absl::make_unique( + buf->width(), buf->height(), buf->format()); + buf->GetProducerContext()->Run([buf, &output] { + TempGlFramebuffer framebuffer; + auto src = buf->GetReadView(internal::types{}, nullptr, 0); + auto dst = + output->GetWriteView(internal::types{}, nullptr, 0); + CopyGlTexture(src, dst); + glFlush(); + }); + return output; +} + +static auto kConverterRegistrationCvpb = + internal::GpuBufferStorageRegistry::Get() + .RegisterConverter( + ConvertToCvPixelBuffer); + +#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER + } // namespace mediapipe From 767cc2ee3cbec8472fcacedbd890def1d9c0b63f Mon Sep 17 00:00:00 2001 From: Camillo Lugaresi Date: Tue, 15 
Nov 2022 18:31:54 -0800 Subject: [PATCH 055/469] More comments on gpu_buffer_storage This gives a basic explanation of the role of storages and views, and provides some details on how to implement a new storage type. PiperOrigin-RevId: 488812807 --- mediapipe/gpu/gpu_buffer_storage.h | 45 +++++++++++++++++++++++++----- 1 file changed, 38 insertions(+), 7 deletions(-) diff --git a/mediapipe/gpu/gpu_buffer_storage.h b/mediapipe/gpu/gpu_buffer_storage.h index 0da5f236a..b15c9c843 100644 --- a/mediapipe/gpu/gpu_buffer_storage.h +++ b/mediapipe/gpu/gpu_buffer_storage.h @@ -22,13 +22,27 @@ struct types {}; template class ViewProvider; -// Interface for a backing storage for GpuBuffer. +// Generic interface for a backing storage for GpuBuffer. +// +// GpuBuffer is an opaque handle to an image. Its contents are handled by +// Storage classes. Application code does not interact with the storages +// directly; to access the data, it asks the GpuBuffer for a View, and in turn +// GpuBuffer looks for a storage that can provide that view. +// This architecture decouples application code from the underlying storage, +// making it possible to use platform-specific optimized storage systems, e.g. +// for zero-copy data sharing between CPU and GPU. +// +// Storage implementations should inherit from GpuBufferStorageImpl. See that +// class for details. class GpuBufferStorage { public: virtual ~GpuBufferStorage() = default; + + // Concrete storage types should override the following three accessors. virtual int width() const = 0; virtual int height() const = 0; virtual GpuBufferFormat format() const = 0; + // We can't use dynamic_cast since we want to support building without RTTI. // The public methods delegate to the type-erased private virtual method. template @@ -72,6 +86,8 @@ class GpuBufferStorageRegistry { return *registry; } + // Registers a storage type by automatically creating a factory for it. + // This is normally called by GpuBufferImpl. template RegistryToken Register() { return RegisterFactory( @@ -82,6 +98,7 @@ class GpuBufferStorageRegistry { }); } + // Registers a new factory for a storage type. template RegistryToken RegisterFactory(F&& factory) { if constexpr (kDisableRegistration) { @@ -90,6 +107,7 @@ class GpuBufferStorageRegistry { return Register(factory, Storage::GetProviderTypes()); } + // Registers a new converter from storage type StorageFrom to StorageTo. template RegistryToken RegisterConverter(F&& converter) { if constexpr (kDisableRegistration) { @@ -162,14 +180,26 @@ struct ForceStaticInstantiation { #endif // _MSC_VER }; -// T: storage type -// U...: ViewProvider +// Inherit from this class to define a new storage type. The storage type itself +// should be passed as the first template argument (CRTP), followed by one or +// more specializations of ViewProvider. +// +// Concrete storage types should implement the basic accessors from +// GpuBufferStorage, plus the view read/write getters for each ViewProvider they +// implement. This class handles the rest. +// +// Arguments: +// T: storage type +// U...: ViewProvider +// Example: +// class MyStorage : public GpuBufferStorageImpl< +// MyStorage, ViewProvider> template class GpuBufferStorageImpl : public GpuBufferStorage, public U... 
{ public: static const std::vector& GetProviderTypes() { - static std::vector kHashes{kTypeId...}; - return kHashes; + static std::vector kProviderIds{kTypeId...}; + return kProviderIds; } // Exposing this as a function allows dependent initializers to call this to @@ -180,10 +210,11 @@ class GpuBufferStorageImpl : public GpuBufferStorage, public U... { } private: - virtual const void* down_cast(TypeId to) const override { + // Allows a down_cast to any of the view provider types in U. + const void* down_cast(TypeId to) const final { return down_cast_impl(to, types{}); } - TypeId storage_type() const override { return kTypeId; } + TypeId storage_type() const final { return kTypeId; } const void* down_cast_impl(TypeId to, types<>) const { return nullptr; } template From 1c0a1d0aab81bdea369ef912f3f0739cfe84ad81 Mon Sep 17 00:00:00 2001 From: Camillo Lugaresi Date: Tue, 15 Nov 2022 18:32:27 -0800 Subject: [PATCH 056/469] Remove shared_ptr member from GlTextureView This only exists to support GlTexture's GetFrame API. It can be moved into GlTexture. PiperOrigin-RevId: 488812896 --- mediapipe/gpu/gl_calculator_helper.h | 8 ++++++-- mediapipe/gpu/gl_calculator_helper_impl_common.cc | 15 +++------------ mediapipe/gpu/gl_texture_view.cc | 1 - mediapipe/gpu/gl_texture_view.h | 4 ---- 4 files changed, 9 insertions(+), 19 deletions(-) diff --git a/mediapipe/gpu/gl_calculator_helper.h b/mediapipe/gpu/gl_calculator_helper.h index e44523202..0a0cc16cb 100644 --- a/mediapipe/gpu/gl_calculator_helper.h +++ b/mediapipe/gpu/gl_calculator_helper.h @@ -201,9 +201,13 @@ class GlTexture { void Release() { view_ = std::make_shared(); } private: - explicit GlTexture(GlTextureView view) - : view_(std::make_shared(std::move(view))) {} + explicit GlTexture(GlTextureView view, GpuBuffer gpu_buffer) + : gpu_buffer_(std::move(gpu_buffer)), + view_(std::make_shared(std::move(view))) {} friend class GlCalculatorHelperImpl; + // We store the GpuBuffer to support GetFrame, and to ensure that the storage + // outlives the view. + GpuBuffer gpu_buffer_; std::shared_ptr view_; }; diff --git a/mediapipe/gpu/gl_calculator_helper_impl_common.cc b/mediapipe/gpu/gl_calculator_helper_impl_common.cc index c5c028d4f..6311d8905 100644 --- a/mediapipe/gpu/gl_calculator_helper_impl_common.cc +++ b/mediapipe/gpu/gl_calculator_helper_impl_common.cc @@ -101,7 +101,7 @@ GlTexture GlCalculatorHelperImpl::MapGpuBuffer(const GpuBuffer& gpu_buffer, glBindTexture(view.target(), 0); } - return GlTexture(std::move(view)); + return GlTexture(std::move(view), gpu_buffer); } GlTexture GlCalculatorHelperImpl::CreateSourceTexture( @@ -143,7 +143,7 @@ template <> std::unique_ptr GlTexture::GetFrame() const { view_->DoneWriting(); std::shared_ptr view = - view_->gpu_buffer().GetReadView(); + gpu_buffer_.GetReadView(); auto copy = absl::make_unique(); copy->CopyFrom(*view, ImageFrame::kDefaultAlignmentBoundary); return copy; @@ -151,17 +151,8 @@ std::unique_ptr GlTexture::GetFrame() const { template <> std::unique_ptr GlTexture::GetFrame() const { - auto gpu_buffer = view_->gpu_buffer(); -#ifdef __EMSCRIPTEN__ - // When WebGL is used, the GL context may be spontaneously lost which can - // cause GpuBuffer allocations to fail. In that case, return a dummy buffer - // to allow processing of the current frame complete. 
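Sketched with toy types (Buffer and View stand in for GpuBuffer and GlTextureView; this is illustrative only, not the real API), the resulting ownership shape looks like this:

#include <utility>

struct Buffer {};
struct View {};

// The wrapper owns the buffer next to the view, so the backing storage
// stays alive as long as the view does, and GetFrame-style accessors can
// go back to the buffer without the view itself holding an owning pointer.
class Texture {
 public:
  Texture(Buffer buffer, View view)
      : buffer_(std::move(buffer)), view_(std::move(view)) {}

 private:
  Buffer buffer_;  // keeps the storage alive
  View view_;      // non-owning view into buffer_
};

int main() {
  Texture t(Buffer{}, View{});
  (void)t;
  return 0;
}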
- if (!gpu_buffer) { - return std::make_unique(); - } -#endif // __EMSCRIPTEN__ view_->DoneWriting(); - return absl::make_unique(gpu_buffer); + return absl::make_unique(gpu_buffer_); } GlTexture GlCalculatorHelperImpl::CreateDestinationTexture( diff --git a/mediapipe/gpu/gl_texture_view.cc b/mediapipe/gpu/gl_texture_view.cc index 5d1862ddc..cae4039a4 100644 --- a/mediapipe/gpu/gl_texture_view.cc +++ b/mediapipe/gpu/gl_texture_view.cc @@ -7,7 +7,6 @@ void GlTextureView::Release() { if (detach_) detach_(*this); detach_ = nullptr; gl_context_ = nullptr; - gpu_buffer_ = nullptr; plane_ = 0; name_ = 0; width_ = 0; diff --git a/mediapipe/gpu/gl_texture_view.h b/mediapipe/gpu/gl_texture_view.h index 8b47d620b..d6734ed71 100644 --- a/mediapipe/gpu/gl_texture_view.h +++ b/mediapipe/gpu/gl_texture_view.h @@ -43,7 +43,6 @@ class GlTextureView { name_ = other.name_; width_ = other.width_; height_ = other.height_; - gpu_buffer_ = std::move(other.gpu_buffer_); plane_ = other.plane_; detach_ = std::exchange(other.detach_, nullptr); done_writing_ = std::exchange(other.done_writing_, nullptr); @@ -55,7 +54,6 @@ class GlTextureView { int height() const { return height_; } GLenum target() const { return target_; } GLuint name() const { return name_; } - const GpuBuffer& gpu_buffer() const { return *gpu_buffer_; } int plane() const { return plane_; } using DetachFn = std::function; @@ -74,7 +72,6 @@ class GlTextureView { name_(name), width_(width), height_(height), - gpu_buffer_(std::move(gpu_buffer)), plane_(plane), detach_(std::move(detach)), done_writing_(std::move(done_writing)) {} @@ -93,7 +90,6 @@ class GlTextureView { // Note: when scale is not 1, we still give the nominal size of the image. int width_ = 0; int height_ = 0; - std::shared_ptr gpu_buffer_; // using shared_ptr temporarily int plane_ = 0; DetachFn detach_; mutable DoneWritingFn done_writing_; From 13b4b825d74672d69a69d501dac4caf41e3ed098 Mon Sep 17 00:00:00 2001 From: Camillo Lugaresi Date: Tue, 15 Nov 2022 18:33:04 -0800 Subject: [PATCH 057/469] Remove std::shared_ptr argument from GetRead/WriteView PiperOrigin-RevId: 488813004 --- mediapipe/gpu/gl_texture_buffer.cc | 23 +++++++--------- mediapipe/gpu/gl_texture_buffer.h | 2 -- mediapipe/gpu/gl_texture_view.h | 12 +++------ mediapipe/gpu/gpu_buffer.h | 6 ++--- .../gpu/gpu_buffer_storage_cv_pixel_buffer.cc | 24 +++++++---------- .../gpu/gpu_buffer_storage_cv_pixel_buffer.h | 27 +++++++------------ .../gpu/gpu_buffer_storage_image_frame.h | 6 ++--- mediapipe/gpu/image_frame_view.h | 5 ++-- 8 files changed, 38 insertions(+), 67 deletions(-) diff --git a/mediapipe/gpu/gl_texture_buffer.cc b/mediapipe/gpu/gl_texture_buffer.cc index 4c2f15a8d..e57195a46 100644 --- a/mediapipe/gpu/gl_texture_buffer.cc +++ b/mediapipe/gpu/gl_texture_buffer.cc @@ -255,9 +255,8 @@ void GlTextureBuffer::WaitForConsumersOnGpu() { // precisely, on only one GL context. 
} -GlTextureView GlTextureBuffer::GetReadView( - internal::types, std::shared_ptr gpu_buffer, - int plane) const { +GlTextureView GlTextureBuffer::GetReadView(internal::types, + int plane) const { auto gl_context = GlContext::GetCurrent(); CHECK(gl_context); CHECK_EQ(plane, 0); @@ -269,13 +268,11 @@ GlTextureView GlTextureBuffer::GetReadView( DidRead(texture.gl_context()->CreateSyncToken()); }; return GlTextureView(gl_context.get(), target(), name(), width(), height(), - std::move(gpu_buffer), plane, std::move(detach), - nullptr); + plane, std::move(detach), nullptr); } -GlTextureView GlTextureBuffer::GetWriteView( - internal::types, std::shared_ptr gpu_buffer, - int plane) { +GlTextureView GlTextureBuffer::GetWriteView(internal::types, + int plane) { auto gl_context = GlContext::GetCurrent(); CHECK(gl_context); CHECK_EQ(plane, 0); @@ -286,8 +283,7 @@ GlTextureView GlTextureBuffer::GetWriteView( GlTextureView::DoneWritingFn done_writing = [this](const GlTextureView& texture) { ViewDoneWriting(texture); }; return GlTextureView(gl_context.get(), target(), name(), width(), height(), - std::move(gpu_buffer), plane, nullptr, - std::move(done_writing)); + plane, nullptr, std::move(done_writing)); } void GlTextureBuffer::ViewDoneWriting(const GlTextureView& view) { @@ -364,7 +360,7 @@ static std::shared_ptr ConvertToImageFrame( absl::make_unique(image_format, buf->width(), buf->height(), ImageFrame::kGlDefaultAlignmentBoundary); buf->GetProducerContext()->Run([buf, &output] { - auto view = buf->GetReadView(internal::types{}, nullptr, 0); + auto view = buf->GetReadView(internal::types{}, 0); ReadTexture(view, buf->format(), output->MutablePixelData(), output->PixelDataSize()); }); @@ -393,9 +389,8 @@ static std::shared_ptr ConvertToCvPixelBuffer( buf->width(), buf->height(), buf->format()); buf->GetProducerContext()->Run([buf, &output] { TempGlFramebuffer framebuffer; - auto src = buf->GetReadView(internal::types{}, nullptr, 0); - auto dst = - output->GetWriteView(internal::types{}, nullptr, 0); + auto src = buf->GetReadView(internal::types{}, 0); + auto dst = output->GetWriteView(internal::types{}, 0); CopyGlTexture(src, dst); glFlush(); }); diff --git a/mediapipe/gpu/gl_texture_buffer.h b/mediapipe/gpu/gl_texture_buffer.h index 1be24a86b..c7643fd1b 100644 --- a/mediapipe/gpu/gl_texture_buffer.h +++ b/mediapipe/gpu/gl_texture_buffer.h @@ -95,10 +95,8 @@ class GlTextureBuffer GpuBufferFormat format() const { return format_; } GlTextureView GetReadView(internal::types, - std::shared_ptr gpu_buffer, int plane) const override; GlTextureView GetWriteView(internal::types, - std::shared_ptr gpu_buffer, int plane) override; // If this texture is going to be used outside of the context that produced diff --git a/mediapipe/gpu/gl_texture_view.h b/mediapipe/gpu/gl_texture_view.h index d6734ed71..b8ead2708 100644 --- a/mediapipe/gpu/gl_texture_view.h +++ b/mediapipe/gpu/gl_texture_view.h @@ -65,8 +65,8 @@ class GlTextureView { friend class GpuBufferStorageCvPixelBuffer; friend class GpuBufferStorageAhwb; GlTextureView(GlContext* context, GLenum target, GLuint name, int width, - int height, std::shared_ptr gpu_buffer, int plane, - DetachFn detach, DoneWritingFn done_writing) + int height, int plane, DetachFn detach, + DoneWritingFn done_writing) : gl_context_(context), target_(target), name_(name), @@ -108,12 +108,8 @@ class ViewProvider { // the same view implement the same signature. 
// Note that we allow different views to have custom signatures, providing // additional view-specific arguments that may be needed. - virtual GlTextureView GetReadView(types, - std::shared_ptr gpu_buffer, - int plane) const = 0; - virtual GlTextureView GetWriteView(types, - std::shared_ptr gpu_buffer, - int plane) = 0; + virtual GlTextureView GetReadView(types, int plane) const = 0; + virtual GlTextureView GetWriteView(types, int plane) = 0; }; } // namespace internal diff --git a/mediapipe/gpu/gpu_buffer.h b/mediapipe/gpu/gpu_buffer.h index 45146a322..56507d92f 100644 --- a/mediapipe/gpu/gpu_buffer.h +++ b/mediapipe/gpu/gpu_buffer.h @@ -106,8 +106,7 @@ class GpuBuffer { template decltype(auto) GetReadView(Args... args) const { return GetViewProviderOrDie(false).GetReadView( - internal::types{}, std::make_shared(*this), - std::forward(args)...); + internal::types{}, std::forward(args)...); } // Gets a write view of the specified type. The arguments depend on the @@ -115,8 +114,7 @@ class GpuBuffer { template decltype(auto) GetWriteView(Args... args) { return GetViewProviderOrDie(true).GetWriteView( - internal::types{}, std::make_shared(*this), - std::forward(args)...); + internal::types{}, std::forward(args)...); } // Attempts to access an underlying storage object of the specified type. diff --git a/mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.cc b/mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.cc index d68ac0db0..f3954a6e4 100644 --- a/mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.cc +++ b/mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.cc @@ -26,8 +26,7 @@ GpuBufferStorageCvPixelBuffer::GpuBufferStorageCvPixelBuffer( } GlTextureView GpuBufferStorageCvPixelBuffer::GetTexture( - std::shared_ptr gpu_buffer, int plane, - GlTextureView::DoneWritingFn done_writing) const { + int plane, GlTextureView::DoneWritingFn done_writing) const { CVReturn err; auto gl_context = GlContext::GetCurrent(); CHECK(gl_context); @@ -60,33 +59,30 @@ GlTextureView GpuBufferStorageCvPixelBuffer::GetTexture( cv_texture.adopt(cv_texture_temp); return GlTextureView( gl_context.get(), CVOpenGLESTextureGetTarget(*cv_texture), - CVOpenGLESTextureGetName(*cv_texture), width(), height(), - std::move(gpu_buffer), plane, + CVOpenGLESTextureGetName(*cv_texture), width(), height(), plane, [cv_texture](mediapipe::GlTextureView&) { /* only retains cv_texture */ }, done_writing); #endif // TARGET_OS_OSX } GlTextureView GpuBufferStorageCvPixelBuffer::GetReadView( - internal::types, std::shared_ptr gpu_buffer, - int plane) const { - return GetTexture(std::move(gpu_buffer), plane, nullptr); + internal::types, int plane) const { + return GetTexture(plane, nullptr); } GlTextureView GpuBufferStorageCvPixelBuffer::GetWriteView( - internal::types, std::shared_ptr gpu_buffer, - int plane) { - return GetTexture( - std::move(gpu_buffer), plane, - [this](const mediapipe::GlTextureView& view) { ViewDoneWriting(view); }); + internal::types, int plane) { + return GetTexture(plane, [this](const mediapipe::GlTextureView& view) { + ViewDoneWriting(view); + }); } std::shared_ptr GpuBufferStorageCvPixelBuffer::GetReadView( - internal::types, std::shared_ptr gpu_buffer) const { + internal::types) const { return CreateImageFrameForCVPixelBuffer(**this); } std::shared_ptr GpuBufferStorageCvPixelBuffer::GetWriteView( - internal::types, std::shared_ptr gpu_buffer) { + internal::types) { return CreateImageFrameForCVPixelBuffer(**this); } diff --git a/mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.h 
b/mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.h index e5bc5de43..a9389ab8a 100644 --- a/mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.h +++ b/mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.h @@ -19,11 +19,9 @@ class ViewProvider { public: virtual ~ViewProvider() = default; virtual CFHolder GetReadView( - internal::types, - std::shared_ptr gpu_buffer) const = 0; + internal::types) const = 0; virtual CFHolder GetWriteView( - internal::types, - std::shared_ptr gpu_buffer) = 0; + internal::types) = 0; }; } // namespace internal @@ -50,37 +48,30 @@ class GpuBufferStorageCvPixelBuffer CVPixelBufferGetPixelFormatType(**this)); } GlTextureView GetReadView(internal::types, - std::shared_ptr gpu_buffer, int plane) const override; GlTextureView GetWriteView(internal::types, - std::shared_ptr gpu_buffer, int plane) override; std::shared_ptr GetReadView( - internal::types, - std::shared_ptr gpu_buffer) const override; + internal::types) const override; std::shared_ptr GetWriteView( - internal::types, - std::shared_ptr gpu_buffer) override; + internal::types) override; CFHolder GetReadView( - internal::types, - std::shared_ptr gpu_buffer) const override; + internal::types) const override; CFHolder GetWriteView( - internal::types, - std::shared_ptr gpu_buffer) override; + internal::types) override; private: - GlTextureView GetTexture(std::shared_ptr gpu_buffer, int plane, + GlTextureView GetTexture(int plane, GlTextureView::DoneWritingFn done_writing) const; void ViewDoneWriting(const GlTextureView& view); }; inline CFHolder GpuBufferStorageCvPixelBuffer::GetReadView( - internal::types, - std::shared_ptr gpu_buffer) const { + internal::types) const { return *this; } inline CFHolder GpuBufferStorageCvPixelBuffer::GetWriteView( - internal::types, std::shared_ptr gpu_buffer) { + internal::types) { return *this; } diff --git a/mediapipe/gpu/gpu_buffer_storage_image_frame.h b/mediapipe/gpu/gpu_buffer_storage_image_frame.h index 2cea3445e..ab547b9ea 100644 --- a/mediapipe/gpu/gpu_buffer_storage_image_frame.h +++ b/mediapipe/gpu/gpu_buffer_storage_image_frame.h @@ -29,13 +29,11 @@ class GpuBufferStorageImageFrame std::shared_ptr image_frame() const { return image_frame_; } std::shared_ptr image_frame() { return image_frame_; } std::shared_ptr GetReadView( - internal::types, - std::shared_ptr gpu_buffer) const override { + internal::types) const override { return image_frame_; } std::shared_ptr GetWriteView( - internal::types, - std::shared_ptr gpu_buffer) override { + internal::types) override { return image_frame_; } diff --git a/mediapipe/gpu/image_frame_view.h b/mediapipe/gpu/image_frame_view.h index 2fc6f2495..b7e58a824 100644 --- a/mediapipe/gpu/image_frame_view.h +++ b/mediapipe/gpu/image_frame_view.h @@ -12,9 +12,8 @@ class ViewProvider { public: virtual ~ViewProvider() = default; virtual std::shared_ptr GetReadView( - types, std::shared_ptr gpu_buffer) const = 0; - virtual std::shared_ptr GetWriteView( - types, std::shared_ptr gpu_buffer) = 0; + types) const = 0; + virtual std::shared_ptr GetWriteView(types) = 0; }; } // namespace internal From a28ccb0964e327c5041c40cc26769275b46ce3b7 Mon Sep 17 00:00:00 2001 From: Camillo Lugaresi Date: Tue, 15 Nov 2022 18:33:32 -0800 Subject: [PATCH 058/469] Remove unnecessary forward declarations PiperOrigin-RevId: 488813066 --- mediapipe/gpu/gl_texture_view.h | 3 --- mediapipe/gpu/gpu_buffer_storage.h | 1 - 2 files changed, 4 deletions(-) diff --git a/mediapipe/gpu/gl_texture_view.h b/mediapipe/gpu/gl_texture_view.h index b8ead2708..8a257cf53 100644 --- 
a/mediapipe/gpu/gl_texture_view.h +++ b/mediapipe/gpu/gl_texture_view.h @@ -25,8 +25,6 @@ namespace mediapipe { class GlContext; -class GlTextureViewManager; -class GpuBuffer; class GlTextureView { public: @@ -60,7 +58,6 @@ class GlTextureView { using DoneWritingFn = std::function<void(const GlTextureView&)>; private: - friend class GpuBuffer; friend class GlTextureBuffer; friend class GpuBufferStorageCvPixelBuffer; friend class GpuBufferStorageAhwb; diff --git a/mediapipe/gpu/gpu_buffer_storage.h b/mediapipe/gpu/gpu_buffer_storage.h index b15c9c843..55bb418cf 100644 --- a/mediapipe/gpu/gpu_buffer_storage.h +++ b/mediapipe/gpu/gpu_buffer_storage.h @@ -13,7 +13,6 @@ #include "mediapipe/gpu/gpu_buffer_format.h" namespace mediapipe { -class GpuBuffer; namespace internal { template <class V> From 8b319e963a4aa46db4f9c9d34c29bdf035f8f9a5 Mon Sep 17 00:00:00 2001 From: Camillo Lugaresi Date: Tue, 15 Nov 2022 18:34:07 -0800 Subject: [PATCH 059/469] Add comment explaining ViewProvider This was only documented via examples (e.g. ViewProvider<GlTextureView>), but it's better to explain it properly in the header where the base case is defined. PiperOrigin-RevId: 488813144 --- mediapipe/gpu/gpu_buffer_storage.h | 22 ++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/mediapipe/gpu/gpu_buffer_storage.h b/mediapipe/gpu/gpu_buffer_storage.h index 55bb418cf..19661d930 100644 --- a/mediapipe/gpu/gpu_buffer_storage.h +++ b/mediapipe/gpu/gpu_buffer_storage.h @@ -18,6 +18,28 @@ namespace internal { template <class... T> struct types {}; +// This template must be specialized for each view type V. Each specialization +// should define a pair of virtual methods called GetReadView and GetWriteView, +// whose first argument is a types<V> tag object. The result type and optional +// further arguments will depend on the view type. +// +// Example: +// template <> +// class ViewProvider<MyView> { +// public: +// virtual ~ViewProvider() = default; +// virtual MyView GetReadView(types<MyView>) const = 0; +// virtual MyView GetWriteView(types<MyView>) = 0; +// }; +// +// The additional arguments and result type are reflected in GpuBuffer's +// GetReadView and GetWriteView methods. +// +// Using a type tag for the first argument allows the methods to be overloaded, +// so that a single storage can implement provider methods for multiple views. +// Since these methods are not template methods, they can (and should) be +// virtual, which allows storage classes to override them, enforcing that all +// storages providing a given view type implement the same interface. template <class V> class ViewProvider; From 1979801a92ad5a4d12bf2dd1ae6611c39de3096a Mon Sep 17 00:00:00 2001 From: Camillo Lugaresi Date: Tue, 15 Nov 2022 18:34:35 -0800 Subject: [PATCH 060/469] Remove GlCalculatorHelperImpl; merge with GlCalculatorHelper Originally, there were multiple implementations of GlCalculatorHelperImpl, depending on the platform and underlying GL APIs. These have all been refactored into other components, and the remaining code in this class is unified and much reduced in size. We can get rid of this implementation detail now.
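For context, a minimal sketch of how a calculator drives the helper after this merge; the public surface is unchanged. The calculator below is illustrative only, not part of this patch:

class ExampleGlCalculator : public CalculatorBase {
 public:
  static absl::Status GetContract(CalculatorContract* cc) {
    // Requests the GPU resources the helper needs.
    return GlCalculatorHelper::UpdateContract(cc);
  }
  absl::Status Open(CalculatorContext* cc) override { return helper_.Open(cc); }
  absl::Status Process(CalculatorContext* cc) override {
    // The lambda is executed on the helper's GL context.
    return helper_.RunInGlContext([]() -> absl::Status { return absl::OkStatus(); });
  }
 private:
  GlCalculatorHelper helper_;
};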
PiperOrigin-RevId: 488813220 --- mediapipe/gpu/BUILD | 2 - mediapipe/gpu/gl_calculator_helper.cc | 163 +++++++++++++---- mediapipe/gpu/gl_calculator_helper.h | 26 ++- mediapipe/gpu/gl_calculator_helper_impl.h | 82 --------- .../gpu/gl_calculator_helper_impl_common.cc | 169 ------------------ 5 files changed, 146 insertions(+), 296 deletions(-) delete mode 100644 mediapipe/gpu/gl_calculator_helper_impl.h delete mode 100644 mediapipe/gpu/gl_calculator_helper_impl_common.cc diff --git a/mediapipe/gpu/BUILD b/mediapipe/gpu/BUILD index 196de3076..b0c1c22b2 100644 --- a/mediapipe/gpu/BUILD +++ b/mediapipe/gpu/BUILD @@ -749,11 +749,9 @@ cc_library( name = "gl_calculator_helper", srcs = [ "gl_calculator_helper.cc", - "gl_calculator_helper_impl_common.cc", ], hdrs = [ "gl_calculator_helper.h", - "gl_calculator_helper_impl.h", ], linkopts = select({ "//conditions:default": [], diff --git a/mediapipe/gpu/gl_calculator_helper.cc b/mediapipe/gpu/gl_calculator_helper.cc index ba1423977..7d317e0f1 100644 --- a/mediapipe/gpu/gl_calculator_helper.cc +++ b/mediapipe/gpu/gl_calculator_helper.cc @@ -20,18 +20,32 @@ #include "mediapipe/framework/port/canonical_errors.h" #include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/status.h" -#include "mediapipe/gpu/gl_calculator_helper_impl.h" #include "mediapipe/gpu/gpu_buffer.h" #include "mediapipe/gpu/gpu_service.h" namespace mediapipe { -// The constructor and destructor need to be defined here so that -// std::unique_ptr can see the full definition of GlCalculatorHelperImpl. -// In the header, it is an incomplete type. GlCalculatorHelper::GlCalculatorHelper() {} -GlCalculatorHelper::~GlCalculatorHelper() {} +GlCalculatorHelper::~GlCalculatorHelper() { + if (!Initialized()) return; + RunInGlContext( + [this] { + if (framebuffer_) { + glDeleteFramebuffers(1, &framebuffer_); + framebuffer_ = 0; + } + return absl::OkStatus(); + }, + /*calculator_context=*/nullptr) + .IgnoreError(); +} + +void GlCalculatorHelper::InitializeInternal(CalculatorContext* cc, + GpuResources* gpu_resources) { + gpu_resources_ = gpu_resources; + gl_context_ = gpu_resources_->gl_context(cc); +} absl::Status GlCalculatorHelper::Open(CalculatorContext* cc) { CHECK(cc); @@ -39,19 +53,16 @@ absl::Status GlCalculatorHelper::Open(CalculatorContext* cc) { RET_CHECK(gpu_service.IsAvailable()) << "GPU service not available. 
Did you forget to call " "GlCalculatorHelper::UpdateContract?"; - // TODO return error from impl_ (needs two-stage init) - impl_ = - absl::make_unique(cc, &gpu_service.GetObject()); + InitializeInternal(cc, &gpu_service.GetObject()); return absl::OkStatus(); } void GlCalculatorHelper::InitializeForTest(GpuSharedData* gpu_shared) { - impl_ = absl::make_unique( - nullptr, gpu_shared->gpu_resources.get()); + InitializeInternal(nullptr, gpu_shared->gpu_resources.get()); } void GlCalculatorHelper::InitializeForTest(GpuResources* gpu_resources) { - impl_ = absl::make_unique(nullptr, gpu_resources); + InitializeInternal(nullptr, gpu_resources); } // static @@ -88,44 +99,109 @@ absl::Status GlCalculatorHelper::SetupInputSidePackets( return absl::OkStatus(); } +absl::Status GlCalculatorHelper::RunInGlContext( + std::function gl_func, + CalculatorContext* calculator_context) { + if (calculator_context) { + return gl_context_->Run(std::move(gl_func), calculator_context->NodeId(), + calculator_context->InputTimestamp()); + } else { + return gl_context_->Run(std::move(gl_func)); + } +} + absl::Status GlCalculatorHelper::RunInGlContext( std::function gl_func) { - if (!impl_) return absl::InternalError("helper not initialized"); + if (!Initialized()) return absl::InternalError("helper not initialized"); // TODO: Remove LegacyCalculatorSupport from MediaPipe OSS. auto calculator_context = LegacyCalculatorSupport::Scoped::current(); - return impl_->RunInGlContext(gl_func, calculator_context); + return RunInGlContext(gl_func, calculator_context); } -GLuint GlCalculatorHelper::framebuffer() const { return impl_->framebuffer(); } +GLuint GlCalculatorHelper::framebuffer() const { return framebuffer_; } + +void GlCalculatorHelper::CreateFramebuffer() { + // Our framebuffer will have a color attachment but no depth attachment, + // so it's important that the depth test be off. It is disabled by default, + // but we wanted to be explicit. + // TODO: move this to glBindFramebuffer? + glDisable(GL_DEPTH_TEST); + glGenFramebuffers(1, &framebuffer_); +} void GlCalculatorHelper::BindFramebuffer(const GlTexture& dst) { - return impl_->BindFramebuffer(dst); +#ifdef __ANDROID__ + // On (some?) Android devices, attaching a new texture to the frame buffer + // does not seem to detach the old one. As a result, using that texture + // for texturing can produce incorrect output. See b/32091368 for details. + // To fix this, we have to call either glBindFramebuffer with a FBO id of 0 + // or glFramebufferTexture2D with a texture ID of 0. + glBindFramebuffer(GL_FRAMEBUFFER, 0); +#endif + if (!framebuffer_) { + CreateFramebuffer(); + } + glBindFramebuffer(GL_FRAMEBUFFER, framebuffer_); + glViewport(0, 0, dst.width(), dst.height()); + glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, dst.target(), + dst.name(), 0); + +#ifndef NDEBUG + GLenum status = glCheckFramebufferStatus(GL_FRAMEBUFFER); + if (status != GL_FRAMEBUFFER_COMPLETE) { + VLOG(2) << "incomplete framebuffer: " << status; + } +#endif } -GlTexture GlCalculatorHelper::CreateSourceTexture( - const GpuBuffer& pixel_buffer) { - return impl_->CreateSourceTexture(pixel_buffer); +GlTexture GlCalculatorHelper::MapGpuBuffer(const GpuBuffer& gpu_buffer, + GlTextureView view) { + if (gpu_buffer.format() != GpuBufferFormat::kUnknown) { + // TODO: do the params need to be reset here?? 
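+    // The bind below only applies the standard sampling parameters for this
+    // buffer's format to the texture; it is unbound again right after.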
+ glBindTexture(view.target(), view.name()); + GlTextureInfo info = GlTextureInfoForGpuBufferFormat( + gpu_buffer.format(), view.plane(), GetGlVersion()); + gl_context_->SetStandardTextureParams(view.target(), + info.gl_internal_format); + glBindTexture(view.target(), 0); + } + + return GlTexture(std::move(view), gpu_buffer); +} + +GlTexture GlCalculatorHelper::CreateSourceTexture(const GpuBuffer& gpu_buffer) { + return CreateSourceTexture(gpu_buffer, 0); +} + +GlTexture GlCalculatorHelper::CreateSourceTexture(const GpuBuffer& gpu_buffer, + int plane) { + return MapGpuBuffer(gpu_buffer, gpu_buffer.GetReadView(plane)); } GlTexture GlCalculatorHelper::CreateSourceTexture( const ImageFrame& image_frame) { - return impl_->CreateSourceTexture(image_frame); -} - -GlTexture GlCalculatorHelper::CreateSourceTexture(const GpuBuffer& pixel_buffer, - int plane) { - return impl_->CreateSourceTexture(pixel_buffer, plane); + auto gpu_buffer = GpuBufferCopyingImageFrame(image_frame); + return MapGpuBuffer(gpu_buffer, gpu_buffer.GetReadView(0)); } GpuBuffer GlCalculatorHelper::GpuBufferWithImageFrame( std::shared_ptr image_frame) { - return impl_->GpuBufferWithImageFrame(std::move(image_frame)); + return GpuBuffer( + std::make_shared(std::move(image_frame))); } GpuBuffer GlCalculatorHelper::GpuBufferCopyingImageFrame( const ImageFrame& image_frame) { - return impl_->GpuBufferCopyingImageFrame(image_frame); +#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER + auto maybe_buffer = CreateCVPixelBufferCopyingImageFrame(image_frame); + // Converts absl::StatusOr to absl::Status since CHECK_OK() currently only + // deals with absl::Status in MediaPipe OSS. + CHECK_OK(maybe_buffer.status()); + return GpuBuffer(std::move(maybe_buffer).value()); +#else + return GpuBuffer(GlTextureBuffer::Create(image_frame)); +#endif // !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER } void GlCalculatorHelper::GetGpuBufferDimensions(const GpuBuffer& pixel_buffer, @@ -136,23 +212,36 @@ void GlCalculatorHelper::GetGpuBufferDimensions(const GpuBuffer& pixel_buffer, *height = pixel_buffer.height(); } -GlTexture GlCalculatorHelper::CreateDestinationTexture(int output_width, - int output_height, +GlTexture GlCalculatorHelper::CreateDestinationTexture(int width, int height, GpuBufferFormat format) { - return impl_->CreateDestinationTexture(output_width, output_height, format); -} + if (!framebuffer_) { + CreateFramebuffer(); + } -GlContext& GlCalculatorHelper::GetGlContext() const { - return impl_->GetGlContext(); -} - -GlVersion GlCalculatorHelper::GetGlVersion() const { - return impl_->GetGlVersion(); + GpuBuffer gpu_buffer = + gpu_resources_->gpu_buffer_pool().GetBuffer(width, height, format); + return MapGpuBuffer(gpu_buffer, gpu_buffer.GetWriteView(0)); } GlTexture GlCalculatorHelper::CreateSourceTexture( const mediapipe::Image& image) { - return impl_->CreateSourceTexture(image.GetGpuBuffer()); + return CreateSourceTexture(image.GetGpuBuffer()); +} + +template <> +std::unique_ptr GlTexture::GetFrame() const { + view_->DoneWriting(); + std::shared_ptr view = + gpu_buffer_.GetReadView(); + auto copy = absl::make_unique(); + copy->CopyFrom(*view, ImageFrame::kDefaultAlignmentBoundary); + return copy; +} + +template <> +std::unique_ptr GlTexture::GetFrame() const { + view_->DoneWriting(); + return absl::make_unique(gpu_buffer_); } template <> diff --git a/mediapipe/gpu/gl_calculator_helper.h b/mediapipe/gpu/gl_calculator_helper.h index 0a0cc16cb..727be7826 100644 --- a/mediapipe/gpu/gl_calculator_helper.h +++ b/mediapipe/gpu/gl_calculator_helper.h 
@@ -33,7 +33,6 @@ namespace mediapipe { -class GlCalculatorHelperImpl; class GlTexture; class GpuResources; struct GpuSharedData; @@ -161,15 +160,30 @@ class GlCalculatorHelper { // TODO: do we need an unbind method too? void BindFramebuffer(const GlTexture& dst); - GlContext& GetGlContext() const; + GlContext& GetGlContext() const { return *gl_context_; } - GlVersion GetGlVersion() const; + GlVersion GetGlVersion() const { return gl_context_->GetGlVersion(); } // Check if the calculator helper has been previously initialized. - bool Initialized() { return impl_ != nullptr; } + bool Initialized() { return gpu_resources_ != nullptr; } private: - std::unique_ptr impl_; + void InitializeInternal(CalculatorContext* cc, GpuResources* gpu_resources); + + absl::Status RunInGlContext(std::function gl_func, + CalculatorContext* calculator_context); + + // Makes a GpuBuffer accessible as a texture in the GL context. + GlTexture MapGpuBuffer(const GpuBuffer& gpu_buffer, GlTextureView view); + + // Create the framebuffer for rendering. + void CreateFramebuffer(); + + std::shared_ptr gl_context_; + + GLuint framebuffer_ = 0; + + GpuResources* gpu_resources_ = nullptr; }; // Represents an OpenGL texture, and is a 'view' into the memory pool. @@ -204,7 +218,7 @@ class GlTexture { explicit GlTexture(GlTextureView view, GpuBuffer gpu_buffer) : gpu_buffer_(std::move(gpu_buffer)), view_(std::make_shared(std::move(view))) {} - friend class GlCalculatorHelperImpl; + friend class GlCalculatorHelper; // We store the GpuBuffer to support GetFrame, and to ensure that the storage // outlives the view. GpuBuffer gpu_buffer_; diff --git a/mediapipe/gpu/gl_calculator_helper_impl.h b/mediapipe/gpu/gl_calculator_helper_impl.h deleted file mode 100644 index 72b3265fe..000000000 --- a/mediapipe/gpu/gl_calculator_helper_impl.h +++ /dev/null @@ -1,82 +0,0 @@ -// Copyright 2019 The MediaPipe Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef MEDIAPIPE_GPU_GL_CALCULATOR_HELPER_IMPL_H_ -#define MEDIAPIPE_GPU_GL_CALCULATOR_HELPER_IMPL_H_ - -#include "mediapipe/gpu/gl_calculator_helper.h" -#include "mediapipe/gpu/gpu_shared_data_internal.h" - -#ifdef __OBJC__ -#import -#import -#endif // __OBJC__ - -#ifdef __ANDROID__ -#include "mediapipe/gpu/gl_texture_buffer_pool.h" -#endif - -namespace mediapipe { - -// This class implements the GlCalculatorHelper for iOS and Android. -// See GlCalculatorHelper for details on these methods. -class GlCalculatorHelperImpl { - public: - explicit GlCalculatorHelperImpl(CalculatorContext* cc, - GpuResources* gpu_resources); - ~GlCalculatorHelperImpl(); - - absl::Status RunInGlContext(std::function gl_func, - CalculatorContext* calculator_context); - - GlTexture CreateSourceTexture(const ImageFrame& image_frame); - GlTexture CreateSourceTexture(const GpuBuffer& gpu_buffer); - - // Note: multi-plane support is currently only available on iOS. 
- GlTexture CreateSourceTexture(const GpuBuffer& gpu_buffer, int plane); - - // Creates a framebuffer and returns the texture that it is bound to. - GlTexture CreateDestinationTexture(int output_width, int output_height, - GpuBufferFormat format); - - GpuBuffer GpuBufferWithImageFrame(std::shared_ptr image_frame); - GpuBuffer GpuBufferCopyingImageFrame(const ImageFrame& image_frame); - - GLuint framebuffer() const { return framebuffer_; } - void BindFramebuffer(const GlTexture& dst); - - GlVersion GetGlVersion() const { return gl_context_->GetGlVersion(); } - - GlContext& GetGlContext() const; - - // For internal use. - static void ReadTexture(const GlTextureView& view, void* output, size_t size); - - private: - // Makes a GpuBuffer accessible as a texture in the GL context. - GlTexture MapGpuBuffer(const GpuBuffer& gpu_buffer, GlTextureView view); - - // Create the framebuffer for rendering. - void CreateFramebuffer(); - - std::shared_ptr gl_context_; - - GLuint framebuffer_ = 0; - - GpuResources& gpu_resources_; -}; - -} // namespace mediapipe - -#endif // MEDIAPIPE_GPU_GL_CALCULATOR_HELPER_IMPL_H_ diff --git a/mediapipe/gpu/gl_calculator_helper_impl_common.cc b/mediapipe/gpu/gl_calculator_helper_impl_common.cc deleted file mode 100644 index 6311d8905..000000000 --- a/mediapipe/gpu/gl_calculator_helper_impl_common.cc +++ /dev/null @@ -1,169 +0,0 @@ -// Copyright 2019 The MediaPipe Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include - -#include "absl/memory/memory.h" -#include "mediapipe/framework/formats/image_frame.h" -#include "mediapipe/gpu/gl_calculator_helper_impl.h" -#include "mediapipe/gpu/gpu_buffer_format.h" -#include "mediapipe/gpu/gpu_shared_data_internal.h" -#include "mediapipe/gpu/image_frame_view.h" - -namespace mediapipe { - -GlCalculatorHelperImpl::GlCalculatorHelperImpl(CalculatorContext* cc, - GpuResources* gpu_resources) - : gpu_resources_(*gpu_resources) { - gl_context_ = gpu_resources_.gl_context(cc); -} - -GlCalculatorHelperImpl::~GlCalculatorHelperImpl() { - RunInGlContext( - [this] { - if (framebuffer_) { - glDeleteFramebuffers(1, &framebuffer_); - framebuffer_ = 0; - } - return absl::OkStatus(); - }, - /*calculator_context=*/nullptr) - .IgnoreError(); -} - -GlContext& GlCalculatorHelperImpl::GetGlContext() const { return *gl_context_; } - -absl::Status GlCalculatorHelperImpl::RunInGlContext( - std::function gl_func, - CalculatorContext* calculator_context) { - if (calculator_context) { - return gl_context_->Run(std::move(gl_func), calculator_context->NodeId(), - calculator_context->InputTimestamp()); - } else { - return gl_context_->Run(std::move(gl_func)); - } -} - -void GlCalculatorHelperImpl::CreateFramebuffer() { - // Our framebuffer will have a color attachment but no depth attachment, - // so it's important that the depth test be off. It is disabled by default, - // but we wanted to be explicit. - // TODO: move this to glBindFramebuffer? 
- glDisable(GL_DEPTH_TEST); - glGenFramebuffers(1, &framebuffer_); -} - -void GlCalculatorHelperImpl::BindFramebuffer(const GlTexture& dst) { -#ifdef __ANDROID__ - // On (some?) Android devices, attaching a new texture to the frame buffer - // does not seem to detach the old one. As a result, using that texture - // for texturing can produce incorrect output. See b/32091368 for details. - // To fix this, we have to call either glBindFramebuffer with a FBO id of 0 - // or glFramebufferTexture2D with a texture ID of 0. - glBindFramebuffer(GL_FRAMEBUFFER, 0); -#endif - if (!framebuffer_) { - CreateFramebuffer(); - } - glBindFramebuffer(GL_FRAMEBUFFER, framebuffer_); - glViewport(0, 0, dst.width(), dst.height()); - glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, dst.target(), - dst.name(), 0); - -#ifndef NDEBUG - GLenum status = glCheckFramebufferStatus(GL_FRAMEBUFFER); - if (status != GL_FRAMEBUFFER_COMPLETE) { - VLOG(2) << "incomplete framebuffer: " << status; - } -#endif -} - -GlTexture GlCalculatorHelperImpl::MapGpuBuffer(const GpuBuffer& gpu_buffer, - GlTextureView view) { - if (gpu_buffer.format() != GpuBufferFormat::kUnknown) { - // TODO: do the params need to be reset here?? - glBindTexture(view.target(), view.name()); - GlTextureInfo info = GlTextureInfoForGpuBufferFormat( - gpu_buffer.format(), view.plane(), GetGlVersion()); - gl_context_->SetStandardTextureParams(view.target(), - info.gl_internal_format); - glBindTexture(view.target(), 0); - } - - return GlTexture(std::move(view), gpu_buffer); -} - -GlTexture GlCalculatorHelperImpl::CreateSourceTexture( - const GpuBuffer& gpu_buffer) { - return CreateSourceTexture(gpu_buffer, 0); -} - -GlTexture GlCalculatorHelperImpl::CreateSourceTexture( - const GpuBuffer& gpu_buffer, int plane) { - return MapGpuBuffer(gpu_buffer, gpu_buffer.GetReadView(plane)); -} - -GlTexture GlCalculatorHelperImpl::CreateSourceTexture( - const ImageFrame& image_frame) { - auto gpu_buffer = GpuBufferCopyingImageFrame(image_frame); - return MapGpuBuffer(gpu_buffer, gpu_buffer.GetReadView(0)); -} - -GpuBuffer GlCalculatorHelperImpl::GpuBufferWithImageFrame( - std::shared_ptr image_frame) { - return GpuBuffer( - std::make_shared(std::move(image_frame))); -} - -GpuBuffer GlCalculatorHelperImpl::GpuBufferCopyingImageFrame( - const ImageFrame& image_frame) { -#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER - auto maybe_buffer = CreateCVPixelBufferCopyingImageFrame(image_frame); - // Converts absl::StatusOr to absl::Status since CHECK_OK() currently only - // deals with absl::Status in MediaPipe OSS. 
- CHECK_OK(maybe_buffer.status()); - return GpuBuffer(std::move(maybe_buffer).value()); -#else - return GpuBuffer(GlTextureBuffer::Create(image_frame)); -#endif // !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER -} - -template <> -std::unique_ptr GlTexture::GetFrame() const { - view_->DoneWriting(); - std::shared_ptr view = - gpu_buffer_.GetReadView(); - auto copy = absl::make_unique(); - copy->CopyFrom(*view, ImageFrame::kDefaultAlignmentBoundary); - return copy; -} - -template <> -std::unique_ptr GlTexture::GetFrame() const { - view_->DoneWriting(); - return absl::make_unique(gpu_buffer_); -} - -GlTexture GlCalculatorHelperImpl::CreateDestinationTexture( - int width, int height, GpuBufferFormat format) { - if (!framebuffer_) { - CreateFramebuffer(); - } - - GpuBuffer gpu_buffer = - gpu_resources_.gpu_buffer_pool().GetBuffer(width, height, format); - return MapGpuBuffer(gpu_buffer, gpu_buffer.GetWriteView(0)); -} - -} // namespace mediapipe From 63e20896391dda07baa25733cc023db233945f8b Mon Sep 17 00:00:00 2001 From: Camillo Lugaresi Date: Tue, 15 Nov 2022 18:35:04 -0800 Subject: [PATCH 061/469] Deprecate a bunch of old stuff in GlCalculatorHelper PiperOrigin-RevId: 488813296 --- mediapipe/gpu/BUILD | 1 + mediapipe/gpu/gl_calculator_helper.h | 7 +++++++ 2 files changed, 8 insertions(+) diff --git a/mediapipe/gpu/BUILD b/mediapipe/gpu/BUILD index b0c1c22b2..4fb59f1b5 100644 --- a/mediapipe/gpu/BUILD +++ b/mediapipe/gpu/BUILD @@ -784,6 +784,7 @@ cc_library( ":shader_util", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_cc_proto", + "@com_google_absl//absl/base:core_headers", "//mediapipe/framework:calculator_context", "//mediapipe/framework:calculator_node", "//mediapipe/framework:calculator_contract", diff --git a/mediapipe/gpu/gl_calculator_helper.h b/mediapipe/gpu/gl_calculator_helper.h index 727be7826..af897bbe9 100644 --- a/mediapipe/gpu/gl_calculator_helper.h +++ b/mediapipe/gpu/gl_calculator_helper.h @@ -17,6 +17,7 @@ #include +#include "absl/base/attributes.h" #include "absl/memory/memory.h" #include "mediapipe/framework/calculator_context.h" #include "mediapipe/framework/calculator_contract.h" @@ -61,6 +62,7 @@ class GlCalculatorHelper { // Can be used to initialize the helper outside of a calculator. Useful for // testing. void InitializeForTest(GpuResources* gpu_resources); + ABSL_DEPRECATED("Use InitializeForTest(GpuResources)") void InitializeForTest(GpuSharedData* gpu_shared); // This method can be called from GetContract to set up the needed GPU @@ -69,6 +71,7 @@ class GlCalculatorHelper { // This method can be called from FillExpectations to set the correct types // for the shared GL input side packet(s). + ABSL_DEPRECATED("Use UpdateContract") static absl::Status SetupInputSidePackets(PacketTypeSet* input_side_packets); // Execute the provided function within the helper's GL context. On some @@ -235,12 +238,14 @@ class GlTexture { // it is better to keep const-safety and accept having two versions of the // same thing. template +ABSL_DEPRECATED("Only for legacy calculators") auto TagOrIndex(const T& collection, const std::string& tag, int index) -> decltype(collection.Tag(tag)) { return collection.UsesTags() ? collection.Tag(tag) : collection.Index(index); } template +ABSL_DEPRECATED("Only for legacy calculators") auto TagOrIndex(T* collection, const std::string& tag, int index) -> decltype(collection->Tag(tag)) { return collection->UsesTags() ? 
collection->Tag(tag) @@ -248,12 +253,14 @@ auto TagOrIndex(T* collection, const std::string& tag, int index) } template +ABSL_DEPRECATED("Only for legacy calculators") bool HasTagOrIndex(const T& collection, const std::string& tag, int index) { return collection.UsesTags() ? collection.HasTag(tag) : index < collection.NumEntries(); } template +ABSL_DEPRECATED("Only for legacy calculators") bool HasTagOrIndex(T* collection, const std::string& tag, int index) { return collection->UsesTags() ? collection->HasTag(tag) : index < collection->NumEntries(); From febfc2029b38411a1835175d0bf3a647684475d9 Mon Sep 17 00:00:00 2001 From: Camillo Lugaresi Date: Tue, 15 Nov 2022 18:35:32 -0800 Subject: [PATCH 062/469] Annotate plane argument PiperOrigin-RevId: 488813363 --- mediapipe/gpu/gl_texture_buffer.cc | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/mediapipe/gpu/gl_texture_buffer.cc b/mediapipe/gpu/gl_texture_buffer.cc index e57195a46..09703d89d 100644 --- a/mediapipe/gpu/gl_texture_buffer.cc +++ b/mediapipe/gpu/gl_texture_buffer.cc @@ -360,7 +360,7 @@ static std::shared_ptr ConvertToImageFrame( absl::make_unique(image_format, buf->width(), buf->height(), ImageFrame::kGlDefaultAlignmentBoundary); buf->GetProducerContext()->Run([buf, &output] { - auto view = buf->GetReadView(internal::types{}, 0); + auto view = buf->GetReadView(internal::types{}, /*plane=*/0); ReadTexture(view, buf->format(), output->MutablePixelData(), output->PixelDataSize()); }); @@ -389,8 +389,9 @@ static std::shared_ptr ConvertToCvPixelBuffer( buf->width(), buf->height(), buf->format()); buf->GetProducerContext()->Run([buf, &output] { TempGlFramebuffer framebuffer; - auto src = buf->GetReadView(internal::types{}, 0); - auto dst = output->GetWriteView(internal::types{}, 0); + auto src = buf->GetReadView(internal::types{}, /*plane=*/0); + auto dst = + output->GetWriteView(internal::types{}, /*plane=*/0); CopyGlTexture(src, dst); glFlush(); }); From f7aef677fc1830af167a4ae989b8ca5abcac485a Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Tue, 15 Nov 2022 18:59:06 -0800 Subject: [PATCH 063/469] Add running mode to all vision tasks PiperOrigin-RevId: 488816785 --- mediapipe/tasks/web/vision/core/BUILD | 25 +++++-- ...nning_mode.ts => vision_task_options.d.ts} | 27 ++++---- .../web/vision/core/vision_task_runner.ts | 66 +++++++++++++++++++ .../tasks/web/vision/gesture_recognizer/BUILD | 5 +- .../gesture_recognizer/gesture_recognizer.ts | 48 +++++++++----- .../gesture_recognizer_options.d.ts | 7 +- .../tasks/web/vision/hand_landmarker/BUILD | 5 +- .../vision/hand_landmarker/hand_landmarker.ts | 47 ++++++++----- .../hand_landmarker_options.d.ts | 7 +- .../tasks/web/vision/image_classifier/BUILD | 5 +- .../image_classifier/image_classifier.ts | 51 +++++++++----- .../image_classifier_options.d.ts | 7 +- .../tasks/web/vision/image_embedder/BUILD | 8 +-- .../vision/image_embedder/image_embedder.ts | 49 ++++++-------- .../image_embedder_options.d.ts | 15 +---- .../tasks/web/vision/object_detector/BUILD | 5 +- .../vision/object_detector/object_detector.ts | 45 +++++++++---- .../object_detector_options.d.ts | 7 +- 18 files changed, 281 insertions(+), 148 deletions(-) rename mediapipe/tasks/web/vision/core/{running_mode.ts => vision_task_options.d.ts} (58%) create mode 100644 mediapipe/tasks/web/vision/core/vision_task_runner.ts diff --git a/mediapipe/tasks/web/vision/core/BUILD b/mediapipe/tasks/web/vision/core/BUILD index 7ab822b7c..8c405ae6e 100644 --- a/mediapipe/tasks/web/vision/core/BUILD +++ 
b/mediapipe/tasks/web/vision/core/BUILD @@ -1,11 +1,26 @@ # This package contains options shared by all MediaPipe Tasks for Web. -load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library") package(default_visibility = ["//mediapipe/tasks:internal"]) -mediapipe_ts_library( - name = "running_mode", - srcs = ["running_mode.ts"], - deps = ["//mediapipe/tasks/cc/core/proto:base_options_jspb_proto"], +mediapipe_ts_declaration( + name = "vision_task_options", + srcs = ["vision_task_options.d.ts"], + deps = [ + "//mediapipe/tasks/web/core", + ], +) + +mediapipe_ts_library( + name = "vision_task_runner", + srcs = ["vision_task_runner.ts"], + deps = [ + ":vision_task_options", + "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", + "//mediapipe/tasks/web/components/processors:base_options", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:task_runner", + "//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts", + ], ) diff --git a/mediapipe/tasks/web/vision/core/running_mode.ts b/mediapipe/tasks/web/vision/core/vision_task_options.d.ts similarity index 58% rename from mediapipe/tasks/web/vision/core/running_mode.ts rename to mediapipe/tasks/web/vision/core/vision_task_options.d.ts index 1e9b1b9a7..8b9562e46 100644 --- a/mediapipe/tasks/web/vision/core/running_mode.ts +++ b/mediapipe/tasks/web/vision/core/vision_task_options.d.ts @@ -14,23 +14,26 @@ * limitations under the License. */ -import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; +import {BaseOptions} from '../../../../tasks/web/core/base_options'; /** - * The running mode of a task. + * The two running modes of a video task. * 1) The image mode for processing single image inputs. * 2) The video mode for processing decoded frames of a video. */ export type RunningMode = 'image'|'video'; -/** Configues the `useStreamMode` option . */ -export function configureRunningMode( - options: {runningMode?: RunningMode}, - proto?: BaseOptionsProto): BaseOptionsProto { - proto = proto ?? new BaseOptionsProto(); - if ('runningMode' in options) { - const useStreamMode = options.runningMode === 'video'; - proto.setUseStreamMode(useStreamMode); - } - return proto; + +/** The options for configuring a MediaPipe vision task. */ +export declare interface VisionTaskOptions { + /** Options to configure the loading of the model assets. */ + baseOptions?: BaseOptions; + + /** + * The running mode of the task. Defaults to the image mode. + * Vision tasks have two running modes: + * 1) The image mode for processing single image inputs. + * 2) The video mode for processing decoded frames of a video. + */ + runningMode?: RunningMode; } diff --git a/mediapipe/tasks/web/vision/core/vision_task_runner.ts b/mediapipe/tasks/web/vision/core/vision_task_runner.ts new file mode 100644 index 000000000..372ce9ba7 --- /dev/null +++ b/mediapipe/tasks/web/vision/core/vision_task_runner.ts @@ -0,0 +1,66 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; +import {convertBaseOptionsToProto} from '../../../../tasks/web/components/processors/base_options'; +import {TaskRunner} from '../../../../tasks/web/core/task_runner'; +import {ImageSource} from '../../../../web/graph_runner/wasm_mediapipe_lib'; + +import {VisionTaskOptions} from './vision_task_options'; + +/** Base class for all MediaPipe Vision Tasks. */ +export abstract class VisionTaskRunner<T> extends TaskRunner { + protected abstract baseOptions?: BaseOptionsProto|undefined; + + /** Configures the shared options of a vision task. */ + async setOptions(options: VisionTaskOptions): Promise<void> { + this.baseOptions = this.baseOptions ?? new BaseOptionsProto(); + if (options.baseOptions) { + this.baseOptions = await convertBaseOptionsToProto( + options.baseOptions, this.baseOptions); + } + if ('runningMode' in options) { + const useStreamMode = + !!options.runningMode && options.runningMode !== 'image'; + this.baseOptions.setUseStreamMode(useStreamMode); + } + } + + /** Sends an image packet to the graph and awaits results. */ + protected abstract process(input: ImageSource, timestamp: number): T; + + /** Sends a single image to the graph and awaits results. */ + protected processImageData(image: ImageSource): T { + if (!!this.baseOptions?.getUseStreamMode()) { + throw new Error( + 'Task is not initialized with image mode. ' + + '\'runningMode\' must be set to \'image\'.'); + } + return this.process(image, performance.now()); + } + + /** Sends a single video frame to the graph and awaits results. */ + protected processVideoData(imageFrame: ImageSource, timestamp: number): T { + if (!this.baseOptions?.getUseStreamMode()) { + throw new Error( + 'Task is not initialized with video mode.
' + + '\'runningMode\' must be set to \'video\'.'); + } + return this.process(imageFrame, timestamp); + } +} + + diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/BUILD b/mediapipe/tasks/web/vision/gesture_recognizer/BUILD index d67974a16..f2b668239 100644 --- a/mediapipe/tasks/web/vision/gesture_recognizer/BUILD +++ b/mediapipe/tasks/web/vision/gesture_recognizer/BUILD @@ -19,6 +19,7 @@ mediapipe_ts_library( "//mediapipe/framework/formats:classification_jspb_proto", "//mediapipe/framework/formats:landmark_jspb_proto", "//mediapipe/framework/formats:rect_jspb_proto", + "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_classifier_graph_options_jspb_proto", "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_jspb_proto", "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_jspb_proto", @@ -27,11 +28,10 @@ mediapipe_ts_library( "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_jspb_proto", "//mediapipe/tasks/web/components/containers:category", "//mediapipe/tasks/web/components/containers:landmark", - "//mediapipe/tasks/web/components/processors:base_options", "//mediapipe/tasks/web/components/processors:classifier_options", "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:classifier_options", - "//mediapipe/tasks/web/core:task_runner", + "//mediapipe/tasks/web/vision/core:vision_task_runner", "//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts", ], ) @@ -47,5 +47,6 @@ mediapipe_ts_declaration( "//mediapipe/tasks/web/components/containers:landmark", "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:classifier_options", + "//mediapipe/tasks/web/vision/core:vision_task_options", ], ) diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts index 6c8072ff5..8e745534e 100644 --- a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts +++ b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts @@ -19,6 +19,7 @@ import {CalculatorOptions} from '../../../../framework/calculator_options_pb'; import {ClassificationList} from '../../../../framework/formats/classification_pb'; import {LandmarkList, NormalizedLandmarkList} from '../../../../framework/formats/landmark_pb'; import {NormalizedRect} from '../../../../framework/formats/rect_pb'; +import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; import {GestureClassifierGraphOptions} from '../../../../tasks/cc/vision/gesture_recognizer/proto/gesture_classifier_graph_options_pb'; import {GestureRecognizerGraphOptions} from '../../../../tasks/cc/vision/gesture_recognizer/proto/gesture_recognizer_graph_options_pb'; import {HandGestureRecognizerGraphOptions} from '../../../../tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options_pb'; @@ -27,10 +28,9 @@ import {HandLandmarkerGraphOptions} from '../../../../tasks/cc/vision/hand_landm import {HandLandmarksDetectorGraphOptions} from '../../../../tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options_pb'; import {Category} from '../../../../tasks/web/components/containers/category'; import {Landmark} from '../../../../tasks/web/components/containers/landmark'; -import {convertBaseOptionsToProto} from '../../../../tasks/web/components/processors/base_options'; import {convertClassifierOptionsToProto} 
from '../../../../tasks/web/components/processors/classifier_options'; -import {TaskRunner} from '../../../../tasks/web/core/task_runner'; import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options'; +import {VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; import {createMediaPipeLib, FileLocator, ImageSource, WasmModule} from '../../../../web/graph_runner/wasm_mediapipe_lib'; // Placeholder for internal dependency on trusted resource url @@ -64,7 +64,8 @@ FULL_IMAGE_RECT.setWidth(1); FULL_IMAGE_RECT.setHeight(1); /** Performs hand gesture recognition on images. */ -export class GestureRecognizer extends TaskRunner { +export class GestureRecognizer extends + VisionTaskRunner { private gestures: Category[][] = []; private landmarks: Landmark[][] = []; private worldLandmarks: Landmark[][] = []; @@ -156,10 +157,14 @@ export class GestureRecognizer extends TaskRunner { this.handGestureRecognizerGraphOptions); this.initDefaults(); + } - // Disables the automatic render-to-screen code, which allows for pure - // CPU processing. - this.setAutoRenderToScreen(false); + protected override get baseOptions(): BaseOptionsProto|undefined { + return this.options.getBaseOptions(); + } + + protected override set baseOptions(proto: BaseOptionsProto|undefined) { + this.options.setBaseOptions(proto); } /** @@ -171,12 +176,8 @@ export class GestureRecognizer extends TaskRunner { * * @param options The options for the gesture recognizer. */ - async setOptions(options: GestureRecognizerOptions): Promise { - if (options.baseOptions) { - const baseOptionsProto = await convertBaseOptionsToProto( - options.baseOptions, this.options.getBaseOptions()); - this.options.setBaseOptions(baseOptionsProto); - } + override async setOptions(options: GestureRecognizerOptions): Promise { + await super.setOptions(options); if ('numHands' in options) { this.handDetectorGraphOptions.setNumHands( @@ -233,12 +234,27 @@ export class GestureRecognizer extends TaskRunner { /** * Performs gesture recognition on the provided single image and waits * synchronously for the response. - * @param imageSource An image source to process. - * @param timestamp The timestamp of the current frame, in ms. If not - * provided, defaults to `performance.now()`. + * @param image A single image to process. * @return The detected gestures. */ - recognize(imageSource: ImageSource, timestamp: number = performance.now()): + recognize(image: ImageSource): GestureRecognizerResult { + return this.processImageData(image); + } + + /** + * Performs gesture recognition on the provided video frame and waits + * synchronously for the response. + * @param videoFrame A video frame to process. + * @param timestamp The timestamp of the current frame, in ms. + * @return The detected gestures. + */ + recognizeForVideo(videoFrame: ImageSource, timestamp: number): + GestureRecognizerResult { + return this.processVideoData(videoFrame, timestamp); + } + + /** Runs the gesture recognition and blocks on the response. 
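   * Illustrative usage, assuming a recognizer instance created elsewhere
   * (names are hypothetical, not from this patch): call `recognize(image)` in
   * the default 'image' mode, or `recognizeForVideo(videoFrame,
   * performance.now())` after `await recognizer.setOptions({runningMode:
   * 'video'})`.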
*/ + protected override process(imageSource: ImageSource, timestamp: number): GestureRecognizerResult { this.gestures = []; this.landmarks = []; diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_options.d.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_options.d.ts index 45601a74c..dd8fc9548 100644 --- a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_options.d.ts +++ b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_options.d.ts @@ -14,14 +14,11 @@ * limitations under the License. */ -import {BaseOptions} from '../../../../tasks/web/core/base_options'; import {ClassifierOptions} from '../../../../tasks/web/core/classifier_options'; +import {VisionTaskOptions} from '../../../../tasks/web/vision/core/vision_task_options'; /** Options to configure the MediaPipe Gesture Recognizer Task */ -export declare interface GestureRecognizerOptions { - /** Options to configure the loading of the model assets. */ - baseOptions?: BaseOptions; - +export declare interface GestureRecognizerOptions extends VisionTaskOptions { /** * The maximum number of hands can be detected by the GestureRecognizer. * Defaults to 1. diff --git a/mediapipe/tasks/web/vision/hand_landmarker/BUILD b/mediapipe/tasks/web/vision/hand_landmarker/BUILD index 25c70e0a5..36f1d7eb7 100644 --- a/mediapipe/tasks/web/vision/hand_landmarker/BUILD +++ b/mediapipe/tasks/web/vision/hand_landmarker/BUILD @@ -19,14 +19,14 @@ mediapipe_ts_library( "//mediapipe/framework/formats:classification_jspb_proto", "//mediapipe/framework/formats:landmark_jspb_proto", "//mediapipe/framework/formats:rect_jspb_proto", + "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_jspb_proto", "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_jspb_proto", "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_jspb_proto", "//mediapipe/tasks/web/components/containers:category", "//mediapipe/tasks/web/components/containers:landmark", - "//mediapipe/tasks/web/components/processors:base_options", "//mediapipe/tasks/web/core", - "//mediapipe/tasks/web/core:task_runner", + "//mediapipe/tasks/web/vision/core:vision_task_runner", "//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts", ], ) @@ -41,5 +41,6 @@ mediapipe_ts_declaration( "//mediapipe/tasks/web/components/containers:category", "//mediapipe/tasks/web/components/containers:landmark", "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/vision/core:vision_task_options", ], ) diff --git a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts index af10305b2..0aba5c82c 100644 --- a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts +++ b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts @@ -19,14 +19,14 @@ import {CalculatorOptions} from '../../../../framework/calculator_options_pb'; import {ClassificationList} from '../../../../framework/formats/classification_pb'; import {LandmarkList, NormalizedLandmarkList} from '../../../../framework/formats/landmark_pb'; import {NormalizedRect} from '../../../../framework/formats/rect_pb'; +import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; import {HandDetectorGraphOptions} from '../../../../tasks/cc/vision/hand_detector/proto/hand_detector_graph_options_pb'; import {HandLandmarkerGraphOptions} from 
'../../../../tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options_pb'; import {HandLandmarksDetectorGraphOptions} from '../../../../tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options_pb'; import {Category} from '../../../../tasks/web/components/containers/category'; import {Landmark} from '../../../../tasks/web/components/containers/landmark'; -import {convertBaseOptionsToProto} from '../../../../tasks/web/components/processors/base_options'; -import {TaskRunner} from '../../../../tasks/web/core/task_runner'; import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options'; +import {VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; import {createMediaPipeLib, FileLocator, ImageSource, WasmModule} from '../../../../web/graph_runner/wasm_mediapipe_lib'; // Placeholder for internal dependency on trusted resource url @@ -58,7 +58,7 @@ FULL_IMAGE_RECT.setWidth(1); FULL_IMAGE_RECT.setHeight(1); /** Performs hand landmarks detection on images. */ -export class HandLandmarker extends TaskRunner { +export class HandLandmarker extends VisionTaskRunner { private landmarks: Landmark[][] = []; private worldLandmarks: Landmark[][] = []; private handednesses: Category[][] = []; @@ -138,10 +138,14 @@ export class HandLandmarker extends TaskRunner { this.options.setHandDetectorGraphOptions(this.handDetectorGraphOptions); this.initDefaults(); + } - // Disables the automatic render-to-screen code, which allows for pure - // CPU processing. - this.setAutoRenderToScreen(false); + protected override get baseOptions(): BaseOptionsProto|undefined { + return this.options.getBaseOptions(); + } + + protected override set baseOptions(proto: BaseOptionsProto|undefined) { + this.options.setBaseOptions(proto); } /** @@ -153,12 +157,8 @@ export class HandLandmarker extends TaskRunner { * * @param options The options for the hand landmarker. */ - async setOptions(options: HandLandmarkerOptions): Promise { - if (options.baseOptions) { - const baseOptionsProto = await convertBaseOptionsToProto( - options.baseOptions, this.options.getBaseOptions()); - this.options.setBaseOptions(baseOptionsProto); - } + override async setOptions(options: HandLandmarkerOptions): Promise { + await super.setOptions(options); // Configure hand detector options. if ('numHands' in options) { @@ -186,12 +186,27 @@ export class HandLandmarker extends TaskRunner { /** * Performs hand landmarks detection on the provided single image and waits * synchronously for the response. - * @param imageSource An image source to process. - * @param timestamp The timestamp of the current frame, in ms. If not - * provided, defaults to `performance.now()`. + * @param image An image to process. * @return The detected hand landmarks. */ - detect(imageSource: ImageSource, timestamp: number = performance.now()): + detect(image: ImageSource): HandLandmarkerResult { + return this.processImageData(image); + } + + /** + * Performs hand landmarks detection on the provided video frame and waits + * synchronously for the response. + * @param videoFrame A video frame to process. + * @param timestamp The timestamp of the current frame, in ms. + * @return The detected hand landmarks. + */ + detectForVideo(videoFrame: ImageSource, timestamp: number): + HandLandmarkerResult { + return this.processVideoData(videoFrame, timestamp); + } + + /** Runs the hand landmarker graph and blocks on the response. 
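   * Illustrative usage, assuming a landmarker instance created elsewhere
   * (names are hypothetical, not from this patch): call `detect(image)` in
   * the default 'image' mode, or `detectForVideo(videoFrame,
   * performance.now())` after `await landmarker.setOptions({runningMode:
   * 'video'})`.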
*/ + protected override process(imageSource: ImageSource, timestamp: number): HandLandmarkerResult { this.landmarks = []; this.worldLandmarks = []; diff --git a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_options.d.ts b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_options.d.ts index 53ad9440a..fe79b7089 100644 --- a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_options.d.ts +++ b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_options.d.ts @@ -14,13 +14,10 @@ * limitations under the License. */ -import {BaseOptions} from '../../../../tasks/web/core/base_options'; +import {VisionTaskOptions} from '../../../../tasks/web/vision/core/vision_task_options'; /** Options to configure the MediaPipe HandLandmarker Task */ -export declare interface HandLandmarkerOptions { - /** Options to configure the loading of the model assets. */ - baseOptions?: BaseOptions; - +export declare interface HandLandmarkerOptions extends VisionTaskOptions { /** * The maximum number of hands can be detected by the HandLandmarker. * Defaults to 1. diff --git a/mediapipe/tasks/web/vision/image_classifier/BUILD b/mediapipe/tasks/web/vision/image_classifier/BUILD index 8506f3574..e7e830332 100644 --- a/mediapipe/tasks/web/vision/image_classifier/BUILD +++ b/mediapipe/tasks/web/vision/image_classifier/BUILD @@ -16,15 +16,15 @@ mediapipe_ts_library( "//mediapipe/framework:calculator_jspb_proto", "//mediapipe/framework:calculator_options_jspb_proto", "//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto", + "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_jspb_proto", "//mediapipe/tasks/web/components/containers:category", "//mediapipe/tasks/web/components/containers:classification_result", - "//mediapipe/tasks/web/components/processors:base_options", "//mediapipe/tasks/web/components/processors:classifier_options", "//mediapipe/tasks/web/components/processors:classifier_result", "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:classifier_options", - "//mediapipe/tasks/web/core:task_runner", + "//mediapipe/tasks/web/vision/core:vision_task_runner", "//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts", ], ) @@ -39,5 +39,6 @@ mediapipe_ts_declaration( "//mediapipe/tasks/web/components/containers:category", "//mediapipe/tasks/web/components/containers:classification_result", "//mediapipe/tasks/web/core:classifier_options", + "//mediapipe/tasks/web/vision/core:vision_task_options", ], ) diff --git a/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts b/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts index 5d60e4a21..0011e9c55 100644 --- a/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts +++ b/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts @@ -17,12 +17,12 @@ import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; import {CalculatorOptions} from '../../../../framework/calculator_options_pb'; import {ClassificationResult} from '../../../../tasks/cc/components/containers/proto/classifications_pb'; +import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; import {ImageClassifierGraphOptions} from '../../../../tasks/cc/vision/image_classifier/proto/image_classifier_graph_options_pb'; -import {convertBaseOptionsToProto} from '../../../../tasks/web/components/processors/base_options'; import {convertClassifierOptionsToProto} from 
'../../../../tasks/web/components/processors/classifier_options'; import {convertFromClassificationResultProto} from '../../../../tasks/web/components/processors/classifier_result'; -import {TaskRunner} from '../../../../tasks/web/core/task_runner'; import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options'; +import {VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; import {createMediaPipeLib, FileLocator, ImageSource} from '../../../../web/graph_runner/wasm_mediapipe_lib'; // Placeholder for internal dependency on trusted resource url @@ -42,7 +42,7 @@ export {ImageSource}; // Used in the public API // tslint:disable:jspb-use-builder-pattern /** Performs classification on images. */ -export class ImageClassifier extends TaskRunner { +export class ImageClassifier extends VisionTaskRunner { private classificationResult: ImageClassifierResult = {classifications: []}; private readonly options = new ImageClassifierGraphOptions(); @@ -105,6 +105,14 @@ export class ImageClassifier extends TaskRunner { wasmLoaderOptions, new Uint8Array(graphData)); } + protected override get baseOptions(): BaseOptionsProto|undefined { + return this.options.getBaseOptions(); + } + + protected override set baseOptions(proto: BaseOptionsProto|undefined) { + this.options.setBaseOptions(proto); + } + /** * Sets new options for the image classifier. * @@ -114,28 +122,39 @@ export class ImageClassifier extends TaskRunner { * * @param options The options for the image classifier. */ - async setOptions(options: ImageClassifierOptions): Promise { - if (options.baseOptions) { - const baseOptionsProto = await convertBaseOptionsToProto( - options.baseOptions, this.options.getBaseOptions()); - this.options.setBaseOptions(baseOptionsProto); - } - + override async setOptions(options: ImageClassifierOptions): Promise { + await super.setOptions(options); this.options.setClassifierOptions(convertClassifierOptionsToProto( options, this.options.getClassifierOptions())); this.refreshGraph(); } /** - * Performs image classification on the provided image and waits synchronously - * for the response. + * Performs image classification on the provided single image and waits + * synchronously for the response. * - * @param imageSource An image source to process. - * @param timestamp The timestamp of the current frame, in ms. If not - * provided, defaults to `performance.now()`. + * @param image An image to process. * @return The classification result of the image */ - classify(imageSource: ImageSource, timestamp?: number): + classify(image: ImageSource): ImageClassifierResult { + return this.processImageData(image); + } + + /** + * Performs image classification on the provided video frame and waits + * synchronously for the response. + * + * @param videoFrame A video frame to process. + * @param timestamp The timestamp of the current frame, in ms. + * @return The classification result of the image + */ + classifyForVideo(videoFrame: ImageSource, timestamp: number): + ImageClassifierResult { + return this.processVideoData(videoFrame, timestamp); + } + + /** Runs the image classification graph and blocks on the response. */ + protected override process(imageSource: ImageSource, timestamp: number): ImageClassifierResult { // Get classification result by running our MediaPipe graph. 
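    // Caller-side sketch (assumes an initialized `imageClassifier` and an
    // HTMLVideoElement `video`; both names are hypothetical):
    //   await imageClassifier.setOptions({runningMode: 'video'});
    //   const result =
    //       imageClassifier.classifyForVideo(video, performance.now());
    // In the default 'image' mode, use `classify(image)` instead.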
this.classificationResult = {classifications: []}; diff --git a/mediapipe/tasks/web/vision/image_classifier/image_classifier_options.d.ts b/mediapipe/tasks/web/vision/image_classifier/image_classifier_options.d.ts index a5f5c2386..c1141d28f 100644 --- a/mediapipe/tasks/web/vision/image_classifier/image_classifier_options.d.ts +++ b/mediapipe/tasks/web/vision/image_classifier/image_classifier_options.d.ts @@ -14,4 +14,9 @@ * limitations under the License. */ -export {ClassifierOptions as ImageClassifierOptions} from '../../../../tasks/web/core/classifier_options'; +import {ClassifierOptions} from '../../../../tasks/web/core/classifier_options'; +import {VisionTaskOptions} from '../../../../tasks/web/vision/core/vision_task_options'; + +/** Options to configure the image classifier task. */ +export declare interface ImageClassifierOptions extends ClassifierOptions, + VisionTaskOptions {} diff --git a/mediapipe/tasks/web/vision/image_embedder/BUILD b/mediapipe/tasks/web/vision/image_embedder/BUILD index 13ff2e4d6..ce1c25700 100644 --- a/mediapipe/tasks/web/vision/image_embedder/BUILD +++ b/mediapipe/tasks/web/vision/image_embedder/BUILD @@ -16,15 +16,15 @@ mediapipe_ts_library( "//mediapipe/framework:calculator_jspb_proto", "//mediapipe/framework:calculator_options_jspb_proto", "//mediapipe/tasks/cc/components/containers/proto:embeddings_jspb_proto", + "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", "//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_jspb_proto", "//mediapipe/tasks/web/components/containers:embedding_result", - "//mediapipe/tasks/web/components/processors:base_options", "//mediapipe/tasks/web/components/processors:embedder_options", "//mediapipe/tasks/web/components/processors:embedder_result", "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:embedder_options", - "//mediapipe/tasks/web/core:task_runner", - "//mediapipe/tasks/web/vision/core:running_mode", + "//mediapipe/tasks/web/vision/core:vision_task_options", + "//mediapipe/tasks/web/vision/core:vision_task_runner", "//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts", ], ) @@ -39,6 +39,6 @@ mediapipe_ts_declaration( "//mediapipe/tasks/web/components/containers:embedding_result", "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:embedder_options", - "//mediapipe/tasks/web/vision/core:running_mode", + "//mediapipe/tasks/web/vision/core:vision_task_options", ], ) diff --git a/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts b/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts index 91d9b5119..d17bc72fa 100644 --- a/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts +++ b/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts @@ -17,13 +17,12 @@ import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; import {CalculatorOptions} from '../../../../framework/calculator_options_pb'; import {EmbeddingResult} from '../../../../tasks/cc/components/containers/proto/embeddings_pb'; +import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; import {ImageEmbedderGraphOptions} from '../../../../tasks/cc/vision/image_embedder/proto/image_embedder_graph_options_pb'; -import {convertBaseOptionsToProto} from '../../../../tasks/web/components/processors/base_options'; import {convertEmbedderOptionsToProto} from '../../../../tasks/web/components/processors/embedder_options'; import {convertFromEmbeddingResultProto} from '../../../../tasks/web/components/processors/embedder_result'; -import
{TaskRunner} from '../../../../tasks/web/core/task_runner'; import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options'; -import {configureRunningMode} from '../../../../tasks/web/vision/core/running_mode'; +import {VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; import {createMediaPipeLib, FileLocator, ImageSource} from '../../../../web/graph_runner/wasm_mediapipe_lib'; // Placeholder for internal dependency on trusted resource url @@ -43,7 +42,7 @@ export * from './image_embedder_result'; export {ImageSource}; // Used in the public API /** Performs embedding extraction on images. */ -export class ImageEmbedder extends TaskRunner { +export class ImageEmbedder extends VisionTaskRunner { private readonly options = new ImageEmbedderGraphOptions(); private embeddings: ImageEmbedderResult = {embeddings: []}; @@ -105,6 +104,14 @@ export class ImageEmbedder extends TaskRunner { wasmLoaderOptions, new Uint8Array(graphData)); } + protected override get baseOptions(): BaseOptionsProto|undefined { + return this.options.getBaseOptions(); + } + + protected override set baseOptions(proto: BaseOptionsProto|undefined) { + this.options.setBaseOptions(proto); + } + /** * Sets new options for the image embedder. * @@ -114,24 +121,16 @@ * * @param options The options for the image embedder. */ - async setOptions(options: ImageEmbedderOptions): Promise<void> { - let baseOptionsProto = this.options.getBaseOptions(); - if (options.baseOptions) { - baseOptionsProto = await convertBaseOptionsToProto( - options.baseOptions, baseOptionsProto); - } - baseOptionsProto = configureRunningMode(options, baseOptionsProto); - this.options.setBaseOptions(baseOptionsProto); - + override async setOptions(options: ImageEmbedderOptions): Promise<void> { + await super.setOptions(options); this.options.setEmbedderOptions(convertEmbedderOptionsToProto( options, this.options.getEmbedderOptions())); - this.refreshGraph(); } /** - * Performs embedding extraction on the provided image and waits synchronously - * for the response. + * Performs embedding extraction on the provided single image and waits + * synchronously for the response. * * Only use this method when the `useStreamMode` option is not set or * explicitly set to `false`. @@ -140,12 +139,7 @@ * @return The embedding result of the image */ embed(image: ImageSource): ImageEmbedderResult { - if (!!this.options.getBaseOptions()?.getUseStreamMode()) { - throw new Error( - 'Task is not initialized with image mode. ' + - '\'runningMode\' must be set to \'image\'.'); - } - return this.performEmbeddingExtraction(image, performance.now()); + return this.processImageData(image); } /** @@ -160,16 +154,11 @@ */ embedForVideo(imageFrame: ImageSource, timestamp: number): ImageEmbedderResult { - if (!this.options.getBaseOptions()?.getUseStreamMode()) { - throw new Error( - 'Task is not initialized with video mode. ' + - '\'runningMode\' must be set to \'video\' or \'live_stream\'.'); - } - return this.performEmbeddingExtraction(imageFrame, timestamp); + return this.processVideoData(imageFrame, timestamp); } - /** Runs the embedding extractio and blocks on the response. */ - private performEmbeddingExtraction(image: ImageSource, timestamp: number): + /** Runs the embedding extraction and blocks on the response.
*/ + protected override process(image: ImageSource, timestamp: number): ImageEmbedderResult { // Get embeddings by running our MediaPipe graph. this.addGpuBufferAsImageToStream( diff --git a/mediapipe/tasks/web/vision/image_embedder/image_embedder_options.d.ts b/mediapipe/tasks/web/vision/image_embedder/image_embedder_options.d.ts index 4d795d0d8..10000825c 100644 --- a/mediapipe/tasks/web/vision/image_embedder/image_embedder_options.d.ts +++ b/mediapipe/tasks/web/vision/image_embedder/image_embedder_options.d.ts @@ -15,17 +15,8 @@ */ import {EmbedderOptions} from '../../../../tasks/web/core/embedder_options'; -import {RunningMode} from '../../../../tasks/web/vision/core/running_mode'; +import {VisionTaskOptions} from '../../../../tasks/web/vision/core/vision_task_options'; /** The options for configuring a MediaPipe image embedder task. */ -export declare interface ImageEmbedderOptions extends EmbedderOptions { - /** - * The running mode of the task. Default to the image mode. - * Image embedder has three running modes: - * 1) The image mode for embedding image on single image inputs. - * 2) The video mode for embedding image on the decoded frames of a video. - * 3) The live stream mode for embedding image on the live stream of input - * data, such as from camera. - */ - runningMode?: RunningMode; -} +export declare interface ImageEmbedderOptions extends EmbedderOptions, + VisionTaskOptions {} diff --git a/mediapipe/tasks/web/vision/object_detector/BUILD b/mediapipe/tasks/web/vision/object_detector/BUILD index a74dc9211..0975a9fd4 100644 --- a/mediapipe/tasks/web/vision/object_detector/BUILD +++ b/mediapipe/tasks/web/vision/object_detector/BUILD @@ -17,11 +17,11 @@ mediapipe_ts_library( "//mediapipe/framework:calculator_jspb_proto", "//mediapipe/framework:calculator_options_jspb_proto", "//mediapipe/framework/formats:detection_jspb_proto", + "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", "//mediapipe/tasks/cc/vision/object_detector/proto:object_detector_options_jspb_proto", "//mediapipe/tasks/web/components/containers:category", - "//mediapipe/tasks/web/components/processors:base_options", "//mediapipe/tasks/web/core", - "//mediapipe/tasks/web/core:task_runner", + "//mediapipe/tasks/web/vision/core:vision_task_runner", "//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts", ], ) @@ -35,5 +35,6 @@ mediapipe_ts_declaration( deps = [ "//mediapipe/tasks/web/components/containers:category", "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/vision/core:vision_task_options", ], ) diff --git a/mediapipe/tasks/web/vision/object_detector/object_detector.ts b/mediapipe/tasks/web/vision/object_detector/object_detector.ts index e17a42020..e6cbd8627 100644 --- a/mediapipe/tasks/web/vision/object_detector/object_detector.ts +++ b/mediapipe/tasks/web/vision/object_detector/object_detector.ts @@ -17,10 +17,10 @@ import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; import {CalculatorOptions} from '../../../../framework/calculator_options_pb'; import {Detection as DetectionProto} from '../../../../framework/formats/detection_pb'; +import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; import {ObjectDetectorOptions as ObjectDetectorOptionsProto} from '../../../../tasks/cc/vision/object_detector/proto/object_detector_options_pb'; -import {convertBaseOptionsToProto} from '../../../../tasks/web/components/processors/base_options'; -import {TaskRunner} from '../../../../tasks/web/core/task_runner'; import {WasmLoaderOptions} from
'../../../../tasks/web/core/wasm_loader_options'; +import {VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; import {createMediaPipeLib, FileLocator, ImageSource} from '../../../../web/graph_runner/wasm_mediapipe_lib'; // Placeholder for internal dependency on trusted resource url @@ -41,7 +41,7 @@ export {ImageSource}; // Used in the public API // tslint:disable:jspb-use-builder-pattern /** Performs object detection on images. */ -export class ObjectDetector extends TaskRunner { +export class ObjectDetector extends VisionTaskRunner { private detections: Detection[] = []; private readonly options = new ObjectDetectorOptionsProto(); @@ -103,6 +103,14 @@ export class ObjectDetector extends TaskRunner { wasmLoaderOptions, new Uint8Array(graphData)); } + protected override get baseOptions(): BaseOptionsProto|undefined { + return this.options.getBaseOptions(); + } + + protected override set baseOptions(proto: BaseOptionsProto|undefined) { + this.options.setBaseOptions(proto); + } + /** * Sets new options for the object detector. * @@ -112,12 +120,8 @@ * * @param options The options for the object detector. */ - async setOptions(options: ObjectDetectorOptions): Promise<void> { - if (options.baseOptions) { - const baseOptionsProto = await convertBaseOptionsToProto( - options.baseOptions, this.options.getBaseOptions()); - this.options.setBaseOptions(baseOptionsProto); - } + override async setOptions(options: ObjectDetectorOptions): Promise<void> { + await super.setOptions(options); // Note that we have to support both JSPB and ProtobufJS, hence we // have to explicitly clear the values instead of setting them to @@ -158,12 +162,27 @@ /** * Performs object detection on the provided single image and waits * synchronously for the response. - * @param imageSource An image source to process. - * @param timestamp The timestamp of the current frame, in ms. If not - * provided, defaults to `performance.now()`. + * @param image An image to process. * @return The list of detected objects */ - detect(imageSource: ImageSource, timestamp?: number): Detection[] { + detect(image: ImageSource): Detection[] { + return this.processImageData(image); + } + + /** + * Performs object detection on the provided video frame and waits + * synchronously for the response. + * @param videoFrame A video frame to process. + * @param timestamp The timestamp of the current frame, in ms. + * @return The list of detected objects + */ + detectForVideo(videoFrame: ImageSource, timestamp: number): Detection[] { + return this.processVideoData(videoFrame, timestamp); + } + + /** Runs the object detector graph and blocks on the response. */ + protected override process(imageSource: ImageSource, timestamp: number): + Detection[] { // Get detections by running our MediaPipe graph. this.detections = []; this.addGpuBufferAsImageToStream( diff --git a/mediapipe/tasks/web/vision/object_detector/object_detector_options.d.ts b/mediapipe/tasks/web/vision/object_detector/object_detector_options.d.ts index eec12cf17..1d20ce1e2 100644 --- a/mediapipe/tasks/web/vision/object_detector/object_detector_options.d.ts +++ b/mediapipe/tasks/web/vision/object_detector/object_detector_options.d.ts @@ -14,13 +14,10 @@ * limitations under the License.
*/ -import {BaseOptions} from '../../../../tasks/web/core/base_options'; +import {VisionTaskOptions} from '../../../../tasks/web/vision/core/vision_task_options'; /** Options to configure the MediaPipe Object Detector Task */ -export interface ObjectDetectorOptions { - /** Options to configure the loading of the model assets. */ - baseOptions?: BaseOptions; - +export interface ObjectDetectorOptions extends VisionTaskOptions { /** * The locale to use for display names specified through the TFLite Model * Metadata, if any. Defaults to English. From dc9578d2263f99c64ab503fb50b727330c7b06e0 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Wed, 16 Nov 2022 08:27:30 -0800 Subject: [PATCH 064/469] Internal change PiperOrigin-RevId: 488946809 --- mediapipe/tasks/cc/core/BUILD | 3 +++ mediapipe/tasks/cc/vision/image_segmenter/BUILD | 3 +++ 2 files changed, 6 insertions(+) diff --git a/mediapipe/tasks/cc/core/BUILD b/mediapipe/tasks/cc/core/BUILD index 291dd29fe..f14457073 100644 --- a/mediapipe/tasks/cc/core/BUILD +++ b/mediapipe/tasks/cc/core/BUILD @@ -22,6 +22,9 @@ cc_library( name = "base_options", srcs = ["base_options.cc"], hdrs = ["base_options.h"], + visibility = [ + "//mediapipe/tasks:internal", + ], deps = [ ":mediapipe_builtin_op_resolver", "//mediapipe/calculators/tensor:inference_calculator_cc_proto", diff --git a/mediapipe/tasks/cc/vision/image_segmenter/BUILD b/mediapipe/tasks/cc/vision/image_segmenter/BUILD index 4c43a07f5..7206a45ea 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/BUILD +++ b/mediapipe/tasks/cc/vision/image_segmenter/BUILD @@ -20,6 +20,9 @@ cc_library( name = "image_segmenter", srcs = ["image_segmenter.cc"], hdrs = ["image_segmenter.h"], + visibility = [ + "//mediapipe/tasks:internal", + ], deps = [ ":image_segmenter_graph", "//mediapipe/framework/api2:builder", From cdd44e77b75da34287938dfe222e220a780f98c7 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Wed, 16 Nov 2022 10:03:11 -0800 Subject: [PATCH 065/469] Internal change PiperOrigin-RevId: 488969539 --- .../python/vision/gesture_recognizer/gesture_recognizer_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py b/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py index 9bac22133..8a6e474d7 100644 --- a/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py @@ -93,7 +93,7 @@ class GestureRecognizerTest(tf.test.TestCase): tflite_file=gesture_classifier_tflite_file, size=[1, model.embedding_size]) - def _test_accuracy(self, model, threshold=0.5): + def _test_accuracy(self, model, threshold=0.25): # Test on _train_data because of our limited dataset size _, accuracy = model.evaluate(self._train_data) tf.compat.v1.logging.info(f'train accuracy: {accuracy}') From 512a531b9e09a681b1a6ee02a08ddf290a48a0f9 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Wed, 16 Nov 2022 10:30:23 -0800 Subject: [PATCH 066/469] Internal change PiperOrigin-RevId: 488977390 --- third_party/external_files.bzl | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/third_party/external_files.bzl b/third_party/external_files.bzl index 1f0b00289..72ca95e66 100644 --- a/third_party/external_files.bzl +++ b/third_party/external_files.bzl @@ -90,8 +90,8 @@ def external_files(): http_file( name = "com_google_mediapipe_canned_gesture_classifier_tflite", 
- sha256 = "2fc7e279966a7a9e15fc869223793e390791fc61fdc0062f9bc7d0eef6be98a2", - urls = ["https://storage.googleapis.com/mediapipe-assets/canned_gesture_classifier.tflite?generation=1668124189331326"], + sha256 = "ee121d85979de1b86126faabb0a0f4d2e4039c3e33e2cd687db50571001b24d0", + urls = ["https://storage.googleapis.com/mediapipe-assets/canned_gesture_classifier.tflite?generation=1668550473107417"], ) http_file( @@ -294,8 +294,8 @@ def external_files(): http_file( name = "com_google_mediapipe_gesture_embedder_tflite", - sha256 = "54abe78de1d1cd5e3cdaa0dab01db18e3ec7e09a76e7c3b5fa278572f7a60977", - urls = ["https://storage.googleapis.com/mediapipe-assets/gesture_embedder.tflite?generation=1668124192126494"], + sha256 = "927e4f6cbe6451da6b4fd1485e2576a6f8dbd95062666661cbd9dea893c41d01", + urls = ["https://storage.googleapis.com/mediapipe-assets/gesture_embedder.tflite?generation=1668550476472972"], ) http_file( @@ -990,14 +990,14 @@ def external_files(): http_file( name = "com_google_mediapipe_gesture_embedder_keras_metadata_pb", - sha256 = "24268b69429be4e307f9ab099ba20d1de7c40e4191a53f6a92dcbbd97a7047d3", - urls = ["https://storage.googleapis.com/mediapipe-assets/gesture_embedder/keras_metadata.pb?generation=1668124196996131"], + sha256 = "c76b856101e2284293a5e5963b7c445e407a0b3e56ec63eb78f64d883e51e3aa", + urls = ["https://storage.googleapis.com/mediapipe-assets/gesture_embedder/keras_metadata.pb?generation=1668550482128410"], ) http_file( name = "com_google_mediapipe_gesture_embedder_saved_model_pb", - sha256 = "f3a2870ba3ef537a4f6a5889ffc5b7061ad98f9fd96ec431a62116892f100659", - urls = ["https://storage.googleapis.com/mediapipe-assets/gesture_embedder/saved_model.pb?generation=1668124199460071"], + sha256 = "0082d37c5b85487fbf553e00a63f640945faf3da2d561a5f5a24c3194fecda6a", + urls = ["https://storage.googleapis.com/mediapipe-assets/gesture_embedder/saved_model.pb?generation=1668550484904822"], ) http_file( @@ -1038,12 +1038,12 @@ def external_files(): http_file( name = "com_google_mediapipe_gesture_embedder_variables_variables_data-00000-of-00001", - sha256 = "9fdb750c4bac67afb9c0f61916510930b496cc47e7f89449aee2bec6b6ed0af8", - urls = ["https://storage.googleapis.com/mediapipe-assets/gesture_embedder/variables/variables.data-00000-of-00001?generation=1668124201918980"], + sha256 = "c156c9654c9ffb1091bb9f06c71080bd1e428586276d3f39c33fbab27fe0522d", + urls = ["https://storage.googleapis.com/mediapipe-assets/gesture_embedder/variables/variables.data-00000-of-00001?generation=1668550487965052"], ) http_file( name = "com_google_mediapipe_gesture_embedder_variables_variables_index", - sha256 = "3ccbcee9488fec4627d496abd9837997276b32b839a4d0ae434bd806fe380b86", - urls = ["https://storage.googleapis.com/mediapipe-assets/gesture_embedder/variables/variables.index?generation=1668124204353848"], + sha256 = "76ea482b8da6bdb3d65d3b2ea989c1699c9fa0d6df0cb6d80863d1dc6fe7c4bd", + urls = ["https://storage.googleapis.com/mediapipe-assets/gesture_embedder/variables/variables.index?generation=1668550490691823"], ) From 74474d859e0891fc97b4038b7b8ecb9420c4b522 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Wed, 16 Nov 2022 13:58:21 -0800 Subject: [PATCH 067/469] Update image_classifier demo with new ImageClassifierOption changes PiperOrigin-RevId: 489031381 --- .../vision/image_classifier/image_classifier_demo.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/mediapipe/model_maker/python/vision/image_classifier/image_classifier_demo.py 
b/mediapipe/model_maker/python/vision/image_classifier/image_classifier_demo.py index 5832ea53a..f382e28aa 100644 --- a/mediapipe/model_maker/python/vision/image_classifier/image_classifier_demo.py +++ b/mediapipe/model_maker/python/vision/image_classifier/image_classifier_demo.py @@ -61,12 +61,14 @@ def run(data_dir: str, export_dir: str, data = image_classifier.Dataset.from_folder(data_dir) train_data, rest_data = data.split(0.8) validation_data, test_data = rest_data.split(0.5) - + model_options = image_classifier.ImageClassifierOptions( + supported_model=model_spec, + hparams=image_classifier.HParams(export_dir=export_dir), + ) model = image_classifier.ImageClassifier.create( - model_spec=model_spec, train_data=train_data, validation_data=validation_data, - hparams=image_classifier.HParams(model_dir=export_dir)) + options=model_options) _, acc = model.evaluate(test_data) print('Test accuracy: %f' % acc) @@ -83,7 +85,6 @@ def run(data_dir: str, export_dir: str, raise ValueError(f'Quantization: {quantization} is not recognized') model.export_model(quantization_config=quantization_config) - model.export_labels(export_dir) def main(_) -> None: From 3cdf0f65365c5f13673034e9abf9ebbbef90c0b2 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt <schmidt.sebastian@google.com> Date: Wed, 16 Nov 2022 14:36:14 -0800 Subject: [PATCH 068/469] Fix a crash that occurred when a model returns fewer vector elements than before PiperOrigin-RevId: 489041814 --- mediapipe/web/graph_runner/wasm_mediapipe_lib.ts | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mediapipe/web/graph_runner/wasm_mediapipe_lib.ts b/mediapipe/web/graph_runner/wasm_mediapipe_lib.ts index 9ecf094ca..5f8040a33 100644 --- a/mediapipe/web/graph_runner/wasm_mediapipe_lib.ts +++ b/mediapipe/web/graph_runner/wasm_mediapipe_lib.ts @@ -406,7 +406,7 @@ export class WasmMediaPipeLib { */ setVectorListener<T>( outputStreamName: string, callbackFcn: (data: T[]) => void) { - const buffer: T[] = []; + let buffer: T[] = []; this.wasmModule.vectorListeners = this.wasmModule.vectorListeners || {}; this.wasmModule.vectorListeners[outputStreamName] = (data: unknown, index: number, length: number) => { @@ -419,6 +419,7 @@ export class WasmMediaPipeLib { // the underlying data elements once we leave the scope of the // listener. callbackFcn(buffer); + buffer = []; } }; } From b6b72d5e4e9b8a3b176331489cae78cc3e9c77df Mon Sep 17 00:00:00 2001 From: MediaPipe Team <mediapipe@google.com> Date: Wed, 16 Nov 2022 15:55:06 -0800 Subject: [PATCH 069/469] Add MuxCalculator test case where graph is being closed while SELECT has not been received.
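For context: when the inputs close, DefaultInputStreamHandler fills the
missing SELECT stream with an empty packet, which MuxCalculator then tries
to dereference. A sketch of the kind of guard a calculator could add to
tolerate this (illustrative only, not part of this change):

  absl::Status Process(CalculatorContext* cc) override {
    if (cc->Inputs().Tag("SELECT").IsEmpty()) {
      // An empty SELECT packet (e.g. one synthesized at close) carries no
      // payload, so skip routing instead of crashing on it.
      return absl::OkStatus();
    }
    // ... existing routing logic ...
    return absl::OkStatus();
  }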
PiperOrigin-RevId: 489061902 --- .../calculators/core/mux_calculator_test.cc | 93 +++++++++++++++++++ 1 file changed, 93 insertions(+) diff --git a/mediapipe/calculators/core/mux_calculator_test.cc b/mediapipe/calculators/core/mux_calculator_test.cc index 86d2fab42..a3ac8a27a 100644 --- a/mediapipe/calculators/core/mux_calculator_test.cc +++ b/mediapipe/calculators/core/mux_calculator_test.cc @@ -398,6 +398,99 @@ TEST(MuxCalculatorTest, HandleTimestampBoundUpdates) { MP_ASSERT_OK(graph.WaitUntilDone()); } +TEST(MuxCalculatorTest, HandlesCloseGracefully) { + CalculatorGraphConfig config = + mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>( + R"pb( + input_stream: "select" + input_stream: "value_0" + input_stream: "value_1" + node { + calculator: "MuxCalculator" + input_stream: "SELECT:select" + input_stream: "INPUT:0:value_0" + input_stream: "INPUT:1:value_1" + output_stream: "OUTPUT:output" + } + )pb"); + CalculatorGraph graph; + MP_ASSERT_OK(graph.Initialize(config)); + + // Observe packets. + std::vector<Packet> output_packets; + MP_ASSERT_OK(graph.ObserveOutputStream( + "output", + [&output_packets](const Packet& p) -> absl::Status { + output_packets.push_back(p); + return absl::OkStatus(); + }, + /*observe_timestamp_bounds=*/true)); + + // Start graph. + MP_ASSERT_OK(graph.StartRun({})); + + // Add a single packet, wait for completion, and close. + MP_ASSERT_OK(graph.AddPacketToInputStream( + "value_0", MakePacket<int>(0).At(Timestamp(1000)))); + MP_ASSERT_OK(graph.WaitUntilIdle()); + MP_ASSERT_OK(graph.CloseAllInputStreams()); + MP_ASSERT_OK(graph.WaitUntilDone()); + + EXPECT_TRUE(output_packets.empty()); +} + +TEST(MuxCalculatorTest, CrashesOnCloseWithDefaultInputStreamHandler) { + CalculatorGraphConfig config = + mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>( + R"pb( + # This is required in order for EXPECT_DEATH to work everywhere + executor { name: "" type: "ApplicationThreadExecutor" } + + input_stream: "select" + input_stream: "value_0" + input_stream: "value_1" + node { + calculator: "MuxCalculator" + input_stream: "SELECT:select" + input_stream: "INPUT:0:value_0" + input_stream: "INPUT:1:value_1" + output_stream: "OUTPUT:output" + input_stream_handler { + input_stream_handler: "DefaultInputStreamHandler" + } + } + )pb"); + CalculatorGraph graph; + MP_ASSERT_OK(graph.Initialize(config)); + + // Observe packets. + std::vector<Packet> output_packets; + MP_ASSERT_OK(graph.ObserveOutputStream( + "output", + [&output_packets](const Packet& p) -> absl::Status { + output_packets.push_back(p); + return absl::OkStatus(); + }, + /*observe_timestamp_bounds=*/true)); + + // Start graph. + MP_ASSERT_OK(graph.StartRun({})); + + // Add a single packet, wait for completion, and close. + MP_ASSERT_OK(graph.AddPacketToInputStream( + "value_0", MakePacket<int>(0).At(Timestamp(1000)))); + MP_ASSERT_OK(graph.WaitUntilIdle()); + // Currently MuxCalculator crashes with a correct packet set from + // DefaultInputStreamHandler. The SELECT packet is missing at Timestamp 1000, + // and an empty packet is the correct representation of that.
+ EXPECT_DEATH( + { + (void)graph.CloseAllInputStreams(); + (void)graph.WaitUntilDone(); + }, + "Check failed: payload_"); +} + } // namespace } // namespace mediapipe From 90eb4a19d8593d366ddf7aed894d8bb1161da39c Mon Sep 17 00:00:00 2001 From: MediaPipe Team <mediapipe@google.com> Date: Wed, 16 Nov 2022 18:11:00 -0800 Subject: [PATCH 070/469] Internal change PiperOrigin-RevId: 489088227 --- mediapipe/framework/deps/status_builder.cc | 23 ++--------- mediapipe/framework/deps/status_builder.h | 19 +-------- .../framework/deps/status_builder_test.cc | 39 ------------------- mediapipe/framework/deps/status_macros.h | 29 +++++++------- 4 files changed, 20 insertions(+), 90 deletions(-) diff --git a/mediapipe/framework/deps/status_builder.cc b/mediapipe/framework/deps/status_builder.cc index 70775949d..0202b8689 100644 --- a/mediapipe/framework/deps/status_builder.cc +++ b/mediapipe/framework/deps/status_builder.cc @@ -97,39 +97,24 @@ absl::Status StatusBuilder::Impl::JoinMessageToStatus() { }()); } -StatusBuilder::Impl::Impl(const absl::Status& status, const char* file, - int line) - : status(status), line(line), file(file), stream() {} - -StatusBuilder::Impl::Impl(absl::Status&& status, const char* file, int line) - : status(std::move(status)), line(line), file(file), stream() {} - StatusBuilder::Impl::Impl(const absl::Status& status, mediapipe::source_location location) - : status(status), - line(location.line()), - file(location.file_name()), - stream() {} + : status(status), location(location), stream() {} StatusBuilder::Impl::Impl(absl::Status&& status, mediapipe::source_location location) - : status(std::move(status)), - line(location.line()), - file(location.file_name()), - stream() {} + : status(std::move(status)), location(location), stream() {} StatusBuilder::Impl::Impl(const Impl& other) : status(other.status), - line(other.line), - file(other.file), + location(other.location), no_logging(other.no_logging), stream(other.stream.str()), join_style(other.join_style) {} StatusBuilder::Impl& StatusBuilder::Impl::operator=(const Impl& other) { status = other.status; - line = other.line; - file = other.file; + location = other.location; no_logging = other.no_logging; stream = std::ostringstream(other.stream.str()); join_style = other.join_style; diff --git a/mediapipe/framework/deps/status_builder.h b/mediapipe/framework/deps/status_builder.h index d2e40d575..ae11699d2 100644 --- a/mediapipe/framework/deps/status_builder.h +++ b/mediapipe/framework/deps/status_builder.h @@ -60,17 +60,6 @@ class ABSL_MUST_USE_RESULT StatusBuilder { ? nullptr : std::make_unique<Impl>(absl::Status(code, ""), location)) {} - StatusBuilder(const absl::Status& original_status, const char* file, int line) - : impl_(original_status.ok() - ? nullptr - : std::make_unique<Impl>(original_status, file, line)) {} - - StatusBuilder(absl::Status&& original_status, const char* file, int line) - : impl_(original_status.ok() - ? nullptr - : std::make_unique<Impl>(std::move(original_status), file, - line)) {} - bool ok() const { return !impl_; } StatusBuilder& SetAppend() &; @@ -109,8 +98,6 @@ class ABSL_MUST_USE_RESULT StatusBuilder { kPrepend, }; - Impl(const absl::Status& status, const char* file, int line); - Impl(absl::Status&& status, const char* file, int line); Impl(const absl::Status& status, mediapipe::source_location location); Impl(absl::Status&& status, mediapipe::source_location location); Impl(const Impl&); @@ -120,10 +107,8 @@ class ABSL_MUST_USE_RESULT StatusBuilder { // The status that the result will be based on.
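// For reviewers, a sketch of how call sites construct a builder after
// this change; MEDIAPIPE_LOC expands to the current source location
// (illustrative only, not part of this diff):
//
//   return mediapipe::StatusBuilder(
//              absl::InvalidArgumentError("bad input"), MEDIAPIPE_LOC)
//          << "while validating options";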
absl::Status status; - // The line to record if this file is logged. - int line; - // Not-owned: The file to record if this status is logged. - const char* file; + // The source location to record if this status is logged. + mediapipe::source_location location; // Logging disabled if true. bool no_logging = false; // The additional messages added with `<<`. This is nullptr when status_ is diff --git a/mediapipe/framework/deps/status_builder_test.cc b/mediapipe/framework/deps/status_builder_test.cc index 560acd3c6..f517bb909 100644 --- a/mediapipe/framework/deps/status_builder_test.cc +++ b/mediapipe/framework/deps/status_builder_test.cc @@ -33,21 +33,6 @@ TEST(StatusBuilder, OkStatusRvalue) { ASSERT_EQ(status, absl::OkStatus()); } -TEST(StatusBuilder, OkStatusFileAndLineRvalueStatus) { - absl::Status status = StatusBuilder(absl::OkStatus(), "hello.cc", 1234) - << "annotated message1 " - << "annotated message2"; - ASSERT_EQ(status, absl::OkStatus()); -} - -TEST(StatusBuilder, OkStatusFileAndLineLvalueStatus) { - const auto original_status = absl::OkStatus(); - absl::Status status = StatusBuilder(original_status, "hello.cc", 1234) - << "annotated message1 " - << "annotated message2"; - ASSERT_EQ(status, absl::OkStatus()); -} - TEST(StatusBuilder, AnnotateMode) { absl::Status status = StatusBuilder(absl::Status(absl::StatusCode::kNotFound, "original message"), @@ -60,30 +45,6 @@ TEST(StatusBuilder, AnnotateMode) { "original message; annotated message1 annotated message2"); } -TEST(StatusBuilder, AnnotateModeFileAndLineRvalueStatus) { - absl::Status status = StatusBuilder(absl::Status(absl::StatusCode::kNotFound, - "original message"), - "hello.cc", 1234) - << "annotated message1 " - << "annotated message2"; - ASSERT_FALSE(status.ok()); - EXPECT_EQ(status.code(), absl::StatusCode::kNotFound); - EXPECT_EQ(status.message(), - "original message; annotated message1 annotated message2"); -} - -TEST(StatusBuilder, AnnotateModeFileAndLineLvalueStatus) { - const auto original_status = - absl::Status(absl::StatusCode::kNotFound, "original message"); - absl::Status status = StatusBuilder(original_status, "hello.cc", 1234) - << "annotated message1 " - << "annotated message2"; - ASSERT_FALSE(status.ok()); - EXPECT_EQ(status.code(), absl::StatusCode::kNotFound); - EXPECT_EQ(status.message(), - "original message; annotated message1 annotated message2"); -} - TEST(StatusBuilder, PrependModeLvalue) { StatusBuilder builder( absl::Status(absl::StatusCode::kInvalidArgument, "original message"), diff --git a/mediapipe/framework/deps/status_macros.h b/mediapipe/framework/deps/status_macros.h index 757d99392..92bbf0b84 100644 --- a/mediapipe/framework/deps/status_macros.h +++ b/mediapipe/framework/deps/status_macros.h @@ -81,11 +81,11 @@ // MP_RETURN_IF_ERROR(foo.Method(args...)); // return absl::OkStatus(); // } -#define MP_RETURN_IF_ERROR(expr) \ - STATUS_MACROS_IMPL_ELSE_BLOCKER_ \ - if (mediapipe::status_macro_internal::StatusAdaptorForMacros \ - status_macro_internal_adaptor = {(expr), __FILE__, __LINE__}) { \ - } else /* NOLINT */ \ +#define MP_RETURN_IF_ERROR(expr) \ + STATUS_MACROS_IMPL_ELSE_BLOCKER_ \ + if (mediapipe::status_macro_internal::StatusAdaptorForMacros \ + status_macro_internal_adaptor = {(expr), MEDIAPIPE_LOC}) { \ + } else /* NOLINT */ \ return status_macro_internal_adaptor.Consume() // Executes an expression `rexpr` that returns a `absl::StatusOr<T>`.
On @@ -156,14 +156,14 @@ return mediapipe::StatusBuilder( \ std::move(STATUS_MACROS_IMPL_CONCAT_(_status_or_value, __LINE__)) \ .status(), \ - __FILE__, __LINE__)) + MEDIAPIPE_LOC)) #define STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_3_(lhs, rexpr, error_expression) \ STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_( \ STATUS_MACROS_IMPL_CONCAT_(_status_or_value, __LINE__), lhs, rexpr, \ mediapipe::StatusBuilder _( \ std::move(STATUS_MACROS_IMPL_CONCAT_(_status_or_value, __LINE__)) \ .status(), \ - __FILE__, __LINE__); \ + MEDIAPIPE_LOC); \ (void)_; /* error_expression is allowed to not use this variable */ \ return (error_expression)) #define STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_(statusor, lhs, rexpr, \ @@ -201,18 +201,17 @@ namespace status_macro_internal { // that declares a variable. class StatusAdaptorForMacros { public: - StatusAdaptorForMacros(const absl::Status& status, const char* file, int line) - : builder_(status, file, line) {} + StatusAdaptorForMacros(const absl::Status& status, source_location location) + : builder_(status, location) {} - StatusAdaptorForMacros(absl::Status&& status, const char* file, int line) - : builder_(std::move(status), file, line) {} + StatusAdaptorForMacros(absl::Status&& status, source_location location) + : builder_(std::move(status), location) {} - StatusAdaptorForMacros(const StatusBuilder& builder, const char* /* file */, - int /* line */) + StatusAdaptorForMacros(const StatusBuilder& builder, + source_location /*location*/) : builder_(builder) {} - StatusAdaptorForMacros(StatusBuilder&& builder, const char* /* file */, - int /* line */) + StatusAdaptorForMacros(StatusBuilder&& builder, source_location /*location*/) : builder_(std::move(builder)) {} StatusAdaptorForMacros(const StatusAdaptorForMacros&) = delete; From e66e88802c42610441dd9acfd193a9ff8e022231 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Wed, 16 Nov 2022 18:32:59 -0800 Subject: [PATCH 071/469] Change NPM Bundle to ESM PiperOrigin-RevId: 489091370 --- mediapipe/tasks/web/BUILD | 80 ++++++------------- mediapipe/tasks/web/audio.ts | 8 +- mediapipe/tasks/web/audio/BUILD | 12 --- mediapipe/tasks/web/audio/index.ts | 17 ---- mediapipe/tasks/web/package.json | 12 +-- mediapipe/tasks/web/rollup.config.iife.mjs | 21 ----- ...ollup.config.cjs.mjs => rollup.config.mjs} | 4 +- mediapipe/tasks/web/text.ts | 10 ++- mediapipe/tasks/web/text/BUILD | 13 --- mediapipe/tasks/web/text/index.ts | 18 ----- mediapipe/tasks/web/vision.ts | 22 ++++- mediapipe/tasks/web/vision/BUILD | 16 ---- mediapipe/tasks/web/vision/index.ts | 21 ----- 13 files changed, 67 insertions(+), 187 deletions(-) delete mode 100644 mediapipe/tasks/web/audio/index.ts delete mode 100644 mediapipe/tasks/web/rollup.config.iife.mjs rename mediapipe/tasks/web/{rollup.config.cjs.mjs => rollup.config.mjs} (86%) delete mode 100644 mediapipe/tasks/web/text/index.ts delete mode 100644 mediapipe/tasks/web/vision/index.ts diff --git a/mediapipe/tasks/web/BUILD b/mediapipe/tasks/web/BUILD index b8777e785..e9703e37a 100644 --- a/mediapipe/tasks/web/BUILD +++ b/mediapipe/tasks/web/BUILD @@ -24,35 +24,25 @@ mediapipe_files(srcs = [ mediapipe_ts_library( name = "audio_lib", srcs = ["audio.ts"], - deps = ["//mediapipe/tasks/web/audio:audio_lib"], -) - -rollup_bundle( - name = "audio_cjs_bundle", - config_file = "rollup.config.cjs.mjs", - entry_point = "audio.ts", - format = "cjs", - output_dir = False, deps = [ - ":audio_lib", - "@npm//@rollup/plugin-commonjs", - "@npm//@rollup/plugin-node-resolve", - "@npm//@rollup/plugin-replace", + 
"//mediapipe/tasks/web/audio/audio_classifier", ], ) rollup_bundle( - name = "audio_iife_bundle", - config_file = "rollup.config.iife.mjs", + name = "audio_bundle", + config_file = "rollup.config.mjs", entry_point = "audio.ts", - format = "iife", + format = "esm", output_dir = False, + sourcemap = "false", deps = [ ":audio_lib", "@npm//@rollup/plugin-commonjs", "@npm//@rollup/plugin-node-resolve", "@npm//@rollup/plugin-replace", "@npm//@rollup/plugin-terser", + "@npm//google-protobuf", ], ) @@ -69,8 +59,7 @@ pkg_npm( deps = [ "wasm/audio_wasm_internal.js", "wasm/audio_wasm_internal.wasm", - ":audio_cjs_bundle", - ":audio_iife_bundle", + ":audio_bundle", ], ) @@ -79,35 +68,26 @@ pkg_npm( mediapipe_ts_library( name = "text_lib", srcs = ["text.ts"], - deps = ["//mediapipe/tasks/web/text:text_lib"], -) - -rollup_bundle( - name = "text_cjs_bundle", - config_file = "rollup.config.cjs.mjs", - entry_point = "text.ts", - format = "cjs", - output_dir = False, deps = [ - ":text_lib", - "@npm//@rollup/plugin-commonjs", - "@npm//@rollup/plugin-node-resolve", - "@npm//@rollup/plugin-replace", + "//mediapipe/tasks/web/text/text_classifier", + "//mediapipe/tasks/web/text/text_embedder", ], ) rollup_bundle( - name = "text_iife_bundle", - config_file = "rollup.config.iife.mjs", + name = "text_bundle", + config_file = "rollup.config.mjs", entry_point = "text.ts", - format = "iife", + format = "esm", output_dir = False, + sourcemap = "false", deps = [ ":text_lib", "@npm//@rollup/plugin-commonjs", "@npm//@rollup/plugin-node-resolve", "@npm//@rollup/plugin-replace", "@npm//@rollup/plugin-terser", + "@npm//google-protobuf", ], ) @@ -124,8 +104,7 @@ pkg_npm( deps = [ "wasm/text_wasm_internal.js", "wasm/text_wasm_internal.wasm", - ":text_cjs_bundle", - ":text_iife_bundle", + ":text_bundle", ], ) @@ -134,35 +113,29 @@ pkg_npm( mediapipe_ts_library( name = "vision_lib", srcs = ["vision.ts"], - deps = ["//mediapipe/tasks/web/vision:vision_lib"], -) - -rollup_bundle( - name = "vision_cjs_bundle", - config_file = "rollup.config.cjs.mjs", - entry_point = "vision.ts", - format = "cjs", - output_dir = False, deps = [ - ":vision_lib", - "@npm//@rollup/plugin-commonjs", - "@npm//@rollup/plugin-node-resolve", - "@npm//@rollup/plugin-replace", + "//mediapipe/tasks/web/vision/gesture_recognizer", + "//mediapipe/tasks/web/vision/hand_landmarker", + "//mediapipe/tasks/web/vision/image_classifier", + "//mediapipe/tasks/web/vision/image_embedder", + "//mediapipe/tasks/web/vision/object_detector", ], ) rollup_bundle( - name = "vision_iife_bundle", - config_file = "rollup.config.iife.mjs", + name = "vision_bundle", + config_file = "rollup.config.mjs", entry_point = "vision.ts", - format = "iife", + format = "esm", output_dir = False, + sourcemap = "false", deps = [ ":vision_lib", "@npm//@rollup/plugin-commonjs", "@npm//@rollup/plugin-node-resolve", "@npm//@rollup/plugin-replace", "@npm//@rollup/plugin-terser", + "@npm//google-protobuf", ], ) @@ -179,7 +152,6 @@ pkg_npm( deps = [ "wasm/vision_wasm_internal.js", "wasm/vision_wasm_internal.wasm", - ":vision_cjs_bundle", - ":vision_iife_bundle", + ":vision_bundle", ], ) diff --git a/mediapipe/tasks/web/audio.ts b/mediapipe/tasks/web/audio.ts index 4a3b80594..764fd8393 100644 --- a/mediapipe/tasks/web/audio.ts +++ b/mediapipe/tasks/web/audio.ts @@ -14,4 +14,10 @@ * limitations under the License. 
*/ -export * from '../../tasks/web/audio/index'; +import {AudioClassifier as AudioClassifierImpl} from '../../tasks/web/audio/audio_classifier/audio_classifier'; + +// Declare the variables locally so that Rollup in OSS includes them explicitly +// as exports. +const AudioClassifier = AudioClassifierImpl; + +export {AudioClassifier}; diff --git a/mediapipe/tasks/web/audio/BUILD b/mediapipe/tasks/web/audio/BUILD index 4f6e48b28..69b0408e9 100644 --- a/mediapipe/tasks/web/audio/BUILD +++ b/mediapipe/tasks/web/audio/BUILD @@ -1,13 +1 @@ # This contains the MediaPipe Audio Tasks. - -load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") - -package(default_visibility = ["//mediapipe/tasks:internal"]) - -mediapipe_ts_library( - name = "audio_lib", - srcs = ["index.ts"], - deps = [ - "//mediapipe/tasks/web/audio/audio_classifier", - ], -) diff --git a/mediapipe/tasks/web/audio/index.ts b/mediapipe/tasks/web/audio/index.ts deleted file mode 100644 index a5083b326..000000000 --- a/mediapipe/tasks/web/audio/index.ts +++ /dev/null @@ -1,17 +0,0 @@ -/** - * Copyright 2022 The MediaPipe Authors. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -export * from '../../../tasks/web/audio/audio_classifier/audio_classifier'; diff --git a/mediapipe/tasks/web/package.json b/mediapipe/tasks/web/package.json index 1870f18a6..89c9a599e 100644 --- a/mediapipe/tasks/web/package.json +++ b/mediapipe/tasks/web/package.json @@ -2,20 +2,10 @@ "name": "@mediapipe/tasks-__NAME__", "version": "__VERSION__", "description": "__DESCRIPTION__", - "main": "__NAME___cjs_bundle.js", - "module": "__NAME___cjs_bundle.js", - "jsdeliver": "__NAME___iife_bundle.js", - "exports": { - ".": "./__NAME___cjs_bundle.js", - "./loader": "./wasm/__NAME___wasm_internal.js", - "./wasm": "./wasm/__NAME___wasm_internal.wasm" - }, + "main": "__NAME___bundle.js", "author": "mediapipe@google.com", "license": "Apache-2.0", "types": "__TYPES__", - "dependencies": { - "google-protobuf": "^3.21.2" - }, "homepage": "http://mediapipe.dev", "keywords": [ "AR", "ML", "Augmented", "MediaPipe", "MediaPipe Tasks" ] } diff --git a/mediapipe/tasks/web/rollup.config.iife.mjs b/mediapipe/tasks/web/rollup.config.iife.mjs deleted file mode 100644 index 1320927aa..000000000 --- a/mediapipe/tasks/web/rollup.config.iife.mjs +++ /dev/null @@ -1,21 +0,0 @@ -import resolve from '@rollup/plugin-node-resolve'; -import commonjs from '@rollup/plugin-commonjs'; -import terser from '@rollup/plugin-terser'; -import replace from '@rollup/plugin-replace'; - -export default { - output: { - name: 'bundle', - sourcemap: false - }, - plugins: [ - // Workaround for https://github.com/protocolbuffers/protobuf-javascript/issues/151 - replace({ - 'var calculator_options_pb = {};': 'var calculator_options_pb = {}; var mediapipe_framework_calculator_options_pb = calculator_options_pb;', - delimiters: ['', ''] - }), - resolve({browser: true}), - commonjs(), - terser() - ] -} diff --git a/mediapipe/tasks/web/rollup.config.cjs.mjs
b/mediapipe/tasks/web/rollup.config.mjs similarity index 86% rename from mediapipe/tasks/web/rollup.config.cjs.mjs rename to mediapipe/tasks/web/rollup.config.mjs index 5f8ca1848..e633bf702 100644 --- a/mediapipe/tasks/web/rollup.config.cjs.mjs +++ b/mediapipe/tasks/web/rollup.config.mjs @@ -1,6 +1,7 @@ import resolve from '@rollup/plugin-node-resolve'; import commonjs from '@rollup/plugin-commonjs'; import replace from '@rollup/plugin-replace'; +import terser from '@rollup/plugin-terser'; export default { plugins: [ @@ -10,6 +11,7 @@ export default { delimiters: ['', ''] }), resolve(), - commonjs() + commonjs(), + terser() ] } diff --git a/mediapipe/tasks/web/text.ts b/mediapipe/tasks/web/text.ts index f8a0b6457..39d101237 100644 --- a/mediapipe/tasks/web/text.ts +++ b/mediapipe/tasks/web/text.ts @@ -14,4 +14,12 @@ * limitations under the License. */ -export * from '../../tasks/web/text/index'; +import {TextClassifier as TextClassifierImpl} from '../../tasks/web/text/text_classifier/text_classifier'; +import {TextEmbedder as TextEmbedderImpl} from '../../tasks/web/text/text_embedder/text_embedder'; + +// Declare the variables locally so that Rollup in OSS includes them explicitly +// as exports. +const TextClassifier = TextClassifierImpl; +const TextEmbedder = TextEmbedderImpl; + +export {TextClassifier, TextEmbedder}; diff --git a/mediapipe/tasks/web/text/BUILD b/mediapipe/tasks/web/text/BUILD index 4b465b0f5..edd23c7d4 100644 --- a/mediapipe/tasks/web/text/BUILD +++ b/mediapipe/tasks/web/text/BUILD @@ -1,14 +1 @@ # This contains the MediaPipe Text Tasks. - -load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") - -package(default_visibility = ["//mediapipe/tasks:internal"]) - -mediapipe_ts_library( - name = "text_lib", - srcs = ["index.ts"], - deps = [ - "//mediapipe/tasks/web/text/text_classifier", - "//mediapipe/tasks/web/text/text_embedder", - ], -) diff --git a/mediapipe/tasks/web/text/index.ts b/mediapipe/tasks/web/text/index.ts deleted file mode 100644 index d50db209c..000000000 --- a/mediapipe/tasks/web/text/index.ts +++ /dev/null @@ -1,18 +0,0 @@ -/** - * Copyright 2022 The MediaPipe Authors. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -export * from '../../../tasks/web/text/text_classifier/text_classifier'; -export * from '../../../tasks/web/text/text_embedder/text_embedder'; diff --git a/mediapipe/tasks/web/vision.ts b/mediapipe/tasks/web/vision.ts index 6ff8f725b..4e4fab43f 100644 --- a/mediapipe/tasks/web/vision.ts +++ b/mediapipe/tasks/web/vision.ts @@ -14,4 +14,24 @@ * limitations under the License.
*/ -export * from '../../tasks/web/vision/index'; +import {GestureRecognizer as GestureRecognizerImpl} from '../../tasks/web/vision/gesture_recognizer/gesture_recognizer'; +import {HandLandmarker as HandLandmarkerImpl} from '../../tasks/web/vision/hand_landmarker/hand_landmarker'; +import {ImageClassifier as ImageClassifierImpl} from '../../tasks/web/vision/image_classifier/image_classifier'; +import {ImageEmbedder as ImageEmbedderImpl} from '../../tasks/web/vision/image_embedder/image_embedder'; +import {ObjectDetector as ObjectDetectorImpl} from '../../tasks/web/vision/object_detector/object_detector'; + +// Declare the variables locally so that Rollup in OSS includes them explicitly +// as exports. +const GestureRecognizer = GestureRecognizerImpl; +const HandLandmarker = HandLandmarkerImpl; +const ImageClassifier = ImageClassifierImpl; +const ImageEmbedder = ImageEmbedderImpl; +const ObjectDetector = ObjectDetectorImpl; + +export { + GestureRecognizer, + HandLandmarker, + ImageClassifier, + ImageEmbedder, + ObjectDetector +}; diff --git a/mediapipe/tasks/web/vision/BUILD b/mediapipe/tasks/web/vision/BUILD index 3c45fbfa6..7267744e2 100644 --- a/mediapipe/tasks/web/vision/BUILD +++ b/mediapipe/tasks/web/vision/BUILD @@ -1,17 +1 @@ # This contains the MediaPipe Vision Tasks. - -load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") - -package(default_visibility = ["//mediapipe/tasks:internal"]) - -mediapipe_ts_library( - name = "vision_lib", - srcs = ["index.ts"], - deps = [ - "//mediapipe/tasks/web/vision/gesture_recognizer", - "//mediapipe/tasks/web/vision/hand_landmarker", - "//mediapipe/tasks/web/vision/image_classifier", - "//mediapipe/tasks/web/vision/image_embedder", - "//mediapipe/tasks/web/vision/object_detector", - ], -) diff --git a/mediapipe/tasks/web/vision/index.ts b/mediapipe/tasks/web/vision/index.ts deleted file mode 100644 index d68c00cc7..000000000 --- a/mediapipe/tasks/web/vision/index.ts +++ /dev/null @@ -1,21 +0,0 @@ -/** - * Copyright 2022 The MediaPipe Authors. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License.
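// With the single ESM bundle, downstream users consume the tasks as named
// exports (a sketch; the package names follow the package.json template
// above):
//
//   import {ImageClassifier, ObjectDetector} from '@mediapipe/tasks-vision';
//   import {TextEmbedder} from '@mediapipe/tasks-text';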
- */ - -export * from '../../../tasks/web/vision/image_classifier/image_classifier'; -export * from '../../../tasks/web/vision/image_embedder/image_embedder'; -export * from '../../../tasks/web/vision/gesture_recognizer/gesture_recognizer'; -export * from '../../../tasks/web/vision/hand_landmarker/hand_landmarker'; -export * from '../../../tasks/web/vision/object_detector/object_detector'; From 6fc277ee1c34eeba9fda1e7fde90b705a4ee5824 Mon Sep 17 00:00:00 2001 From: Camillo Lugaresi <camillol@google.com> Date: Wed, 16 Nov 2022 18:34:14 -0800 Subject: [PATCH 072/469] Internal change PiperOrigin-RevId: 489091534 --- mediapipe/gpu/gl_context.cc | 8 ++++++-- mediapipe/gpu/gl_context.h | 4 ++++ 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/mediapipe/gpu/gl_context.cc b/mediapipe/gpu/gl_context.cc index 7f7ba0e23..91d2837c5 100644 --- a/mediapipe/gpu/gl_context.cc +++ b/mediapipe/gpu/gl_context.cc @@ -826,10 +826,14 @@ std::shared_ptr<GlSyncPoint> GlContext::CreateSyncToken() { return token; } -bool GlContext::IsAnyContextCurrent() { +PlatformGlContext GlContext::GetCurrentNativeContext() { ContextBinding ctx; GetCurrentContextBinding(&ctx); - return ctx.context != kPlatformGlContextNone; + return ctx.context; +} + +bool GlContext::IsAnyContextCurrent() { + return GetCurrentNativeContext() != kPlatformGlContextNone; } std::shared_ptr<GlSyncPoint> diff --git a/mediapipe/gpu/gl_context.h b/mediapipe/gpu/gl_context.h index 957cb510f..7f5168d8b 100644 --- a/mediapipe/gpu/gl_context.h +++ b/mediapipe/gpu/gl_context.h @@ -307,6 +307,10 @@ class GlContext : public std::enable_shared_from_this<GlContext> { // the GlContext class, is current. static bool IsAnyContextCurrent(); + // Returns the current native context, whether managed by this class or not. + // Useful as a cross-platform way to get the current PlatformGlContext. + static PlatformGlContext GetCurrentNativeContext(); + // Creates a synchronization token for the current, non-GlContext-owned // context. This can be passed to MediaPipe so it can synchronize with the // commands issued in the external context up to this point. From 899c87466ec7cc62b5b60f10564c997c49bc9395 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt <schmidt.sebastian@google.com> Date: Wed, 16 Nov 2022 20:55:18 -0800 Subject: [PATCH 073/469] Add MP Tasks entrypoints PiperOrigin-RevId: 489110875 --- mediapipe/tasks/web/audio/BUILD | 12 ++++++++++++ mediapipe/tasks/web/audio/index.ts | 17 +++++++++++++++++ mediapipe/tasks/web/text/BUILD | 13 +++++++++++++ mediapipe/tasks/web/text/index.ts | 18 ++++++++++++++++++ mediapipe/tasks/web/vision/BUILD | 16 ++++++++++++++++ mediapipe/tasks/web/vision/index.ts | 21 +++++++++++++++++++++ 6 files changed, 97 insertions(+) create mode 100644 mediapipe/tasks/web/audio/index.ts create mode 100644 mediapipe/tasks/web/text/index.ts create mode 100644 mediapipe/tasks/web/vision/index.ts diff --git a/mediapipe/tasks/web/audio/BUILD b/mediapipe/tasks/web/audio/BUILD index 69b0408e9..4f6e48b28 100644 --- a/mediapipe/tasks/web/audio/BUILD +++ b/mediapipe/tasks/web/audio/BUILD @@ -1 +1,13 @@ # This contains the MediaPipe Audio Tasks.
+ +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +mediapipe_ts_library( + name = "audio_lib", + srcs = ["index.ts"], + deps = [ + "//mediapipe/tasks/web/audio/audio_classifier", + ], +) diff --git a/mediapipe/tasks/web/audio/index.ts b/mediapipe/tasks/web/audio/index.ts new file mode 100644 index 000000000..a5083b326 --- /dev/null +++ b/mediapipe/tasks/web/audio/index.ts @@ -0,0 +1,17 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +export * from '../../../tasks/web/audio/audio_classifier/audio_classifier'; diff --git a/mediapipe/tasks/web/text/BUILD b/mediapipe/tasks/web/text/BUILD index edd23c7d4..4b465b0f5 100644 --- a/mediapipe/tasks/web/text/BUILD +++ b/mediapipe/tasks/web/text/BUILD @@ -1 +1,14 @@ # This contains the MediaPipe Text Tasks. + +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +mediapipe_ts_library( + name = "text_lib", + srcs = ["index.ts"], + deps = [ + "//mediapipe/tasks/web/text/text_classifier", + "//mediapipe/tasks/web/text/text_embedder", + ], +) diff --git a/mediapipe/tasks/web/text/index.ts b/mediapipe/tasks/web/text/index.ts new file mode 100644 index 000000000..d50db209c --- /dev/null +++ b/mediapipe/tasks/web/text/index.ts @@ -0,0 +1,18 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +export * from '../../../tasks/web/text/text_classifier/text_classifier'; +export * from '../../../tasks/web/text/text_embedder/text_embedder'; diff --git a/mediapipe/tasks/web/vision/BUILD b/mediapipe/tasks/web/vision/BUILD index 7267744e2..3c45fbfa6 100644 --- a/mediapipe/tasks/web/vision/BUILD +++ b/mediapipe/tasks/web/vision/BUILD @@ -1 +1,17 @@ # This contains the MediaPipe Vision Tasks. 
+ +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +mediapipe_ts_library( + name = "vision_lib", + srcs = ["index.ts"], + deps = [ + "//mediapipe/tasks/web/vision/gesture_recognizer", + "//mediapipe/tasks/web/vision/hand_landmarker", + "//mediapipe/tasks/web/vision/image_classifier", + "//mediapipe/tasks/web/vision/image_embedder", + "//mediapipe/tasks/web/vision/object_detector", + ], +) diff --git a/mediapipe/tasks/web/vision/index.ts b/mediapipe/tasks/web/vision/index.ts new file mode 100644 index 000000000..d68c00cc7 --- /dev/null +++ b/mediapipe/tasks/web/vision/index.ts @@ -0,0 +1,21 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +export * from '../../../tasks/web/vision/image_classifier/image_classifier'; +export * from '../../../tasks/web/vision/image_embedder/image_embedder'; +export * from '../../../tasks/web/vision/gesture_recognizer/gesture_recognizer'; +export * from '../../../tasks/web/vision/hand_landmarker/hand_landmarker'; +export * from '../../../tasks/web/vision/object_detector/object_detector'; From 5a6837d034f9583e2f43659c388638ac14ad0b7e Mon Sep 17 00:00:00 2001 From: kinaryml Date: Wed, 16 Nov 2022 22:08:52 -0800 Subject: [PATCH 074/469] Fix errors that will occur in python 3.11 --- mediapipe/tasks/python/audio/audio_classifier.py | 3 ++- mediapipe/tasks/python/audio/audio_embedder.py | 3 ++- mediapipe/tasks/python/text/text_classifier.py | 4 +++- mediapipe/tasks/python/text/text_embedder.py | 4 +++- mediapipe/tasks/python/vision/gesture_recognizer.py | 6 ++++-- mediapipe/tasks/python/vision/image_classifier.py | 3 ++- mediapipe/tasks/python/vision/image_embedder.py | 3 ++- 7 files changed, 18 insertions(+), 8 deletions(-) diff --git a/mediapipe/tasks/python/audio/audio_classifier.py b/mediapipe/tasks/python/audio/audio_classifier.py index 7955cc4dc..2dd1cc4a3 100644 --- a/mediapipe/tasks/python/audio/audio_classifier.py +++ b/mediapipe/tasks/python/audio/audio_classifier.py @@ -70,7 +70,8 @@ class AudioClassifierOptions: """ base_options: _BaseOptions running_mode: _RunningMode = _RunningMode.AUDIO_CLIPS - classifier_options: _ClassifierOptions = _ClassifierOptions() + classifier_options: Optional[_ClassifierOptions] = dataclasses.field( + default_factory=lambda: _ClassifierOptions()) result_callback: Optional[Callable[[AudioClassifierResult, int], None]] = None @doc_controls.do_not_generate_docs diff --git a/mediapipe/tasks/python/audio/audio_embedder.py b/mediapipe/tasks/python/audio/audio_embedder.py index a774d71e9..4484064ee 100644 --- a/mediapipe/tasks/python/audio/audio_embedder.py +++ b/mediapipe/tasks/python/audio/audio_embedder.py @@ -71,7 +71,8 @@ class AudioEmbedderOptions: """ base_options: _BaseOptions running_mode: _RunningMode = _RunningMode.AUDIO_CLIPS - embedder_options: _EmbedderOptions = _EmbedderOptions() + embedder_options: Optional[_EmbedderOptions] = 
dataclasses.field( + default_factory=lambda: _EmbedderOptions()) result_callback: Optional[Callable[[AudioEmbedderResult, int], None]] = None @doc_controls.do_not_generate_docs diff --git a/mediapipe/tasks/python/text/text_classifier.py b/mediapipe/tasks/python/text/text_classifier.py index 92d547f20..c6095e1c3 100644 --- a/mediapipe/tasks/python/text/text_classifier.py +++ b/mediapipe/tasks/python/text/text_classifier.py @@ -14,6 +14,7 @@ """MediaPipe text classifier task.""" import dataclasses +from typing import Optional from mediapipe.python import packet_creator from mediapipe.python import packet_getter @@ -48,7 +49,8 @@ class TextClassifierOptions: classifier_options: Options for the text classification task. """ base_options: _BaseOptions - classifier_options: _ClassifierOptions = _ClassifierOptions() + classifier_options: Optional[_ClassifierOptions] = dataclasses.field( + default_factory=lambda: _ClassifierOptions()) @doc_controls.do_not_generate_docs def to_pb2(self) -> _TextClassifierGraphOptionsProto: diff --git a/mediapipe/tasks/python/text/text_embedder.py b/mediapipe/tasks/python/text/text_embedder.py index f3e5eecbe..1a32796a3 100644 --- a/mediapipe/tasks/python/text/text_embedder.py +++ b/mediapipe/tasks/python/text/text_embedder.py @@ -14,6 +14,7 @@ """MediaPipe text embedder task.""" import dataclasses +from typing import Optional from mediapipe.python import packet_creator from mediapipe.python import packet_getter @@ -49,7 +50,8 @@ class TextEmbedderOptions: embedder_options: Options for the text embedder task. """ base_options: _BaseOptions - embedder_options: _EmbedderOptions = _EmbedderOptions() + embedder_options: Optional[_EmbedderOptions] = dataclasses.field( + default_factory=lambda: _EmbedderOptions()) @doc_controls.do_not_generate_docs def to_pb2(self) -> _TextEmbedderGraphOptionsProto: diff --git a/mediapipe/tasks/python/vision/gesture_recognizer.py b/mediapipe/tasks/python/vision/gesture_recognizer.py index 9b6fd8cab..8addebe4c 100644 --- a/mediapipe/tasks/python/vision/gesture_recognizer.py +++ b/mediapipe/tasks/python/vision/gesture_recognizer.py @@ -181,9 +181,11 @@ class GestureRecognizerOptions: min_hand_presence_confidence: Optional[float] = 0.5 min_tracking_confidence: Optional[float] = 0.5 canned_gesture_classifier_options: Optional[ - _ClassifierOptions] = _ClassifierOptions() + _ClassifierOptions] = dataclasses.field( + default_factory=lambda: _ClassifierOptions()) custom_gesture_classifier_options: Optional[ - _ClassifierOptions] = _ClassifierOptions() + _ClassifierOptions] = dataclasses.field( + default_factory=lambda: _ClassifierOptions()) result_callback: Optional[Callable[ [GestureRecognizerResult, image_module.Image, int], None]] = None diff --git a/mediapipe/tasks/python/vision/image_classifier.py b/mediapipe/tasks/python/vision/image_classifier.py index 763160e1e..d3c2965ba 100644 --- a/mediapipe/tasks/python/vision/image_classifier.py +++ b/mediapipe/tasks/python/vision/image_classifier.py @@ -70,7 +70,8 @@ class ImageClassifierOptions: """ base_options: _BaseOptions running_mode: _RunningMode = _RunningMode.IMAGE - classifier_options: _ClassifierOptions = _ClassifierOptions() + classifier_options: Optional[_ClassifierOptions] = dataclasses.field( + default_factory=lambda: _ClassifierOptions()) result_callback: Optional[Callable[ [ImageClassifierResult, image_module.Image, int], None]] = None diff --git a/mediapipe/tasks/python/vision/image_embedder.py b/mediapipe/tasks/python/vision/image_embedder.py index f299fa590..06624d16e 100644 --- 
a/mediapipe/tasks/python/vision/image_embedder.py +++ b/mediapipe/tasks/python/vision/image_embedder.py @@ -69,7 +69,8 @@ class ImageEmbedderOptions: """ base_options: _BaseOptions running_mode: _RunningMode = _RunningMode.IMAGE - embedder_options: _EmbedderOptions = _EmbedderOptions() + embedder_options: Optional[_EmbedderOptions] = dataclasses.field( + default_factory=lambda: _EmbedderOptions()) result_callback: Optional[Callable[ [ImageEmbedderResult, image_module.Image, int], None]] = None From ea4989b6f146b9589fdd048ec4702a7c5384fe52 Mon Sep 17 00:00:00 2001 From: Hadon Nash Date: Thu, 17 Nov 2022 00:06:17 -0800 Subject: [PATCH 075/469] Internal change PiperOrigin-RevId: 489135553 --- .../core/flow_limiter_calculator_test.cc | 96 ++------- mediapipe/framework/BUILD | 1 + mediapipe/framework/calculator_graph.cc | 26 ++- mediapipe/framework/calculator_graph.h | 6 + .../framework/calculator_graph_bounds_test.cc | 194 +++++++++++++++++- mediapipe/util/packet_test_util.h | 80 +++++++- 6 files changed, 302 insertions(+), 101 deletions(-) diff --git a/mediapipe/calculators/core/flow_limiter_calculator_test.cc b/mediapipe/calculators/core/flow_limiter_calculator_test.cc index 45bace271..5d0594de9 100644 --- a/mediapipe/calculators/core/flow_limiter_calculator_test.cc +++ b/mediapipe/calculators/core/flow_limiter_calculator_test.cc @@ -85,75 +85,6 @@ std::string SourceString(Timestamp t) { : absl::StrCat("Timestamp(", t.DebugString(), ")"); } -template -std::string SourceString(Packet packet) { - std::ostringstream oss; - if (packet.IsEmpty()) { - oss << "Packet()"; - } else { - oss << "MakePacket<" << MediaPipeTypeStringOrDemangled() << ">(" - << packet.Get() << ")"; - } - oss << ".At(" << SourceString(packet.Timestamp()) << ")"; - return oss.str(); -} - -template -class PacketsEqMatcher - : public ::testing::MatcherInterface { - public: - PacketsEqMatcher(PacketContainer packets) : packets_(packets) {} - void DescribeTo(::std::ostream* os) const override { - *os << "The expected packet contents: \n"; - Print(packets_, os); - } - bool MatchAndExplain( - const PacketContainer& value, - ::testing::MatchResultListener* listener) const override { - if (!Equals(packets_, value)) { - if (listener->IsInterested()) { - *listener << "The actual packet contents: \n"; - Print(value, listener->stream()); - } - return false; - } - return true; - } - - private: - bool Equals(const PacketContainer& c1, const PacketContainer& c2) const { - if (c1.size() != c2.size()) { - return false; - } - for (auto i1 = c1.begin(), i2 = c2.begin(); i1 != c1.end(); ++i1, ++i2) { - Packet p1 = *i1, p2 = *i2; - if (p1.Timestamp() != p2.Timestamp() || p1.IsEmpty() != p2.IsEmpty() || - (!p1.IsEmpty() && - p1.Get() != p2.Get())) { - return false; - } - } - return true; - } - void Print(const PacketContainer& packets, ::std::ostream* os) const { - for (auto it = packets.begin(); it != packets.end(); ++it) { - const Packet& packet = *it; - *os << (it == packets.begin() ? "{" : ""); - *os << SourceString(packet); - *os << (std::next(it) == packets.end() ? "}" : ", "); - } - } - - const PacketContainer packets_; -}; - -template -::testing::Matcher PacketsEq( - const PacketContainer& packets) { - return MakeMatcher( - new PacketsEqMatcher(packets)); -} - // A Calculator::Process callback function. typedef std::function @@ -743,9 +674,6 @@ TEST_F(FlowLimiterCalculatorTest, TwoInputStreams) { // The processing time "sleep_time" is reduced from 22ms to 12ms to create // the same frame rate as FlowLimiterCalculatorTest::TwoInputStreams. 
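// The PacketsEqMatcher/PacketsEq helpers deleted above are superseded by the
// shared PacketEq/PacketMatchers utilities that this change adds to
// mediapipe/util/packet_test_util.h. A minimal usage sketch (the packet
// values and timestamps here are illustrative only):
//
//   std::vector<Packet> expected = {
//       MakePacket<int>(1).At(Timestamp(10000)),
//       Packet().At(Timestamp(20000)),  // empty packet, i.e. a bound update
//   };
//   EXPECT_THAT(actual_packets,
//               testing::ElementsAreArray(PacketMatchers<int>(expected)));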
TEST_F(FlowLimiterCalculatorTest, ZeroQueue) { - auto BoolPacketsEq = PacketsEq, bool>; - auto IntPacketsEq = PacketsEq, int>; - // Configure the test. SetUpInputData(); SetUpSimulationClock(); @@ -839,13 +767,16 @@ TEST_F(FlowLimiterCalculatorTest, ZeroQueue) { input_packets_[0], input_packets_[2], input_packets_[15], input_packets_[17], input_packets_[19], }; - EXPECT_THAT(out_1_packets_, IntPacketsEq(expected_output)); + EXPECT_THAT(out_1_packets_, + ElementsAreArray(PacketMatchers(expected_output))); + // Exactly the timestamps released by FlowLimiterCalculator for in_1_sampled. std::vector expected_output_2 = { input_packets_[0], input_packets_[2], input_packets_[4], input_packets_[15], input_packets_[17], input_packets_[19], }; - EXPECT_THAT(out_2_packets, IntPacketsEq(expected_output_2)); + EXPECT_THAT(out_2_packets, + ElementsAreArray(PacketMatchers(expected_output_2))); // Validate the ALLOW stream output. std::vector expected_allow = { @@ -871,7 +802,8 @@ TEST_F(FlowLimiterCalculatorTest, ZeroQueue) { MakePacket(true).At(Timestamp(190000)), MakePacket(false).At(Timestamp(200000)), }; - EXPECT_THAT(allow_packets_, BoolPacketsEq(expected_allow)); + EXPECT_THAT(allow_packets_, + ElementsAreArray(PacketMatchers(expected_allow))); } std::vector StripBoundsUpdates(const std::vector& packets, @@ -891,9 +823,6 @@ std::vector StripBoundsUpdates(const std::vector& packets, // Shows how FlowLimiterCalculator releases auxiliary input packets. // In this test, auxiliary input packets arrive at twice the primary rate. TEST_F(FlowLimiterCalculatorTest, AuxiliaryInputs) { - auto BoolPacketsEq = PacketsEq, bool>; - auto IntPacketsEq = PacketsEq, int>; - // Configure the test. SetUpInputData(); SetUpSimulationClock(); @@ -1011,7 +940,8 @@ TEST_F(FlowLimiterCalculatorTest, AuxiliaryInputs) { MakePacket(6).At(Timestamp(60000)), Packet().At(Timestamp(80000)), }; - EXPECT_THAT(out_1_packets_, IntPacketsEq(expected_output)); + EXPECT_THAT(out_1_packets_, + ElementsAreArray(PacketMatchers(expected_output))); // Packets following input packets 2 and 6, and not input packets 4 and 8. std::vector expected_auxiliary_output = { @@ -1031,12 +961,13 @@ TEST_F(FlowLimiterCalculatorTest, AuxiliaryInputs) { }; std::vector actual_2 = StripBoundsUpdates(out_2_packets, Timestamp(90000)); - EXPECT_THAT(actual_2, IntPacketsEq(expected_auxiliary_output)); + EXPECT_THAT(actual_2, + ElementsAreArray(PacketMatchers(expected_auxiliary_output))); std::vector expected_3 = StripBoundsUpdates(expected_auxiliary_output, Timestamp(39999)); std::vector actual_3 = StripBoundsUpdates(out_3_packets, Timestamp(39999)); - EXPECT_THAT(actual_3, IntPacketsEq(expected_3)); + EXPECT_THAT(actual_3, ElementsAreArray(PacketMatchers(expected_3))); // Validate the ALLOW stream output. 
std::vector expected_allow = { @@ -1045,7 +976,8 @@ TEST_F(FlowLimiterCalculatorTest, AuxiliaryInputs) { MakePacket(true).At(Timestamp(60000)), MakePacket(false).At(Timestamp(80000)), }; - EXPECT_THAT(allow_packets_, BoolPacketsEq(expected_allow)); + EXPECT_THAT(allow_packets_, + ElementsAreArray(PacketMatchers(expected_allow))); } } // anonymous namespace diff --git a/mediapipe/framework/BUILD b/mediapipe/framework/BUILD index 19c51853c..8ccdac3b9 100644 --- a/mediapipe/framework/BUILD +++ b/mediapipe/framework/BUILD @@ -1469,6 +1469,7 @@ cc_test( "//mediapipe/framework/stream_handler:mux_input_stream_handler", "//mediapipe/framework/stream_handler:sync_set_input_stream_handler", "//mediapipe/framework/tool:sink", + "//mediapipe/util:packet_test_util", "@com_google_absl//absl/strings", ], ) diff --git a/mediapipe/framework/calculator_graph.cc b/mediapipe/framework/calculator_graph.cc index c17a2e1e2..526a74835 100644 --- a/mediapipe/framework/calculator_graph.cc +++ b/mediapipe/framework/calculator_graph.cc @@ -98,14 +98,13 @@ void CalculatorGraph::GraphInputStream::SetHeader(const Packet& header) { manager_->LockIntroData(); } +void CalculatorGraph::GraphInputStream::SetNextTimestampBound( + Timestamp timestamp) { + shard_.SetNextTimestampBound(timestamp); +} + void CalculatorGraph::GraphInputStream::PropagateUpdatesToMirrors() { - // Since GraphInputStream doesn't allow SetOffset() and - // SetNextTimestampBound(), the timestamp bound to propagate is only - // determined by the timestamp of the output packets. - CHECK(!shard_.IsEmpty()) << "Shard with name \"" << manager_->Name() - << "\" failed"; - manager_->PropagateUpdatesToMirrors( - shard_.LastAddedPacketTimestamp().NextAllowedInStream(), &shard_); + manager_->PropagateUpdatesToMirrors(shard_.NextTimestampBound(), &shard_); } void CalculatorGraph::GraphInputStream::Close() { @@ -868,6 +867,19 @@ absl::Status CalculatorGraph::AddPacketToInputStream( return AddPacketToInputStreamInternal(stream_name, std::move(packet)); } +absl::Status CalculatorGraph::SetInputStreamTimestampBound( + const std::string& stream_name, Timestamp timestamp) { + std::unique_ptr* stream = + mediapipe::FindOrNull(graph_input_streams_, stream_name); + RET_CHECK(stream).SetNoLogging() << absl::Substitute( + "SetInputStreamTimestampBound called on input stream \"$0\" which is not " + "a graph input stream.", + stream_name); + (*stream)->SetNextTimestampBound(timestamp); + (*stream)->PropagateUpdatesToMirrors(); + return absl::OkStatus(); +} + // We avoid having two copies of this code for AddPacketToInputStream( // const Packet&) and AddPacketToInputStream(Packet &&) by having this // internal-only templated version. T&& is a forwarding reference here, so diff --git a/mediapipe/framework/calculator_graph.h b/mediapipe/framework/calculator_graph.h index c51476102..04f9de45f 100644 --- a/mediapipe/framework/calculator_graph.h +++ b/mediapipe/framework/calculator_graph.h @@ -257,6 +257,10 @@ class CalculatorGraph { absl::Status AddPacketToInputStream(const std::string& stream_name, Packet&& packet); + // Indicates that input will arrive no earlier than a certain timestamp. + absl::Status SetInputStreamTimestampBound(const std::string& stream_name, + Timestamp timestamp); + // Sets the queue size of a graph input stream, overriding the graph default. 
absl::Status SetInputStreamMaxQueueSize(const std::string& stream_name, int max_queue_size); @@ -425,6 +429,8 @@ class CalculatorGraph { void AddPacket(Packet&& packet) { shard_.AddPacket(std::move(packet)); } + void SetNextTimestampBound(Timestamp timestamp); + void PropagateUpdatesToMirrors(); void Close(); diff --git a/mediapipe/framework/calculator_graph_bounds_test.cc b/mediapipe/framework/calculator_graph_bounds_test.cc index b55f9459d..d149337cc 100644 --- a/mediapipe/framework/calculator_graph_bounds_test.cc +++ b/mediapipe/framework/calculator_graph_bounds_test.cc @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include + #include "absl/strings/str_replace.h" #include "mediapipe/framework/calculator_context.h" #include "mediapipe/framework/calculator_framework.h" @@ -24,6 +26,7 @@ #include "mediapipe/framework/port/status_matchers.h" #include "mediapipe/framework/thread_pool_executor.h" #include "mediapipe/framework/timestamp.h" +#include "mediapipe/util/packet_test_util.h" namespace mediapipe { namespace { @@ -1536,7 +1539,7 @@ class EmptyPacketCalculator : public CalculatorBase { }; REGISTER_CALCULATOR(EmptyPacketCalculator); -// This test shows that an output timestamp bound can be specified by outputing +// This test shows that an output timestamp bound can be specified by outputting // an empty packet with a settled timestamp. TEST(CalculatorGraphBoundsTest, EmptyPacketOutput) { // OffsetAndBoundCalculator runs on parallel threads and sends ts @@ -1580,6 +1583,195 @@ TEST(CalculatorGraphBoundsTest, EmptyPacketOutput) { EXPECT_EQ(output_0_packets[i].Timestamp(), Timestamp(10 + i * 10)); } + // Shut down the graph. + MP_ASSERT_OK(graph.CloseAllPacketSources()); + MP_ASSERT_OK(graph.WaitUntilDone()); +} + +// This test shows that input timestamp bounds can be specified using +// CalculatorGraph::SetInputStreamTimestampBound. +TEST(CalculatorGraphBoundsTest, SetInputStreamTimestampBound) { + std::string config_str = R"( + input_stream: "input_0" + node { + calculator: "ProcessBoundToPacketCalculator" + input_stream: "input_0" + output_stream: "output_0" + } + )"; + CalculatorGraphConfig config = + mediapipe::ParseTextProtoOrDie(config_str); + CalculatorGraph graph; + std::vector output_0_packets; + MP_ASSERT_OK(graph.Initialize(config)); + MP_ASSERT_OK(graph.ObserveOutputStream("output_0", [&](const Packet& p) { + output_0_packets.push_back(p); + return absl::OkStatus(); + })); + MP_ASSERT_OK(graph.StartRun({})); + MP_ASSERT_OK(graph.WaitUntilIdle()); + + // Send in timestamp bounds. + for (int i = 0; i < 9; ++i) { + const int ts = 10 + i * 10; + MP_ASSERT_OK(graph.SetInputStreamTimestampBound( + "input_0", Timestamp(ts).NextAllowedInStream())); + MP_ASSERT_OK(graph.WaitUntilIdle()); + } + + // 9 timestamp bounds are converted to packets. + EXPECT_EQ(output_0_packets.size(), 9); + for (int i = 0; i < 9; ++i) { + EXPECT_EQ(output_0_packets[i].Timestamp(), Timestamp(10 + i * 10)); + } + + // Shutdown the graph. + MP_ASSERT_OK(graph.CloseAllPacketSources()); + MP_ASSERT_OK(graph.WaitUntilDone()); +} + +// This test shows how an input stream with infrequent packets, such as +// configuration protobufs, can be consumed while processing more frequent +// packets, such as video frames. +TEST(CalculatorGraphBoundsTest, TimestampBoundsForInfrequentInput) { + // PassThroughCalculator consuming two input streams, with default ISH. 
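+  // With the default input stream handler, Process() is invoked for a
+  // timestamp only once the bounds of *all* input streams have advanced past
+  // it, so calling SetInputStreamTimestampBound() on the quiet "config"
+  // stream is what unblocks the frequent "frame" stream. A sketch of the
+  // pattern used below (timestamps illustrative):
+  //
+  //   MP_ASSERT_OK(graph.AddPacketToInputStream(
+  //       "frame", MakePacket<std::string>("frame_0").At(Timestamp(0))));
+  //   // Not processed yet: "config" might still produce a packet at ts 0.
+  //   MP_ASSERT_OK(
+  //       graph.SetInputStreamTimestampBound("config", Timestamp(10000)));
+  //   // All inputs are now settled through ts 0; the frame goes through.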
+ std::string config_str = R"pb( + input_stream: "INFREQUENT:config" + input_stream: "FREQUENT:frame" + node { + calculator: "PassThroughCalculator" + input_stream: "CONFIG:config" + input_stream: "VIDEO:frame" + output_stream: "VIDEO:output_frame" + output_stream: "CONFIG:output_config" + } + )pb"; + + CalculatorGraphConfig config = + mediapipe::ParseTextProtoOrDie(config_str); + CalculatorGraph graph; + std::vector frame_packets; + MP_ASSERT_OK(graph.Initialize(config)); + MP_ASSERT_OK(graph.ObserveOutputStream( + "output_frame", + [&](const Packet& p) { + frame_packets.push_back(p); + return absl::OkStatus(); + }, + /*observe_bound_updates=*/true)); + std::vector config_packets; + MP_ASSERT_OK(graph.ObserveOutputStream( + "output_config", + [&](const Packet& p) { + config_packets.push_back(p); + return absl::OkStatus(); + }, + /*observe_bound_updates=*/true)); + MP_ASSERT_OK(graph.StartRun({})); + MP_ASSERT_OK(graph.WaitUntilIdle()); + + // Utility functions to send packets or timestamp bounds. + auto send_fn = [&](std::string stream, std::string value, int ts) { + MP_ASSERT_OK(graph.AddPacketToInputStream( + stream, + MakePacket(absl::StrCat(value)).At(Timestamp(ts)))); + MP_ASSERT_OK(graph.WaitUntilIdle()); + }; + auto bound_fn = [&](std::string stream, int ts) { + MP_ASSERT_OK(graph.SetInputStreamTimestampBound(stream, Timestamp(ts))); + MP_ASSERT_OK(graph.WaitUntilIdle()); + }; + + // Send in a frame packet. + send_fn("frame", "frame_0", 0); + // The frame is not processed yet. + EXPECT_THAT(frame_packets, ElementsAreArray(PacketMatchers({}))); + bound_fn("config", 10000); + // The frame is processed after a fresh config timestamp bound arrives. + EXPECT_THAT(frame_packets, + ElementsAreArray(PacketMatchers({ + MakePacket("frame_0").At(Timestamp(0)), + }))); + + // Send in a frame packet. + send_fn("frame", "frame_1", 20000); + // The frame is not processed yet. + // The PassThroughCalculator with TimestampOffset 0 now propagates + // Timestamp bound 10000 to both "output_frame" and "output_config", + // which appears here as Packet().At(Timestamp(9999). The timestamp + // bounds at 29999 and 50000 are propagated similarly. + EXPECT_THAT(frame_packets, + ElementsAreArray(PacketMatchers({ + MakePacket("frame_0").At(Timestamp(0)), + Packet().At(Timestamp(9999)), + }))); + bound_fn("config", 30000); + // The frame is processed after a fresh config timestamp bound arrives. + EXPECT_THAT(frame_packets, + ElementsAreArray(PacketMatchers({ + MakePacket("frame_0").At(Timestamp(0)), + Packet().At(Timestamp(9999)), + MakePacket("frame_1").At(Timestamp(20000)), + }))); + + // Send in a frame packet. + send_fn("frame", "frame_2", 40000); + // The frame is not processed yet. + EXPECT_THAT(frame_packets, + ElementsAreArray(PacketMatchers({ + MakePacket("frame_0").At(Timestamp(0)), + Packet().At(Timestamp(9999)), + MakePacket("frame_1").At(Timestamp(20000)), + Packet().At(Timestamp(29999)), + }))); + send_fn("config", "config_1", 50000); + // The frame is processed after a fresh config arrives. + EXPECT_THAT(frame_packets, + ElementsAreArray(PacketMatchers({ + MakePacket("frame_0").At(Timestamp(0)), + Packet().At(Timestamp(9999)), + MakePacket("frame_1").At(Timestamp(20000)), + Packet().At(Timestamp(29999)), + MakePacket("frame_2").At(Timestamp(40000)), + }))); + + // Send in a frame packet. + send_fn("frame", "frame_3", 60000); + // The frame is not processed yet. 
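+  // As before, the earlier bound_fn("config", 30000) call surfaces as an
+  // empty bound-update packet at Timestamp(29999) in the expectations below.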
+ EXPECT_THAT(frame_packets, + ElementsAreArray(PacketMatchers({ + MakePacket("frame_0").At(Timestamp(0)), + Packet().At(Timestamp(9999)), + MakePacket("frame_1").At(Timestamp(20000)), + Packet().At(Timestamp(29999)), + MakePacket("frame_2").At(Timestamp(40000)), + Packet().At(Timestamp(50000)), + }))); + bound_fn("config", 70000); + // The frame is processed after a fresh config timestamp bound arrives. + EXPECT_THAT(frame_packets, + ElementsAreArray(PacketMatchers({ + MakePacket("frame_0").At(Timestamp(0)), + Packet().At(Timestamp(9999)), + MakePacket("frame_1").At(Timestamp(20000)), + Packet().At(Timestamp(29999)), + MakePacket("frame_2").At(Timestamp(40000)), + Packet().At(Timestamp(50000)), + MakePacket("frame_3").At(Timestamp(60000)), + }))); + + // One config packet is deleivered. + EXPECT_THAT(config_packets, + ElementsAreArray(PacketMatchers({ + Packet().At(Timestamp(0)), + Packet().At(Timestamp(9999)), + Packet().At(Timestamp(20000)), + Packet().At(Timestamp(29999)), + Packet().At(Timestamp(40000)), + MakePacket("config_1").At(Timestamp(50000)), + Packet().At(Timestamp(60000)), + }))); + // Shutdown the graph. MP_ASSERT_OK(graph.CloseAllPacketSources()); MP_ASSERT_OK(graph.WaitUntilDone()); diff --git a/mediapipe/util/packet_test_util.h b/mediapipe/util/packet_test_util.h index 106d7f8d4..61e9322e1 100644 --- a/mediapipe/util/packet_test_util.h +++ b/mediapipe/util/packet_test_util.h @@ -32,30 +32,29 @@ namespace mediapipe { namespace internal { template -class PacketMatcher : public ::testing::MatcherInterface { +class PacketMatcher : public testing::MatcherInterface { public: template explicit PacketMatcher(InnerMatcher inner_matcher) : inner_matcher_( - ::testing::SafeMatcherCast(inner_matcher)) {} + testing::SafeMatcherCast(inner_matcher)) {} // Returns true iff the packet contains value of PayloadType satisfying // the inner matcher. - bool MatchAndExplain( - const Packet& packet, - ::testing::MatchResultListener* listener) const override { + bool MatchAndExplain(const Packet& packet, + testing::MatchResultListener* listener) const override { if (!packet.ValidateAsType().ok()) { *listener << packet.DebugString() << " does not contain expected type " << ExpectedTypeName(); return false; } - ::testing::StringMatchResultListener match_listener; + testing::StringMatchResultListener match_listener; const PayloadType& payload = packet.Get(); const bool matches = inner_matcher_.MatchAndExplain(payload, &match_listener); const std::string explanation = match_listener.str(); *listener << packet.DebugString() << " containing value " - << ::testing::PrintToString(payload); + << testing::PrintToString(payload); if (!explanation.empty()) { *listener << ", which " << explanation; } @@ -78,9 +77,28 @@ class PacketMatcher : public ::testing::MatcherInterface { return ::mediapipe::Demangle(typeid(PayloadType).name()); } - const ::testing::Matcher inner_matcher_; + const testing::Matcher inner_matcher_; }; +inline std::string SourceString(Timestamp t) { + return (t.IsSpecialValue()) + ? 
t.DebugString() + : absl::StrCat("Timestamp(", t.DebugString(), ")"); +} + +template +std::string SourceString(Packet packet) { + std::ostringstream oss; + if (packet.IsEmpty()) { + oss << "Packet()"; + } else { + oss << "MakePacket<" << MediaPipeTypeStringOrDemangled() << ">(" + << packet.Get() << ")"; + } + oss << ".At(" << SourceString(packet.Timestamp()) << ")"; + return oss.str(); +} + } // namespace internal // Creates matcher validating that the packet contains value of expected type @@ -91,9 +109,9 @@ class PacketMatcher : public ::testing::MatcherInterface { // // EXPECT_THAT(MakePacket(42), PacketContains(Eq(42))) template -inline ::testing::Matcher PacketContains( +inline testing::Matcher PacketContains( InnerMatcher inner_matcher) { - return ::testing::MakeMatcher( + return testing::MakeMatcher( new internal::PacketMatcher(inner_matcher)); } @@ -110,7 +128,7 @@ inline ::testing::Matcher PacketContains( // Eq(42))) template -inline ::testing::Matcher PacketContainsTimestampAndPayload( +inline testing::Matcher PacketContainsTimestampAndPayload( TimestampMatcher timestamp_matcher, ContentMatcher content_matcher) { return testing::AllOf( testing::Property("Packet::Timestamp", &Packet::Timestamp, @@ -118,6 +136,46 @@ inline ::testing::Matcher PacketContainsTimestampAndPayload( PacketContains(content_matcher)); } +template +class PacketEqMatcher : public testing::MatcherInterface { + public: + PacketEqMatcher(Packet packet) : packet_(packet) {} + void DescribeTo(::std::ostream* os) const override { + *os << "The expected packet: " << internal::SourceString(packet_); + } + bool MatchAndExplain(Packet value, + testing::MatchResultListener* listener) const override { + bool unequal = (value.Timestamp() != packet_.Timestamp() || + value.IsEmpty() != packet_.IsEmpty() || + (!value.IsEmpty() && value.Get() != packet_.Get())); + if (unequal && listener->IsInterested()) { + *listener << "The actual packet: " << internal::SourceString(value); + } + return !unequal; + } + const Packet packet_; +}; + +template +testing::Matcher PacketEq(Packet packet) { + return MakeMatcher(new PacketEqMatcher(packet)); +} + +template +std::vector> PacketMatchers( + std::vector packets) { + std::vector> result; + for (const auto& packet : packets) { + result.push_back(PacketEq(packet)); + } + return result; +} + +} // namespace mediapipe + +namespace mediapipe { +using mediapipe::PacketContains; +using mediapipe::PacketContainsTimestampAndPayload; } // namespace mediapipe #endif // MEDIAPIPE_UTIL_PACKET_TEST_UTIL_H_ From 3ccf7308e03933ceb6285e7f347d2865c7a4d540 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Thu, 17 Nov 2022 05:26:56 -0800 Subject: [PATCH 076/469] Add shared options for Text and Audio Tasks PiperOrigin-RevId: 489186644 --- .../audioembedder/AudioEmbedderResult.java | 4 +- .../tasks/audio/core/RunningMode.java | 2 +- .../tasks/web/audio/audio_classifier/BUILD | 1 + .../audio_classifier_options.d.ts | 7 ++- mediapipe/tasks/web/audio/core/BUILD | 13 ++++++ .../web/audio/core/audio_task_options.d.ts | 44 +++++++++++++++++++ .../tasks/web/core/classifier_options.d.ts | 5 +-- .../tasks/web/core/embedder_options.d.ts | 5 +-- mediapipe/tasks/web/text/core/BUILD | 11 +++++ .../web/text/core/text_task_options.d.ts | 23 ++++++++++ .../tasks/web/text/text_classifier/BUILD | 1 + .../text_classifier_options.d.ts | 7 ++- mediapipe/tasks/web/text/text_embedder/BUILD | 1 + .../text_embedder/text_embedder_options.d.ts | 7 ++- mediapipe/tasks/web/vision/core/BUILD | 2 +- 
.../web/vision/core/vision_task_options.d.ts | 2 +- .../image_classifier_options.d.ts | 2 +- .../image_embedder_options.d.ts | 2 +- 18 files changed, 121 insertions(+), 18 deletions(-) create mode 100644 mediapipe/tasks/web/audio/core/BUILD create mode 100644 mediapipe/tasks/web/audio/core/audio_task_options.d.ts create mode 100644 mediapipe/tasks/web/text/core/BUILD create mode 100644 mediapipe/tasks/web/text/core/text_task_options.d.ts diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioembedder/AudioEmbedderResult.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioembedder/AudioEmbedderResult.java index ee4df0198..a986048f0 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioembedder/AudioEmbedderResult.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioembedder/AudioEmbedderResult.java @@ -65,8 +65,8 @@ public abstract class AudioEmbedderResult implements TaskResult { /** * Contains one set of results per classifier head. A {@link EmbeddingResult} usually represents - * one audio embedding result in an audio stream, and s only available when running with the audio - * stream mode. + * one audio embedding result in an audio stream, and is only available when running with the + * audio stream mode. */ public abstract Optional embeddingResult(); diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/core/RunningMode.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/core/RunningMode.java index f0a123810..a778eae46 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/core/RunningMode.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/core/RunningMode.java @@ -20,7 +20,7 @@ package com.google.mediapipe.tasks.audio.core; *
 * <ul>
 *   <li>AUDIO_CLIPS: The mode for running a mediapipe audio task on independent audio clips.
 *   <li>AUDIO_STREAM: The mode for running a mediapipe audio task on an audio stream, such as from
- *       microphone.
+ *       a microphone.
 * </ul>
*/ public enum RunningMode { diff --git a/mediapipe/tasks/web/audio/audio_classifier/BUILD b/mediapipe/tasks/web/audio/audio_classifier/BUILD index 6a78116c3..412af3bea 100644 --- a/mediapipe/tasks/web/audio/audio_classifier/BUILD +++ b/mediapipe/tasks/web/audio/audio_classifier/BUILD @@ -36,6 +36,7 @@ mediapipe_ts_declaration( "audio_classifier_result.d.ts", ], deps = [ + "//mediapipe/tasks/web/audio/core:audio_task_options", "//mediapipe/tasks/web/components/containers:category", "//mediapipe/tasks/web/components/containers:classification_result", "//mediapipe/tasks/web/core", diff --git a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier_options.d.ts b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier_options.d.ts index 93bd9927e..975b1e315 100644 --- a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier_options.d.ts +++ b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier_options.d.ts @@ -14,4 +14,9 @@ * limitations under the License. */ -export {ClassifierOptions as AudioClassifierOptions} from '../../../../tasks/web/core/classifier_options'; +import {AudioTaskOptions} from '../../../../tasks/web/audio/core/audio_task_options'; +import {ClassifierOptions} from '../../../../tasks/web/core/classifier_options'; + +/** Options to configure the MediaPipe Audio Classifier Task */ +export declare interface AudioClassifierOptions extends ClassifierOptions, + AudioTaskOptions {} diff --git a/mediapipe/tasks/web/audio/core/BUILD b/mediapipe/tasks/web/audio/core/BUILD new file mode 100644 index 000000000..ed60f2435 --- /dev/null +++ b/mediapipe/tasks/web/audio/core/BUILD @@ -0,0 +1,13 @@ +# This package contains options shared by all MediaPipe Audio Tasks for Web. + +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration") + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +mediapipe_ts_declaration( + name = "audio_task_options", + srcs = ["audio_task_options.d.ts"], + deps = [ + "//mediapipe/tasks/web/core", + ], +) diff --git a/mediapipe/tasks/web/audio/core/audio_task_options.d.ts b/mediapipe/tasks/web/audio/core/audio_task_options.d.ts new file mode 100644 index 000000000..58a6e55d8 --- /dev/null +++ b/mediapipe/tasks/web/audio/core/audio_task_options.d.ts @@ -0,0 +1,44 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import {BaseOptions} from '../../../../tasks/web/core/base_options'; + +/** + * MediaPipe audio task running mode. A MediaPipe audio task can be run with + * two different modes: + * - audio_clips: The mode for running a mediapipe audio task on independent + * audio clips. + * - audio_stream: The mode for running a mediapipe audio task on an audio + * stream, such as from a microphone. + * + */ +export type RunningMode = 'audio_clips'|'audio_stream'; + +/** The options for configuring a MediaPipe Audio Task. */ +export declare interface AudioTaskOptions { + /** Options to configure the loading of the model assets. 
*/ + baseOptions?: BaseOptions; + + /** + * The running mode of the task. Default to the audio_clips mode. + * Audio tasks have two running modes: + * 1) The mode for running a mediapipe audio task on independent + * audio clips. + * 2) The mode for running a mediapipe audio task on an audio + * stream, such as from a microphone. + */ + runningMode?: RunningMode; +} diff --git a/mediapipe/tasks/web/core/classifier_options.d.ts b/mediapipe/tasks/web/core/classifier_options.d.ts index 3dec8d27e..1d804d629 100644 --- a/mediapipe/tasks/web/core/classifier_options.d.ts +++ b/mediapipe/tasks/web/core/classifier_options.d.ts @@ -16,11 +16,8 @@ import {BaseOptions} from '../../../tasks/web/core/base_options'; -/** Options to configure the Mediapipe Classifier Task. */ +/** Options to configure a MediaPipe Classifier Task. */ export declare interface ClassifierOptions { - /** Options to configure the loading of the model assets. */ - baseOptions?: BaseOptions; - /** * The locale to use for display names specified through the TFLite Model * Metadata, if any. Defaults to English. diff --git a/mediapipe/tasks/web/core/embedder_options.d.ts b/mediapipe/tasks/web/core/embedder_options.d.ts index 78ddad1ae..3ec2a170c 100644 --- a/mediapipe/tasks/web/core/embedder_options.d.ts +++ b/mediapipe/tasks/web/core/embedder_options.d.ts @@ -16,11 +16,8 @@ import {BaseOptions} from '../../../tasks/web/core/base_options'; -/** Options to configure the MediaPipe Embedder Task */ +/** Options to configure a MediaPipe Embedder Task */ export declare interface EmbedderOptions { - /** Options to configure the loading of the model assets. */ - baseOptions?: BaseOptions; - /** * Whether to normalize the returned feature vector with L2 norm. Use this * option only if the model does not already contain a native L2_NORMALIZATION diff --git a/mediapipe/tasks/web/text/core/BUILD b/mediapipe/tasks/web/text/core/BUILD new file mode 100644 index 000000000..3e7faec93 --- /dev/null +++ b/mediapipe/tasks/web/text/core/BUILD @@ -0,0 +1,11 @@ +# This package contains options shared by all MediaPipe Texxt Tasks for Web. + +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration") + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +mediapipe_ts_declaration( + name = "text_task_options", + srcs = ["text_task_options.d.ts"], + deps = ["//mediapipe/tasks/web/core"], +) diff --git a/mediapipe/tasks/web/text/core/text_task_options.d.ts b/mediapipe/tasks/web/text/core/text_task_options.d.ts new file mode 100644 index 000000000..4874e35bf --- /dev/null +++ b/mediapipe/tasks/web/text/core/text_task_options.d.ts @@ -0,0 +1,23 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import {BaseOptions} from '../../../../tasks/web/core/base_options'; + +/** The options for configuring a MediaPipe Text task. */ +export declare interface TextTaskOptions { + /** Options to configure the loading of the model assets. 
*/ + baseOptions?: BaseOptions; +} diff --git a/mediapipe/tasks/web/text/text_classifier/BUILD b/mediapipe/tasks/web/text/text_classifier/BUILD index 7dbbb18ca..8c3b8e226 100644 --- a/mediapipe/tasks/web/text/text_classifier/BUILD +++ b/mediapipe/tasks/web/text/text_classifier/BUILD @@ -40,5 +40,6 @@ mediapipe_ts_declaration( "//mediapipe/tasks/web/components/containers:category", "//mediapipe/tasks/web/components/containers:classification_result", "//mediapipe/tasks/web/core:classifier_options", + "//mediapipe/tasks/web/text/core:text_task_options", ], ) diff --git a/mediapipe/tasks/web/text/text_classifier/text_classifier_options.d.ts b/mediapipe/tasks/web/text/text_classifier/text_classifier_options.d.ts index 51b2b3947..b50767e1a 100644 --- a/mediapipe/tasks/web/text/text_classifier/text_classifier_options.d.ts +++ b/mediapipe/tasks/web/text/text_classifier/text_classifier_options.d.ts @@ -14,4 +14,9 @@ * limitations under the License. */ -export {ClassifierOptions as TextClassifierOptions} from '../../../../tasks/web/core/classifier_options'; +import {ClassifierOptions} from '../../../../tasks/web/core/classifier_options'; +import {TextTaskOptions} from '../../../../tasks/web/text/core/text_task_options'; + +/** Options to configure the MediaPipe Text Classifier Task */ +export declare interface TextClassifierOptions extends ClassifierOptions, + TextTaskOptions {} diff --git a/mediapipe/tasks/web/text/text_embedder/BUILD b/mediapipe/tasks/web/text/text_embedder/BUILD index bebd612dd..17b5eac06 100644 --- a/mediapipe/tasks/web/text/text_embedder/BUILD +++ b/mediapipe/tasks/web/text/text_embedder/BUILD @@ -39,5 +39,6 @@ mediapipe_ts_declaration( "//mediapipe/tasks/web/components/containers:embedding_result", "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:embedder_options", + "//mediapipe/tasks/web/text/core:text_task_options", ], ) diff --git a/mediapipe/tasks/web/text/text_embedder/text_embedder_options.d.ts b/mediapipe/tasks/web/text/text_embedder/text_embedder_options.d.ts index 9af263765..9ea570304 100644 --- a/mediapipe/tasks/web/text/text_embedder/text_embedder_options.d.ts +++ b/mediapipe/tasks/web/text/text_embedder/text_embedder_options.d.ts @@ -14,4 +14,9 @@ * limitations under the License. */ -export {EmbedderOptions as TextEmbedderOptions} from '../../../../tasks/web/core/embedder_options'; +import {EmbedderOptions} from '../../../../tasks/web/core/embedder_options'; +import {TextTaskOptions} from '../../../../tasks/web/text/core/text_task_options'; + +/** Options to configure the MediaPipe Text Embedder Task */ +export declare interface TextEmbedderOptions extends EmbedderOptions, + TextTaskOptions {} diff --git a/mediapipe/tasks/web/vision/core/BUILD b/mediapipe/tasks/web/vision/core/BUILD index 8c405ae6e..e3a5edf33 100644 --- a/mediapipe/tasks/web/vision/core/BUILD +++ b/mediapipe/tasks/web/vision/core/BUILD @@ -1,4 +1,4 @@ -# This package contains options shared by all MediaPipe Tasks for Web. +# This package contains options shared by all MediaPipe Vision Tasks for Web. 
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library") diff --git a/mediapipe/tasks/web/vision/core/vision_task_options.d.ts b/mediapipe/tasks/web/vision/core/vision_task_options.d.ts index 8b9562e46..e04eb6596 100644 --- a/mediapipe/tasks/web/vision/core/vision_task_options.d.ts +++ b/mediapipe/tasks/web/vision/core/vision_task_options.d.ts @@ -17,7 +17,7 @@ import {BaseOptions} from '../../../../tasks/web/core/base_options'; /** - * The two running modes of a video task. + * The two running modes of a vision task. * 1) The image mode for processing single image inputs. * 2) The video mode for processing decoded frames of a video. */ diff --git a/mediapipe/tasks/web/vision/image_classifier/image_classifier_options.d.ts b/mediapipe/tasks/web/vision/image_classifier/image_classifier_options.d.ts index c1141d28f..e99dd2b69 100644 --- a/mediapipe/tasks/web/vision/image_classifier/image_classifier_options.d.ts +++ b/mediapipe/tasks/web/vision/image_classifier/image_classifier_options.d.ts @@ -17,6 +17,6 @@ import {ClassifierOptions} from '../../../../tasks/web/core/classifier_options'; import {VisionTaskOptions} from '../../../../tasks/web/vision/core/vision_task_options'; -/** Ooptions to configure the image classifier task. */ +/** Options to configure the MediaPipe Image Classifier Task. */ export declare interface ImageClassifierOptions extends ClassifierOptions, VisionTaskOptions {} diff --git a/mediapipe/tasks/web/vision/image_embedder/image_embedder_options.d.ts b/mediapipe/tasks/web/vision/image_embedder/image_embedder_options.d.ts index 10000825c..8a04be5e1 100644 --- a/mediapipe/tasks/web/vision/image_embedder/image_embedder_options.d.ts +++ b/mediapipe/tasks/web/vision/image_embedder/image_embedder_options.d.ts @@ -17,6 +17,6 @@ import {EmbedderOptions} from '../../../../tasks/web/core/embedder_options'; import {VisionTaskOptions} from '../../../../tasks/web/vision/core/vision_task_options'; -/** The options for configuring a MediaPipe image embedder task. */ +/** Options for configuring a MediaPipe Image Embedder task. 
*/ export declare interface ImageEmbedderOptions extends EmbedderOptions, VisionTaskOptions {} From 1fb0902aa06d45ebc73f5337d9f65f06c418c24b Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Thu, 17 Nov 2022 14:01:14 -0800 Subject: [PATCH 077/469] Update gesture_recognizer test PiperOrigin-RevId: 489301508 --- .../vision/gesture_recognizer/gesture_recognizer_test.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py b/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py index 8a6e474d7..39272cbbc 100644 --- a/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py @@ -14,6 +14,7 @@ import io import os +import random import tempfile from unittest import mock as unittest_mock import zipfile @@ -41,6 +42,7 @@ class GestureRecognizerTest(tf.test.TestCase): def setUp(self): super().setUp() + random.seed(1234) all_data = self._load_data() # Splits data, 90% data for training, 10% for validation self._train_data, self._validation_data = all_data.split(0.9) @@ -93,11 +95,11 @@ class GestureRecognizerTest(tf.test.TestCase): tflite_file=gesture_classifier_tflite_file, size=[1, model.embedding_size]) - def _test_accuracy(self, model, threshold=0.25): + def _test_accuracy(self, model, threshold=0.0): # Test on _train_data because of our limited dataset size _, accuracy = model.evaluate(self._train_data) tf.compat.v1.logging.info(f'train accuracy: {accuracy}') - self.assertGreaterEqual(accuracy, threshold) + self.assertGreater(accuracy, threshold) @unittest_mock.patch.object( gesture_recognizer.hyperparameters, From a7bd725e65e34ea416b15ceeffed972a2b205071 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Thu, 17 Nov 2022 16:06:04 -0800 Subject: [PATCH 078/469] Internal change PiperOrigin-RevId: 489331826 --- mediapipe/gpu/gl_context.cc | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/mediapipe/gpu/gl_context.cc b/mediapipe/gpu/gl_context.cc index 91d2837c5..53e3ff8b7 100644 --- a/mediapipe/gpu/gl_context.cc +++ b/mediapipe/gpu/gl_context.cc @@ -290,8 +290,15 @@ absl::Status GlContext::FinishInitialization(bool create_thread) { // some Emscripten cases), there might be some existing tripped error. ForceClearExistingGlErrors(); - absl::string_view version_string( - reinterpret_cast(glGetString(GL_VERSION))); + absl::string_view version_string; + const GLubyte* version_string_ptr = glGetString(GL_VERSION); + if (version_string_ptr != nullptr) { + version_string = reinterpret_cast(version_string_ptr); + } else { + // This may happen when using SwiftShader, but the numeric versions are + // available and will be used instead. + LOG(WARNING) << "failed to get GL_VERSION string"; + } // We will decide later whether we want to use the version numbers we query // for, or instead derive that information from the context creation result, @@ -333,7 +340,7 @@ absl::Status GlContext::FinishInitialization(bool create_thread) { } LOG(INFO) << "GL version: " << gl_major_version_ << "." 
<< gl_minor_version_ - << " (" << glGetString(GL_VERSION) << ")"; + << " (" << version_string << ")"; { auto status = GetGlExtensions(); if (!status.ok()) { From ab3a5f0fbf1883c4d1dfe1df2db80a7045a390c4 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Thu, 17 Nov 2022 16:28:08 -0800 Subject: [PATCH 079/469] Make MuxCalculator with DefaultInputStreamHandler to handle graph closure gracefully PiperOrigin-RevId: 489336722 --- mediapipe/calculators/core/mux_calculator.cc | 4 ++++ .../calculators/core/mux_calculator_test.cc | 16 ++++++---------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/mediapipe/calculators/core/mux_calculator.cc b/mediapipe/calculators/core/mux_calculator.cc index a0ce2ae34..88b04a32b 100644 --- a/mediapipe/calculators/core/mux_calculator.cc +++ b/mediapipe/calculators/core/mux_calculator.cc @@ -41,6 +41,10 @@ class MuxCalculator : public Node { StreamHandler("MuxInputStreamHandler")); absl::Status Process(CalculatorContext* cc) final { + if (kSelect(cc).IsStream() && kSelect(cc).IsEmpty()) { + return absl::OkStatus(); + } + int select = *kSelect(cc); RET_CHECK(0 <= select && select < kIn(cc).Count()); if (!kIn(cc)[select].IsEmpty()) { diff --git a/mediapipe/calculators/core/mux_calculator_test.cc b/mediapipe/calculators/core/mux_calculator_test.cc index a3ac8a27a..6b9434be9 100644 --- a/mediapipe/calculators/core/mux_calculator_test.cc +++ b/mediapipe/calculators/core/mux_calculator_test.cc @@ -439,7 +439,7 @@ TEST(MuxCalculatorTest, HandlesCloseGracefully) { EXPECT_TRUE(output_packets.empty()); } -TEST(MuxCalculatorTest, CrashesOnCloseWithDeafultInputStreamHandler) { +TEST(MuxCalculatorTest, HandlesCloseGracefullyWithDeafultInputStreamHandler) { CalculatorGraphConfig config = mediapipe::ParseTextProtoOrDie( R"pb( @@ -480,15 +480,11 @@ TEST(MuxCalculatorTest, CrashesOnCloseWithDeafultInputStreamHandler) { MP_ASSERT_OK(graph.AddPacketToInputStream( "value_0", MakePacket(0).At(Timestamp(1000)))); MP_ASSERT_OK(graph.WaitUntilIdle()); - // Currently MuxCalculator crashes with a correct packet set from - // DefaultInputStreamHandler. The SELECT packet is missing at Timestamp 1000, - // and an empty packet is the correct representation of that. 
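  // Context for the rewritten expectations below: on CloseAllInputStreams()
  // the missing SELECT at Timestamp 1000 settles as an empty packet, which
  // the old code dereferenced, tripping "Check failed: payload_". With the
  // IsEmpty() early return added to MuxCalculator::Process() above, the
  // graph now drains cleanly and the test asserts a single empty output
  // packet instead of a crash.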
- EXPECT_DEATH( - { - (void)graph.CloseAllInputStreams(); - (void)graph.WaitUntilDone(); - }, - "Check failed: payload_"); + MP_ASSERT_OK(graph.CloseAllInputStreams()); + MP_ASSERT_OK(graph.WaitUntilDone()); + + ASSERT_EQ(output_packets.size(), 1); + EXPECT_TRUE(output_packets[0].IsEmpty()); } } // namespace From 6f3cb340e153af68c31462a337ee0bf1c113f7cd Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Thu, 17 Nov 2022 17:14:56 -0800 Subject: [PATCH 080/469] Internal change PiperOrigin-RevId: 489345940 --- .../tasks/web/audio/audio_classifier/BUILD | 2 +- .../audio_classifier/audio_classifier.ts | 2 +- mediapipe/tasks/web/core/BUILD | 4 ++-- mediapipe/tasks/web/core/task_runner.ts | 6 +++--- .../tasks/web/text/text_classifier/BUILD | 2 +- .../text/text_classifier/text_classifier.ts | 2 +- mediapipe/tasks/web/text/text_embedder/BUILD | 2 +- .../web/text/text_embedder/text_embedder.ts | 2 +- mediapipe/tasks/web/vision/core/BUILD | 2 +- .../web/vision/core/vision_task_runner.ts | 2 +- .../tasks/web/vision/gesture_recognizer/BUILD | 2 +- .../gesture_recognizer/gesture_recognizer.ts | 2 +- .../tasks/web/vision/hand_landmarker/BUILD | 2 +- .../vision/hand_landmarker/hand_landmarker.ts | 2 +- .../tasks/web/vision/image_classifier/BUILD | 2 +- .../image_classifier/image_classifier.ts | 2 +- .../tasks/web/vision/image_embedder/BUILD | 2 +- .../vision/image_embedder/image_embedder.ts | 2 +- .../tasks/web/vision/object_detector/BUILD | 2 +- .../vision/object_detector/object_detector.ts | 2 +- mediapipe/web/graph_runner/BUILD | 20 ++++++------------- ...{wasm_mediapipe_lib.ts => graph_runner.ts} | 14 ++++++------- ...image_lib.ts => graph_runner_image_lib.ts} | 10 +++++----- .../register_model_resources_graph_service.ts | 10 +++++----- 24 files changed, 46 insertions(+), 54 deletions(-) rename mediapipe/web/graph_runner/{wasm_mediapipe_lib.ts => graph_runner.ts} (99%) rename mediapipe/web/graph_runner/{wasm_mediapipe_image_lib.ts => graph_runner_image_lib.ts} (83%) diff --git a/mediapipe/tasks/web/audio/audio_classifier/BUILD b/mediapipe/tasks/web/audio/audio_classifier/BUILD index 412af3bea..9e1fcbc51 100644 --- a/mediapipe/tasks/web/audio/audio_classifier/BUILD +++ b/mediapipe/tasks/web/audio/audio_classifier/BUILD @@ -25,7 +25,7 @@ mediapipe_ts_library( "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:classifier_options", "//mediapipe/tasks/web/core:task_runner", - "//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts", + "//mediapipe/web/graph_runner:graph_runner_ts", ], ) diff --git a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts index 76b926723..5533b0eaa 100644 --- a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts +++ b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts @@ -23,7 +23,7 @@ import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/ import {convertFromClassificationResultProto} from '../../../../tasks/web/components/processors/classifier_result'; import {TaskRunner} from '../../../../tasks/web/core/task_runner'; import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options'; -import {createMediaPipeLib, FileLocator} from '../../../../web/graph_runner/wasm_mediapipe_lib'; +import {createMediaPipeLib, FileLocator} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url import {AudioClassifierOptions} from './audio_classifier_options'; diff --git 
a/mediapipe/tasks/web/core/BUILD b/mediapipe/tasks/web/core/BUILD index e9ef85d46..6eca8bb4a 100644 --- a/mediapipe/tasks/web/core/BUILD +++ b/mediapipe/tasks/web/core/BUILD @@ -18,9 +18,9 @@ mediapipe_ts_library( "task_runner.ts", ], deps = [ + "//mediapipe/web/graph_runner:graph_runner_image_lib_ts", + "//mediapipe/web/graph_runner:graph_runner_ts", "//mediapipe/web/graph_runner:register_model_resources_graph_service_ts", - "//mediapipe/web/graph_runner:wasm_mediapipe_image_lib_ts", - "//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts", ], ) diff --git a/mediapipe/tasks/web/core/task_runner.ts b/mediapipe/tasks/web/core/task_runner.ts index c948930fc..67aa4e4df 100644 --- a/mediapipe/tasks/web/core/task_runner.ts +++ b/mediapipe/tasks/web/core/task_runner.ts @@ -15,12 +15,12 @@ */ import {SupportModelResourcesGraphService} from '../../../web/graph_runner/register_model_resources_graph_service'; -import {SupportImage} from '../../../web/graph_runner/wasm_mediapipe_image_lib'; -import {WasmMediaPipeLib, WasmModule} from '../../../web/graph_runner/wasm_mediapipe_lib'; +import {SupportImage} from '../../../web/graph_runner/graph_runner_image_lib'; +import {GraphRunner, WasmModule} from '../../../web/graph_runner/graph_runner'; // tslint:disable-next-line:enforce-name-casing const WasmMediaPipeImageLib = - SupportModelResourcesGraphService(SupportImage(WasmMediaPipeLib)); + SupportModelResourcesGraphService(SupportImage(GraphRunner)); /** Base class for all MediaPipe Tasks. */ export abstract class TaskRunner extends WasmMediaPipeImageLib { diff --git a/mediapipe/tasks/web/text/text_classifier/BUILD b/mediapipe/tasks/web/text/text_classifier/BUILD index 8c3b8e226..71ef02c92 100644 --- a/mediapipe/tasks/web/text/text_classifier/BUILD +++ b/mediapipe/tasks/web/text/text_classifier/BUILD @@ -26,7 +26,7 @@ mediapipe_ts_library( "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:classifier_options", "//mediapipe/tasks/web/core:task_runner", - "//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts", + "//mediapipe/web/graph_runner:graph_runner_ts", ], ) diff --git a/mediapipe/tasks/web/text/text_classifier/text_classifier.ts b/mediapipe/tasks/web/text/text_classifier/text_classifier.ts index d4f413efa..04789f5e1 100644 --- a/mediapipe/tasks/web/text/text_classifier/text_classifier.ts +++ b/mediapipe/tasks/web/text/text_classifier/text_classifier.ts @@ -23,7 +23,7 @@ import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/ import {convertFromClassificationResultProto} from '../../../../tasks/web/components/processors/classifier_result'; import {TaskRunner} from '../../../../tasks/web/core/task_runner'; import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options'; -import {createMediaPipeLib, FileLocator} from '../../../../web/graph_runner/wasm_mediapipe_lib'; +import {createMediaPipeLib, FileLocator} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url import {TextClassifierOptions} from './text_classifier_options'; diff --git a/mediapipe/tasks/web/text/text_embedder/BUILD b/mediapipe/tasks/web/text/text_embedder/BUILD index 17b5eac06..c555f8d33 100644 --- a/mediapipe/tasks/web/text/text_embedder/BUILD +++ b/mediapipe/tasks/web/text/text_embedder/BUILD @@ -25,7 +25,7 @@ mediapipe_ts_library( "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:embedder_options", "//mediapipe/tasks/web/core:task_runner", - "//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts", + 
"//mediapipe/web/graph_runner:graph_runner_ts", ], ) diff --git a/mediapipe/tasks/web/text/text_embedder/text_embedder.ts b/mediapipe/tasks/web/text/text_embedder/text_embedder.ts index 7c631683d..57b91d575 100644 --- a/mediapipe/tasks/web/text/text_embedder/text_embedder.ts +++ b/mediapipe/tasks/web/text/text_embedder/text_embedder.ts @@ -23,7 +23,7 @@ import {convertEmbedderOptionsToProto} from '../../../../tasks/web/components/pr import {convertFromEmbeddingResultProto} from '../../../../tasks/web/components/processors/embedder_result'; import {TaskRunner} from '../../../../tasks/web/core/task_runner'; import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options'; -import {createMediaPipeLib, FileLocator} from '../../../../web/graph_runner/wasm_mediapipe_lib'; +import {createMediaPipeLib, FileLocator} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url import {TextEmbedderOptions} from './text_embedder_options'; diff --git a/mediapipe/tasks/web/vision/core/BUILD b/mediapipe/tasks/web/vision/core/BUILD index e3a5edf33..1d8944f14 100644 --- a/mediapipe/tasks/web/vision/core/BUILD +++ b/mediapipe/tasks/web/vision/core/BUILD @@ -21,6 +21,6 @@ mediapipe_ts_library( "//mediapipe/tasks/web/components/processors:base_options", "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:task_runner", - "//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts", + "//mediapipe/web/graph_runner:graph_runner_ts", ], ) diff --git a/mediapipe/tasks/web/vision/core/vision_task_runner.ts b/mediapipe/tasks/web/vision/core/vision_task_runner.ts index 372ce9ba7..79ff45156 100644 --- a/mediapipe/tasks/web/vision/core/vision_task_runner.ts +++ b/mediapipe/tasks/web/vision/core/vision_task_runner.ts @@ -17,7 +17,7 @@ import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; import {convertBaseOptionsToProto} from '../../../../tasks/web/components/processors/base_options'; import {TaskRunner} from '../../../../tasks/web/core/task_runner'; -import {ImageSource} from '../../../../web/graph_runner/wasm_mediapipe_lib'; +import {ImageSource} from '../../../../web/graph_runner/graph_runner'; import {VisionTaskOptions} from './vision_task_options'; diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/BUILD b/mediapipe/tasks/web/vision/gesture_recognizer/BUILD index f2b668239..ddfd1a327 100644 --- a/mediapipe/tasks/web/vision/gesture_recognizer/BUILD +++ b/mediapipe/tasks/web/vision/gesture_recognizer/BUILD @@ -32,7 +32,7 @@ mediapipe_ts_library( "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:classifier_options", "//mediapipe/tasks/web/vision/core:vision_task_runner", - "//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts", + "//mediapipe/web/graph_runner:graph_runner_ts", ], ) diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts index 8e745534e..dd050d0f1 100644 --- a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts +++ b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts @@ -31,7 +31,7 @@ import {Landmark} from '../../../../tasks/web/components/containers/landmark'; import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/processors/classifier_options'; import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options'; import {VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; -import 
{createMediaPipeLib, FileLocator, ImageSource, WasmModule} from '../../../../web/graph_runner/wasm_mediapipe_lib'; +import {createMediaPipeLib, FileLocator, ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url import {GestureRecognizerOptions} from './gesture_recognizer_options'; diff --git a/mediapipe/tasks/web/vision/hand_landmarker/BUILD b/mediapipe/tasks/web/vision/hand_landmarker/BUILD index 36f1d7eb7..1849687c5 100644 --- a/mediapipe/tasks/web/vision/hand_landmarker/BUILD +++ b/mediapipe/tasks/web/vision/hand_landmarker/BUILD @@ -27,7 +27,7 @@ mediapipe_ts_library( "//mediapipe/tasks/web/components/containers:landmark", "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/vision/core:vision_task_runner", - "//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts", + "//mediapipe/web/graph_runner:graph_runner_ts", ], ) diff --git a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts index 0aba5c82c..32b1eed4b 100644 --- a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts +++ b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts @@ -27,7 +27,7 @@ import {Category} from '../../../../tasks/web/components/containers/category'; import {Landmark} from '../../../../tasks/web/components/containers/landmark'; import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options'; import {VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; -import {createMediaPipeLib, FileLocator, ImageSource, WasmModule} from '../../../../web/graph_runner/wasm_mediapipe_lib'; +import {createMediaPipeLib, FileLocator, ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url import {HandLandmarkerOptions} from './hand_landmarker_options'; diff --git a/mediapipe/tasks/web/vision/image_classifier/BUILD b/mediapipe/tasks/web/vision/image_classifier/BUILD index e7e830332..ebe64ecf4 100644 --- a/mediapipe/tasks/web/vision/image_classifier/BUILD +++ b/mediapipe/tasks/web/vision/image_classifier/BUILD @@ -25,7 +25,7 @@ mediapipe_ts_library( "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:classifier_options", "//mediapipe/tasks/web/vision/core:vision_task_runner", - "//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts", + "//mediapipe/web/graph_runner:graph_runner_ts", ], ) diff --git a/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts b/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts index 0011e9c55..b59cb6fb1 100644 --- a/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts +++ b/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts @@ -23,7 +23,7 @@ import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/ import {convertFromClassificationResultProto} from '../../../../tasks/web/components/processors/classifier_result'; import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options'; import {VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; -import {createMediaPipeLib, FileLocator, ImageSource} from '../../../../web/graph_runner/wasm_mediapipe_lib'; +import {createMediaPipeLib, FileLocator, ImageSource} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url import {ImageClassifierOptions} from './image_classifier_options'; diff --git 
a/mediapipe/tasks/web/vision/image_embedder/BUILD b/mediapipe/tasks/web/vision/image_embedder/BUILD index ce1c25700..feb3ae054 100644 --- a/mediapipe/tasks/web/vision/image_embedder/BUILD +++ b/mediapipe/tasks/web/vision/image_embedder/BUILD @@ -25,7 +25,7 @@ mediapipe_ts_library( "//mediapipe/tasks/web/core:embedder_options", "//mediapipe/tasks/web/vision/core:vision_task_options", "//mediapipe/tasks/web/vision/core:vision_task_runner", - "//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts", + "//mediapipe/web/graph_runner:graph_runner_ts", ], ) diff --git a/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts b/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts index d17bc72fa..c60665052 100644 --- a/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts +++ b/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts @@ -23,7 +23,7 @@ import {convertEmbedderOptionsToProto} from '../../../../tasks/web/components/pr import {convertFromEmbeddingResultProto} from '../../../../tasks/web/components/processors/embedder_result'; import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options'; import {VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; -import {createMediaPipeLib, FileLocator, ImageSource} from '../../../../web/graph_runner/wasm_mediapipe_lib'; +import {createMediaPipeLib, FileLocator, ImageSource} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url import {ImageEmbedderOptions} from './image_embedder_options'; diff --git a/mediapipe/tasks/web/vision/object_detector/BUILD b/mediapipe/tasks/web/vision/object_detector/BUILD index 0975a9fd4..b6bef6bfa 100644 --- a/mediapipe/tasks/web/vision/object_detector/BUILD +++ b/mediapipe/tasks/web/vision/object_detector/BUILD @@ -22,7 +22,7 @@ mediapipe_ts_library( "//mediapipe/tasks/web/components/containers:category", "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/vision/core:vision_task_runner", - "//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts", + "//mediapipe/web/graph_runner:graph_runner_ts", ], ) diff --git a/mediapipe/tasks/web/vision/object_detector/object_detector.ts b/mediapipe/tasks/web/vision/object_detector/object_detector.ts index e6cbd8627..44046cd1e 100644 --- a/mediapipe/tasks/web/vision/object_detector/object_detector.ts +++ b/mediapipe/tasks/web/vision/object_detector/object_detector.ts @@ -21,7 +21,7 @@ import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/b import {ObjectDetectorOptions as ObjectDetectorOptionsProto} from '../../../../tasks/cc/vision/object_detector/proto/object_detector_options_pb'; import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options'; import {VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; -import {createMediaPipeLib, FileLocator, ImageSource} from '../../../../web/graph_runner/wasm_mediapipe_lib'; +import {createMediaPipeLib, FileLocator, ImageSource} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url import {ObjectDetectorOptions} from './object_detector_options'; diff --git a/mediapipe/web/graph_runner/BUILD b/mediapipe/web/graph_runner/BUILD index dab6be50f..5c12947af 100644 --- a/mediapipe/web/graph_runner/BUILD +++ b/mediapipe/web/graph_runner/BUILD @@ -3,32 +3,24 @@ load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") package(default_visibility = [ - ":internal", "//mediapipe/tasks:internal", 
]) -package_group( - name = "internal", - packages = [ - "//mediapipe/app/pursuit/wasm/web_ml_cpu/typescript/...", - ], -) - mediapipe_ts_library( - name = "wasm_mediapipe_lib_ts", + name = "graph_runner_ts", srcs = [ - ":wasm_mediapipe_lib.ts", + ":graph_runner.ts", ], allow_unoptimized_namespaces = True, ) mediapipe_ts_library( - name = "wasm_mediapipe_image_lib_ts", + name = "graph_runner_image_lib_ts", srcs = [ - ":wasm_mediapipe_image_lib.ts", + ":graph_runner_image_lib.ts", ], allow_unoptimized_namespaces = True, - deps = [":wasm_mediapipe_lib_ts"], + deps = [":graph_runner_ts"], ) mediapipe_ts_library( @@ -37,5 +29,5 @@ mediapipe_ts_library( ":register_model_resources_graph_service.ts", ], allow_unoptimized_namespaces = True, - deps = [":wasm_mediapipe_lib_ts"], + deps = [":graph_runner_ts"], ) diff --git a/mediapipe/web/graph_runner/wasm_mediapipe_lib.ts b/mediapipe/web/graph_runner/graph_runner.ts similarity index 99% rename from mediapipe/web/graph_runner/wasm_mediapipe_lib.ts rename to mediapipe/web/graph_runner/graph_runner.ts index 5f8040a33..7de5aa33b 100644 --- a/mediapipe/web/graph_runner/wasm_mediapipe_lib.ts +++ b/mediapipe/web/graph_runner/graph_runner.ts @@ -129,7 +129,7 @@ declare global { declare function importScripts(...urls: Array<string|URL>): void; /** - * Valid types of image sources which we can run our WasmMediaPipeLib over. + * Valid types of image sources which we can run our GraphRunner over. */ export type ImageSource = HTMLCanvasElement|HTMLVideoElement|HTMLImageElement|ImageData|ImageBitmap; @@ -138,7 +138,7 @@ export type ImageSource = /** A listener that will be invoked with an absl::StatusCode and message. */ export type ErrorListener = (code: number, message: string) => void; -// Internal type of constructors used for initializing WasmMediaPipeLib and +// Internal type of constructors used for initializing GraphRunner and // subclasses. type WasmMediaPipeConstructor<LibType> = (new ( @@ -151,7 +151,7 @@ type WasmMediaPipeConstructor<LibType> = * into canvas, or else return the output WebGLTexture. Takes a WebAssembly * Module (must be instantiated to self.Module). */ -export class WasmMediaPipeLib { +export class GraphRunner { // TODO: These should be protected/private, but are left exposed for // now so that we can use proper TS mixins with this class as a base. This // should be somewhat fixed when we create our .d.ts files. @@ -989,7 +989,7 @@ async function runScript(scriptUrl: string) { /** * Global function to initialize Wasm blob and load runtime assets for a * specialized MediaPipe library. This allows us to create a requested - * subclass inheriting from WasmMediaPipeLib. + * subclass inheriting from GraphRunner. * @param constructorFcn The name of the class to instantiate via "new". * @param wasmLoaderScript Url for the wasm-runner script; produced by the build * process. @@ -1043,12 +1043,12 @@ * @return promise A promise which will resolve when initialization has * completed successfully.
*/ -export async function createWasmMediaPipeLib( +export async function createGraphRunner( wasmLoaderScript?: string, assetLoaderScript?: string, glCanvas?: HTMLCanvasElement|OffscreenCanvas|null, - fileLocator?: FileLocator): Promise<WasmMediaPipeLib> { + fileLocator?: FileLocator): Promise<GraphRunner> { return createMediaPipeLib( - WasmMediaPipeLib, wasmLoaderScript, assetLoaderScript, glCanvas, + GraphRunner, wasmLoaderScript, assetLoaderScript, glCanvas, fileLocator); } diff --git a/mediapipe/web/graph_runner/wasm_mediapipe_image_lib.ts b/mediapipe/web/graph_runner/graph_runner_image_lib.ts similarity index 83% rename from mediapipe/web/graph_runner/wasm_mediapipe_image_lib.ts rename to mediapipe/web/graph_runner/graph_runner_image_lib.ts index 3b45e8230..e886999cb 100644 --- a/mediapipe/web/graph_runner/wasm_mediapipe_image_lib.ts +++ b/mediapipe/web/graph_runner/graph_runner_image_lib.ts @@ -1,12 +1,12 @@ -import {ImageSource, WasmMediaPipeLib} from './wasm_mediapipe_lib'; +import {ImageSource, GraphRunner} from './graph_runner'; /** - * We extend from a WasmMediaPipeLib constructor. This ensures our mixin has + * We extend from a GraphRunner constructor. This ensures our mixin has * access to the wasmModule, among other things. The `any` type is required for * mixin constructors. */ // tslint:disable-next-line:no-any -type LibConstructor = new (...args: any[]) => WasmMediaPipeLib; +type LibConstructor = new (...args: any[]) => GraphRunner; /** * Declarations for Emscripten's WebAssembly Module behavior, so TS compiler * doesn't break our JS/C++ bridge. @@ -19,10 +19,10 @@ export declare interface WasmImageModule { } /** - * An implementation of WasmMediaPipeLib that supports binding GPU image data as + * An implementation of GraphRunner that supports binding GPU image data as * `mediapipe::Image` instances. We implement as a proper TS mixin, to allow for * effective multiple inheritance. Example usage: - * `const WasmMediaPipeImageLib = SupportImage(WasmMediaPipeLib);` + * `const WasmMediaPipeImageLib = SupportImage(GraphRunner);` */ // tslint:disable-next-line:enforce-name-casing export function SupportImage<TBase extends LibConstructor>(Base: TBase) { diff --git a/mediapipe/web/graph_runner/register_model_resources_graph_service.ts b/mediapipe/web/graph_runner/register_model_resources_graph_service.ts index e85d63b06..bc9c93e8a 100644 --- a/mediapipe/web/graph_runner/register_model_resources_graph_service.ts +++ b/mediapipe/web/graph_runner/register_model_resources_graph_service.ts @@ -1,12 +1,12 @@ -import {WasmMediaPipeLib} from './wasm_mediapipe_lib'; +import {GraphRunner} from './graph_runner'; /** - * We extend from a WasmMediaPipeLib constructor. This ensures our mixin has + * We extend from a GraphRunner constructor. This ensures our mixin has * access to the wasmModule, among other things. The `any` type is required for * mixin constructors. */ // tslint:disable-next-line:no-any -type LibConstructor = new (...args: any[]) => WasmMediaPipeLib; +type LibConstructor = new (...args: any[]) => GraphRunner; /** * Declarations for Emscripten's WebAssembly Module behavior, so TS compiler @@ -17,11 +17,11 @@ export declare interface WasmModuleRegisterModelResources { } /** - * An implementation of WasmMediaPipeLib that supports registering model + * An implementation of GraphRunner that supports registering model * resources to a cache, in the form of a GraphService C++-side. We implement as * a proper TS mixin, to allow for effective multiple inheritance.
Sample usage: * `const WasmMediaPipeImageLib = SupportModelResourcesGraphService( - * WasmMediaPipeLib);` + * GraphRunner);` */ // tslint:disable:enforce-name-casing export function SupportModelResourcesGraphService( From efcdedbd59a135d757a49b0ff27b656e793386ad Mon Sep 17 00:00:00 2001 From: Camillo Lugaresi Date: Thu, 17 Nov 2022 18:14:58 -0800 Subject: [PATCH 081/469] Remove redundant _ios targets PiperOrigin-RevId: 489355333 --- mediapipe/gpu/BUILD | 14 -------------- mediapipe/objc/BUILD | 4 ++-- 2 files changed, 2 insertions(+), 16 deletions(-) diff --git a/mediapipe/gpu/BUILD b/mediapipe/gpu/BUILD index 4fb59f1b5..27d91f21a 100644 --- a/mediapipe/gpu/BUILD +++ b/mediapipe/gpu/BUILD @@ -810,20 +810,6 @@ cc_library( }), ) -# TODO: remove -objc_library( - name = "gl_calculator_helper_ios", - copts = [ - "-Wno-shorten-64-to-32", - ], - visibility = ["//visibility:public"], - deps = [ - ":gl_calculator_helper", - "//mediapipe/objc:mediapipe_framework_ios", - "//mediapipe/objc:util", - ], -) - objc_library( name = "MPPMetalHelper", srcs = ["MPPMetalHelper.mm"], diff --git a/mediapipe/objc/BUILD b/mediapipe/objc/BUILD index 48c9b181a..d77692164 100644 --- a/mediapipe/objc/BUILD +++ b/mediapipe/objc/BUILD @@ -147,7 +147,7 @@ objc_library( visibility = ["//mediapipe/framework:mediapipe_internal"], deps = [ ":mediapipe_framework_ios", - "//mediapipe/gpu:gl_calculator_helper_ios", + "//mediapipe/gpu:gl_calculator_helper", "//mediapipe/gpu:gl_quad_renderer", "//mediapipe/gpu:gl_simple_shaders", ], @@ -173,7 +173,7 @@ objc_library( deps = [ ":mediapipe_framework_ios", ":mediapipe_gl_view_renderer", - "//mediapipe/gpu:gl_calculator_helper_ios", + "//mediapipe/gpu:gl_calculator_helper", ], ) From ae44012c0c5a53916f9ee01b3c745868836c784b Mon Sep 17 00:00:00 2001 From: Hadon Nash Date: Fri, 18 Nov 2022 08:39:37 -0800 Subject: [PATCH 082/469] Allowing BypassCalculator to accept InputSidePackets. 
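BypassCalculator previously declared packet types only for its input and output streams, so wiring an input side packet into it failed contract validation. GetContract() now also marks every input side packet as accepting any packet type. A minimal sketch of a graph this enables, written against the public C++ API (the stream and side packet names are illustrative, not taken from this change):

    #include "mediapipe/framework/calculator_framework.h"
    #include "mediapipe/framework/port/parse_text_proto.h"

    // Hypothetical wiring: the side packet reaches the BypassCalculator node
    // and is accepted as type "Any" by the updated GetContract() below.
    mediapipe::CalculatorGraphConfig config =
        mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig>(R"pb(
          input_stream: "frame"
          input_side_packet: "session_config"
          node {
            calculator: "BypassCalculator"
            input_stream: "frame"
            output_stream: "frame_out"
            input_side_packet: "session_config"
          }
        )pb");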
PiperOrigin-RevId: 489483992 --- mediapipe/calculators/core/bypass_calculator.cc | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/mediapipe/calculators/core/bypass_calculator.cc b/mediapipe/calculators/core/bypass_calculator.cc index efc0612ec..4e007329b 100644 --- a/mediapipe/calculators/core/bypass_calculator.cc +++ b/mediapipe/calculators/core/bypass_calculator.cc @@ -111,6 +111,10 @@ class BypassCalculator : public Node { cc->Outputs().Get(id).SetAny(); } } + for (auto id = cc->InputSidePackets().BeginId(); + id != cc->InputSidePackets().EndId(); ++id) { + cc->InputSidePackets().Get(id).SetAny(); + } return absl::OkStatus(); } From e046982a3c6706625c997df50e51e19157624ac7 Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Fri, 18 Nov 2022 08:44:02 -0800 Subject: [PATCH 083/469] Internal change PiperOrigin-RevId: 489484898 --- .../tensor/audio_to_tensor_calculator.cc | 49 ++++++++++++++++--- .../tensor/audio_to_tensor_calculator.proto | 13 +++++ 2 files changed, 54 insertions(+), 8 deletions(-) diff --git a/mediapipe/calculators/tensor/audio_to_tensor_calculator.cc b/mediapipe/calculators/tensor/audio_to_tensor_calculator.cc index d0513518a..9cb23a393 100644 --- a/mediapipe/calculators/tensor/audio_to_tensor_calculator.cc +++ b/mediapipe/calculators/tensor/audio_to_tensor_calculator.cc @@ -43,6 +43,7 @@ namespace api2 { namespace { using Options = ::mediapipe::AudioToTensorCalculatorOptions; +using DftTensorFormat = Options::DftTensorFormat; using FlushMode = Options::FlushMode; std::vector<float> HannWindow(int window_size, bool sqrt_hann) { @@ -188,6 +189,8 @@ class AudioToTensorCalculator : public Node { int padding_samples_before_; int padding_samples_after_; FlushMode flush_mode_; + DftTensorFormat dft_tensor_format_; + Timestamp initial_timestamp_ = Timestamp::Unstarted(); int64 cumulative_input_samples_ = 0; Timestamp next_output_timestamp_ = Timestamp::Unstarted(); @@ -273,6 +276,7 @@ absl::Status AudioToTensorCalculator::Open(CalculatorContext* cc) { } padding_samples_before_ = options.padding_samples_before(); padding_samples_after_ = options.padding_samples_after(); + dft_tensor_format_ = options.dft_tensor_format(); flush_mode_ = options.flush_mode(); RET_CHECK(kAudioSampleRateIn(cc).IsConnected() ^ @@ -492,14 +496,43 @@ absl::Status AudioToTensorCalculator::OutputTensor(const Matrix& block, kDcAndNyquistOut(cc).Send(std::make_pair(fft_output_[0], fft_output_[1]), timestamp); } - Matrix fft_output_matrix = - Eigen::Map<Matrix>(fft_output_.data() + 2, 1, fft_size_ - 2); - fft_output_matrix.conservativeResize(Eigen::NoChange, fft_size_); - // The last two elements are the DFT Nyquist values. - fft_output_matrix(fft_size_ - 2) = fft_output_[1]; // Nyquist real part - fft_output_matrix(fft_size_ - 1) = 0.0f; // Nyquist imagery part - ASSIGN_OR_RETURN(output_tensor, - ConvertToTensor(fft_output_matrix, {2, fft_size_ / 2})); + switch (dft_tensor_format_) { + case Options::WITH_NYQUIST: { + Matrix fft_output_matrix = + Eigen::Map<Matrix>(fft_output_.data() + 2, 1, fft_size_ - 2); + fft_output_matrix.conservativeResize(Eigen::NoChange, fft_size_); + // The last two elements are the Nyquist component.
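+ // Packed layout for fft_size_ = N, as interleaved real/imaginary pairs
+ // (the DC bin is dropped here; it was already emitted through
+ // kDcAndNyquistOut above):
+ //   [re(1), im(1), ..., re(N/2 - 1), im(N/2 - 1), nyquist_re, 0],
+ // which ConvertToTensor below reshapes into a {2, N / 2} tensor.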
+ fft_output_matrix(fft_size_ - 2) = fft_output_[1]; // Nyquist real part + fft_output_matrix(fft_size_ - 1) = 0.0f; // Nyquist imaginary part + ASSIGN_OR_RETURN(output_tensor, ConvertToTensor(fft_output_matrix, + {2, fft_size_ / 2})); + break; + } + case Options::WITH_DC_AND_NYQUIST: { + Matrix fft_output_matrix = + Eigen::Map<Matrix>(fft_output_.data(), 1, fft_size_); + fft_output_matrix.conservativeResize(Eigen::NoChange, fft_size_ + 2); + fft_output_matrix(1) = 0.0f; // DC imaginary part. + // The last two elements are the Nyquist component. + fft_output_matrix(fft_size_) = fft_output_[1]; // Nyquist real part + fft_output_matrix(fft_size_ + 1) = 0.0f; // Nyquist imaginary part + ASSIGN_OR_RETURN( + output_tensor, + ConvertToTensor(fft_output_matrix, {2, (fft_size_ + 2) / 2})); + break; + } + case Options::WITHOUT_DC_AND_NYQUIST: { + Matrix fft_output_matrix = + Eigen::Map<Matrix>(fft_output_.data() + 2, 1, fft_size_ - 2); + ASSIGN_OR_RETURN( + output_tensor, + ConvertToTensor(fft_output_matrix, {2, (fft_size_ - 2) / 2})); + break; + } + default: + return absl::InvalidArgumentError("Unsupported dft tensor format."); + } + } else { ASSIGN_OR_RETURN(output_tensor, ConvertToTensor(block, {num_channels_, num_samples_})); diff --git a/mediapipe/calculators/tensor/audio_to_tensor_calculator.proto b/mediapipe/calculators/tensor/audio_to_tensor_calculator.proto index cff6b2878..aa3c1229c 100644 --- a/mediapipe/calculators/tensor/audio_to_tensor_calculator.proto +++ b/mediapipe/calculators/tensor/audio_to_tensor_calculator.proto @@ -68,4 +68,17 @@ message AudioToTensorCalculatorOptions { } optional FlushMode flush_mode = 10 [default = ENTIRE_TAIL_AT_TIMESTAMP_MAX]; + + enum DftTensorFormat { + DFT_TENSOR_FORMAT_UNKNOWN = 0; + // The output DFT tensor excludes the DC and Nyquist components. + WITHOUT_DC_AND_NYQUIST = 1; + // The output DFT tensor contains the Nyquist component as the last + // two values. + WITH_NYQUIST = 2; + // The output DFT tensor contains the DC component as the first two values + // and the Nyquist component as the last two values.
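+ // For example, with fft_size N the packed layout is
+ // [dc_re, 0, re(1), im(1), ..., re(N/2 - 1), im(N/2 - 1), nyquist_re, 0],
+ // i.e. (N + 2) / 2 complex values whose DC and Nyquist imaginary parts
+ // are always zero.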
+ WITH_DC_AND_NYQUIST = 3; + } + optional DftTensorFormat dft_tensor_format = 11 [default = WITH_NYQUIST]; } From 2f361e2f4791fa774db5cb20dbc888f89c234447 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Fri, 18 Nov 2022 08:51:30 -0800 Subject: [PATCH 084/469] Internal change PiperOrigin-RevId: 489486417 --- mediapipe/util/tracking/BUILD | 3 +-- mediapipe/util/tracking/motion_analysis.cc | 2 +- .../util/tracking/region_flow_computation.cc | 16 ++++++---------- .../tracking/region_flow_computation_test.cc | 2 +- 4 files changed, 9 insertions(+), 14 deletions(-) diff --git a/mediapipe/util/tracking/BUILD b/mediapipe/util/tracking/BUILD index 319e99d5b..3f1ebb353 100644 --- a/mediapipe/util/tracking/BUILD +++ b/mediapipe/util/tracking/BUILD @@ -458,7 +458,6 @@ cc_library( "//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:opencv_core", - "//mediapipe/framework/port:opencv_highgui", ], ) @@ -739,7 +738,7 @@ cc_test( "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:opencv_core", - "//mediapipe/framework/port:opencv_highgui", + "//mediapipe/framework/port:opencv_imgcodecs", "//mediapipe/framework/port:opencv_imgproc", "//mediapipe/framework/port:status", "//mediapipe/framework/port:vector", diff --git a/mediapipe/util/tracking/motion_analysis.cc b/mediapipe/util/tracking/motion_analysis.cc index 0b7678889..5b6a970cf 100644 --- a/mediapipe/util/tracking/motion_analysis.cc +++ b/mediapipe/util/tracking/motion_analysis.cc @@ -791,7 +791,7 @@ void MotionAnalysis::VisualizeBlurAnalysisRegions(cv::Mat* input_view) { region_flow_computation_->ComputeBlurMask(*input_view, &corner_values, &mask); cv::Mat mask_3c; - cv::cvtColor(mask, mask_3c, CV_GRAY2RGB); + cv::cvtColor(mask, mask_3c, cv::COLOR_GRAY2RGB); cv::addWeighted(*input_view, 0.5, mask_3c, 0.5, -128, *input_view); } diff --git a/mediapipe/util/tracking/region_flow_computation.cc b/mediapipe/util/tracking/region_flow_computation.cc index cfd5c23c2..708c868b5 100644 --- a/mediapipe/util/tracking/region_flow_computation.cc +++ b/mediapipe/util/tracking/region_flow_computation.cc @@ -30,6 +30,7 @@ #include "absl/container/node_hash_set.h" #include "absl/memory/memory.h" #include "mediapipe/framework/port/logging.h" +#include "mediapipe/framework/port/opencv_core_inc.h" #include "mediapipe/framework/port/opencv_features2d_inc.h" #include "mediapipe/framework/port/opencv_imgproc_inc.h" #include "mediapipe/framework/port/opencv_video_inc.h" @@ -935,12 +936,13 @@ bool RegionFlowComputation::InitFrame(const cv::Mat& source, // Area based method best for downsampling. // For color images to temporary buffer. cv::Mat& resized = source.channels() == 1 ? dest_frame : *curr_color_image_; - cv::resize(source, resized, resized.size(), 0, 0, CV_INTER_AREA); + cv::resize(source, resized, resized.size(), 0, 0, cv::INTER_AREA); source_ptr = &resized; // Resize feature extraction mask if needed. 
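// Nearest-neighbor sampling is used for the mask resize below: unlike area
// or linear interpolation, it cannot blend discrete mask labels into new
// in-between values.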
if (!source_mask.empty()) { dest_mask.create(resized.rows, resized.cols, CV_8UC1); - cv::resize(source_mask, dest_mask, dest_mask.size(), 0, 0, CV_INTER_NN); + cv::resize(source_mask, dest_mask, dest_mask.size(), 0, 0, + cv::INTER_NEAREST); } } else if (!source_mask.empty()) { source_mask.copyTo(dest_mask); @@ -954,7 +956,7 @@ bool RegionFlowComputation::InitFrame(const cv::Mat& source, const int dimension = visual_options.tiny_image_dimension(); data->tiny_image.create(dimension, dimension, type); cv::resize(*source_ptr, data->tiny_image, data->tiny_image.size(), 0, 0, - CV_INTER_AREA); + cv::INTER_AREA); } if (source_ptr->channels() == 1 && @@ -2286,7 +2288,7 @@ void RegionFlowComputation::ExtractFeatures( // Initialize mask from frame's feature extraction mask, by downsampling and // negating the latter mask. if (!data->mask.empty()) { - cv::resize(data->mask, mask, mask.size(), 0, 0, CV_INTER_NN); + cv::resize(data->mask, mask, mask.size(), 0, 0, cv::INTER_NEAREST); for (int y = 0; y < mask.rows; ++y) { uint8* mask_ptr = mask.ptr(y); for (int x = 0; x < mask.cols; ++x) { @@ -2590,12 +2592,6 @@ void RegionFlowComputation::TrackFeatures(FrameTrackingData* from_data_ptr, cv::_InputArray input_frame2(data2.pyramid); #endif - // Using old c-interface for OpenCV's 2.2 tracker. - CvTermCriteria criteria; - criteria.type = CV_TERMCRIT_EPS | CV_TERMCRIT_ITER; - criteria.max_iter = options_.tracking_options().tracking_iterations(); - criteria.epsilon = 0.02f; - feature_track_error_.resize(num_features); feature_status_.resize(num_features); if (use_cv_tracking_) { diff --git a/mediapipe/util/tracking/region_flow_computation_test.cc b/mediapipe/util/tracking/region_flow_computation_test.cc index 0ac6dc2a5..435a8e200 100644 --- a/mediapipe/util/tracking/region_flow_computation_test.cc +++ b/mediapipe/util/tracking/region_flow_computation_test.cc @@ -28,7 +28,7 @@ #include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/logging.h" #include "mediapipe/framework/port/opencv_core_inc.h" -#include "mediapipe/framework/port/opencv_highgui_inc.h" +#include "mediapipe/framework/port/opencv_imgcodecs_inc.h" #include "mediapipe/framework/port/opencv_imgproc_inc.h" #include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/vector.h" From 03d388fecffe3734d8f6878f6f0def404065076b Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Fri, 18 Nov 2022 09:49:23 -0800 Subject: [PATCH 085/469] Add hand landmark named index constants PiperOrigin-RevId: 489498248 --- .../tasks/cc/components/containers/BUILD | 5 ++ .../tasks/cc/components/containers/landmark.h | 48 +++++++++++++ .../tasks/components/containers/BUILD | 12 ++++ .../components/containers/HandLandmark.java | 72 +++++++++++++++++++ .../python/components/containers/landmark.py | 26 +++++++ .../web/components/containers/landmark.d.ts | 25 +++++++ 6 files changed, 188 insertions(+) create mode 100644 mediapipe/tasks/cc/components/containers/landmark.h create mode 100644 mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/HandLandmark.java diff --git a/mediapipe/tasks/cc/components/containers/BUILD b/mediapipe/tasks/cc/components/containers/BUILD index bd66a0f28..2f5f8be5b 100644 --- a/mediapipe/tasks/cc/components/containers/BUILD +++ b/mediapipe/tasks/cc/components/containers/BUILD @@ -49,3 +49,8 @@ cc_library( "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto", ], ) + +cc_library( + name = "landmark", + hdrs = ["landmark.h"], +) diff --git 
a/mediapipe/tasks/cc/components/containers/landmark.h b/mediapipe/tasks/cc/components/containers/landmark.h new file mode 100644 index 000000000..6fdd294ae --- /dev/null +++ b/mediapipe/tasks/cc/components/containers/landmark.h @@ -0,0 +1,48 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_LANDMARK_H_ +#define MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_LANDMARK_H_ + +namespace mediapipe::tasks::components::containers { + +// The 21 hand landmarks. +enum HandLandmark { + WRIST = 0, + THUMB_CMC = 1, + THUMB_MCP = 2, + THUMB_IP = 3, + THUMB_TIP = 4, + INDEX_FINGER_MCP = 5, + INDEX_FINGER_PIP = 6, + INDEX_FINGER_DIP = 7, + INDEX_FINGER_TIP = 8, + MIDDLE_FINGER_MCP = 9, + MIDDLE_FINGER_PIP = 10, + MIDDLE_FINGER_DIP = 11, + MIDDLE_FINGER_TIP = 12, + RING_FINGER_MCP = 13, + RING_FINGER_PIP = 14, + RING_FINGER_DIP = 15, + RING_FINGER_TIP = 16, + PINKY_MCP = 17, + PINKY_PIP = 18, + PINKY_DIP = 19, + PINKY_TIP = 20 +}; + +} // namespace mediapipe::tasks::components::containers + +#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_LANDMARK_H_ diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD index d6e6ac740..869157295 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD @@ -74,6 +74,18 @@ android_library( ], ) +android_library( + name = "handlandmark", + srcs = ["HandLandmark.java"], + javacopts = [ + "-Xep:AndroidJdkLibsChecker:OFF", + ], + deps = [ + "@maven//:androidx_annotation_annotation", + "@maven//:com_google_guava_guava", + ], +) + android_library( name = "landmark", srcs = ["Landmark.java"], diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/HandLandmark.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/HandLandmark.java new file mode 100644 index 000000000..da7c4e0ca --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/HandLandmark.java @@ -0,0 +1,72 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package com.google.mediapipe.tasks.components.containers; + +import androidx.annotation.IntDef; + +/** The 21 hand landmarks. */ +public final class HandLandmark { + public static final int NUM_LANDMARKS = 21; + + public static final int WRIST = 0; + public static final int THUMB_CMC = 1; + public static final int THUMB_MCP = 2; + public static final int THUMB_IP = 3; + public static final int THUMB_TIP = 4; + public static final int INDEX_FINGER_MCP = 5; + public static final int INDEX_FINGER_PIP = 6; + public static final int INDEX_FINGER_DIP = 7; + public static final int INDEX_FINGER_TIP = 8; + public static final int MIDDLE_FINGER_MCP = 9; + public static final int MIDDLE_FINGER_PIP = 10; + public static final int MIDDLE_FINGER_DIP = 11; + public static final int MIDDLE_FINGER_TIP = 12; + public static final int RING_FINGER_MCP = 13; + public static final int RING_FINGER_PIP = 14; + public static final int RING_FINGER_DIP = 15; + public static final int RING_FINGER_TIP = 16; + public static final int PINKY_MCP = 17; + public static final int PINKY_PIP = 18; + public static final int PINKY_DIP = 19; + public static final int PINKY_TIP = 20; + + /** Represents a hand landmark type. */ + @IntDef({ + WRIST, + THUMB_CMC, + THUMB_MCP, + THUMB_IP, + THUMB_TIP, + INDEX_FINGER_MCP, + INDEX_FINGER_PIP, + INDEX_FINGER_DIP, + INDEX_FINGER_TIP, + MIDDLE_FINGER_MCP, + MIDDLE_FINGER_PIP, + MIDDLE_FINGER_DIP, + MIDDLE_FINGER_TIP, + RING_FINGER_MCP, + RING_FINGER_PIP, + RING_FINGER_DIP, + RING_FINGER_TIP, + PINKY_MCP, + PINKY_PIP, + PINKY_DIP, + PINKY_TIP, + }) + public @interface HandLandmarkType {} + + private HandLandmark() {} +} diff --git a/mediapipe/tasks/python/components/containers/landmark.py b/mediapipe/tasks/python/components/containers/landmark.py index dee2a16ad..81b2943dc 100644 --- a/mediapipe/tasks/python/components/containers/landmark.py +++ b/mediapipe/tasks/python/components/containers/landmark.py @@ -14,6 +14,7 @@ """Landmark data class.""" import dataclasses +import enum from typing import Optional from mediapipe.framework.formats import landmark_pb2 @@ -120,3 +121,28 @@ class NormalizedLandmark: z=pb2_obj.z, visibility=pb2_obj.visibility, presence=pb2_obj.presence) + + +class HandLandmark(enum.IntEnum): + """The 21 hand landmarks.""" + WRIST = 0 + THUMB_CMC = 1 + THUMB_MCP = 2 + THUMB_IP = 3 + THUMB_TIP = 4 + INDEX_FINGER_MCP = 5 + INDEX_FINGER_PIP = 6 + INDEX_FINGER_DIP = 7 + INDEX_FINGER_TIP = 8 + MIDDLE_FINGER_MCP = 9 + MIDDLE_FINGER_PIP = 10 + MIDDLE_FINGER_DIP = 11 + MIDDLE_FINGER_TIP = 12 + RING_FINGER_MCP = 13 + RING_FINGER_PIP = 14 + RING_FINGER_DIP = 15 + RING_FINGER_TIP = 16 + PINKY_MCP = 17 + PINKY_PIP = 18 + PINKY_DIP = 19 + PINKY_TIP = 20 diff --git a/mediapipe/tasks/web/components/containers/landmark.d.ts b/mediapipe/tasks/web/components/containers/landmark.d.ts index c887303d0..352717a2f 100644 --- a/mediapipe/tasks/web/components/containers/landmark.d.ts +++ b/mediapipe/tasks/web/components/containers/landmark.d.ts @@ -33,3 +33,28 @@ export declare interface Landmark { /** Whether this landmark is normalized with respect to the image size. */ normalized: boolean; } + +/** The 21 hand landmarks. 
*/ +export const enum HandLandmark { + WRIST = 0, + THUMB_CMC = 1, + THUMB_MCP = 2, + THUMB_IP = 3, + THUMB_TIP = 4, + INDEX_FINGER_MCP = 5, + INDEX_FINGER_PIP = 6, + INDEX_FINGER_DIP = 7, + INDEX_FINGER_TIP = 8, + MIDDLE_FINGER_MCP = 9, + MIDDLE_FINGER_PIP = 10, + MIDDLE_FINGER_DIP = 11, + MIDDLE_FINGER_TIP = 12, + RING_FINGER_MCP = 13, + RING_FINGER_PIP = 14, + RING_FINGER_DIP = 15, + RING_FINGER_TIP = 16, + PINKY_MCP = 17, + PINKY_PIP = 18, + PINKY_DIP = 19, + PINKY_TIP = 20 +} From ac212c15070854b407812148739f6e1b72089a75 Mon Sep 17 00:00:00 2001 From: Adam Cozzette Date: Fri, 18 Nov 2022 10:06:47 -0800 Subject: [PATCH 086/469] Internal change PiperOrigin-RevId: 489502255 --- mediapipe/calculators/audio/BUILD | 1 - mediapipe/calculators/core/BUILD | 6 ++---- mediapipe/calculators/image/BUILD | 10 +++++----- mediapipe/calculators/tensor/BUILD | 6 +++--- mediapipe/calculators/tensorflow/BUILD | 14 ++++++++------ mediapipe/calculators/tflite/BUILD | 6 +++--- mediapipe/calculators/util/BUILD | 9 ++++----- mediapipe/calculators/video/BUILD | 4 ++-- mediapipe/framework/BUILD | 4 ---- mediapipe/framework/formats/BUILD | 8 +++++--- mediapipe/framework/formats/motion/BUILD | 4 ++-- mediapipe/framework/profiler/BUILD | 4 ++++ mediapipe/framework/stream_handler/BUILD | 4 ++-- mediapipe/framework/tool/BUILD | 7 ++----- mediapipe/gpu/BUILD | 1 - 15 files changed, 42 insertions(+), 46 deletions(-) diff --git a/mediapipe/calculators/audio/BUILD b/mediapipe/calculators/audio/BUILD index ba461e4a7..555f7543f 100644 --- a/mediapipe/calculators/audio/BUILD +++ b/mediapipe/calculators/audio/BUILD @@ -197,7 +197,6 @@ cc_library( ":spectrogram_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:matrix", - "//mediapipe/framework/formats:time_series_header_cc_proto", "//mediapipe/framework/port:core_proto", "//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:logging", diff --git a/mediapipe/calculators/core/BUILD b/mediapipe/calculators/core/BUILD index ecd878115..39837fadb 100644 --- a/mediapipe/calculators/core/BUILD +++ b/mediapipe/calculators/core/BUILD @@ -341,7 +341,6 @@ cc_test( srcs = ["concatenate_proto_list_calculator_test.cc"], deps = [ ":concatenate_proto_list_calculator", - ":concatenate_vector_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_runner", "//mediapipe/framework:timestamp", @@ -403,7 +402,6 @@ cc_test( srcs = ["clip_vector_size_calculator_test.cc"], deps = [ ":clip_vector_size_calculator", - "//mediapipe/calculators/core:packet_resampler_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_runner", "//mediapipe/framework:timestamp", @@ -956,10 +954,10 @@ cc_library( deps = [ ":split_vector_calculator_cc_proto", "//mediapipe/framework/formats:detection_cc_proto", - "//mediapipe/framework:calculator_framework", - "//mediapipe/framework/formats:landmark_cc_proto", "//mediapipe/framework/formats:classification_cc_proto", + "//mediapipe/framework/formats:landmark_cc_proto", "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:matrix", "//mediapipe/framework/formats:tensor", "//mediapipe/framework/port:ret_check", diff --git a/mediapipe/calculators/image/BUILD b/mediapipe/calculators/image/BUILD index 89e2d371c..c78bc5cf7 100644 --- a/mediapipe/calculators/image/BUILD +++ b/mediapipe/calculators/image/BUILD @@ -159,8 +159,8 @@ cc_library( visibility = 
["//visibility:public"], deps = [ ":set_alpha_calculator_cc_proto", - "//mediapipe/framework:calculator_options_cc_proto", "//mediapipe/framework/formats:image_format_cc_proto", + "//mediapipe/framework:calculator_options_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:image_frame", "//mediapipe/framework/formats:image_frame_opencv", @@ -186,8 +186,8 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":bilateral_filter_calculator_cc_proto", - "//mediapipe/framework:calculator_options_cc_proto", "//mediapipe/framework/formats:image_format_cc_proto", + "//mediapipe/framework:calculator_options_cc_proto", "@com_google_absl//absl/strings", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:image_frame", @@ -290,10 +290,10 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":image_cropping_calculator_cc_proto", + "//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:image_frame", "//mediapipe/framework/formats:image_frame_opencv", - "//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework/port:opencv_core", "//mediapipe/framework/port:opencv_imgproc", "//mediapipe/framework/port:ret_check", @@ -361,12 +361,12 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":recolor_calculator_cc_proto", + "//mediapipe/util:color_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:image_frame", "//mediapipe/framework/formats:image_frame_opencv", "//mediapipe/framework/port:status", "//mediapipe/framework/port:ret_check", - "//mediapipe/util:color_cc_proto", "//mediapipe/framework/port:opencv_core", "//mediapipe/framework/port:opencv_imgproc", ] + select({ @@ -630,8 +630,8 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":segmentation_smoothing_calculator_cc_proto", - "//mediapipe/framework:calculator_options_cc_proto", "//mediapipe/framework/formats:image_format_cc_proto", + "//mediapipe/framework:calculator_options_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:image_frame", "//mediapipe/framework/formats:image", diff --git a/mediapipe/calculators/tensor/BUILD b/mediapipe/calculators/tensor/BUILD index 3f1278397..4c06df0ff 100644 --- a/mediapipe/calculators/tensor/BUILD +++ b/mediapipe/calculators/tensor/BUILD @@ -433,6 +433,7 @@ cc_library( }), visibility = ["//visibility:public"], deps = [ + ":inference_calculator_cc_proto", ":inference_calculator_options_lib", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:node", @@ -794,12 +795,12 @@ cc_library( "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", "//mediapipe/framework/api2:node", + "//mediapipe/framework/formats/object_detection:anchor_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:port", "//mediapipe/framework/deps:file_path", "//mediapipe/framework/formats:location", "//mediapipe/framework/formats:tensor", - "//mediapipe/framework/formats/object_detection:anchor_cc_proto", "//mediapipe/framework/port:ret_check", ] + selects.with_or({ ":compute_shader_unavailable": [], @@ -1279,7 +1280,6 @@ cc_library( "//mediapipe/gpu:MPPMetalHelper", "@com_google_absl//absl/strings", "//mediapipe/framework:calculator_framework", - "//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework/formats:tensor", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", @@ -1378,9 +1378,9 
@@ cc_library( "//mediapipe/framework:calculator_context", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:port", + "//mediapipe/gpu:gpu_origin_cc_proto", "//mediapipe/util:resource_util", "@org_tensorflow//tensorflow/lite:framework", - "//mediapipe/gpu:gpu_origin_cc_proto", "//mediapipe/framework/port:statusor", ] + selects.with_or({ "//mediapipe/gpu:disable_gpu": [], diff --git a/mediapipe/calculators/tensorflow/BUILD b/mediapipe/calculators/tensorflow/BUILD index d0dfc12ab..45f64f4f7 100644 --- a/mediapipe/calculators/tensorflow/BUILD +++ b/mediapipe/calculators/tensorflow/BUILD @@ -346,8 +346,8 @@ cc_library( srcs = ["matrix_to_tensor_calculator.cc"], visibility = ["//visibility:public"], deps = [ - "//mediapipe/framework/formats:time_series_header_cc_proto", ":matrix_to_tensor_calculator_options_cc_proto", + "//mediapipe/framework/formats:time_series_header_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:matrix", "//mediapipe/framework/port:status", @@ -414,7 +414,7 @@ cc_library( "//mediapipe/calculators/image:opencv_image_encoder_calculator_cc_proto", "//mediapipe/calculators/tensorflow:pack_media_sequence_calculator_cc_proto", "//mediapipe/framework:calculator_framework", - "//mediapipe/framework/formats:detection_cc_proto", # build_cleaner: keep + "//mediapipe/framework/formats:detection_cc_proto", "//mediapipe/framework/formats:location", "//mediapipe/framework/formats:location_opencv", "//mediapipe/framework/port:opencv_imgcodecs", @@ -451,8 +451,8 @@ cc_library( srcs = ["tensorflow_inference_calculator.cc"], visibility = ["//visibility:public"], deps = [ - ":tensorflow_session", ":tensorflow_inference_calculator_cc_proto", + ":tensorflow_session", "@com_google_absl//absl/log:check", "//mediapipe/framework:timestamp", "@com_google_absl//absl/base:core_headers", @@ -515,6 +515,7 @@ cc_library( "//mediapipe/framework/port:logging", "//mediapipe/framework/port:status", "//mediapipe/framework/port:ret_check", + "@org_tensorflow//tensorflow/core:protos_all_cc", ] + select({ "//conditions:default": [ "//mediapipe/framework/port:file_helpers", @@ -546,6 +547,7 @@ cc_library( "//mediapipe/framework/deps:clock", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:ret_check", + "@org_tensorflow//tensorflow/core:protos_all_cc", ] + select({ "//conditions:default": [ "//mediapipe/framework/port:file_helpers", @@ -666,8 +668,8 @@ cc_library( srcs = ["tensor_to_matrix_calculator.cc"], visibility = ["//visibility:public"], deps = [ - "//mediapipe/framework/formats:time_series_header_cc_proto", ":tensor_to_matrix_calculator_cc_proto", + "//mediapipe/framework/formats:time_series_header_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:matrix", "//mediapipe/framework/port:status", @@ -704,10 +706,10 @@ cc_library( srcs = ["tensor_to_vector_float_calculator.cc"], visibility = ["//visibility:public"], deps = [ + ":tensor_to_vector_float_calculator_options_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:status", "//mediapipe/framework/port:ret_check", - ":tensor_to_vector_float_calculator_options_cc_proto", ] + select({ "//conditions:default": [ "@org_tensorflow//tensorflow/core:framework", @@ -1083,7 +1085,6 @@ cc_test( linkstatic = 1, deps = [ ":tensor_to_image_frame_calculator", - ":tensor_to_image_frame_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_runner", 
"//mediapipe/framework/formats:image_frame", @@ -1236,6 +1237,7 @@ cc_test( data = [":test_frozen_graph"], linkstatic = 1, deps = [ + ":tensorflow_inference_calculator_cc_proto", ":tensorflow_session", ":tensorflow_inference_calculator", ":tensorflow_session_from_frozen_graph_generator", diff --git a/mediapipe/calculators/tflite/BUILD b/mediapipe/calculators/tflite/BUILD index 2007a4fe1..8edaeee02 100644 --- a/mediapipe/calculators/tflite/BUILD +++ b/mediapipe/calculators/tflite/BUILD @@ -289,8 +289,8 @@ cc_library( }), visibility = ["//visibility:public"], deps = [ - "//mediapipe/util/tflite:config", ":tflite_converter_calculator_cc_proto", + "//mediapipe/util/tflite:config", "//mediapipe/util:resource_util", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:image_frame", @@ -410,15 +410,15 @@ cc_library( }), visibility = ["//visibility:public"], deps = [ - "//mediapipe/util/tflite:config", ":tflite_tensors_to_detections_calculator_cc_proto", "//mediapipe/framework/formats:detection_cc_proto", + "//mediapipe/framework/formats/object_detection:anchor_cc_proto", + "//mediapipe/util/tflite:config", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", "//mediapipe/framework/deps:file_path", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:location", - "//mediapipe/framework/formats/object_detection:anchor_cc_proto", "//mediapipe/framework/port:ret_check", "@org_tensorflow//tensorflow/lite:framework", ] + selects.with_or({ diff --git a/mediapipe/calculators/util/BUILD b/mediapipe/calculators/util/BUILD index 3a9ddc36f..24e976a73 100644 --- a/mediapipe/calculators/util/BUILD +++ b/mediapipe/calculators/util/BUILD @@ -23,8 +23,8 @@ cc_library( srcs = ["alignment_points_to_rects_calculator.cc"], visibility = ["//visibility:public"], deps = [ + ":detections_to_rects_calculator_cc_proto", "//mediapipe/calculators/util:detections_to_rects_calculator", - "//mediapipe/calculators/util:detections_to_rects_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_options_cc_proto", "//mediapipe/framework/formats:detection_cc_proto", @@ -266,8 +266,8 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":annotation_overlay_calculator_cc_proto", - "//mediapipe/framework:calculator_options_cc_proto", "//mediapipe/framework/formats:image_format_cc_proto", + "//mediapipe/framework:calculator_options_cc_proto", "//mediapipe/util:color_cc_proto", "@com_google_absl//absl/strings", "//mediapipe/framework:calculator_framework", @@ -755,7 +755,6 @@ cc_library( deps = [ ":labels_to_render_data_calculator_cc_proto", "//mediapipe/framework:calculator_framework", - "//mediapipe/framework:calculator_options_cc_proto", "//mediapipe/framework/formats:classification_cc_proto", "//mediapipe/framework/formats:video_stream_header", "//mediapipe/framework/port:ret_check", @@ -1313,8 +1312,8 @@ cc_library( srcs = ["to_image_calculator.cc"], visibility = ["//visibility:public"], deps = [ - "//mediapipe/framework:calculator_options_cc_proto", "//mediapipe/framework/formats:image_format_cc_proto", + "//mediapipe/framework:calculator_options_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:node", "//mediapipe/framework/formats:image_frame", @@ -1336,8 +1335,8 @@ cc_library( srcs = ["from_image_calculator.cc"], visibility = ["//visibility:public"], deps = [ - "//mediapipe/framework:calculator_options_cc_proto", 
"//mediapipe/framework/formats:image_format_cc_proto", + "//mediapipe/framework:calculator_options_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:image_frame", "//mediapipe/framework/formats:image", diff --git a/mediapipe/calculators/video/BUILD b/mediapipe/calculators/video/BUILD index 53d968151..2db3ed252 100644 --- a/mediapipe/calculators/video/BUILD +++ b/mediapipe/calculators/video/BUILD @@ -342,12 +342,12 @@ cc_library( "//mediapipe/framework/port:opencv_features2d", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", + "//mediapipe/util/tracking:box_tracker_cc_proto", + "//mediapipe/util/tracking:flow_packager_cc_proto", "//mediapipe/util:resource_util", "//mediapipe/util/tracking", "//mediapipe/util/tracking:box_detector", "//mediapipe/util/tracking:box_tracker", - "//mediapipe/util/tracking:box_tracker_cc_proto", - "//mediapipe/util/tracking:flow_packager_cc_proto", "//mediapipe/util/tracking:tracking_visualization_utilities", ] + select({ "//mediapipe:android": [ diff --git a/mediapipe/framework/BUILD b/mediapipe/framework/BUILD index 8ccdac3b9..e3429f1e9 100644 --- a/mediapipe/framework/BUILD +++ b/mediapipe/framework/BUILD @@ -1039,7 +1039,6 @@ cc_library( ":graph_service_manager", ":port", "//mediapipe/framework:calculator_cc_proto", - "//mediapipe/framework:mediapipe_options_cc_proto", "//mediapipe/framework/deps:registration", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", @@ -1660,9 +1659,6 @@ cc_test( "//mediapipe/calculators/core:constant_side_packet_calculator", "//mediapipe/calculators/core:default_side_packet_calculator", "//mediapipe/calculators/core:pass_through_calculator", - "//mediapipe/framework:calculator_cc_proto", - "//mediapipe/framework:packet_generator_cc_proto", - "//mediapipe/framework:status_handler_cc_proto", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/tool:template_parser", diff --git a/mediapipe/framework/formats/BUILD b/mediapipe/framework/formats/BUILD index c3241d911..e13bb2704 100644 --- a/mediapipe/framework/formats/BUILD +++ b/mediapipe/framework/formats/BUILD @@ -133,9 +133,9 @@ cc_library( "//visibility:public", ], deps = [ + ":affine_transform_data_cc_proto", "//mediapipe/framework:port", "//mediapipe/framework:type_map", - "//mediapipe/framework/formats:affine_transform_data_cc_proto", "//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:point", @@ -209,8 +209,8 @@ cc_library( visibility = ["//visibility:public"], deps = [ "@com_google_protobuf//:protobuf", - "//mediapipe/framework/formats:location_data_cc_proto", "//mediapipe/framework/formats/annotation:locus_cc_proto", + "//mediapipe/framework/formats:location_data_cc_proto", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", @@ -241,6 +241,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":location", + "//mediapipe/framework/formats/annotation:rasterization_cc_proto", "//mediapipe/framework/port:opencv_imgproc", ], alwayslink = 1, @@ -251,6 +252,7 @@ cc_test( srcs = ["location_opencv_test.cc"], deps = [ ":location_opencv", + "//mediapipe/framework/formats/annotation:rasterization_cc_proto", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:rectangle", ], @@ -346,8 +348,8 @@ cc_library( }), visibility = ["//visibility:public"], deps = [ - 
"//mediapipe/framework/formats:image_frame", "//mediapipe/framework/formats:image_format_cc_proto", + "//mediapipe/framework/formats:image_frame", "@com_google_absl//absl/synchronization", "//mediapipe/framework:port", "//mediapipe/framework:type_map", diff --git a/mediapipe/framework/formats/motion/BUILD b/mediapipe/framework/formats/motion/BUILD index 28e0bfc6a..9819d262c 100644 --- a/mediapipe/framework/formats/motion/BUILD +++ b/mediapipe/framework/formats/motion/BUILD @@ -16,10 +16,10 @@ # Description: # Working with dense optical flow in mediapipe. -licenses(["notice"]) - load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library") +licenses(["notice"]) + package(default_visibility = ["//visibility:private"]) proto_library( diff --git a/mediapipe/framework/profiler/BUILD b/mediapipe/framework/profiler/BUILD index 237aa825f..b53a1ac39 100644 --- a/mediapipe/framework/profiler/BUILD +++ b/mediapipe/framework/profiler/BUILD @@ -334,6 +334,10 @@ cc_library( "graph_profiler_stub.h", ], visibility = ["//mediapipe/framework:__pkg__"], + deps = [ + "//mediapipe/framework:calculator_cc_proto", + "//mediapipe/framework:calculator_profile_cc_proto", + ], ) cc_test( diff --git a/mediapipe/framework/stream_handler/BUILD b/mediapipe/framework/stream_handler/BUILD index 8771a8773..866a5120e 100644 --- a/mediapipe/framework/stream_handler/BUILD +++ b/mediapipe/framework/stream_handler/BUILD @@ -13,6 +13,8 @@ # limitations under the License. # +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library") + licenses(["notice"]) package( @@ -20,8 +22,6 @@ package( features = ["-layering_check"], ) -load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library") - proto_library( name = "default_input_stream_handler_proto", srcs = ["default_input_stream_handler.proto"], diff --git a/mediapipe/framework/tool/BUILD b/mediapipe/framework/tool/BUILD index e54fb2177..52d04b4b1 100644 --- a/mediapipe/framework/tool/BUILD +++ b/mediapipe/framework/tool/BUILD @@ -299,12 +299,12 @@ mediapipe_cc_test( data = [":node_chain_subgraph.proto"], requires_full_emulation = False, deps = [ + ":node_chain_subgraph_cc_proto", ":options_field_util", ":options_registry", ":options_syntax_util", ":options_util", "//mediapipe/calculators/core:flow_limiter_calculator", - "//mediapipe/calculators/core:flow_limiter_calculator_cc_proto", "//mediapipe/framework:basic_types_registration", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_runner", @@ -312,6 +312,7 @@ mediapipe_cc_test( "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:status", + "//mediapipe/framework/testdata:night_light_calculator_cc_proto", "//mediapipe/framework/testdata:night_light_calculator_options_lib", "//mediapipe/framework/tool:node_chain_subgraph_options_lib", "//mediapipe/util:header_util", @@ -486,7 +487,6 @@ cc_library( deps = [ ":proto_util_lite", "//mediapipe/framework:calculator_cc_proto", - "//mediapipe/framework/deps:proto_descriptor_cc_proto", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:numbers", "//mediapipe/framework/port:ret_check", @@ -738,9 +738,7 @@ cc_test( "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:graph_service_manager", - "//mediapipe/framework:mediapipe_options_cc_proto", "//mediapipe/framework:packet", - "//mediapipe/framework:packet_generator_cc_proto", "//mediapipe/framework:packet_set", 
"//mediapipe/framework:packet_type", "//mediapipe/framework:status_handler", @@ -923,7 +921,6 @@ cc_test( "//mediapipe/calculators/core:pass_through_calculator", "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework:calculator_framework", - "//mediapipe/framework:stream_handler_cc_proto", "//mediapipe/framework:subgraph", "//mediapipe/framework:test_calculators", "//mediapipe/framework/port:gtest_main", diff --git a/mediapipe/gpu/BUILD b/mediapipe/gpu/BUILD index 27d91f21a..10a8d7fff 100644 --- a/mediapipe/gpu/BUILD +++ b/mediapipe/gpu/BUILD @@ -783,7 +783,6 @@ cc_library( ":image_frame_view", ":shader_util", "//mediapipe/framework:calculator_framework", - "//mediapipe/framework:calculator_cc_proto", "@com_google_absl//absl/base:core_headers", "//mediapipe/framework:calculator_context", "//mediapipe/framework:calculator_node", From e2052a6a517fe1d8ce487f46a9856a225644d3f2 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Fri, 18 Nov 2022 11:11:22 -0800 Subject: [PATCH 087/469] Rename embedding postprocessor "configure" method for consistency with classification postprocessor. PiperOrigin-RevId: 489518257 --- .../audio/audio_embedder/audio_embedder_graph.cc | 10 ++++++---- .../processors/embedding_postprocessing_graph.cc | 6 +++--- .../processors/embedding_postprocessing_graph.h | 2 +- .../embedding_postprocessing_graph_test.cc | 14 +++++++------- .../cc/text/text_embedder/text_embedder_graph.cc | 10 ++++++---- .../vision/image_embedder/image_embedder_graph.cc | 10 ++++++---- 6 files changed, 29 insertions(+), 23 deletions(-) diff --git a/mediapipe/tasks/cc/audio/audio_embedder/audio_embedder_graph.cc b/mediapipe/tasks/cc/audio/audio_embedder/audio_embedder_graph.cc index 7667feaa3..f093b4d25 100644 --- a/mediapipe/tasks/cc/audio/audio_embedder/audio_embedder_graph.cc +++ b/mediapipe/tasks/cc/audio/audio_embedder/audio_embedder_graph.cc @@ -158,10 +158,12 @@ class AudioEmbedderGraph : public core::ModelTaskGraph { // inference results. auto& postprocessing = graph.AddNode( "mediapipe.tasks.components.processors.EmbeddingPostprocessingGraph"); - MP_RETURN_IF_ERROR(components::processors::ConfigureEmbeddingPostprocessing( - model_resources, task_options.embedder_options(), - &postprocessing.GetOptions())); + MP_RETURN_IF_ERROR( + components::processors::ConfigureEmbeddingPostprocessingGraph( + model_resources, task_options.embedder_options(), + &postprocessing + .GetOptions())); inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag); // Time aggregation is only needed for performing audio embedding on // audio files. Disables timestamp aggregation by not connecting the diff --git a/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.cc b/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.cc index 880aec5d7..ad4881e12 100644 --- a/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.cc +++ b/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.cc @@ -150,7 +150,7 @@ absl::StatusOr> GetHeadNames( } // namespace -absl::Status ConfigureEmbeddingPostprocessing( +absl::Status ConfigureEmbeddingPostprocessingGraph( const ModelResources& model_resources, const proto::EmbedderOptions& embedder_options, proto::EmbeddingPostprocessingGraphOptions* options) { @@ -193,8 +193,8 @@ absl::Status ConfigureEmbeddingPostprocessing( // timestamp aggregation is required. // // The recommended way of using this graph is through the GraphBuilder API using -// the 'ConfigureEmbeddingPostprocessing()' function. 
See header file for more -// details. +// the 'ConfigureEmbeddingPostprocessingGraph()' function. See header file for +// more details. class EmbeddingPostprocessingGraph : public mediapipe::Subgraph { public: absl::StatusOr GetConfig( diff --git a/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.h b/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.h index 58606ed80..889992463 100644 --- a/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.h +++ b/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.h @@ -58,7 +58,7 @@ namespace processors { // The embedding result aggregated by timestamp, then by head. Must be // connected if the TIMESTAMPS input is connected, as it signals that // timestamp aggregation is required. -absl::Status ConfigureEmbeddingPostprocessing( +absl::Status ConfigureEmbeddingPostprocessingGraph( const tasks::core::ModelResources& model_resources, const proto::EmbedderOptions& embedder_options, proto::EmbeddingPostprocessingGraphOptions* options); diff --git a/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph_test.cc b/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph_test.cc index 84d84d648..163e46ee8 100644 --- a/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph_test.cc +++ b/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph_test.cc @@ -95,8 +95,8 @@ TEST_F(ConfigureTest, SucceedsWithQuantizedModelWithMetadata) { options_in.set_l2_normalize(true); proto::EmbeddingPostprocessingGraphOptions options_out; - MP_ASSERT_OK(ConfigureEmbeddingPostprocessing(*model_resources, options_in, - &options_out)); + MP_ASSERT_OK(ConfigureEmbeddingPostprocessingGraph(*model_resources, + options_in, &options_out)); EXPECT_THAT( options_out, @@ -117,8 +117,8 @@ TEST_F(ConfigureTest, SucceedsWithQuantizedModelWithoutMetadata) { options_in.set_quantize(true); proto::EmbeddingPostprocessingGraphOptions options_out; - MP_ASSERT_OK(ConfigureEmbeddingPostprocessing(*model_resources, options_in, - &options_out)); + MP_ASSERT_OK(ConfigureEmbeddingPostprocessingGraph(*model_resources, + options_in, &options_out)); EXPECT_THAT( options_out, @@ -138,8 +138,8 @@ TEST_F(ConfigureTest, SucceedsWithFloatModelWithMetadata) { options_in.set_l2_normalize(true); proto::EmbeddingPostprocessingGraphOptions options_out; - MP_ASSERT_OK(ConfigureEmbeddingPostprocessing(*model_resources, options_in, - &options_out)); + MP_ASSERT_OK(ConfigureEmbeddingPostprocessingGraph(*model_resources, + options_in, &options_out)); EXPECT_THAT( options_out, @@ -164,7 +164,7 @@ class PostprocessingTest : public tflite_shims::testing::Test { auto& postprocessing = graph.AddNode( "mediapipe.tasks.components.processors." "EmbeddingPostprocessingGraph"); - MP_RETURN_IF_ERROR(ConfigureEmbeddingPostprocessing( + MP_RETURN_IF_ERROR(ConfigureEmbeddingPostprocessingGraph( *model_resources, options, &postprocessing .GetOptions())); diff --git a/mediapipe/tasks/cc/text/text_embedder/text_embedder_graph.cc b/mediapipe/tasks/cc/text/text_embedder/text_embedder_graph.cc index 79eedb6b5..c54636ee2 100644 --- a/mediapipe/tasks/cc/text/text_embedder/text_embedder_graph.cc +++ b/mediapipe/tasks/cc/text/text_embedder/text_embedder_graph.cc @@ -128,10 +128,12 @@ class TextEmbedderGraph : public core::ModelTaskGraph { // inference results. 
    auto& postprocessing = graph.AddNode(
        "mediapipe.tasks.components.processors.EmbeddingPostprocessingGraph");
-    MP_RETURN_IF_ERROR(components::processors::ConfigureEmbeddingPostprocessing(
-        model_resources, task_options.embedder_options(),
-        &postprocessing.GetOptions()));
+    MP_RETURN_IF_ERROR(
+        components::processors::ConfigureEmbeddingPostprocessingGraph(
+            model_resources, task_options.embedder_options(),
+            &postprocessing
+                 .GetOptions()));
     inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag);
 
     // Outputs the embedding result.
diff --git a/mediapipe/tasks/cc/vision/image_embedder/image_embedder_graph.cc b/mediapipe/tasks/cc/vision/image_embedder/image_embedder_graph.cc
index 11e25144c..bf0dcf3c7 100644
--- a/mediapipe/tasks/cc/vision/image_embedder/image_embedder_graph.cc
+++ b/mediapipe/tasks/cc/vision/image_embedder/image_embedder_graph.cc
@@ -151,10 +151,12 @@ class ImageEmbedderGraph : public core::ModelTaskGraph {
     // inference results.
     auto& postprocessing = graph.AddNode(
        "mediapipe.tasks.components.processors.EmbeddingPostprocessingGraph");
-    MP_RETURN_IF_ERROR(components::processors::ConfigureEmbeddingPostprocessing(
-        model_resources, task_options.embedder_options(),
-        &postprocessing.GetOptions()));
+    MP_RETURN_IF_ERROR(
+        components::processors::ConfigureEmbeddingPostprocessingGraph(
+            model_resources, task_options.embedder_options(),
+            &postprocessing
+                 .GetOptions()));
     inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag);
 
     // Outputs the embedding results.

From 71ae496a2001d1206b792bedd45d4027d7f043c7 Mon Sep 17 00:00:00 2001
From: Sebastian Schmidt
Date: Fri, 18 Nov 2022 12:10:47 -0800
Subject: [PATCH 088/469] Add AudioEmbedder documentation

PiperOrigin-RevId: 489532283
---
 .../audio_embedder/audio_embedder_graph.cc | 40 +++++++++++++++++++
 1 file changed, 40 insertions(+)

diff --git a/mediapipe/tasks/cc/audio/audio_embedder/audio_embedder_graph.cc b/mediapipe/tasks/cc/audio/audio_embedder/audio_embedder_graph.cc
index f093b4d25..187f11f7f 100644
--- a/mediapipe/tasks/cc/audio/audio_embedder/audio_embedder_graph.cc
+++ b/mediapipe/tasks/cc/audio/audio_embedder/audio_embedder_graph.cc
@@ -100,6 +100,46 @@ void ConfigureAudioToTensorCalculator(
   }
 }
 }  // namespace
+
+// An "AudioEmbedderGraph" performs embedding extraction.
+// - Accepts CPU audio buffer and outputs embedding results on CPU.
+//
+// Inputs:
+//   AUDIO - Matrix
+//     Audio buffer to perform embedding extraction on.
+//   SAMPLE_RATE - double @Optional
+//     The sample rate of the corresponding audio data in the "AUDIO" stream.
+//     If sample rate is not provided, the "AUDIO" stream must carry a time
+//     series stream header with sample rate info.
+//
+// Outputs:
+//   EMBEDDINGS - EmbeddingResult @Optional
+//     The embedding results aggregated by head. Only produces results if
+//     the graph's 'use_stream_mode' option is true.
+//   TIMESTAMPED_EMBEDDINGS - std::vector<EmbeddingResult> @Optional
+//     The embedding result aggregated by timestamp, then by head. Only
+//     produces results if the graph's 'use_stream_mode' option is false.
+//
+// Example:
+// node {
+//   calculator: "mediapipe.tasks.audio.audio_embedder.AudioEmbedderGraph"
+//   input_stream: "AUDIO:audio_in"
+//   input_stream: "SAMPLE_RATE:sample_rate_in"
+//   output_stream: "EMBEDDINGS:embeddings_out"
+//   output_stream: "TIMESTAMPED_EMBEDDINGS:timestamped_embeddings_out"
+//   options {
+//     [mediapipe.tasks.audio.audio_embedder.proto.AudioEmbedderGraphOptions.ext]
+//     {
+//       base_options {
+//         model_asset {
+//           file_name: "/path/to/model.tflite"
+//         }
+//       }
+//       embedder_options {
+//         l2_normalize: true
+//       }
+//     }
+//   }
+// }
 class AudioEmbedderGraph : public core::ModelTaskGraph {
  public:
   absl::StatusOr GetConfig(

From 1b594a0310f9c1bc3ece2562455bba0f812efd3a Mon Sep 17 00:00:00 2001
From: Jiuqiang Tang
Date: Fri, 18 Nov 2022 12:42:58 -0800
Subject: [PATCH 089/469] Return an error status when any TFLite input or
 output tensor doesn't have the valid dimensionality information needed to
 allocate a GL/Metal buffer before calling ModifyGraphWithDelegate.

PiperOrigin-RevId: 489539740
---
 mediapipe/calculators/tensor/BUILD                      | 2 ++
 mediapipe/calculators/tensor/inference_calculator_gl.cc | 8 ++++++++
 .../calculators/tensor/inference_calculator_metal.cc    | 7 +++++++
 3 files changed, 17 insertions(+)

diff --git a/mediapipe/calculators/tensor/BUILD b/mediapipe/calculators/tensor/BUILD
index 4c06df0ff..2a573fc44 100644
--- a/mediapipe/calculators/tensor/BUILD
+++ b/mediapipe/calculators/tensor/BUILD
@@ -464,6 +464,7 @@ cc_library(
         "//mediapipe/gpu:gl_calculator_helper",
         "@com_google_absl//absl/memory",
         "@com_google_absl//absl/status",
+        "@com_google_absl//absl/strings:str_format",
         "@org_tensorflow//tensorflow/lite/delegates/gpu:gl_delegate",
     ],
     alwayslink = 1,
@@ -513,6 +514,7 @@ cc_library(
         "//mediapipe/objc:mediapipe_framework_ios",
         "//mediapipe/util/tflite:config",
         "@com_google_absl//absl/memory",
+        "@com_google_absl//absl/strings:str_format",
         "@org_tensorflow//tensorflow/lite/delegates/gpu:metal_delegate",
         "@org_tensorflow//tensorflow/lite/delegates/gpu:metal_delegate_internal",
         "@org_tensorflow//tensorflow/lite/delegates/gpu/common:shape",
diff --git a/mediapipe/calculators/tensor/inference_calculator_gl.cc b/mediapipe/calculators/tensor/inference_calculator_gl.cc
index bd8eb3eed..27b8bc23a 100644
--- a/mediapipe/calculators/tensor/inference_calculator_gl.cc
+++ b/mediapipe/calculators/tensor/inference_calculator_gl.cc
@@ -20,6 +20,7 @@
 #include
 #include "absl/memory/memory.h"
 #include "absl/status/status.h"
+#include "absl/strings/str_format.h"
 #include "mediapipe/calculators/tensor/inference_calculator.h"
 #include "mediapipe/calculators/tensor/inference_calculator.pb.h"
 #include "mediapipe/framework/calculator_context.h"
@@ -154,6 +155,10 @@ absl::Status InferenceCalculatorGlImpl::GpuInferenceRunner::LoadDelegate(
   const auto& input_indices = interpreter_->inputs();
   for (int i = 0; i < input_indices.size(); ++i) {
     const TfLiteTensor* tensor = interpreter_->tensor(input_indices[i]);
+    RET_CHECK(tensor->dims->size > 0) << absl::StrFormat(
+        "Input tensor at index [%d] doesn't specify dimensions.",
+        input_indices[i]);
+
     gpu_buffers_in_.emplace_back(absl::make_unique(
         Tensor::ElementType::kFloat32,
         Tensor::Shape{std::vector{
@@ -171,6 +176,9 @@ absl::Status InferenceCalculatorGlImpl::GpuInferenceRunner::LoadDelegate(
   // Create and bind output buffers.
 for (int i = 0; i < output_size_; ++i) {
     const TfLiteTensor* tensor = interpreter_->tensor(output_indices[i]);
+    RET_CHECK(tensor->dims->size > 0) << absl::StrFormat(
+        "Output tensor at index [%d] doesn't specify dimensions.",
+        output_indices[i]);
     gpu_buffers_out_.emplace_back(absl::make_unique(
         Tensor::ElementType::kFloat32,
         Tensor::Shape{std::vector{
diff --git a/mediapipe/calculators/tensor/inference_calculator_metal.cc b/mediapipe/calculators/tensor/inference_calculator_metal.cc
index a85071f3e..750f0456e 100644
--- a/mediapipe/calculators/tensor/inference_calculator_metal.cc
+++ b/mediapipe/calculators/tensor/inference_calculator_metal.cc
@@ -22,6 +22,7 @@
 #include
 
 #include "absl/memory/memory.h"
+#include "absl/strings/str_format.h"
 #include "mediapipe/calculators/tensor/inference_calculator.h"
 #import "mediapipe/gpu/MPPMetalHelper.h"
 #include "mediapipe/gpu/MPPMetalUtil.h"
@@ -245,6 +246,9 @@ absl::Status InferenceCalculatorMetalImpl::CreateConverters(
   const auto& input_indices = interpreter_->inputs();
   for (int i = 0; i < input_indices.size(); ++i) {
     const TfLiteTensor* tensor = interpreter_->tensor(input_indices[i]);
+    RET_CHECK(tensor->dims->size > 0) << absl::StrFormat(
+        "Input tensor at index [%d] doesn't specify dimensions.",
+        input_indices[i]);
     // Create and bind input buffer.
     std::vector dims{tensor->dims->data,
                      tensor->dims->data + tensor->dims->size};
@@ -266,6 +270,9 @@ absl::Status InferenceCalculatorMetalImpl::CreateConverters(
   output_shapes_.resize(output_indices.size());
   for (int i = 0; i < output_shapes_.size(); ++i) {
     const TfLiteTensor* tensor = interpreter_->tensor(output_indices[i]);
+    RET_CHECK(tensor->dims->size > 0) << absl::StrFormat(
+        "Output tensor at index [%d] doesn't specify dimensions.",
+        output_indices[i]);
     RET_CHECK(tensor->dims->size <= 4);
     // Create and bind output buffers.
     // Channels are always padded to multiple of 4.

From 524ac3ca61dc165f23a8d6ce29a9ff36d2fa7e98 Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Fri, 18 Nov 2022 12:45:56 -0800
Subject: [PATCH 090/469] Internal change for Model Maker

PiperOrigin-RevId: 489540387
---
 mediapipe/model_maker/python/core/tasks/classifier.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/mediapipe/model_maker/python/core/tasks/classifier.py b/mediapipe/model_maker/python/core/tasks/classifier.py
index 200726864..f376edffa 100644
--- a/mediapipe/model_maker/python/core/tasks/classifier.py
+++ b/mediapipe/model_maker/python/core/tasks/classifier.py
@@ -91,6 +91,10 @@ class Classifier(custom_model.CustomModel):
     self._history = self._model.fit(
         x=train_dataset,
         epochs=self._hparams.epochs,
+        # `steps_per_epoch` is intentionally set to None in case the dataset
+        # is not repeated. Otherwise, the training process will stop when the
+        # dataset is exhausted even if there are epochs remaining.
+        steps_per_epoch=None,
         validation_data=validation_dataset,
         callbacks=self._callbacks)

From bbd5da7971aa0d39bbeba638de34ded860bd30b3 Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Fri, 18 Nov 2022 17:10:54 -0800
Subject: [PATCH 091/469] Added grayscale image support for the
 ImageToTensorCalculator on CPU.
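For illustration, here is a minimal standalone sketch of the CPU channel
mapping this change adopts (GetNumOutputChannelsCpu is a hypothetical
stand-in, not the calculator's real helper): 1-channel inputs now produce
1-channel tensors, while 3- and 4-channel inputs keep producing 3-channel
tensors.

    // Sketch only; the helper name below is hypothetical.
    #include <cassert>

    int GetNumOutputChannelsCpu(int image_channels) {
      // GRAY8 stays single-channel; SRGB/SRGBA still map to 3 channels.
      return image_channels == 1 ? 1 : 3;
    }

    int main() {
      assert(GetNumOutputChannelsCpu(1) == 1);  // GRAY8
      assert(GetNumOutputChannelsCpu(3) == 3);  // SRGB
      assert(GetNumOutputChannelsCpu(4) == 3);  // SRGBA (alpha is dropped)
      return 0;
    }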
PiperOrigin-RevId: 489593917
---
 .../tensor/image_to_tensor_calculator_test.cc | 79 ++++++++++++++++---
 .../image_to_tensor_converter_opencv.cc       | 29 ++++---
 .../tensor/image_to_tensor_utils.cc           |  7 +-
 3 files changed, 93 insertions(+), 22 deletions(-)

diff --git a/mediapipe/calculators/tensor/image_to_tensor_calculator_test.cc b/mediapipe/calculators/tensor/image_to_tensor_calculator_test.cc
index 07a5f9fe1..7ea60d98e 100644
--- a/mediapipe/calculators/tensor/image_to_tensor_calculator_test.cc
+++ b/mediapipe/calculators/tensor/image_to_tensor_calculator_test.cc
@@ -54,6 +54,13 @@ cv::Mat GetRgba(absl::string_view path) {
   return rgb;
 }
 
+cv::Mat GetGray(absl::string_view path) {
+  cv::Mat bgr = cv::imread(file::JoinPath("./", path));
+  cv::Mat gray;
+  cv::cvtColor(bgr, gray, cv::COLOR_BGR2GRAY);
+  return gray;
+}
+
 // Image to tensor test template.
 // No processing/assertions should be done after the function is invoked.
 void RunTestWithInputImagePacket(const Packet& input_image_packet,
@@ -147,29 +154,34 @@ void RunTestWithInputImagePacket(const Packet& input_image_packet,
   ASSERT_THAT(tensor_vec, testing::SizeIs(1));
 
   const Tensor& tensor = tensor_vec[0];
+  const int channels = tensor.shape().dims[3];
+  ASSERT_TRUE(channels == 1 || channels == 3);
   auto view = tensor.GetCpuReadView();
   cv::Mat tensor_mat;
   if (output_int_tensor) {
     if (range_min < 0) {
       EXPECT_EQ(tensor.element_type(), Tensor::ElementType::kInt8);
-      tensor_mat = cv::Mat(tensor_height, tensor_width, CV_8SC3,
+      tensor_mat = cv::Mat(tensor_height, tensor_width,
+                           channels == 1 ? CV_8SC1 : CV_8SC3,
                            const_cast(view.buffer()));
     } else {
       EXPECT_EQ(tensor.element_type(), Tensor::ElementType::kUInt8);
-      tensor_mat = cv::Mat(tensor_height, tensor_width, CV_8UC3,
+      tensor_mat = cv::Mat(tensor_height, tensor_width,
+                           channels == 1 ? CV_8UC1 : CV_8UC3,
                            const_cast(view.buffer()));
     }
   } else {
     EXPECT_EQ(tensor.element_type(), Tensor::ElementType::kFloat32);
-    tensor_mat = cv::Mat(tensor_height, tensor_width, CV_32FC3,
+    tensor_mat = cv::Mat(tensor_height, tensor_width,
+                         channels == 1 ? CV_32FC1 : CV_32FC3,
                          const_cast(view.buffer()));
   }
 
   cv::Mat result_rgb;
   auto transformation =
       GetValueRangeTransformation(range_min, range_max, 0.0f, 255.0f).value();
-  tensor_mat.convertTo(result_rgb, CV_8UC3, transformation.scale,
-                       transformation.offset);
+  tensor_mat.convertTo(result_rgb, channels == 1 ? CV_8UC1 : CV_8UC3,
+                       transformation.scale, transformation.offset);
 
   cv::Mat diff;
   cv::absdiff(result_rgb, expected_result, diff);
@@ -185,17 +197,27 @@ void RunTestWithInputImagePacket(const Packet& input_image_packet,
   MP_ASSERT_OK(graph.WaitUntilDone());
 }
 
+mediapipe::ImageFormat::Format GetImageFormat(int image_channels) {
+  if (image_channels == 4) {
+    return ImageFormat::SRGBA;
+  } else if (image_channels == 3) {
+    return ImageFormat::SRGB;
+  } else if (image_channels == 1) {
+    return ImageFormat::GRAY8;
+  }
+  CHECK(false) << "Unsupported input image channels: " << image_channels;
+}
+
 Packet MakeImageFramePacket(cv::Mat input) {
-  ImageFrame input_image(
-      input.channels() == 4 ? ImageFormat::SRGBA : ImageFormat::SRGB,
-      input.cols, input.rows, input.step, input.data, [](uint8*) {});
+  ImageFrame input_image(GetImageFormat(input.channels()), input.cols,
+                         input.rows, input.step, input.data, [](uint8*) {});
   return MakePacket(std::move(input_image)).At(Timestamp(0));
 }
 
 Packet MakeImagePacket(cv::Mat input) {
   mediapipe::Image input_image(std::make_shared(
-      input.channels() == 4 ?
ImageFormat::SRGBA : ImageFormat::SRGB, - input.cols, input.rows, input.step, input.data, [](uint8*) {})); + GetImageFormat(input.channels()), input.cols, input.rows, input.step, + input.data, [](uint8*) {})); return MakePacket(std::move(input_image)).At(Timestamp(0)); } @@ -429,6 +451,24 @@ TEST(ImageToTensorCalculatorTest, LargeSubRectKeepAspectWithRotation) { /*border_mode=*/{}, roi); } +TEST(ImageToTensorCalculatorTest, LargeSubRectKeepAspectWithRotationGray) { + mediapipe::NormalizedRect roi; + roi.set_x_center(0.5f); + roi.set_y_center(0.5f); + roi.set_width(1.5f); + roi.set_height(1.1f); + roi.set_rotation(M_PI * -15.0f / 180.0f); + RunTest(GetGray("/mediapipe/calculators/" + "tensor/testdata/image_to_tensor/input.jpg"), + GetGray("/mediapipe/calculators/" + "tensor/testdata/image_to_tensor/" + "large_sub_rect_keep_aspect_with_rotation.png"), + /*float_ranges=*/{{0.0f, 1.0f}}, + /*int_ranges=*/{{0, 255}, {-128, 127}}, + /*tensor_width=*/128, /*tensor_height=*/128, /*keep_aspect=*/true, + /*border_mode=*/{}, roi); +} + TEST(ImageToTensorCalculatorTest, LargeSubRectKeepAspectWithRotationBorderZero) { mediapipe::NormalizedRect roi; @@ -448,6 +488,25 @@ TEST(ImageToTensorCalculatorTest, /*border_mode=*/BorderMode::kZero, roi); } +TEST(ImageToTensorCalculatorTest, + LargeSubRectKeepAspectWithRotationBorderZeroGray) { + mediapipe::NormalizedRect roi; + roi.set_x_center(0.5f); + roi.set_y_center(0.5f); + roi.set_width(1.5f); + roi.set_height(1.1f); + roi.set_rotation(M_PI * -15.0f / 180.0f); + RunTest(GetGray("/mediapipe/calculators/" + "tensor/testdata/image_to_tensor/input.jpg"), + GetGray("/mediapipe/calculators/" + "tensor/testdata/image_to_tensor/" + "large_sub_rect_keep_aspect_with_rotation_border_zero.png"), + /*float_ranges=*/{{0.0f, 1.0f}}, + /*int_ranges=*/{{0, 255}}, + /*tensor_width=*/128, /*tensor_height=*/128, /*keep_aspect=*/true, + /*border_mode=*/BorderMode::kZero, roi); +} + TEST(ImageToTensorCalculatorTest, NoOpExceptRange) { mediapipe::NormalizedRect roi; roi.set_x_center(0.5f); diff --git a/mediapipe/calculators/tensor/image_to_tensor_converter_opencv.cc b/mediapipe/calculators/tensor/image_to_tensor_converter_opencv.cc index f910b59f3..76e46f99d 100644 --- a/mediapipe/calculators/tensor/image_to_tensor_converter_opencv.cc +++ b/mediapipe/calculators/tensor/image_to_tensor_converter_opencv.cc @@ -48,15 +48,19 @@ class OpenCvProcessor : public ImageToTensorConverter { switch (tensor_type_) { case Tensor::ElementType::kInt8: mat_type_ = CV_8SC3; + mat_gray_type_ = CV_8SC1; break; case Tensor::ElementType::kFloat32: mat_type_ = CV_32FC3; + mat_gray_type_ = CV_32FC1; break; case Tensor::ElementType::kUInt8: mat_type_ = CV_8UC3; + mat_gray_type_ = CV_8UC1; break; default: mat_type_ = -1; + mat_gray_type_ = -1; } } @@ -64,11 +68,13 @@ class OpenCvProcessor : public ImageToTensorConverter { float range_min, float range_max, int tensor_buffer_offset, Tensor& output_tensor) override { - if (input.image_format() != mediapipe::ImageFormat::SRGB && - input.image_format() != mediapipe::ImageFormat::SRGBA) { - return InvalidArgumentError( - absl::StrCat("Only RGBA/RGB formats are supported, passed format: ", - static_cast(input.image_format()))); + const bool is_supported_format = + input.image_format() == mediapipe::ImageFormat::SRGB || + input.image_format() == mediapipe::ImageFormat::SRGBA || + input.image_format() == mediapipe::ImageFormat::GRAY8; + if (!is_supported_format) { + return InvalidArgumentError(absl::StrCat( + "Unsupported format: ", 
static_cast(input.image_format())));
     }
     // TODO: Remove the check once tensor_buffer_offset > 0 is
     // supported.
@@ -82,17 +88,18 @@
     const int output_channels = output_shape.dims[3];
     auto buffer_view = output_tensor.GetCpuWriteView();
     cv::Mat dst;
+    const int dst_data_type = output_channels == 1 ? mat_gray_type_ : mat_type_;
     switch (tensor_type_) {
       case Tensor::ElementType::kInt8:
-        dst = cv::Mat(output_height, output_width, mat_type_,
+        dst = cv::Mat(output_height, output_width, dst_data_type,
                       buffer_view.buffer());
         break;
       case Tensor::ElementType::kFloat32:
-        dst = cv::Mat(output_height, output_width, mat_type_,
+        dst = cv::Mat(output_height, output_width, dst_data_type,
                       buffer_view.buffer());
         break;
      case Tensor::ElementType::kUInt8:
-        dst = cv::Mat(output_height, output_width, mat_type_,
+        dst = cv::Mat(output_height, output_width, dst_data_type,
                       buffer_view.buffer());
         break;
       default:
@@ -137,7 +144,8 @@ class OpenCvProcessor : public ImageToTensorConverter {
         auto transform,
         GetValueRangeTransformation(kInputImageRangeMin, kInputImageRangeMax,
                                     range_min, range_max));
-    transformed.convertTo(dst, mat_type_, transform.scale, transform.offset);
+    transformed.convertTo(dst, dst_data_type, transform.scale,
+                          transform.offset);
     return absl::OkStatus();
   }
 
@@ -148,7 +156,7 @@ class OpenCvProcessor : public ImageToTensorConverter {
     RET_CHECK_EQ(output_shape.dims[0], 1)
         << "Handling batch dimension not equal to 1 is not implemented in this "
           "converter.";
-    RET_CHECK_EQ(output_shape.dims[3], 3)
+    RET_CHECK(output_shape.dims[3] == 3 || output_shape.dims[3] == 1)
         << "Wrong output channel: " << output_shape.dims[3];
     return absl::OkStatus();
   }
 
   enum cv::BorderTypes border_mode_;
   Tensor::ElementType tensor_type_;
   int mat_type_;
+  int mat_gray_type_;
 };
 
 }  // namespace
diff --git a/mediapipe/calculators/tensor/image_to_tensor_utils.cc b/mediapipe/calculators/tensor/image_to_tensor_utils.cc
index 3f4c05d4e..d27c595b5 100644
--- a/mediapipe/calculators/tensor/image_to_tensor_utils.cc
+++ b/mediapipe/calculators/tensor/image_to_tensor_utils.cc
@@ -253,8 +253,11 @@ int GetNumOutputChannels(const mediapipe::Image& image) {
 }
 #endif  // MEDIAPIPE_METAL_ENABLED
 #endif  // !MEDIAPIPE_DISABLE_GPU
-  // All of the processors except for Metal expect 3 channels.
-  return 3;
+  // The output tensor has 1 channel for a 1-channel input image and 3
+  // channels for an input image with 3 or 4 channels.
+  // TODO: Add a unittest here to test the behavior on GPU, i.e.
+  // failure.
+  return image.channels() == 1 ? 1 : 3;
 }
 
 absl::StatusOr> GetInputImage(

From eb8ef1ace0a2b4c84c04a468478d8eb8463daeed Mon Sep 17 00:00:00 2001
From: Camillo Lugaresi
Date: Fri, 18 Nov 2022 19:41:05 -0800
Subject: [PATCH 092/469] Use shared_from_this in GlTextureBuffer::GetReadView,
 GetWriteView

This ensures that the callbacks in GlTextureView won't call an expired
object, even if user code holds a GlTextureView after releasing the buffer.

Note that GlTextureBuffer is not always held by a shared_ptr, but it always
is when GpuBuffer calls GetRead/WriteView on it. An alternative solution
would have been to have GpuBuffer pass its shared_ptr to the view method,
which could have been implemented with some compile-time logic to detect
whether the method expects such an argument. However, that doesn't seem
necessary.
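For illustration, a minimal standalone sketch of the ownership pattern this
change relies on (Storage and MakeDetachFn are hypothetical stand-ins, not
the MediaPipe classes): a callback that captures shared_from_this() keeps
the object alive even after the caller drops its own reference.

    // Sketch only; names are hypothetical.
    #include <functional>
    #include <iostream>
    #include <memory>

    class Storage : public std::enable_shared_from_this<Storage> {
     public:
      // The returned callback shares ownership of this object, so invoking
      // it stays safe even if every other shared_ptr has been released.
      std::function<void()> MakeDetachFn() {
        return [self = shared_from_this()] { self->DidRead(); };
      }
      void DidRead() { std::cout << "DidRead on a live object\n"; }
    };

    int main() {
      std::function<void()> detach;
      {
        auto storage = std::make_shared<Storage>();
        detach = storage->MakeDetachFn();
      }  // The caller's reference goes away here.
      detach();  // Safe: the lambda still owns the Storage.
      return 0;
    }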
PiperOrigin-RevId: 489611843 --- mediapipe/gpu/gl_texture_buffer.cc | 23 +++++++++++++++++------ mediapipe/gpu/gl_texture_buffer.h | 3 ++- mediapipe/gpu/gpu_buffer_test.cc | 22 ++++++++++++++++++++++ 3 files changed, 41 insertions(+), 7 deletions(-) diff --git a/mediapipe/gpu/gl_texture_buffer.cc b/mediapipe/gpu/gl_texture_buffer.cc index 09703d89d..7f77cd4b3 100644 --- a/mediapipe/gpu/gl_texture_buffer.cc +++ b/mediapipe/gpu/gl_texture_buffer.cc @@ -260,13 +260,18 @@ GlTextureView GlTextureBuffer::GetReadView(internal::types, auto gl_context = GlContext::GetCurrent(); CHECK(gl_context); CHECK_EQ(plane, 0); + // Note that this method is only supposed to be called by GpuBuffer, which + // ensures this condition is satisfied. + DCHECK(!weak_from_this().expired()) + << "GlTextureBuffer must be held in shared_ptr to get a GlTextureView"; // Insert wait call to sync with the producer. WaitOnGpu(); - GlTextureView::DetachFn detach = [this](GlTextureView& texture) { - // Inform the GlTextureBuffer that we have finished accessing its - // contents, and create a consumer sync point. - DidRead(texture.gl_context()->CreateSyncToken()); - }; + GlTextureView::DetachFn detach = + [texbuf = shared_from_this()](GlTextureView& texture) { + // Inform the GlTextureBuffer that we have finished accessing its + // contents, and create a consumer sync point. + texbuf->DidRead(texture.gl_context()->CreateSyncToken()); + }; return GlTextureView(gl_context.get(), target(), name(), width(), height(), plane, std::move(detach), nullptr); } @@ -276,12 +281,18 @@ GlTextureView GlTextureBuffer::GetWriteView(internal::types, auto gl_context = GlContext::GetCurrent(); CHECK(gl_context); CHECK_EQ(plane, 0); + // Note that this method is only supposed to be called by GpuBuffer, which + // ensures this condition is satisfied. + DCHECK(!weak_from_this().expired()) + << "GlTextureBuffer must be held in shared_ptr to get a GlTextureView"; // Insert wait call to sync with the producer. WaitOnGpu(); Reuse(); // TODO: the producer wait should probably be part of Reuse in the // case when there are no consumers. GlTextureView::DoneWritingFn done_writing = - [this](const GlTextureView& texture) { ViewDoneWriting(texture); }; + [texbuf = shared_from_this()](const GlTextureView& texture) { + texbuf->ViewDoneWriting(texture); + }; return GlTextureView(gl_context.get(), target(), name(), width(), height(), plane, nullptr, std::move(done_writing)); } diff --git a/mediapipe/gpu/gl_texture_buffer.h b/mediapipe/gpu/gl_texture_buffer.h index c7643fd1b..f785571a1 100644 --- a/mediapipe/gpu/gl_texture_buffer.h +++ b/mediapipe/gpu/gl_texture_buffer.h @@ -35,7 +35,8 @@ class GlCalculatorHelperImpl; // Implements a GPU memory buffer as an OpenGL texture. For internal use. class GlTextureBuffer : public internal::GpuBufferStorageImpl< - GlTextureBuffer, internal::ViewProvider> { + GlTextureBuffer, internal::ViewProvider>, + public std::enable_shared_from_this { public: // This is called when the texture buffer is deleted. It is passed a sync // token created at that time on the GlContext. 
If the GlTextureBuffer has diff --git a/mediapipe/gpu/gpu_buffer_test.cc b/mediapipe/gpu/gpu_buffer_test.cc index 796cb1d9d..145b71806 100644 --- a/mediapipe/gpu/gpu_buffer_test.cc +++ b/mediapipe/gpu/gpu_buffer_test.cc @@ -14,6 +14,8 @@ #include "mediapipe/gpu/gpu_buffer.h" +#include + #include "mediapipe/framework/formats/image_format.pb.h" #include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gtest.h" @@ -206,5 +208,25 @@ TEST_F(GpuBufferTest, Overwrite) { } } +TEST_F(GpuBufferTest, GlTextureViewRetainsWhatItNeeds) { + GpuBuffer buffer(300, 200, GpuBufferFormat::kBGRA32); + { + std::shared_ptr view = buffer.GetWriteView(); + EXPECT_EQ(view->Width(), 300); + EXPECT_EQ(view->Height(), 200); + FillImageFrameRGBA(*view, 255, 0, 0, 255); + } + + RunInGlContext([buffer = std::move(buffer)]() mutable { + // This is not a recommended pattern, but let's make sure that we don't + // crash if the buffer is released before the view. The view can hold + // callbacks into its underlying storage. + auto view = buffer.GetReadView(0); + buffer = nullptr; + }); + // We're really checking that we haven't crashed. + EXPECT_TRUE(true); +} + } // anonymous namespace } // namespace mediapipe From e853f04b79bb47e9542f54ba34065de3c5dcbd73 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Fri, 18 Nov 2022 19:53:21 -0800 Subject: [PATCH 093/469] Create AudioTaskRunner PiperOrigin-RevId: 489613573 --- .../tasks/audio/core/BaseAudioTaskApi.java | 1 + .../tasks/web/audio/audio_classifier/BUILD | 4 +- .../audio_classifier/audio_classifier.ts | 53 ++++++++--------- mediapipe/tasks/web/audio/core/BUILD | 14 ++++- .../web/audio/core/audio_task_options.d.ts | 21 ------- .../tasks/web/audio/core/audio_task_runner.ts | 58 +++++++++++++++++++ 6 files changed, 98 insertions(+), 53 deletions(-) create mode 100644 mediapipe/tasks/web/audio/core/audio_task_runner.ts diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/core/BaseAudioTaskApi.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/core/BaseAudioTaskApi.java index 8eaf0adcb..2782f8d36 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/core/BaseAudioTaskApi.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/core/BaseAudioTaskApi.java @@ -116,6 +116,7 @@ public class BaseAudioTaskApi implements AutoCloseable { defaultSampleRate = sampleRate; } } + /** * An asynchronous method to send audio stream data to the {@link TaskRunner}. The results will be * available in the user-defined result listener. 
diff --git a/mediapipe/tasks/web/audio/audio_classifier/BUILD b/mediapipe/tasks/web/audio/audio_classifier/BUILD index 9e1fcbc51..498b17845 100644 --- a/mediapipe/tasks/web/audio/audio_classifier/BUILD +++ b/mediapipe/tasks/web/audio/audio_classifier/BUILD @@ -17,14 +17,14 @@ mediapipe_ts_library( "//mediapipe/framework:calculator_options_jspb_proto", "//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_graph_options_jspb_proto", "//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto", + "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", + "//mediapipe/tasks/web/audio/core:audio_task_runner", "//mediapipe/tasks/web/components/containers:category", "//mediapipe/tasks/web/components/containers:classification_result", - "//mediapipe/tasks/web/components/processors:base_options", "//mediapipe/tasks/web/components/processors:classifier_options", "//mediapipe/tasks/web/components/processors:classifier_result", "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:classifier_options", - "//mediapipe/tasks/web/core:task_runner", "//mediapipe/web/graph_runner:graph_runner_ts", ], ) diff --git a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts index 5533b0eaa..0c54a4718 100644 --- a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts +++ b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts @@ -18,10 +18,10 @@ import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; import {CalculatorOptions} from '../../../../framework/calculator_options_pb'; import {AudioClassifierGraphOptions} from '../../../../tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options_pb'; import {ClassificationResult} from '../../../../tasks/cc/components/containers/proto/classifications_pb'; -import {convertBaseOptionsToProto} from '../../../../tasks/web/components/processors/base_options'; +import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; +import {AudioTaskRunner} from '../../../../tasks/web/audio/core/audio_task_runner'; import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/processors/classifier_options'; import {convertFromClassificationResultProto} from '../../../../tasks/web/components/processors/classifier_result'; -import {TaskRunner} from '../../../../tasks/web/core/task_runner'; import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options'; import {createMediaPipeLib, FileLocator} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url @@ -47,9 +47,8 @@ const TIMESTAMPED_CLASSIFICATIONS_STREAM = 'timestamped_classifications'; // tslint:disable:jspb-use-builder-pattern /** Performs audio classification. */ -export class AudioClassifier extends TaskRunner { +export class AudioClassifier extends AudioTaskRunner { private classificationResults: AudioClassifierResult[] = []; - private defaultSampleRate = 48000; private readonly options = new AudioClassifierGraphOptions(); /** @@ -111,6 +110,14 @@ export class AudioClassifier extends TaskRunner { wasmLoaderOptions, new Uint8Array(graphData)); } + protected override get baseOptions(): BaseOptionsProto|undefined { + return this.options.getBaseOptions(); + } + + protected override set baseOptions(proto: BaseOptionsProto|undefined) { + this.options.setBaseOptions(proto); + } + /** * Sets new options for the audio classifier. 
   *
@@ -120,34 +127,19 @@
    * Calling `setOptions()` with a subset of options only affects those options.
    * You can reset an option back to its default value by explicitly setting it
    * to `undefined`.
    *
    * @param options The options for the audio classifier.
    */
-  async setOptions(options: AudioClassifierOptions): Promise {
-    if (options.baseOptions) {
-      const baseOptionsProto = await convertBaseOptionsToProto(
-          options.baseOptions, this.options.getBaseOptions());
-      this.options.setBaseOptions(baseOptionsProto);
-    }
-
+  override async setOptions(options: AudioClassifierOptions): Promise {
+    await super.setOptions(options);
     this.options.setClassifierOptions(convertClassifierOptionsToProto(
         options, this.options.getClassifierOptions()));
     this.refreshGraph();
   }
 
   /**
-   * Sets the sample rate for all calls to `classify()` that omit an explicit
-   * sample rate. `48000` is used as a default if this method is not called.
-   *
-   * @param sampleRate A sample rate (e.g. `44100`).
-   */
-  setDefaultSampleRate(sampleRate: number) {
-    this.defaultSampleRate = sampleRate;
-  }
-
-  /**
-   * Performs audio classification on the provided audio data and waits
+   * Performs audio classification on the provided audio clip and waits
    * synchronously for the response.
    *
-   * @param audioData An array of raw audio capture data, like
-   * from a call to getChannelData on an AudioBuffer.
+   * @param audioData An array of raw audio capture data, like from a call to
+   *     `getChannelData()` on an AudioBuffer.
    * @param sampleRate The sample rate in Hz of the provided audio data. If not
    *     set, defaults to the sample rate set via `setDefaultSampleRate()` or
    *     `48000` if no custom default was set.
@@ -155,18 +147,21 @@
    */
   classify(audioData: Float32Array, sampleRate?: number):
       AudioClassifierResult[] {
-    sampleRate = sampleRate ?? this.defaultSampleRate;
+    return this.processAudioClip(audioData, sampleRate);
+  }
 
+  /** Sends an audio packet to the graph and returns the classifications. */
+  protected override process(
+      audioData: Float32Array, sampleRate: number,
+      timestampMs: number): AudioClassifierResult[] {
     // Configures the number of samples in the WASM layer. We re-configure the
     // number of samples and the sample rate for every frame, but ignore other
     // side effects of this function (such as sending the input side packet and
     // the input stream header).
     this.configureAudio(
         /* numChannels= */ 1, /* numSamples= */ audioData.length, sampleRate);
-
-    const timestamp = performance.now();
-    this.addDoubleToStream(sampleRate, SAMPLE_RATE_STREAM, timestamp);
-    this.addAudioToStream(audioData, timestamp);
+    this.addDoubleToStream(sampleRate, SAMPLE_RATE_STREAM, timestampMs);
+    this.addAudioToStream(audioData, timestampMs);
 
     this.classificationResults = [];
     this.finishProcessing();
diff --git a/mediapipe/tasks/web/audio/core/BUILD b/mediapipe/tasks/web/audio/core/BUILD
index ed60f2435..91ebbf524 100644
--- a/mediapipe/tasks/web/audio/core/BUILD
+++ b/mediapipe/tasks/web/audio/core/BUILD
@@ -1,6 +1,6 @@
 # This package contains options shared by all MediaPipe Audio Tasks for Web.
-load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration") +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library") package(default_visibility = ["//mediapipe/tasks:internal"]) @@ -11,3 +11,15 @@ mediapipe_ts_declaration( "//mediapipe/tasks/web/core", ], ) + +mediapipe_ts_library( + name = "audio_task_runner", + srcs = ["audio_task_runner.ts"], + deps = [ + ":audio_task_options", + "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", + "//mediapipe/tasks/web/components/processors:base_options", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:task_runner", + ], +) diff --git a/mediapipe/tasks/web/audio/core/audio_task_options.d.ts b/mediapipe/tasks/web/audio/core/audio_task_options.d.ts index 58a6e55d8..e3068625d 100644 --- a/mediapipe/tasks/web/audio/core/audio_task_options.d.ts +++ b/mediapipe/tasks/web/audio/core/audio_task_options.d.ts @@ -16,29 +16,8 @@ import {BaseOptions} from '../../../../tasks/web/core/base_options'; -/** - * MediaPipe audio task running mode. A MediaPipe audio task can be run with - * two different modes: - * - audio_clips: The mode for running a mediapipe audio task on independent - * audio clips. - * - audio_stream: The mode for running a mediapipe audio task on an audio - * stream, such as from a microphone. - * - */ -export type RunningMode = 'audio_clips'|'audio_stream'; - /** The options for configuring a MediaPipe Audio Task. */ export declare interface AudioTaskOptions { /** Options to configure the loading of the model assets. */ baseOptions?: BaseOptions; - - /** - * The running mode of the task. Default to the audio_clips mode. - * Audio tasks have two running modes: - * 1) The mode for running a mediapipe audio task on independent - * audio clips. - * 2) The mode for running a mediapipe audio task on an audio - * stream, such as from a microphone. - */ - runningMode?: RunningMode; } diff --git a/mediapipe/tasks/web/audio/core/audio_task_runner.ts b/mediapipe/tasks/web/audio/core/audio_task_runner.ts new file mode 100644 index 000000000..ceff3895b --- /dev/null +++ b/mediapipe/tasks/web/audio/core/audio_task_runner.ts @@ -0,0 +1,58 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; +import {convertBaseOptionsToProto} from '../../../../tasks/web/components/processors/base_options'; +import {TaskRunner} from '../../../../tasks/web/core/task_runner'; + +import {AudioTaskOptions} from './audio_task_options'; + +/** Base class for all MediaPipe Audio Tasks. */ +export abstract class AudioTaskRunner extends TaskRunner { + protected abstract baseOptions?: BaseOptionsProto|undefined; + private defaultSampleRate = 48000; + + /** Configures the shared options of an audio task. */ + async setOptions(options: AudioTaskOptions): Promise { + this.baseOptions = this.baseOptions ?? 
new BaseOptionsProto(); + if (options.baseOptions) { + this.baseOptions = await convertBaseOptionsToProto( + options.baseOptions, this.baseOptions); + } + } + + /** + * Sets the sample rate for API calls that omit an explicit sample rate. + * `48000` is used as a default if this method is not called. + * + * @param sampleRate A sample rate (e.g. `44100`). + */ + setDefaultSampleRate(sampleRate: number) { + this.defaultSampleRate = sampleRate; + } + + /** Sends an audio packet to the graph and awaits results. */ + protected abstract process( + audioData: Float32Array, sampleRate: number, timestampMs: number): T; + + /** Sends a single audio clip to the graph and awaits results. */ + protected processAudioClip(audioData: Float32Array, sampleRate?: number): T { + return this.process( + audioData, sampleRate ?? this.defaultSampleRate, performance.now()); + } +} + + From bbcbd5fc6c8fcefaf45da9c126a6f7aa8b6386c2 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Sat, 19 Nov 2022 04:47:55 -0800 Subject: [PATCH 094/469] Audio Embedder for Web PiperOrigin-RevId: 489669966 --- mediapipe/tasks/web/BUILD | 1 + mediapipe/tasks/web/audio.ts | 4 +- mediapipe/tasks/web/audio/BUILD | 1 + .../tasks/web/audio/audio_embedder/BUILD | 43 ++++ .../audio/audio_embedder/audio_embedder.ts | 211 ++++++++++++++++++ .../audio_embedder_options.d.ts | 22 ++ .../audio_embedder/audio_embedder_result.d.ts | 17 ++ mediapipe/tasks/web/audio/index.ts | 1 + 8 files changed, 299 insertions(+), 1 deletion(-) create mode 100644 mediapipe/tasks/web/audio/audio_embedder/BUILD create mode 100644 mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts create mode 100644 mediapipe/tasks/web/audio/audio_embedder/audio_embedder_options.d.ts create mode 100644 mediapipe/tasks/web/audio/audio_embedder/audio_embedder_result.d.ts diff --git a/mediapipe/tasks/web/BUILD b/mediapipe/tasks/web/BUILD index e9703e37a..af76a1fe8 100644 --- a/mediapipe/tasks/web/BUILD +++ b/mediapipe/tasks/web/BUILD @@ -26,6 +26,7 @@ mediapipe_ts_library( srcs = ["audio.ts"], deps = [ "//mediapipe/tasks/web/audio/audio_classifier", + "//mediapipe/tasks/web/audio/audio_embedder", ], ) diff --git a/mediapipe/tasks/web/audio.ts b/mediapipe/tasks/web/audio.ts index 764fd8393..056426f50 100644 --- a/mediapipe/tasks/web/audio.ts +++ b/mediapipe/tasks/web/audio.ts @@ -15,9 +15,11 @@ */ import {AudioClassifier as AudioClassifierImpl} from '../../tasks/web/audio/audio_classifier/audio_classifier'; +import {AudioEmbedder as AudioEmbedderImpl} from '../../tasks/web/audio/audio_embedder/audio_embedder'; // Declare the variables locally so that Rollup in OSS includes them explcilty // as exports. const AudioClassifier = AudioClassifierImpl; +const AudioEmbedder = AudioEmbedderImpl; -export {AudioClassifier}; +export {AudioClassifier, AudioEmbedder}; diff --git a/mediapipe/tasks/web/audio/BUILD b/mediapipe/tasks/web/audio/BUILD index 4f6e48b28..acd7494d7 100644 --- a/mediapipe/tasks/web/audio/BUILD +++ b/mediapipe/tasks/web/audio/BUILD @@ -9,5 +9,6 @@ mediapipe_ts_library( srcs = ["index.ts"], deps = [ "//mediapipe/tasks/web/audio/audio_classifier", + "//mediapipe/tasks/web/audio/audio_embedder", ], ) diff --git a/mediapipe/tasks/web/audio/audio_embedder/BUILD b/mediapipe/tasks/web/audio/audio_embedder/BUILD new file mode 100644 index 000000000..7d9a994a3 --- /dev/null +++ b/mediapipe/tasks/web/audio/audio_embedder/BUILD @@ -0,0 +1,43 @@ +# This contains the MediaPipe Audio Embedder Task. +# +# This task takes audio input and performs embedding. 
+ +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library") + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +mediapipe_ts_library( + name = "audio_embedder", + srcs = ["audio_embedder.ts"], + deps = [ + ":audio_embedder_types", + "//mediapipe/framework:calculator_jspb_proto", + "//mediapipe/framework:calculator_options_jspb_proto", + "//mediapipe/tasks/cc/audio/audio_embedder/proto:audio_embedder_graph_options_jspb_proto", + "//mediapipe/tasks/cc/components/containers/proto:embeddings_jspb_proto", + "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", + "//mediapipe/tasks/web/audio/core:audio_task_runner", + "//mediapipe/tasks/web/components/containers:embedding_result", + "//mediapipe/tasks/web/components/processors:embedder_options", + "//mediapipe/tasks/web/components/processors:embedder_result", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:embedder_options", + "//mediapipe/web/graph_runner:graph_runner_ts", + ], +) + +mediapipe_ts_declaration( + name = "audio_embedder_types", + srcs = [ + "audio_embedder_options.d.ts", + "audio_embedder_result.d.ts", + ], + deps = [ + "//mediapipe/tasks/web/audio/core:audio_task_options", + "//mediapipe/tasks/web/components/containers:embedding_result", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:embedder_options", + ], +) diff --git a/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts new file mode 100644 index 000000000..51cb819de --- /dev/null +++ b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts @@ -0,0 +1,211 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
+import {CalculatorOptions} from '../../../../framework/calculator_options_pb';
+import {AudioEmbedderGraphOptions as AudioEmbedderGraphOptionsProto} from '../../../../tasks/cc/audio/audio_embedder/proto/audio_embedder_graph_options_pb';
+import {EmbeddingResult} from '../../../../tasks/cc/components/containers/proto/embeddings_pb';
+import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb';
+import {AudioTaskRunner} from '../../../../tasks/web/audio/core/audio_task_runner';
+import {convertEmbedderOptionsToProto} from '../../../../tasks/web/components/processors/embedder_options';
+import {convertFromEmbeddingResultProto} from '../../../../tasks/web/components/processors/embedder_result';
+import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options';
+import {createMediaPipeLib, FileLocator} from '../../../../web/graph_runner/graph_runner';
+// Placeholder for internal dependency on trusted resource url
+
+import {AudioEmbedderOptions} from './audio_embedder_options';
+import {AudioEmbedderResult} from './audio_embedder_result';
+
+export * from './audio_embedder_options';
+export * from './audio_embedder_result';
+
+// The OSS JS API does not support the builder pattern.
+// tslint:disable:jspb-use-builder-pattern
+
+// Note: `input_audio` is hardcoded in 'gl_graph_runner_internal_audio' and
+// cannot be changed
+// TODO: Change this to `audio_in` to match the name in the CC
+// implementation
+const AUDIO_STREAM = 'input_audio';
+const SAMPLE_RATE_STREAM = 'sample_rate';
+const EMBEDDINGS_STREAM = 'embeddings_out';
+const TIMESTAMPED_EMBEDDINGS_STREAM = 'timestamped_embeddings_out';
+const AUDIO_EMBEDDER_CALCULATOR =
+    'mediapipe.tasks.audio.audio_embedder.AudioEmbedderGraph';
+
+/** Performs embedding extraction on audio. */
+export class AudioEmbedder extends AudioTaskRunner {
+  private embeddingResults: AudioEmbedderResult[] = [];
+  private readonly options = new AudioEmbedderGraphOptionsProto();
+
+  /**
+   * Initializes the Wasm runtime and creates a new audio embedder from the
+   * provided options.
+   * @param wasmLoaderOptions A configuration object that provides the location
+   *     of the Wasm binary and its loader.
+   * @param audioEmbedderOptions The options for the audio embedder. Note that
+   *     either a path to the TFLite model or the model itself needs to be
+   *     provided (via `baseOptions`).
+   */
+  static async createFromOptions(
+      wasmLoaderOptions: WasmLoaderOptions,
+      audioEmbedderOptions: AudioEmbedderOptions): Promise {
+    // Create a file locator based on the loader options
+    const fileLocator: FileLocator = {
+      locateFile() {
+        // The only file we load is the Wasm binary
+        return wasmLoaderOptions.wasmBinaryPath.toString();
+      }
+    };
+
+    const embedder = await createMediaPipeLib(
+        AudioEmbedder, wasmLoaderOptions.wasmLoaderPath,
+        /* assetLoaderScript= */ undefined,
+        /* glCanvas= */ undefined, fileLocator);
+    await embedder.setOptions(audioEmbedderOptions);
+    return embedder;
+  }
+
+  /**
+   * Initializes the Wasm runtime and creates a new audio embedder based on the
+   * provided model asset buffer.
+   * @param wasmLoaderOptions A configuration object that provides the location
+   *     of the Wasm binary and its loader.
+   * @param modelAssetBuffer A binary representation of the TFLite model.
+ */
+  static createFromModelBuffer(
+      wasmLoaderOptions: WasmLoaderOptions,
+      modelAssetBuffer: Uint8Array): Promise {
+    return AudioEmbedder.createFromOptions(
+        wasmLoaderOptions, {baseOptions: {modelAssetBuffer}});
+  }
+
+  /**
+   * Initializes the Wasm runtime and creates a new audio embedder based on the
+   * path to the model asset.
+   * @param wasmLoaderOptions A configuration object that provides the location
+   *     of the Wasm binary and its loader.
+   * @param modelAssetPath The path to the TFLite model.
+   */
+  static async createFromModelPath(
+      wasmLoaderOptions: WasmLoaderOptions,
+      modelAssetPath: string): Promise {
+    const response = await fetch(modelAssetPath.toString());
+    const graphData = await response.arrayBuffer();
+    return AudioEmbedder.createFromModelBuffer(
+        wasmLoaderOptions, new Uint8Array(graphData));
+  }
+
+  protected override get baseOptions(): BaseOptionsProto|undefined {
+    return this.options.getBaseOptions();
+  }
+
+  protected override set baseOptions(proto: BaseOptionsProto|undefined) {
+    this.options.setBaseOptions(proto);
+  }
+
+  /**
+   * Sets new options for the audio embedder.
+   *
+   * Calling `setOptions()` with a subset of options only affects those options.
+   * You can reset an option back to its default value by explicitly setting it
+   * to `undefined`.
+   *
+   * @param options The options for the audio embedder.
+   */
+  override async setOptions(options: AudioEmbedderOptions): Promise {
+    await super.setOptions(options);
+    this.options.setEmbedderOptions(convertEmbedderOptionsToProto(
+        options, this.options.getEmbedderOptions()));
+    this.refreshGraph();
+  }
+
+  /**
+   * Performs embedding extraction on the provided audio clip and waits
+   * synchronously for the response.
+   *
+   * @param audioData An array of raw audio capture data, like from a call to
+   *     `getChannelData()` on an AudioBuffer.
+   * @param sampleRate The sample rate in Hz of the provided audio data. If not
+   *     set, defaults to the sample rate set via `setDefaultSampleRate()` or
+   *     `48000` if no custom default was set.
+   * @return The embedding results of the audio
+   */
+  embed(audioData: Float32Array, sampleRate?: number): AudioEmbedderResult[] {
+    return this.processAudioClip(audioData, sampleRate);
+  }
+
+  protected override process(
+      audioData: Float32Array, sampleRate: number,
+      timestampMs: number): AudioEmbedderResult[] {
+    // Configures the number of samples in the WASM layer. We re-configure the
+    // number of samples and the sample rate for every frame, but ignore other
+    // side effects of this function (such as sending the input side packet and
+    // the input stream header).
+    this.configureAudio(
+        /* numChannels= */ 1, /* numSamples= */ audioData.length, sampleRate);
+    this.addDoubleToStream(sampleRate, SAMPLE_RATE_STREAM, timestampMs);
+    this.addAudioToStream(audioData, timestampMs);
+
+    this.embeddingResults = [];
+    this.finishProcessing();
+    return this.embeddingResults;
+  }
+
+  /** Updates the MediaPipe graph configuration.
*/ + private refreshGraph(): void { + const graphConfig = new CalculatorGraphConfig(); + graphConfig.addInputStream(AUDIO_STREAM); + graphConfig.addInputStream(SAMPLE_RATE_STREAM); + graphConfig.addOutputStream(EMBEDDINGS_STREAM); + graphConfig.addOutputStream(TIMESTAMPED_EMBEDDINGS_STREAM); + + const calculatorOptions = new CalculatorOptions(); + calculatorOptions.setExtension( + AudioEmbedderGraphOptionsProto.ext, this.options); + + const embedderNode = new CalculatorGraphConfig.Node(); + embedderNode.setCalculator(AUDIO_EMBEDDER_CALCULATOR); + embedderNode.addInputStream('AUDIO:' + AUDIO_STREAM); + embedderNode.addInputStream('SAMPLE_RATE:' + SAMPLE_RATE_STREAM); + embedderNode.addOutputStream('EMBEDDINGS:' + EMBEDDINGS_STREAM); + embedderNode.addOutputStream( + 'TIMESTAMPED_EMBEDDINGS:' + TIMESTAMPED_EMBEDDINGS_STREAM); + embedderNode.setOptions(calculatorOptions); + + graphConfig.addNode(embedderNode); + + this.attachProtoListener(EMBEDDINGS_STREAM, binaryProto => { + const embeddingResult = EmbeddingResult.deserializeBinary(binaryProto); + this.embeddingResults.push( + convertFromEmbeddingResultProto(embeddingResult)); + }); + + this.attachProtoVectorListener(TIMESTAMPED_EMBEDDINGS_STREAM, data => { + for (const binaryProto of data) { + const embeddingResult = EmbeddingResult.deserializeBinary(binaryProto); + this.embeddingResults.push( + convertFromEmbeddingResultProto(embeddingResult)); + } + }); + + const binaryGraph = graphConfig.serializeBinary(); + this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true); + } +} + + + diff --git a/mediapipe/tasks/web/audio/audio_embedder/audio_embedder_options.d.ts b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder_options.d.ts new file mode 100644 index 000000000..98f412d0f --- /dev/null +++ b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder_options.d.ts @@ -0,0 +1,22 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import {AudioTaskOptions} from '../../../../tasks/web/audio/core/audio_task_options'; +import {EmbedderOptions} from '../../../../tasks/web/core/embedder_options'; + +/** Options to configure the MediaPipe Audio Embedder Task */ +export declare interface AudioEmbedderOptions extends EmbedderOptions, + AudioTaskOptions {} diff --git a/mediapipe/tasks/web/audio/audio_embedder/audio_embedder_result.d.ts b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder_result.d.ts new file mode 100644 index 000000000..13abc28d9 --- /dev/null +++ b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder_result.d.ts @@ -0,0 +1,17 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+export {Embedding, EmbeddingResult as AudioEmbedderResult} from '../../../../tasks/web/components/containers/embedding_result';
diff --git a/mediapipe/tasks/web/audio/index.ts b/mediapipe/tasks/web/audio/index.ts
index a5083b326..17a908f30 100644
--- a/mediapipe/tasks/web/audio/index.ts
+++ b/mediapipe/tasks/web/audio/index.ts
@@ -15,3 +15,4 @@
  */
 
 export * from '../../../tasks/web/audio/audio_classifier/audio_classifier';
+export * from '../../../tasks/web/audio/audio_embedder/audio_embedder';

From 977ee4272e90272fef0ab140036816e83e05c615 Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Sat, 19 Nov 2022 10:51:20 -0800
Subject: [PATCH 095/469] Add public visibility to the model maker public API.

PiperOrigin-RevId: 489701768
---
 mediapipe/model_maker/python/text/text_classifier/BUILD    | 7 +++++++
 .../model_maker/python/vision/gesture_recognizer/BUILD     | 7 +++++++
 mediapipe/model_maker/python/vision/image_classifier/BUILD | 7 +++++++
 3 files changed, 21 insertions(+)

diff --git a/mediapipe/model_maker/python/text/text_classifier/BUILD b/mediapipe/model_maker/python/text/text_classifier/BUILD
index 0c35e7966..7bb41351e 100644
--- a/mediapipe/model_maker/python/text/text_classifier/BUILD
+++ b/mediapipe/model_maker/python/text/text_classifier/BUILD
@@ -21,9 +21,16 @@ package(
 
 licenses(["notice"])
 
+######################################################################
+# Public target of the MediaPipe Model Maker TextClassifier APIs.
+
+# Please see https://developers.google.com/mediapipe/solutions/text/text_classifier/customize for
+# more information about the MediaPipe Model Maker TextClassifier APIs.
+######################################################################
 py_library(
     name = "text_classifier_import",
     srcs = ["__init__.py"],
+    visibility = ["//visibility:public"],
     deps = [
         ":dataset",
         ":model_options",
diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/BUILD b/mediapipe/model_maker/python/vision/gesture_recognizer/BUILD
index b7d334d9c..b9425a181 100644
--- a/mediapipe/model_maker/python/vision/gesture_recognizer/BUILD
+++ b/mediapipe/model_maker/python/vision/gesture_recognizer/BUILD
@@ -103,9 +103,16 @@ py_library(
     ],
 )
 
+######################################################################
+# Public target of the MediaPipe Model Maker GestureRecognizer APIs.
+
+# Please see https://developers.google.com/mediapipe/solutions/vision/gesture_recognizer/customize
+# for more information about the MediaPipe Model Maker GestureRecognizer APIs.
+###################################################################### py_library( name = "gesture_recognizer_import", srcs = ["__init__.py"], + visibility = ["//visibility:public"], deps = [ ":dataset", ":gesture_recognizer", diff --git a/mediapipe/model_maker/python/vision/image_classifier/BUILD b/mediapipe/model_maker/python/vision/image_classifier/BUILD index c581d9fbc..29ae189e9 100644 --- a/mediapipe/model_maker/python/vision/image_classifier/BUILD +++ b/mediapipe/model_maker/python/vision/image_classifier/BUILD @@ -21,9 +21,16 @@ package( default_visibility = ["//mediapipe:__subpackages__"], ) +###################################################################### +# Public target of the MediaPipe Model Maker ImageClassifier APIs. + +# Please see https://developers.google.com/mediapipe/solutions/vision/image_classifier/customize for +# more information about the MediaPipe Model Maker ImageClassifier APIs. +###################################################################### py_library( name = "image_classifier_import", srcs = ["__init__.py"], + visibility = ["//visibility:public"], deps = [ ":dataset", ":hyperparameters", From a33cb1e05e602cb06b6e6ecdc3a12dad82f5f4e4 Mon Sep 17 00:00:00 2001 From: Camillo Lugaresi Date: Sat, 19 Nov 2022 21:03:29 -0800 Subject: [PATCH 096/469] Check that Java buffer supports direct access before using it If the buffer is not created with allocateDirect, JNI APIs will return a data pointer of nullptr and a capacity of -1. This can cause a crash when we access it. Also clean up the code to raise exceptions instead of just logging errors and returning nullptr. PiperOrigin-RevId: 489751312 --- .../framework/jni/packet_creator_jni.cc | 171 +++++++++++------- .../framework/jni/packet_getter_jni.cc | 42 +++-- 2 files changed, 133 insertions(+), 80 deletions(-) diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/packet_creator_jni.cc b/mediapipe/java/com/google/mediapipe/framework/jni/packet_creator_jni.cc index 250d7c938..2d5447401 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/packet_creator_jni.cc +++ b/mediapipe/java/com/google/mediapipe/framework/jni/packet_creator_jni.cc @@ -17,6 +17,8 @@ #include #include +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" #include "mediapipe/framework/calculator.pb.h" #include "mediapipe/framework/camera_intrinsics.h" #include "mediapipe/framework/formats/image.h" @@ -107,17 +109,18 @@ absl::StatusOr CreateGpuBuffer( // Create a 1, 3, or 4 channel 8-bit ImageFrame shared pointer from a Java // ByteBuffer. 
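// The rewrite below converts this helper to return absl::StatusOr instead of
// a bare (possibly null) pointer, and validates direct-buffer access up
// front, as described in the commit message. As a self-contained sketch of
// that validation, assuming only <jni.h> and Abseil (the DirectBufferView
// struct and GetDirectBufferView name are illustrative, not part of this
// patch):

#include <cstdint>

#include <jni.h>

#include "absl/status/status.h"
#include "absl/status/statusor.h"

struct DirectBufferView {
  const void* data;  // Address of the direct buffer's storage.
  int64_t size;      // Capacity in bytes.
};

// For a ByteBuffer that was not created with ByteBuffer.allocateDirect(),
// JNI returns a null address and a capacity of -1; dereferencing that
// address is what previously crashed.
absl::StatusOr<DirectBufferView> GetDirectBufferView(JNIEnv* env,
                                                     jobject byte_buffer) {
  const void* data = env->GetDirectBufferAddress(byte_buffer);
  const int64_t size = env->GetDirectBufferCapacity(byte_buffer);
  if (data == nullptr || size < 0) {
    return absl::InvalidArgumentError(
        "Cannot get direct access to the input buffer. It should be created "
        "using allocateDirect.");
  }
  return DirectBufferView{data, size};
}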
-std::unique_ptr CreateImageFrameFromByteBuffer( - JNIEnv* env, jobject byte_buffer, jint width, jint height, - mediapipe::ImageFormat::Format format) { +absl::StatusOr> +CreateImageFrameFromByteBuffer(JNIEnv* env, jobject byte_buffer, jint width, + jint height, + mediapipe::ImageFormat::Format format) { switch (format) { case mediapipe::ImageFormat::SRGBA: case mediapipe::ImageFormat::SRGB: case mediapipe::ImageFormat::GRAY8: break; default: - LOG(ERROR) << "Format must be either SRGBA, SRGB, or GRAY8."; - return nullptr; + return absl::InvalidArgumentError( + "Format must be either SRGBA, SRGB, or GRAY8."); } auto image_frame = std::make_unique( @@ -125,25 +128,30 @@ std::unique_ptr CreateImageFrameFromByteBuffer( mediapipe::ImageFrame::kGlDefaultAlignmentBoundary); const int64_t buffer_size = env->GetDirectBufferCapacity(byte_buffer); + const void* buffer_data = env->GetDirectBufferAddress(byte_buffer); + if (buffer_data == nullptr || buffer_size < 0) { + return absl::InvalidArgumentError( + "Cannot get direct access to the input buffer. It should be created " + "using allocateDirect."); + } + const int num_channels = image_frame->NumberOfChannels(); const int expected_buffer_size = num_channels == 1 ? width * height : image_frame->PixelDataSize(); - if (buffer_size != expected_buffer_size) { - if (num_channels != 1) - LOG(ERROR) << "The input image buffer should have 4 bytes alignment."; - LOG(ERROR) << "Please check the input buffer size."; - LOG(ERROR) << "Buffer size: " << buffer_size - << ", Buffer size needed: " << expected_buffer_size - << ", Image width: " << width; - return nullptr; - } + RET_CHECK_EQ(buffer_size, expected_buffer_size) + << (num_channels != 1 + ? "The input image buffer should have 4 bytes alignment. " + : "") + << "Please check the input buffer size." + << " Buffer size: " << buffer_size + << ", Buffer size needed: " << expected_buffer_size + << ", Image width: " << width; // Copy buffer data to image frame's pixel_data_. if (num_channels == 1) { const int width_step = image_frame->WidthStep(); - const char* src_row = - reinterpret_cast(env->GetDirectBufferAddress(byte_buffer)); + const char* src_row = reinterpret_cast(buffer_data); char* dst_row = reinterpret_cast(image_frame->MutablePixelData()); for (int i = height; i > 0; --i) { std::memcpy(dst_row, src_row, width); @@ -152,7 +160,6 @@ std::unique_ptr CreateImageFrameFromByteBuffer( } } else { // 3 and 4 channels. 
- const void* buffer_data = env->GetDirectBufferAddress(byte_buffer); std::memcpy(image_frame->MutablePixelData(), buffer_data, image_frame->PixelDataSize()); } @@ -176,77 +183,100 @@ JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateReferencePacket)( JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateRgbImage)( JNIEnv* env, jobject thiz, jlong context, jobject byte_buffer, jint width, jint height) { - auto image_frame = CreateImageFrameFromByteBuffer( + auto image_frame_or = CreateImageFrameFromByteBuffer( env, byte_buffer, width, height, mediapipe::ImageFormat::SRGB); - if (nullptr == image_frame) return 0L; + if (ThrowIfError(env, image_frame_or.status())) return 0L; - mediapipe::Packet packet = mediapipe::Adopt(image_frame.release()); + mediapipe::Packet packet = mediapipe::Adopt(image_frame_or->release()); return CreatePacketWithContext(context, packet); } +absl::StatusOr> CreateRgbImageFromRgba( + JNIEnv* env, jobject byte_buffer, jint width, jint height) { + const uint8_t* rgba_data = + static_cast(env->GetDirectBufferAddress(byte_buffer)); + int64_t buffer_size = env->GetDirectBufferCapacity(byte_buffer); + if (rgba_data == nullptr || buffer_size < 0) { + return absl::InvalidArgumentError( + "Cannot get direct access to the input buffer. It should be created " + "using allocateDirect."); + } + + const int expected_buffer_size = width * height * 4; + RET_CHECK_EQ(buffer_size, expected_buffer_size) + << "Please check the input buffer size." + << " Buffer size: " << buffer_size + << ", Buffer size needed: " << expected_buffer_size + << ", Image width: " << width; + + auto image_frame = absl::make_unique( + mediapipe::ImageFormat::SRGB, width, height, + mediapipe::ImageFrame::kGlDefaultAlignmentBoundary); + mediapipe::android::RgbaToRgb(rgba_data, width * 4, width, height, + image_frame->MutablePixelData(), + image_frame->WidthStep()); + return image_frame; +} + JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateRgbImageFromRgba)( JNIEnv* env, jobject thiz, jlong context, jobject byte_buffer, jint width, jint height) { - const uint8_t* rgba_data = - static_cast(env->GetDirectBufferAddress(byte_buffer)); - auto image_frame = absl::make_unique( - mediapipe::ImageFormat::SRGB, width, height, - mediapipe::ImageFrame::kGlDefaultAlignmentBoundary); - int64_t buffer_size = env->GetDirectBufferCapacity(byte_buffer); - if (buffer_size != width * height * 4) { - LOG(ERROR) << "Please check the input buffer size."; - LOG(ERROR) << "Buffer size: " << buffer_size - << ", Buffer size needed: " << width * height * 4 - << ", Image width: " << width; - return 0L; - } - mediapipe::android::RgbaToRgb(rgba_data, width * 4, width, height, - image_frame->MutablePixelData(), - image_frame->WidthStep()); - mediapipe::Packet packet = mediapipe::Adopt(image_frame.release()); + auto image_frame_or = CreateRgbImageFromRgba(env, byte_buffer, width, height); + if (ThrowIfError(env, image_frame_or.status())) return 0L; + + mediapipe::Packet packet = mediapipe::Adopt(image_frame_or->release()); return CreatePacketWithContext(context, packet); } JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateGrayscaleImage)( JNIEnv* env, jobject thiz, jlong context, jobject byte_buffer, jint width, jint height) { - auto image_frame = CreateImageFrameFromByteBuffer( + auto image_frame_or = CreateImageFrameFromByteBuffer( env, byte_buffer, width, height, mediapipe::ImageFormat::GRAY8); - if (nullptr == image_frame) return 0L; + if (ThrowIfError(env, image_frame_or.status())) return 0L; - 
mediapipe::Packet packet = mediapipe::Adopt(image_frame.release()); + mediapipe::Packet packet = mediapipe::Adopt(image_frame_or->release()); return CreatePacketWithContext(context, packet); } JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateFloatImageFrame)( JNIEnv* env, jobject thiz, jlong context, jobject byte_buffer, jint width, jint height) { - const void* data = env->GetDirectBufferAddress(byte_buffer); - auto image_frame = absl::make_unique( - mediapipe::ImageFormat::VEC32F1, width, height, - mediapipe::ImageFrame::kGlDefaultAlignmentBoundary); - int64_t buffer_size = env->GetDirectBufferCapacity(byte_buffer); - if (buffer_size != image_frame->PixelDataSize()) { - LOG(ERROR) << "Please check the input buffer size."; - LOG(ERROR) << "Buffer size: " << buffer_size - << ", Buffer size needed: " << image_frame->PixelDataSize() - << ", Image width: " << width; - return 0L; - } - std::memcpy(image_frame->MutablePixelData(), data, - image_frame->PixelDataSize()); - mediapipe::Packet packet = mediapipe::Adopt(image_frame.release()); + // TODO: merge this case with CreateImageFrameFromByteBuffer. + auto image_frame_or = + [&]() -> absl::StatusOr> { + const void* data = env->GetDirectBufferAddress(byte_buffer); + int64_t buffer_size = env->GetDirectBufferCapacity(byte_buffer); + if (data == nullptr || buffer_size < 0) { + return absl::InvalidArgumentError( + "input buffer does not support direct access"); + } + + auto image_frame = absl::make_unique( + mediapipe::ImageFormat::VEC32F1, width, height, + mediapipe::ImageFrame::kGlDefaultAlignmentBoundary); + RET_CHECK_EQ(buffer_size, image_frame->PixelDataSize()) + << "Please check the input buffer size." + << " Buffer size: " << buffer_size + << ", Buffer size needed: " << image_frame->PixelDataSize() + << ", Image width: " << width; + std::memcpy(image_frame->MutablePixelData(), data, + image_frame->PixelDataSize()); + return image_frame; + }(); + if (ThrowIfError(env, image_frame_or.status())) return 0L; + mediapipe::Packet packet = mediapipe::Adopt(image_frame_or->release()); return CreatePacketWithContext(context, packet); } JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateRgbaImageFrame)( JNIEnv* env, jobject thiz, jlong context, jobject byte_buffer, jint width, jint height) { - auto image_frame = CreateImageFrameFromByteBuffer( + auto image_frame_or = CreateImageFrameFromByteBuffer( env, byte_buffer, width, height, mediapipe::ImageFormat::SRGBA); - if (nullptr == image_frame) return 0L; + if (ThrowIfError(env, image_frame_or.status())) return 0L; - mediapipe::Packet packet = mediapipe::Adopt(image_frame.release()); + mediapipe::Packet packet = mediapipe::Adopt(image_frame_or->release()); return CreatePacketWithContext(context, packet); } @@ -291,6 +321,12 @@ JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateAudioPacketDirect)( jint num_samples) { const uint8_t* audio_sample = reinterpret_cast(env->GetDirectBufferAddress(data)); + if (!audio_sample) { + ThrowIfError(env, absl::InvalidArgumentError( + "Cannot get direct access to the input buffer. 
It " + "should be created using allocateDirect.")); + return 0L; + } mediapipe::Packet packet = createAudioPacket(audio_sample, num_samples, num_channels); return CreatePacketWithContext(context, packet); @@ -360,8 +396,10 @@ JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateMatrix)( JNIEnv* env, jobject thiz, jlong context, jint rows, jint cols, jfloatArray data) { if (env->GetArrayLength(data) != rows * cols) { - LOG(ERROR) << "Please check the matrix data size, has to be rows * cols = " - << rows * cols; + ThrowIfError( + env, absl::InvalidArgumentError(absl::StrCat( + "Please check the matrix data size, has to be rows * cols = ", + rows * cols))); return 0L; } std::unique_ptr matrix(new mediapipe::Matrix(rows, cols)); @@ -392,16 +430,18 @@ JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateCpuImage)( format = mediapipe::ImageFormat::GRAY8; break; default: - LOG(ERROR) << "Channels must be either 1, 3, or 4."; + ThrowIfError(env, absl::InvalidArgumentError(absl::StrCat( + "Channels must be either 1, 3, or 4, but are ", + num_channels))); return 0L; } - auto image_frame = + auto image_frame_or = CreateImageFrameFromByteBuffer(env, byte_buffer, width, height, format); - if (nullptr == image_frame) return 0L; + if (ThrowIfError(env, image_frame_or.status())) return 0L; mediapipe::Packet packet = - mediapipe::MakePacket(std::move(image_frame)); + mediapipe::MakePacket(*std::move(image_frame_or)); return CreatePacketWithContext(context, packet); } @@ -502,7 +542,8 @@ JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateCalculatorOptions)( jbyte* data_ref = env->GetByteArrayElements(data, nullptr); auto options = absl::make_unique(); if (!options->ParseFromArray(data_ref, count)) { - LOG(ERROR) << "Parsing binary-encoded CalculatorOptions failed."; + ThrowIfError(env, absl::InvalidArgumentError(absl::StrCat( + "Parsing binary-encoded CalculatorOptions failed."))); return 0L; } mediapipe::Packet packet = mediapipe::Adopt(options.release()); diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.cc b/mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.cc index c215dd929..737f6db72 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.cc +++ b/mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.cc @@ -14,6 +14,7 @@ #include "mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.h" +#include "absl/strings/str_cat.h" #include "mediapipe/framework/calculator.pb.h" #include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/image_frame.h" @@ -299,34 +300,38 @@ JNIEXPORT jboolean JNICALL PACKET_GETTER_METHOD(nativeGetImageData)( : GetFromNativeHandle(packet); int64_t buffer_size = env->GetDirectBufferCapacity(byte_buffer); + void* buffer_data = env->GetDirectBufferAddress(byte_buffer); + if (buffer_data == nullptr || buffer_size < 0) { + ThrowIfError(env, absl::InvalidArgumentError( + "input buffer does not support direct access")); + return false; + } // Assume byte buffer stores pixel data contiguously. 
const int expected_buffer_size = image.Width() * image.Height() * image.ByteDepth() * image.NumberOfChannels(); if (buffer_size != expected_buffer_size) { - LOG(ERROR) << "Expected buffer size " << expected_buffer_size - << " got: " << buffer_size << ", width " << image.Width() - << ", height " << image.Height() << ", channels " - << image.NumberOfChannels(); + ThrowIfError( + env, absl::InvalidArgumentError(absl::StrCat( + "Expected buffer size ", expected_buffer_size, + " got: ", buffer_size, ", width ", image.Width(), ", height ", + image.Height(), ", channels ", image.NumberOfChannels()))); return false; } switch (image.ByteDepth()) { case 1: { - uint8* data = - static_cast(env->GetDirectBufferAddress(byte_buffer)); + uint8* data = static_cast(buffer_data); image.CopyToBuffer(data, expected_buffer_size); break; } case 2: { - uint16* data = - static_cast(env->GetDirectBufferAddress(byte_buffer)); + uint16* data = static_cast(buffer_data); image.CopyToBuffer(data, expected_buffer_size); break; } case 4: { - float* data = - static_cast(env->GetDirectBufferAddress(byte_buffer)); + float* data = static_cast(buffer_data); image.CopyToBuffer(data, expected_buffer_size); break; } @@ -351,12 +356,19 @@ JNIEXPORT jboolean JNICALL PACKET_GETTER_METHOD(nativeGetRgbaFromRgb)( uint8_t* rgba_data = static_cast(env->GetDirectBufferAddress(byte_buffer)); int64_t buffer_size = env->GetDirectBufferCapacity(byte_buffer); + if (rgba_data == nullptr || buffer_size < 0) { + ThrowIfError(env, absl::InvalidArgumentError( + "input buffer does not support direct access")); + return false; + } if (buffer_size != image.Width() * image.Height() * 4) { - LOG(ERROR) << "Buffer size has to be width*height*4\n" - << "Image width: " << image.Width() - << ", Image height: " << image.Height() - << ", Buffer size: " << buffer_size << ", Buffer size needed: " - << image.Width() * image.Height() * 4; + ThrowIfError(env, + absl::InvalidArgumentError(absl::StrCat( + "Buffer size has to be width*height*4\n" + "Image width: ", + image.Width(), ", Image height: ", image.Height(), + ", Buffer size: ", buffer_size, ", Buffer size needed: ", + image.Width() * image.Height() * 4))); return false; } mediapipe::android::RgbToRgba(image.PixelData(), image.WidthStep(), From bdf4078e89cb11e01da0c5eda6322a22ad74e127 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Sat, 19 Nov 2022 21:12:23 -0800 Subject: [PATCH 097/469] Internal change PiperOrigin-RevId: 489752009 --- mediapipe/model_maker/python/core/utils/BUILD | 1 + .../python/core/utils/model_util_test.py | 15 ++++++++++++--- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/mediapipe/model_maker/python/core/utils/BUILD b/mediapipe/model_maker/python/core/utils/BUILD index 12fef631f..492bba0a9 100644 --- a/mediapipe/model_maker/python/core/utils/BUILD +++ b/mediapipe/model_maker/python/core/utils/BUILD @@ -45,6 +45,7 @@ py_test( name = "model_util_test", srcs = ["model_util_test.py"], deps = [ + ":file_util", ":model_util", ":quantization", ":test_util", diff --git a/mediapipe/model_maker/python/core/utils/model_util_test.py b/mediapipe/model_maker/python/core/utils/model_util_test.py index 05c6ffe3f..f0020db25 100644 --- a/mediapipe/model_maker/python/core/utils/model_util_test.py +++ b/mediapipe/model_maker/python/core/utils/model_util_test.py @@ -14,10 +14,12 @@ import os from typing import Optional +from unittest import mock as unittest_mock from absl.testing import parameterized import tensorflow as tf +from mediapipe.model_maker.python.core.utils import file_util 
from mediapipe.model_maker.python.core.utils import model_util from mediapipe.model_maker.python.core.utils import quantization from mediapipe.model_maker.python.core.utils import test_util @@ -25,11 +27,15 @@ from mediapipe.model_maker.python.core.utils import test_util class ModelUtilTest(tf.test.TestCase, parameterized.TestCase): - def test_load_keras_model(self): + @unittest_mock.patch.object(file_util, 'get_absolute_path', autospec=True) + def test_load_keras_model(self, mock_get_absolute_path): input_dim = 4 model = test_util.build_model(input_shape=[input_dim], num_classes=2) saved_model_path = os.path.join(self.get_temp_dir(), 'saved_model') model.save(saved_model_path) + # model_util.load_keras_model takes in a relative path to files within the + # model_maker dir, so we patch the function for testing + mock_get_absolute_path.return_value = saved_model_path loaded_model = model_util.load_keras_model(saved_model_path) input_tensors = test_util.create_random_sample(size=[1, input_dim]) @@ -37,13 +43,16 @@ class ModelUtilTest(tf.test.TestCase, parameterized.TestCase): loaded_model_output = loaded_model.predict_on_batch(input_tensors) self.assertTrue((model_output == loaded_model_output).all()) - def test_load_tflite_model_buffer(self): + @unittest_mock.patch.object(file_util, 'get_absolute_path', autospec=True) + def test_load_tflite_model_buffer(self, mock_get_absolute_path): input_dim = 4 model = test_util.build_model(input_shape=[input_dim], num_classes=2) tflite_model = model_util.convert_to_tflite(model) tflite_file = os.path.join(self.get_temp_dir(), 'model.tflite') model_util.save_tflite(tflite_model=tflite_model, tflite_file=tflite_file) - + # model_util.load_tflite_model_buffer takes in a relative path to files + # within the model_maker dir, so we patch the function for testing + mock_get_absolute_path.return_value = tflite_file tflite_model_buffer = model_util.load_tflite_model_buffer(tflite_file) test_util.test_tflite( keras_model=model, From a367753eda595f01a60e4ccb12845f2675cb37c5 Mon Sep 17 00:00:00 2001 From: Scott Zhu Date: Sun, 20 Nov 2022 10:39:59 -0800 Subject: [PATCH 098/469] Internal change PiperOrigin-RevId: 489824381 --- .../vision/gesture_recognizer/gesture_recognizer_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py b/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py index 39272cbbc..9cee88362 100644 --- a/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py @@ -14,7 +14,6 @@ import io import os -import random import tempfile from unittest import mock as unittest_mock import zipfile @@ -27,6 +26,7 @@ from mediapipe.model_maker.python.vision import gesture_recognizer from mediapipe.tasks.python.test import test_utils _TEST_DATA_DIR = 'mediapipe/model_maker/python/vision/gesture_recognizer/test_data' +tf.keras.backend.experimental.enable_tf_random_generator() class GestureRecognizerTest(tf.test.TestCase): @@ -42,7 +42,7 @@ class GestureRecognizerTest(tf.test.TestCase): def setUp(self): super().setUp() - random.seed(1234) + tf.keras.utils.set_random_seed(87654321) all_data = self._load_data() # Splits data, 90% data for training, 10% for validation self._train_data, self._validation_data = all_data.split(0.9) From 6cf464636b00fb5039bf705319ffe09408d207b3 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: 
Sun, 20 Nov 2022 14:24:21 -0800 Subject: [PATCH 099/469] Internal change PiperOrigin-RevId: 489842199 --- mediapipe/tasks/BUILD | 7 ++ .../tasks/cc/audio/audio_classifier/BUILD | 53 ++++++----- mediapipe/tasks/cc/audio/audio_embedder/BUILD | 55 ++++++------ mediapipe/tasks/cc/audio/core/BUILD | 1 + .../tasks/cc/components/containers/BUILD | 2 +- .../tasks/cc/components/processors/BUILD | 2 + mediapipe/tasks/cc/core/BUILD | 4 +- mediapipe/tasks/cc/text/text_classifier/BUILD | 51 ++++++----- mediapipe/tasks/cc/text/text_embedder/BUILD | 3 + mediapipe/tasks/cc/vision/core/BUILD | 2 + .../tasks/cc/vision/gesture_recognizer/BUILD | 90 ++++++++++--------- .../tasks/cc/vision/hand_landmarker/BUILD | 72 ++++++++------- .../tasks/cc/vision/image_classifier/BUILD | 49 +++++----- .../tasks/cc/vision/image_embedder/BUILD | 49 +++++----- .../tasks/cc/vision/image_segmenter/BUILD | 6 +- .../tasks/cc/vision/object_detector/BUILD | 65 +++++++------- 16 files changed, 278 insertions(+), 233 deletions(-) diff --git a/mediapipe/tasks/BUILD b/mediapipe/tasks/BUILD index 242a88cfc..98ddd5777 100644 --- a/mediapipe/tasks/BUILD +++ b/mediapipe/tasks/BUILD @@ -21,3 +21,10 @@ package_group( "//mediapipe/tasks/...", ], ) + +package_group( + name = "users", + includes = [ + ":internal", + ], +) diff --git a/mediapipe/tasks/cc/audio/audio_classifier/BUILD b/mediapipe/tasks/cc/audio/audio_classifier/BUILD index 1955adfe7..a817bcc3b 100644 --- a/mediapipe/tasks/cc/audio/audio_classifier/BUILD +++ b/mediapipe/tasks/cc/audio/audio_classifier/BUILD @@ -16,6 +16,35 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) licenses(["notice"]) +# Docs for Mediapipe Tasks Audio Classifier +# https://developers.google.com/mediapipe/solutions/audio/audio_classifier +cc_library( + name = "audio_classifier", + srcs = ["audio_classifier.cc"], + hdrs = ["audio_classifier.h"], + visibility = [ + "//mediapipe/tasks:users", + ], + deps = [ + ":audio_classifier_graph", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/formats:matrix", + "//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_graph_options_cc_proto", + "//mediapipe/tasks/cc/audio/core:audio_task_api_factory", + "//mediapipe/tasks/cc/audio/core:base_audio_task_api", + "//mediapipe/tasks/cc/audio/core:running_mode", + "//mediapipe/tasks/cc/components/containers:classification_result", + "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", + "//mediapipe/tasks/cc/components/processors:classifier_options", + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto", + "//mediapipe/tasks/cc/core:base_options", + "//mediapipe/tasks/cc/core:task_runner", + "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", + "@com_google_absl//absl/status:statusor", + "@org_tensorflow//tensorflow/lite/core/api:op_resolver", + ], +) + cc_library( name = "audio_classifier_graph", srcs = ["audio_classifier_graph.cc"], @@ -52,28 +81,4 @@ cc_library( alwayslink = 1, ) -cc_library( - name = "audio_classifier", - srcs = ["audio_classifier.cc"], - hdrs = ["audio_classifier.h"], - deps = [ - ":audio_classifier_graph", - "//mediapipe/framework/api2:builder", - "//mediapipe/framework/formats:matrix", - "//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_graph_options_cc_proto", - "//mediapipe/tasks/cc/audio/core:audio_task_api_factory", - "//mediapipe/tasks/cc/audio/core:base_audio_task_api", - "//mediapipe/tasks/cc/audio/core:running_mode", - 
"//mediapipe/tasks/cc/components/containers:classification_result", - "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", - "//mediapipe/tasks/cc/components/processors:classifier_options", - "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto", - "//mediapipe/tasks/cc/core:base_options", - "//mediapipe/tasks/cc/core:task_runner", - "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", - "@com_google_absl//absl/status:statusor", - "@org_tensorflow//tensorflow/lite/core/api:op_resolver", - ], -) - # TODO: mediapipe/tasks/cc/audio/utils:test_utils does not compile in the OSS build diff --git a/mediapipe/tasks/cc/audio/audio_embedder/BUILD b/mediapipe/tasks/cc/audio/audio_embedder/BUILD index b982ef39a..adba28e6a 100644 --- a/mediapipe/tasks/cc/audio/audio_embedder/BUILD +++ b/mediapipe/tasks/cc/audio/audio_embedder/BUILD @@ -16,6 +16,36 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) licenses(["notice"]) +# Docs for Mediapipe Tasks Audio Embedder +# https://developers.google.com/mediapipe/solutions/audio/audio_embedder +cc_library( + name = "audio_embedder", + srcs = ["audio_embedder.cc"], + hdrs = ["audio_embedder.h"], + visibility = [ + "//mediapipe/tasks:users", + ], + deps = [ + ":audio_embedder_graph", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/formats:matrix", + "//mediapipe/tasks/cc/audio/audio_embedder/proto:audio_embedder_graph_options_cc_proto", + "//mediapipe/tasks/cc/audio/core:audio_task_api_factory", + "//mediapipe/tasks/cc/audio/core:base_audio_task_api", + "//mediapipe/tasks/cc/audio/core:running_mode", + "//mediapipe/tasks/cc/components/containers:embedding_result", + "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto", + "//mediapipe/tasks/cc/components/processors:embedder_options", + "//mediapipe/tasks/cc/components/processors/proto:embedder_options_cc_proto", + "//mediapipe/tasks/cc/components/utils:cosine_similarity", + "//mediapipe/tasks/cc/core:base_options", + "//mediapipe/tasks/cc/core:task_runner", + "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", + "@com_google_absl//absl/status:statusor", + "@org_tensorflow//tensorflow/lite/core/api:op_resolver", + ], +) + cc_library( name = "audio_embedder_graph", srcs = ["audio_embedder_graph.cc"], @@ -51,29 +81,4 @@ cc_library( alwayslink = 1, ) -cc_library( - name = "audio_embedder", - srcs = ["audio_embedder.cc"], - hdrs = ["audio_embedder.h"], - deps = [ - ":audio_embedder_graph", - "//mediapipe/framework/api2:builder", - "//mediapipe/framework/formats:matrix", - "//mediapipe/tasks/cc/audio/audio_embedder/proto:audio_embedder_graph_options_cc_proto", - "//mediapipe/tasks/cc/audio/core:audio_task_api_factory", - "//mediapipe/tasks/cc/audio/core:base_audio_task_api", - "//mediapipe/tasks/cc/audio/core:running_mode", - "//mediapipe/tasks/cc/components/containers:embedding_result", - "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto", - "//mediapipe/tasks/cc/components/processors:embedder_options", - "//mediapipe/tasks/cc/components/processors/proto:embedder_options_cc_proto", - "//mediapipe/tasks/cc/components/utils:cosine_similarity", - "//mediapipe/tasks/cc/core:base_options", - "//mediapipe/tasks/cc/core:task_runner", - "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", - "@com_google_absl//absl/status:statusor", - "@org_tensorflow//tensorflow/lite/core/api:op_resolver", - ], -) - # TODO: mediapipe/tasks/cc/audio/utils:test_utils does not compile in the OSS build 
diff --git a/mediapipe/tasks/cc/audio/core/BUILD b/mediapipe/tasks/cc/audio/core/BUILD index 93362fd3d..016faa10f 100644 --- a/mediapipe/tasks/cc/audio/core/BUILD +++ b/mediapipe/tasks/cc/audio/core/BUILD @@ -19,6 +19,7 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) cc_library( name = "running_mode", hdrs = ["running_mode.h"], + visibility = ["//visibility:public"], ) cc_library( diff --git a/mediapipe/tasks/cc/components/containers/BUILD b/mediapipe/tasks/cc/components/containers/BUILD index 2f5f8be5b..dec977fb8 100644 --- a/mediapipe/tasks/cc/components/containers/BUILD +++ b/mediapipe/tasks/cc/components/containers/BUILD @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -package(default_visibility = ["//mediapipe/tasks:internal"]) +package(default_visibility = ["//visibility:public"]) licenses(["notice"]) diff --git a/mediapipe/tasks/cc/components/processors/BUILD b/mediapipe/tasks/cc/components/processors/BUILD index 7845a3dae..32a628db7 100644 --- a/mediapipe/tasks/cc/components/processors/BUILD +++ b/mediapipe/tasks/cc/components/processors/BUILD @@ -20,6 +20,7 @@ cc_library( name = "classifier_options", srcs = ["classifier_options.cc"], hdrs = ["classifier_options.h"], + visibility = ["//visibility:public"], deps = ["//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto"], ) @@ -67,6 +68,7 @@ cc_library( name = "embedder_options", srcs = ["embedder_options.cc"], hdrs = ["embedder_options.h"], + visibility = ["//visibility:public"], deps = ["//mediapipe/tasks/cc/components/processors/proto:embedder_options_cc_proto"], ) diff --git a/mediapipe/tasks/cc/core/BUILD b/mediapipe/tasks/cc/core/BUILD index f14457073..202f3ea3c 100644 --- a/mediapipe/tasks/cc/core/BUILD +++ b/mediapipe/tasks/cc/core/BUILD @@ -22,9 +22,7 @@ cc_library( name = "base_options", srcs = ["base_options.cc"], hdrs = ["base_options.h"], - visibility = [ - "//mediapipe/tasks:internal", - ], + visibility = ["//visibility:public"], deps = [ ":mediapipe_builtin_op_resolver", "//mediapipe/calculators/tensor:inference_calculator_cc_proto", diff --git a/mediapipe/tasks/cc/text/text_classifier/BUILD b/mediapipe/tasks/cc/text/text_classifier/BUILD index 52b0c0e4b..01adc9fc3 100644 --- a/mediapipe/tasks/cc/text/text_classifier/BUILD +++ b/mediapipe/tasks/cc/text/text_classifier/BUILD @@ -16,6 +16,33 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) licenses(["notice"]) +# Docs for Mediapipe Tasks Text Classifier +# https://developers.google.com/mediapipe/solutions/text/text_classifier +cc_library( + name = "text_classifier", + srcs = ["text_classifier.cc"], + hdrs = ["text_classifier.h"], + visibility = ["//visibility:public"], + deps = [ + ":text_classifier_graph", + "//mediapipe/framework:packet", + "//mediapipe/framework/api2:builder", + "//mediapipe/tasks/cc/components/containers:category", + "//mediapipe/tasks/cc/components/containers:classification_result", + "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", + "//mediapipe/tasks/cc/components/processors:classifier_options", + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto", + "//mediapipe/tasks/cc/core:base_options", + "//mediapipe/tasks/cc/core:base_task_api", + "//mediapipe/tasks/cc/core:task_api_factory", + "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_cc_proto", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + 
"@com_google_absl//absl/strings", + "@org_tensorflow//tensorflow/lite/core/api:op_resolver", + ], +) + cc_library( name = "text_classifier_graph", srcs = ["text_classifier_graph.cc"], @@ -41,30 +68,6 @@ cc_library( alwayslink = 1, ) -cc_library( - name = "text_classifier", - srcs = ["text_classifier.cc"], - hdrs = ["text_classifier.h"], - deps = [ - ":text_classifier_graph", - "//mediapipe/framework:packet", - "//mediapipe/framework/api2:builder", - "//mediapipe/tasks/cc/components/containers:category", - "//mediapipe/tasks/cc/components/containers:classification_result", - "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", - "//mediapipe/tasks/cc/components/processors:classifier_options", - "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto", - "//mediapipe/tasks/cc/core:base_options", - "//mediapipe/tasks/cc/core:base_task_api", - "//mediapipe/tasks/cc/core:task_api_factory", - "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_cc_proto", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@org_tensorflow//tensorflow/lite/core/api:op_resolver", - ], -) - cc_test( name = "text_classifier_test", srcs = ["text_classifier_test.cc"], diff --git a/mediapipe/tasks/cc/text/text_embedder/BUILD b/mediapipe/tasks/cc/text/text_embedder/BUILD index e2e16c9c1..27c9cb730 100644 --- a/mediapipe/tasks/cc/text/text_embedder/BUILD +++ b/mediapipe/tasks/cc/text/text_embedder/BUILD @@ -16,10 +16,13 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) licenses(["notice"]) +# Docs for Mediapipe Tasks Text Embedder +# https://developers.google.com/mediapipe/solutions/text/text_embedder cc_library( name = "text_embedder", srcs = ["text_embedder.cc"], hdrs = ["text_embedder.h"], + visibility = ["//visibility:public"], deps = [ ":text_embedder_graph", "//mediapipe/calculators/tensor:inference_calculator_cc_proto", diff --git a/mediapipe/tasks/cc/vision/core/BUILD b/mediapipe/tasks/cc/vision/core/BUILD index e8e197a1d..1f5ab5faf 100644 --- a/mediapipe/tasks/cc/vision/core/BUILD +++ b/mediapipe/tasks/cc/vision/core/BUILD @@ -19,11 +19,13 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) cc_library( name = "running_mode", hdrs = ["running_mode.h"], + visibility = ["//visibility:public"], ) cc_library( name = "image_processing_options", hdrs = ["image_processing_options.h"], + visibility = ["//visibility:public"], deps = [ "//mediapipe/tasks/cc/components/containers:rect", ], diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/BUILD b/mediapipe/tasks/cc/vision/gesture_recognizer/BUILD index 75289b1e8..7b144e7aa 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/BUILD +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/BUILD @@ -18,6 +18,52 @@ package(default_visibility = [ licenses(["notice"]) +# Docs for Mediapipe Tasks Gesture Recognizer +# https://developers.google.com/mediapipe/solutions/vision/gesture_recognizer +cc_library( + name = "gesture_recognizer", + srcs = ["gesture_recognizer.cc"], + hdrs = ["gesture_recognizer.h"], + visibility = ["//visibility:public"], + deps = [ + ":gesture_recognizer_graph", + ":gesture_recognizer_result", + ":hand_gesture_recognizer_graph", + "//mediapipe/framework:packet", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/formats:classification_cc_proto", + "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:landmark_cc_proto", + 
"//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/tasks/cc:common", + "//mediapipe/tasks/cc/components:image_preprocessing", + "//mediapipe/tasks/cc/components/processors:classifier_options", + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto", + "//mediapipe/tasks/cc/core:base_options", + "//mediapipe/tasks/cc/core:base_task_api", + "//mediapipe/tasks/cc/core:model_resources", + "//mediapipe/tasks/cc/core:task_runner", + "//mediapipe/tasks/cc/core:utils", + "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", + "//mediapipe/tasks/cc/vision/core:base_vision_task_api", + "//mediapipe/tasks/cc/vision/core:image_processing_options", + "//mediapipe/tasks/cc/vision/core:running_mode", + "//mediapipe/tasks/cc/vision/core:vision_task_api_factory", + "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_classifier_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_cc_proto", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@org_tensorflow//tensorflow/lite/core/api:op_resolver", + "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", + ], +) + cc_library( name = "handedness_util", srcs = ["handedness_util.cc"], @@ -127,51 +173,9 @@ cc_library( cc_library( name = "gesture_recognizer_result", hdrs = ["gesture_recognizer_result.h"], + visibility = ["//visibility:public"], deps = [ "//mediapipe/framework/formats:classification_cc_proto", "//mediapipe/framework/formats:landmark_cc_proto", ], ) - -cc_library( - name = "gesture_recognizer", - srcs = ["gesture_recognizer.cc"], - hdrs = ["gesture_recognizer.h"], - deps = [ - ":gesture_recognizer_graph", - ":gesture_recognizer_result", - ":hand_gesture_recognizer_graph", - "//mediapipe/framework:packet", - "//mediapipe/framework/api2:builder", - "//mediapipe/framework/api2:port", - "//mediapipe/framework/formats:classification_cc_proto", - "//mediapipe/framework/formats:image", - "//mediapipe/framework/formats:landmark_cc_proto", - "//mediapipe/framework/formats:rect_cc_proto", - "//mediapipe/tasks/cc:common", - "//mediapipe/tasks/cc/components:image_preprocessing", - "//mediapipe/tasks/cc/components/processors:classifier_options", - "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto", - "//mediapipe/tasks/cc/core:base_options", - "//mediapipe/tasks/cc/core:base_task_api", - "//mediapipe/tasks/cc/core:model_resources", - "//mediapipe/tasks/cc/core:task_runner", - "//mediapipe/tasks/cc/core:utils", - "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", - "//mediapipe/tasks/cc/vision/core:base_vision_task_api", - "//mediapipe/tasks/cc/vision/core:image_processing_options", - "//mediapipe/tasks/cc/vision/core:running_mode", - "//mediapipe/tasks/cc/vision/core:vision_task_api_factory", - "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_classifier_graph_options_cc_proto", - "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_cc_proto", - 
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_cc_proto", - "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_cc_proto", - "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_cc_proto", - "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_cc_proto", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@org_tensorflow//tensorflow/lite/core/api:op_resolver", - "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", - ], -) diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/BUILD b/mediapipe/tasks/cc/vision/hand_landmarker/BUILD index 5c5073fc2..3b869eab4 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/BUILD +++ b/mediapipe/tasks/cc/vision/hand_landmarker/BUILD @@ -18,6 +18,43 @@ package(default_visibility = [ licenses(["notice"]) +# Docs for Mediapipe Tasks Hand Landmarker +# https://developers.google.com/mediapipe/solutions/vision/hand_landmarker +cc_library( + name = "hand_landmarker", + srcs = ["hand_landmarker.cc"], + hdrs = ["hand_landmarker.h"], + visibility = ["//visibility:public"], + deps = [ + ":hand_landmarker_graph", + ":hand_landmarker_result", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/formats:classification_cc_proto", + "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:landmark_cc_proto", + "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/tasks/cc:common", + "//mediapipe/tasks/cc/components:image_preprocessing", + "//mediapipe/tasks/cc/components/processors:classifier_options", + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto", + "//mediapipe/tasks/cc/core:base_options", + "//mediapipe/tasks/cc/core:base_task_api", + "//mediapipe/tasks/cc/core:model_resources", + "//mediapipe/tasks/cc/core:task_runner", + "//mediapipe/tasks/cc/core:utils", + "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", + "//mediapipe/tasks/cc/vision/core:base_vision_task_api", + "//mediapipe/tasks/cc/vision/core:image_processing_options", + "//mediapipe/tasks/cc/vision/core:running_mode", + "//mediapipe/tasks/cc/vision/core:vision_task_api_factory", + "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_cc_proto", + "@com_google_absl//absl/status:statusor", + ], +) + cc_library( name = "hand_landmarks_detector_graph", srcs = ["hand_landmarks_detector_graph.cc"], @@ -113,44 +150,11 @@ cc_library( cc_library( name = "hand_landmarker_result", hdrs = ["hand_landmarker_result.h"], + visibility = ["//visibility:public"], deps = [ "//mediapipe/framework/formats:classification_cc_proto", "//mediapipe/framework/formats:landmark_cc_proto", ], ) -cc_library( - name = "hand_landmarker", - srcs = ["hand_landmarker.cc"], - hdrs = ["hand_landmarker.h"], - deps = [ - ":hand_landmarker_graph", - ":hand_landmarker_result", - "//mediapipe/framework/api2:builder", - "//mediapipe/framework/api2:port", - "//mediapipe/framework/formats:classification_cc_proto", - "//mediapipe/framework/formats:image", - "//mediapipe/framework/formats:landmark_cc_proto", - "//mediapipe/framework/formats:rect_cc_proto", - "//mediapipe/tasks/cc:common", - 
"//mediapipe/tasks/cc/components:image_preprocessing", - "//mediapipe/tasks/cc/components/processors:classifier_options", - "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto", - "//mediapipe/tasks/cc/core:base_options", - "//mediapipe/tasks/cc/core:base_task_api", - "//mediapipe/tasks/cc/core:model_resources", - "//mediapipe/tasks/cc/core:task_runner", - "//mediapipe/tasks/cc/core:utils", - "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", - "//mediapipe/tasks/cc/vision/core:base_vision_task_api", - "//mediapipe/tasks/cc/vision/core:image_processing_options", - "//mediapipe/tasks/cc/vision/core:running_mode", - "//mediapipe/tasks/cc/vision/core:vision_task_api_factory", - "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_cc_proto", - "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_cc_proto", - "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_cc_proto", - "@com_google_absl//absl/status:statusor", - ], -) - # TODO: Enable this test diff --git a/mediapipe/tasks/cc/vision/image_classifier/BUILD b/mediapipe/tasks/cc/vision/image_classifier/BUILD index b59d8d682..2b93aa262 100644 --- a/mediapipe/tasks/cc/vision/image_classifier/BUILD +++ b/mediapipe/tasks/cc/vision/image_classifier/BUILD @@ -16,33 +16,13 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) licenses(["notice"]) -cc_library( - name = "image_classifier_graph", - srcs = ["image_classifier_graph.cc"], - deps = [ - "//mediapipe/calculators/tensor:inference_calculator", - "//mediapipe/framework:calculator_framework", - "//mediapipe/framework/api2:builder", - "//mediapipe/framework/api2:port", - "//mediapipe/framework/formats:image", - "//mediapipe/framework/formats:rect_cc_proto", - "//mediapipe/tasks/cc/components:image_preprocessing", - "//mediapipe/tasks/cc/components:image_preprocessing_options_cc_proto", - "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", - "//mediapipe/tasks/cc/components/processors:classification_postprocessing_graph", - "//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto", - "//mediapipe/tasks/cc/core:model_resources", - "//mediapipe/tasks/cc/core:model_task_graph", - "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_cc_proto", - "@com_google_absl//absl/status:statusor", - ], - alwayslink = 1, -) - +# Docs for Mediapipe Tasks Image Classifier +# https://developers.google.com/mediapipe/solutions/vision/image_classifier cc_library( name = "image_classifier", srcs = ["image_classifier.cc"], hdrs = ["image_classifier.h"], + visibility = ["//visibility:public"], deps = [ ":image_classifier_graph", "//mediapipe/framework:packet", @@ -69,4 +49,27 @@ cc_library( ], ) +cc_library( + name = "image_classifier_graph", + srcs = ["image_classifier_graph.cc"], + deps = [ + "//mediapipe/calculators/tensor:inference_calculator", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/tasks/cc/components:image_preprocessing", + "//mediapipe/tasks/cc/components:image_preprocessing_options_cc_proto", + "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", + "//mediapipe/tasks/cc/components/processors:classification_postprocessing_graph", + 
"//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto", + "//mediapipe/tasks/cc/core:model_resources", + "//mediapipe/tasks/cc/core:model_task_graph", + "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_cc_proto", + "@com_google_absl//absl/status:statusor", + ], + alwayslink = 1, +) + # TODO: This test fails in OSS diff --git a/mediapipe/tasks/cc/vision/image_embedder/BUILD b/mediapipe/tasks/cc/vision/image_embedder/BUILD index ea7f40261..8fdb97ccd 100644 --- a/mediapipe/tasks/cc/vision/image_embedder/BUILD +++ b/mediapipe/tasks/cc/vision/image_embedder/BUILD @@ -16,33 +16,13 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) licenses(["notice"]) -cc_library( - name = "image_embedder_graph", - srcs = ["image_embedder_graph.cc"], - deps = [ - "//mediapipe/calculators/tensor:inference_calculator", - "//mediapipe/framework:calculator_framework", - "//mediapipe/framework/api2:builder", - "//mediapipe/framework/api2:port", - "//mediapipe/framework/formats:image", - "//mediapipe/framework/formats:rect_cc_proto", - "//mediapipe/tasks/cc/components:image_preprocessing", - "//mediapipe/tasks/cc/components:image_preprocessing_options_cc_proto", - "//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator", - "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto", - "//mediapipe/tasks/cc/components/processors:embedding_postprocessing_graph", - "//mediapipe/tasks/cc/components/processors/proto:embedding_postprocessing_graph_options_cc_proto", - "//mediapipe/tasks/cc/core:model_task_graph", - "//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_cc_proto", - "@com_google_absl//absl/status:statusor", - ], - alwayslink = 1, -) - +# Docs for Mediapipe Tasks Image Embedder +# https://developers.google.com/mediapipe/solutions/vision/image_embedder cc_library( name = "image_embedder", srcs = ["image_embedder.cc"], hdrs = ["image_embedder.h"], + visibility = ["//visibility:public"], deps = [ ":image_embedder_graph", "//mediapipe/framework/api2:builder", @@ -67,4 +47,27 @@ cc_library( ], ) +cc_library( + name = "image_embedder_graph", + srcs = ["image_embedder_graph.cc"], + deps = [ + "//mediapipe/calculators/tensor:inference_calculator", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/tasks/cc/components:image_preprocessing", + "//mediapipe/tasks/cc/components:image_preprocessing_options_cc_proto", + "//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator", + "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto", + "//mediapipe/tasks/cc/components/processors:embedding_postprocessing_graph", + "//mediapipe/tasks/cc/components/processors/proto:embedding_postprocessing_graph_options_cc_proto", + "//mediapipe/tasks/cc/core:model_task_graph", + "//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_cc_proto", + "@com_google_absl//absl/status:statusor", + ], + alwayslink = 1, +) + # TODO: This test fails in OSS diff --git a/mediapipe/tasks/cc/vision/image_segmenter/BUILD b/mediapipe/tasks/cc/vision/image_segmenter/BUILD index 7206a45ea..595eef568 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/BUILD +++ b/mediapipe/tasks/cc/vision/image_segmenter/BUILD @@ -16,13 +16,13 @@ package(default_visibility = 
["//mediapipe/tasks:internal"]) licenses(["notice"]) +# Docs for Mediapipe Tasks Image Segmenter +# https://developers.google.com/mediapipe/solutions/vision/image_segmenter cc_library( name = "image_segmenter", srcs = ["image_segmenter.cc"], hdrs = ["image_segmenter.h"], - visibility = [ - "//mediapipe/tasks:internal", - ], + visibility = ["//visibility:public"], deps = [ ":image_segmenter_graph", "//mediapipe/framework/api2:builder", diff --git a/mediapipe/tasks/cc/vision/object_detector/BUILD b/mediapipe/tasks/cc/vision/object_detector/BUILD index 8220d8b7f..b8002fa96 100644 --- a/mediapipe/tasks/cc/vision/object_detector/BUILD +++ b/mediapipe/tasks/cc/vision/object_detector/BUILD @@ -16,6 +16,41 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) licenses(["notice"]) +# Docs for Mediapipe Tasks Object Detector +# https://developers.google.com/mediapipe/solutions/vision/object_detector +cc_library( + name = "object_detector", + srcs = ["object_detector.cc"], + hdrs = ["object_detector.h"], + visibility = [ + "//mediapipe/tasks:users", + ], + deps = [ + ":object_detector_graph", + "//mediapipe/calculators/core:concatenate_vector_calculator", + "//mediapipe/calculators/core:split_vector_calculator", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/formats:detection_cc_proto", + "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/tasks/cc:common", + "//mediapipe/tasks/cc/components/calculators:score_calibration_calculator", + "//mediapipe/tasks/cc/core:base_options", + "//mediapipe/tasks/cc/core:utils", + "//mediapipe/tasks/cc/core/proto:base_options_cc_proto", + "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", + "//mediapipe/tasks/cc/vision/core:base_vision_task_api", + "//mediapipe/tasks/cc/vision/core:image_processing_options", + "//mediapipe/tasks/cc/vision/core:running_mode", + "//mediapipe/tasks/cc/vision/core:vision_task_api_factory", + "//mediapipe/tasks/cc/vision/object_detector/proto:object_detector_options_cc_proto", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@org_tensorflow//tensorflow/lite/core/api:op_resolver", + ], +) + cc_library( name = "object_detector_graph", srcs = ["object_detector_graph.cc"], @@ -56,34 +91,4 @@ cc_library( alwayslink = 1, ) -cc_library( - name = "object_detector", - srcs = ["object_detector.cc"], - hdrs = ["object_detector.h"], - deps = [ - ":object_detector_graph", - "//mediapipe/calculators/core:concatenate_vector_calculator", - "//mediapipe/calculators/core:split_vector_calculator", - "//mediapipe/framework/api2:builder", - "//mediapipe/framework/formats:detection_cc_proto", - "//mediapipe/framework/formats:image", - "//mediapipe/framework/formats:rect_cc_proto", - "//mediapipe/tasks/cc:common", - "//mediapipe/tasks/cc/components/calculators:score_calibration_calculator", - "//mediapipe/tasks/cc/core:base_options", - "//mediapipe/tasks/cc/core:utils", - "//mediapipe/tasks/cc/core/proto:base_options_cc_proto", - "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", - "//mediapipe/tasks/cc/vision/core:base_vision_task_api", - "//mediapipe/tasks/cc/vision/core:image_processing_options", - "//mediapipe/tasks/cc/vision/core:running_mode", - "//mediapipe/tasks/cc/vision/core:vision_task_api_factory", - "//mediapipe/tasks/cc/vision/object_detector/proto:object_detector_options_cc_proto", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - 
"@com_google_absl//absl/strings", - "@org_tensorflow//tensorflow/lite/core/api:op_resolver", - ], -) - # TODO: This test fails in OSS From 3ac7f6a216c12d617edd6549ace59f4f76e085c7 Mon Sep 17 00:00:00 2001 From: Camillo Lugaresi Date: Sun, 20 Nov 2022 19:30:05 -0800 Subject: [PATCH 100/469] Simplify image creation in PacketCreator Use more existing functions, remove redundant code, remove direct use of RuntimeException. PiperOrigin-RevId: 489868983 --- .../mediapipe/framework/PacketCreator.java | 53 +++++---- .../framework/jni/packet_creator_jni.cc | 104 +++++------------- .../framework/jni/packet_creator_jni.h | 2 +- 3 files changed, 64 insertions(+), 95 deletions(-) diff --git a/mediapipe/java/com/google/mediapipe/framework/PacketCreator.java b/mediapipe/java/com/google/mediapipe/framework/PacketCreator.java index d93eea7b5..04265cab5 100644 --- a/mediapipe/java/com/google/mediapipe/framework/PacketCreator.java +++ b/mediapipe/java/com/google/mediapipe/framework/PacketCreator.java @@ -55,7 +55,11 @@ public class PacketCreator { public Packet createRgbImage(ByteBuffer buffer, int width, int height) { int widthStep = (((width * 3) + 3) / 4) * 4; if (widthStep * height != buffer.capacity()) { - throw new RuntimeException("The size of the buffer should be: " + widthStep * height); + throw new IllegalArgumentException( + "The size of the buffer should be: " + + widthStep * height + + " but is " + + buffer.capacity()); } return Packet.create( nativeCreateRgbImage(mediapipeGraph.getNativeHandle(), buffer, width, height)); @@ -123,7 +127,11 @@ public class PacketCreator { */ public Packet createRgbImageFromRgba(ByteBuffer buffer, int width, int height) { if (width * height * 4 != buffer.capacity()) { - throw new RuntimeException("The size of the buffer should be: " + width * height * 4); + throw new IllegalArgumentException( + "The size of the buffer should be: " + + width * height * 4 + + " but is " + + buffer.capacity()); } return Packet.create( nativeCreateRgbImageFromRgba(mediapipeGraph.getNativeHandle(), buffer, width, height)); @@ -136,7 +144,7 @@ public class PacketCreator { */ public Packet createGrayscaleImage(ByteBuffer buffer, int width, int height) { if (width * height != buffer.capacity()) { - throw new RuntimeException( + throw new IllegalArgumentException( "The size of the buffer should be: " + width * height + " but is " + buffer.capacity()); } return Packet.create( @@ -150,7 +158,11 @@ public class PacketCreator { */ public Packet createRgbaImageFrame(ByteBuffer buffer, int width, int height) { if (buffer.capacity() != width * height * 4) { - throw new RuntimeException("buffer doesn't have the correct size."); + throw new IllegalArgumentException( + "The size of the buffer should be: " + + width * height * 4 + + " but is " + + buffer.capacity()); } return Packet.create( nativeCreateRgbaImageFrame(mediapipeGraph.getNativeHandle(), buffer, width, height)); @@ -163,7 +175,11 @@ public class PacketCreator { */ public Packet createFloatImageFrame(FloatBuffer buffer, int width, int height) { if (buffer.capacity() != width * height * 4) { - throw new RuntimeException("buffer doesn't have the correct size."); + throw new IllegalArgumentException( + "The size of the buffer should be: " + + width * height * 4 + + " but is " + + buffer.capacity()); } return Packet.create( nativeCreateFloatImageFrame(mediapipeGraph.getNativeHandle(), buffer, width, height)); @@ -354,25 +370,24 @@ public class PacketCreator { *
For 3 and 4 channel images, the pixel rows should have 4-byte alignment. */ public Packet createImage(ByteBuffer buffer, int width, int height, int numChannels) { + int widthStep; if (numChannels == 4) { - if (buffer.capacity() != width * height * 4) { - throw new RuntimeException("buffer doesn't have the correct size."); - } + widthStep = width * 4; } else if (numChannels == 3) { - int widthStep = (((width * 3) + 3) / 4) * 4; - if (widthStep * height != buffer.capacity()) { - throw new RuntimeException("The size of the buffer should be: " + widthStep * height); - } + widthStep = (((width * 3) + 3) / 4) * 4; } else if (numChannels == 1) { - if (width * height != buffer.capacity()) { - throw new RuntimeException( - "The size of the buffer should be: " + width * height + " but is " + buffer.capacity()); - } + widthStep = width; } else { - throw new RuntimeException("Channels should be: 1, 3, or 4, but is " + numChannels); + throw new IllegalArgumentException("Channels should be: 1, 3, or 4, but is " + numChannels); + } + int expectedSize = widthStep * height; + if (buffer.capacity() != expectedSize) { + throw new IllegalArgumentException( + "The size of the buffer should be: " + expectedSize + " but is " + buffer.capacity()); } return Packet.create( - nativeCreateCpuImage(mediapipeGraph.getNativeHandle(), buffer, width, height, numChannels)); + nativeCreateCpuImage( + mediapipeGraph.getNativeHandle(), buffer, width, height, widthStep, numChannels)); } /** Helper callback adaptor to create the Java {@link GlSyncToken}. This is called by JNI code. */ @@ -430,7 +445,7 @@ public class PacketCreator { long context, int name, int width, int height, TextureReleaseCallback releaseCallback); private native long nativeCreateCpuImage( - long context, ByteBuffer buffer, int width, int height, int numChannels); + long context, ByteBuffer buffer, int width, int height, int rowBytes, int numChannels); private native long nativeCreateInt32Array(long context, int[] data); diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/packet_creator_jni.cc b/mediapipe/java/com/google/mediapipe/framework/jni/packet_creator_jni.cc index 2d5447401..46ea1ce41 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/packet_creator_jni.cc +++ b/mediapipe/java/com/google/mediapipe/framework/jni/packet_creator_jni.cc @@ -111,22 +111,8 @@ absl::StatusOr CreateGpuBuffer( // ByteBuffer. absl::StatusOr> CreateImageFrameFromByteBuffer(JNIEnv* env, jobject byte_buffer, jint width, - jint height, + jint height, jint width_step, mediapipe::ImageFormat::Format format) { - switch (format) { - case mediapipe::ImageFormat::SRGBA: - case mediapipe::ImageFormat::SRGB: - case mediapipe::ImageFormat::GRAY8: - break; - default: - return absl::InvalidArgumentError( - "Format must be either SRGBA, SRGB, or GRAY8."); - } - - auto image_frame = std::make_unique( - format, width, height, - mediapipe::ImageFrame::kGlDefaultAlignmentBoundary); - const int64_t buffer_size = env->GetDirectBufferCapacity(byte_buffer); const void* buffer_data = env->GetDirectBufferAddress(byte_buffer); if (buffer_data == nullptr || buffer_size < 0) { @@ -135,34 +121,19 @@ CreateImageFrameFromByteBuffer(JNIEnv* env, jobject byte_buffer, jint width, "using allocateDirect."); } - const int num_channels = image_frame->NumberOfChannels(); - const int expected_buffer_size = - num_channels == 1 ? 
width * height : image_frame->PixelDataSize(); - + const int expected_buffer_size = height * width_step; RET_CHECK_EQ(buffer_size, expected_buffer_size) - << (num_channels != 1 - ? "The input image buffer should have 4 bytes alignment. " - : "") - << "Please check the input buffer size." - << " Buffer size: " << buffer_size - << ", Buffer size needed: " << expected_buffer_size - << ", Image width: " << width; + << "Input buffer size should be " << expected_buffer_size + << " but is: " << buffer_size; - // Copy buffer data to image frame's pixel_data_. - if (num_channels == 1) { - const int width_step = image_frame->WidthStep(); - const char* src_row = reinterpret_cast(buffer_data); - char* dst_row = reinterpret_cast(image_frame->MutablePixelData()); - for (int i = height; i > 0; --i) { - std::memcpy(dst_row, src_row, width); - src_row += width; - dst_row += width_step; - } - } else { - // 3 and 4 channels. - std::memcpy(image_frame->MutablePixelData(), buffer_data, - image_frame->PixelDataSize()); - } + auto image_frame = std::make_unique(); + // TODO: we could retain the buffer with a special deleter and use + // the data directly without a copy. May need a new Java API since existing + // code might expect to be able to overwrite the buffer after creating an + // ImageFrame from it. + image_frame->CopyPixelData( + format, width, height, width_step, static_cast(buffer_data), + mediapipe::ImageFrame::kGlDefaultAlignmentBoundary); return image_frame; } @@ -183,8 +154,12 @@ JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateReferencePacket)( JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateRgbImage)( JNIEnv* env, jobject thiz, jlong context, jobject byte_buffer, jint width, jint height) { - auto image_frame_or = CreateImageFrameFromByteBuffer( - env, byte_buffer, width, height, mediapipe::ImageFormat::SRGB); + // We require 4-byte alignment. See Java method. + constexpr int kAlignment = 4; + int width_step = ((width * 3 - 1) | (kAlignment - 1)) + 1; + auto image_frame_or = + CreateImageFrameFromByteBuffer(env, byte_buffer, width, height, + width_step, mediapipe::ImageFormat::SRGB); if (ThrowIfError(env, image_frame_or.status())) return 0L; mediapipe::Packet packet = mediapipe::Adopt(image_frame_or->release()); @@ -204,10 +179,8 @@ absl::StatusOr> CreateRgbImageFromRgba( const int expected_buffer_size = width * height * 4; RET_CHECK_EQ(buffer_size, expected_buffer_size) - << "Please check the input buffer size." 
- << " Buffer size: " << buffer_size - << ", Buffer size needed: " << expected_buffer_size - << ", Image width: " << width; + << "Input buffer size should be " << expected_buffer_size + << " but is: " << buffer_size; auto image_frame = absl::make_unique( mediapipe::ImageFormat::SRGB, width, height, @@ -232,7 +205,7 @@ JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateGrayscaleImage)( JNIEnv* env, jobject thiz, jlong context, jobject byte_buffer, jint width, jint height) { auto image_frame_or = CreateImageFrameFromByteBuffer( - env, byte_buffer, width, height, mediapipe::ImageFormat::GRAY8); + env, byte_buffer, width, height, width, mediapipe::ImageFormat::GRAY8); if (ThrowIfError(env, image_frame_or.status())) return 0L; mediapipe::Packet packet = mediapipe::Adopt(image_frame_or->release()); @@ -242,28 +215,9 @@ JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateGrayscaleImage)( JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateFloatImageFrame)( JNIEnv* env, jobject thiz, jlong context, jobject byte_buffer, jint width, jint height) { - // TODO: merge this case with CreateImageFrameFromByteBuffer. auto image_frame_or = - [&]() -> absl::StatusOr> { - const void* data = env->GetDirectBufferAddress(byte_buffer); - int64_t buffer_size = env->GetDirectBufferCapacity(byte_buffer); - if (data == nullptr || buffer_size < 0) { - return absl::InvalidArgumentError( - "input buffer does not support direct access"); - } - - auto image_frame = absl::make_unique( - mediapipe::ImageFormat::VEC32F1, width, height, - mediapipe::ImageFrame::kGlDefaultAlignmentBoundary); - RET_CHECK_EQ(buffer_size, image_frame->PixelDataSize()) - << "Please check the input buffer size." - << " Buffer size: " << buffer_size - << ", Buffer size needed: " << image_frame->PixelDataSize() - << ", Image width: " << width; - std::memcpy(image_frame->MutablePixelData(), data, - image_frame->PixelDataSize()); - return image_frame; - }(); + CreateImageFrameFromByteBuffer(env, byte_buffer, width, height, width * 4, + mediapipe::ImageFormat::VEC32F1); if (ThrowIfError(env, image_frame_or.status())) return 0L; mediapipe::Packet packet = mediapipe::Adopt(image_frame_or->release()); return CreatePacketWithContext(context, packet); @@ -272,10 +226,10 @@ JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateFloatImageFrame)( JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateRgbaImageFrame)( JNIEnv* env, jobject thiz, jlong context, jobject byte_buffer, jint width, jint height) { - auto image_frame_or = CreateImageFrameFromByteBuffer( - env, byte_buffer, width, height, mediapipe::ImageFormat::SRGBA); + auto image_frame_or = + CreateImageFrameFromByteBuffer(env, byte_buffer, width, height, width * 4, + mediapipe::ImageFormat::SRGBA); if (ThrowIfError(env, image_frame_or.status())) return 0L; - mediapipe::Packet packet = mediapipe::Adopt(image_frame_or->release()); return CreatePacketWithContext(context, packet); } @@ -417,7 +371,7 @@ JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateMatrix)( JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateCpuImage)( JNIEnv* env, jobject thiz, jlong context, jobject byte_buffer, jint width, - jint height, jint num_channels) { + jint height, jint width_step, jint num_channels) { mediapipe::ImageFormat::Format format; switch (num_channels) { case 4: @@ -436,8 +390,8 @@ JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateCpuImage)( return 0L; } - auto image_frame_or = - CreateImageFrameFromByteBuffer(env, byte_buffer, width, height, format); + auto 
image_frame_or = CreateImageFrameFromByteBuffer( + env, byte_buffer, width, height, width_step, format); if (ThrowIfError(env, image_frame_or.status())) return 0L; mediapipe::Packet packet = diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/packet_creator_jni.h b/mediapipe/java/com/google/mediapipe/framework/jni/packet_creator_jni.h index d6f44b0a3..b3b1043fb 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/packet_creator_jni.h +++ b/mediapipe/java/com/google/mediapipe/framework/jni/packet_creator_jni.h @@ -99,7 +99,7 @@ JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateMatrix)( JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateCpuImage)( JNIEnv* env, jobject thiz, jlong context, jobject byte_buffer, jint width, - jint height, jint num_channels); + jint height, jint width_step, jint num_channels); JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateGpuImage)( JNIEnv* env, jobject thiz, jlong context, jint name, jint width, From 13c6b9a8c6ce6fc9d0e34316821d497bb7f4f9f2 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Sun, 20 Nov 2022 22:18:49 -0800 Subject: [PATCH 101/469] Allow kernel cache path to be specified without trailing path delimiter PiperOrigin-RevId: 489891079 --- .../calculators/tensor/inference_calculator_gl_advanced.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mediapipe/calculators/tensor/inference_calculator_gl_advanced.cc b/mediapipe/calculators/tensor/inference_calculator_gl_advanced.cc index ad5df849f..c2c723402 100644 --- a/mediapipe/calculators/tensor/inference_calculator_gl_advanced.cc +++ b/mediapipe/calculators/tensor/inference_calculator_gl_advanced.cc @@ -241,9 +241,9 @@ absl::Status InferenceCalculatorGlAdvancedImpl::OnDiskCacheHelper::Init( gpu_delegate_options.has_model_token(); if (use_kernel_caching_) { - cached_kernel_filename_ = gpu_delegate_options.cached_kernel_path() + - mediapipe::File::Basename(options.model_path()) + - ".ker"; + cached_kernel_filename_ = mediapipe::file::JoinPath( + gpu_delegate_options.cached_kernel_path(), + mediapipe::File::Basename(options.model_path()) + ".ker"); } if (use_serialized_model_) { serialized_model_path_ = From 7acbf557a1294e3809e8671ac769c855dd3336c4 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Mon, 21 Nov 2022 01:55:49 -0800 Subject: [PATCH 102/469] Cleanup after migration to new classification output format. 
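
After this cleanup, the graphs expose only the CLASSIFICATIONS output (a
ClassificationResult aggregated by classifier head) and, for streaming use
cases, TIMESTAMPED_CLASSIFICATIONS; the deprecated CLASSIFICATION_RESULT
stream and the Category/ClassificationEntry protos are removed. A minimal
sketch of reading the new format, assuming a ClassificationResult proto taken
from the CLASSIFICATIONS stream (the helper below is illustrative and not
part of this change):

  #include <iostream>

  #include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"

  using ::mediapipe::tasks::components::containers::proto::ClassificationResult;

  // Illustrative only: prints every category of every classifier head.
  void PrintClassificationResult(const ClassificationResult& result) {
    for (const auto& head : result.classifications()) {
      // One Classifications message per classifier head (output tensor).
      std::cout << "head " << head.head_index() << " (" << head.head_name()
                << ")\n";
      for (const auto& c : head.classification_list().classification()) {
        // Each entry is a mediapipe.Classification with index, score, label.
        std::cout << "  " << c.label() << ": " << c.score() << "\n";
      }
    }
  }

This mirrors the contract now enforced in
ClassificationAggregationCalculator::UpdateContract: TIMESTAMPED_CLASSIFICATIONS
must be connected when TIMESTAMPS is connected, and CLASSIFICATIONS otherwise.
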
PiperOrigin-RevId: 489921603 --- .../tasks/cc/components/calculators/BUILD | 1 - .../classification_aggregation_calculator.cc | 68 +--- .../cc/components/containers/proto/BUILD | 6 - .../containers/proto/category.proto | 41 --- .../containers/proto/classifications.proto | 17 +- .../classification_postprocessing_graph.cc | 9 - .../classification_postprocessing_graph.h | 3 - ...lassification_postprocessing_graph_test.cc | 322 ------------------ .../text_classifier/text_classifier_graph.cc | 27 +- .../image_classifier_graph.cc | 9 - .../com/google/mediapipe/tasks/text/BUILD | 1 - .../com/google/mediapipe/tasks/vision/BUILD | 1 - .../tasks/python/components/containers/BUILD | 2 +- .../python/components/containers/category.py | 16 +- .../containers/classification_result.py | 15 +- 15 files changed, 23 insertions(+), 515 deletions(-) delete mode 100644 mediapipe/tasks/cc/components/containers/proto/category.proto diff --git a/mediapipe/tasks/cc/components/calculators/BUILD b/mediapipe/tasks/cc/components/calculators/BUILD index 1f726a018..16931811c 100644 --- a/mediapipe/tasks/cc/components/calculators/BUILD +++ b/mediapipe/tasks/cc/components/calculators/BUILD @@ -37,7 +37,6 @@ cc_library( "//mediapipe/framework/api2:packet", "//mediapipe/framework/api2:port", "//mediapipe/framework/formats:classification_cc_proto", - "//mediapipe/tasks/cc/components/containers/proto:category_cc_proto", "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", "@com_google_absl//absl/status", ], diff --git a/mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.cc b/mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.cc index 1a83fdad2..ad2c668c3 100644 --- a/mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.cc +++ b/mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.cc @@ -25,14 +25,12 @@ #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/formats/classification.pb.h" #include "mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.pb.h" -#include "mediapipe/tasks/cc/components/containers/proto/category.pb.h" #include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" namespace mediapipe { namespace api2 { using ::mediapipe::tasks::components::containers::proto::ClassificationResult; -using ::mediapipe::tasks::components::containers::proto::Classifications; // Aggregates ClassificationLists into either a ClassificationResult object // representing the classification results aggregated by classifier head, or @@ -57,9 +55,6 @@ using ::mediapipe::tasks::components::containers::proto::Classifications; // The classification result aggregated by timestamp, then by head. Must be // connected if the TIMESTAMPS input is connected, as it signals that // timestamp aggregation is required. -// // TODO: remove output once migration is over. -// CLASSIFICATION_RESULT - (DEPRECATED) ClassificationResult @Optional -// The aggregated classification result. // // Example without timestamp aggregation: // node { @@ -122,9 +117,6 @@ class ClassificationAggregationCalculator : public Node { ClassificationResult ConvertToClassificationResult(CalculatorContext* cc); std::vector ConvertToTimestampedClassificationResults( CalculatorContext* cc); - // TODO: deprecate this function once migration is over. 
- ClassificationResult LegacyConvertToClassificationResult( - CalculatorContext* cc); }; absl::Status ClassificationAggregationCalculator::UpdateContract( @@ -137,10 +129,11 @@ absl::Status ClassificationAggregationCalculator::UpdateContract( << "The size of classifications input streams should match the " "size of head names specified in the calculator options"; } - // TODO: enforce connecting TIMESTAMPED_CLASSIFICATIONS if - // TIMESTAMPS is connected, and connecting CLASSIFICATIONS if TIMESTAMPS is - // not connected. All dependent tasks must be updated to use these outputs - // first. + if (kTimestampsIn(cc).IsConnected()) { + RET_CHECK(kTimestampedClassificationsOut(cc).IsConnected()); + } else { + RET_CHECK(kClassificationsOut(cc).IsConnected()); + } return absl::OkStatus(); } @@ -170,11 +163,9 @@ absl::Status ClassificationAggregationCalculator::Process( if (kTimestampsIn(cc).IsEmpty()) { return absl::OkStatus(); } - classification_result = LegacyConvertToClassificationResult(cc); kTimestampedClassificationsOut(cc).Send( ConvertToTimestampedClassificationResults(cc)); } else { - classification_result = LegacyConvertToClassificationResult(cc); kClassificationsOut(cc).Send(ConvertToClassificationResult(cc)); } kClassificationResultOut(cc).Send(classification_result); @@ -226,55 +217,6 @@ ClassificationAggregationCalculator::ConvertToTimestampedClassificationResults( return results; } -ClassificationResult -ClassificationAggregationCalculator::LegacyConvertToClassificationResult( - CalculatorContext* cc) { - ClassificationResult result; - Timestamp first_timestamp(0); - std::vector timestamps; - if (time_aggregation_enabled_) { - timestamps = kTimestampsIn(cc).Get(); - first_timestamp = timestamps[0]; - } else { - timestamps = {cc->InputTimestamp()}; - } - for (Timestamp timestamp : timestamps) { - int count = cached_classifications_[timestamp.Value()].size(); - for (int i = 0; i < count; ++i) { - Classifications* c; - if (result.classifications_size() <= i) { - c = result.add_classifications(); - if (!head_names_.empty()) { - c->set_head_index(i); - c->set_head_name(head_names_[i]); - } - } else { - c = result.mutable_classifications(i); - } - auto* entry = c->add_entries(); - for (const auto& elem : - cached_classifications_[timestamp.Value()][i].classification()) { - auto* category = entry->add_categories(); - if (elem.has_index()) { - category->set_index(elem.index()); - } - if (elem.has_score()) { - category->set_score(elem.score()); - } - if (elem.has_label()) { - category->set_category_name(elem.label()); - } - if (elem.has_display_name()) { - category->set_display_name(elem.display_name()); - } - } - entry->set_timestamp_ms((timestamp.Value() - first_timestamp.Value()) / - 1000); - } - } - return result; -} - MEDIAPIPE_REGISTER_NODE(ClassificationAggregationCalculator); } // namespace api2 diff --git a/mediapipe/tasks/cc/components/containers/proto/BUILD b/mediapipe/tasks/cc/components/containers/proto/BUILD index 7b455c0c4..27d2357b5 100644 --- a/mediapipe/tasks/cc/components/containers/proto/BUILD +++ b/mediapipe/tasks/cc/components/containers/proto/BUILD @@ -18,16 +18,10 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) licenses(["notice"]) -mediapipe_proto_library( - name = "category_proto", - srcs = ["category.proto"], -) - mediapipe_proto_library( name = "classifications_proto", srcs = ["classifications.proto"], deps = [ - ":category_proto", "//mediapipe/framework/formats:classification_proto", ], ) diff --git 
a/mediapipe/tasks/cc/components/containers/proto/category.proto b/mediapipe/tasks/cc/components/containers/proto/category.proto deleted file mode 100644 index 412e71428..000000000 --- a/mediapipe/tasks/cc/components/containers/proto/category.proto +++ /dev/null @@ -1,41 +0,0 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -syntax = "proto2"; - -package mediapipe.tasks.components.containers.proto; - -option java_package = "com.google.mediapipe.tasks.components.containers.proto"; -option java_outer_classname = "CategoryProto"; - -// TODO: deprecate this message once migration is over. -// A single classification result. -message Category { - // The index of the category in the corresponding label map, usually packed in - // the TFLite Model Metadata [1]. - // - // [1]: https://www.tensorflow.org/lite/convert/metadata - optional int32 index = 1; - // The score for this category, e.g. (but not necessarily) a probability in - // [0,1]. - optional float score = 2; - // A human readable name of the category filled from the label map. - optional string display_name = 3; - // An ID for the category, not necessarily human-readable, e.g. a Google - // Knowledge Graph ID [1], filled from the label map. - // - // [1]: https://developers.google.com/knowledge-graph - optional string category_name = 4; -} diff --git a/mediapipe/tasks/cc/components/containers/proto/classifications.proto b/mediapipe/tasks/cc/components/containers/proto/classifications.proto index f098ed0e4..2b2306829 100644 --- a/mediapipe/tasks/cc/components/containers/proto/classifications.proto +++ b/mediapipe/tasks/cc/components/containers/proto/classifications.proto @@ -18,27 +18,12 @@ syntax = "proto2"; package mediapipe.tasks.components.containers.proto; import "mediapipe/framework/formats/classification.proto"; -import "mediapipe/tasks/cc/components/containers/proto/category.proto"; option java_package = "com.google.mediapipe.tasks.components.containers.proto"; option java_outer_classname = "ClassificationsProto"; -// TODO: deprecate this message once migration is over. -// List of predicted categories with an optional timestamp. -message ClassificationEntry { - // The array of predicted categories, usually sorted by descending scores, - // e.g., from high to low probability. - repeated Category categories = 1; - // The optional timestamp (in milliseconds) associated to the classifcation - // entry. This is useful for time series use cases, e.g., audio - // classification. - optional int64 timestamp_ms = 2; -} - // Classifications for a given classifier head, i.e. for a given output tensor. message Classifications { - // TODO: deprecate this field once migration is over. - repeated ClassificationEntry entries = 1; // The classification results for this head. optional mediapipe.ClassificationList classification_list = 4; // The index of the classifier head these categories refer to. 
This is useful @@ -48,6 +33,8 @@ message Classifications { // name. // TODO: Add github link to metadata_schema.fbs. optional string head_name = 3; + // Reserved fields. + reserved 1; } // Classifications for a given classifier model. diff --git a/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.cc b/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.cc index 0fb62afaf..5a0472f5c 100644 --- a/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.cc +++ b/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.cc @@ -73,7 +73,6 @@ using TensorsSource = mediapipe::tasks::SourceOrNodeOutput>; constexpr float kDefaultScoreThreshold = std::numeric_limits::lowest(); constexpr char kCalibratedScoresTag[] = "CALIBRATED_SCORES"; -constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT"; constexpr char kClassificationsTag[] = "CLASSIFICATIONS"; constexpr char kScoresTag[] = "SCORES"; constexpr char kTensorsTag[] = "TENSORS"; @@ -82,7 +81,6 @@ constexpr char kTimestampedClassificationsTag[] = "TIMESTAMPED_CLASSIFICATIONS"; // Struct holding the different output streams produced by the graph. struct ClassificationPostprocessingOutputStreams { - Source classification_result; Source classifications; Source> timestamped_classifications; }; @@ -400,9 +398,6 @@ absl::Status ConfigureClassificationPostprocessingGraph( // The classification result aggregated by timestamp, then by head. Must be // connected if the TIMESTAMPS input is connected, as it signals that // timestamp aggregation is required. -// // TODO: remove output once migration is over. -// CLASSIFICATION_RESULT - (DEPRECATED) ClassificationResult @Optional -// The aggregated classification result. // // The recommended way of using this graph is through the GraphBuilder API // using the 'ConfigureClassificationPostprocessingGraph()' function. See header @@ -418,8 +413,6 @@ class ClassificationPostprocessingGraph : public mediapipe::Subgraph { sc->Options(), graph[Input>(kTensorsTag)], graph[Input>(kTimestampsTag)], graph)); - output_streams.classification_result >> - graph[Output(kClassificationResultTag)]; output_streams.classifications >> graph[Output(kClassificationsTag)]; output_streams.timestamped_classifications >> @@ -536,8 +529,6 @@ class ClassificationPostprocessingGraph : public mediapipe::Subgraph { // Connects output. ClassificationPostprocessingOutputStreams output_streams{ - /*classification_result=*/result_aggregation - [Output(kClassificationResultTag)], /*classifications=*/ result_aggregation[Output(kClassificationsTag)], /*timestamped_classifications=*/ diff --git a/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h b/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h index 48575ceb0..03ae91130 100644 --- a/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h +++ b/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h @@ -58,9 +58,6 @@ namespace processors { // The classification result aggregated by timestamp, then by head. Must be // connected if the TIMESTAMPS input is connected, as it signals that // timestamp aggregation is required. -// // TODO: remove output once migration is over. -// CLASSIFICATION_RESULT - (DEPRECATED) ClassificationResult @Optional -// The aggregated classification result. 
absl::Status ConfigureClassificationPostprocessingGraph( const tasks::core::ModelResources& model_resources, const proto::ClassifierOptions& classifier_options, diff --git a/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph_test.cc b/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph_test.cc index d4728e725..8eb6f3c3b 100644 --- a/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph_test.cc +++ b/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph_test.cc @@ -86,8 +86,6 @@ constexpr char kTensorsTag[] = "TENSORS"; constexpr char kTensorsName[] = "tensors"; constexpr char kTimestampsTag[] = "TIMESTAMPS"; constexpr char kTimestampsName[] = "timestamps"; -constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT"; -constexpr char kClassificationResultName[] = "classification_result"; constexpr char kClassificationsTag[] = "CLASSIFICATIONS"; constexpr char kClassificationsName[] = "classifications"; constexpr char kTimestampedClassificationsTag[] = "TIMESTAMPED_CLASSIFICATIONS"; @@ -728,326 +726,6 @@ TEST_F(PostprocessingTest, SucceedsWithTimestamps) { })pb")})); } -// TODO: remove these tests once migration is over. -class LegacyPostprocessingTest : public tflite_shims::testing::Test { - protected: - absl::StatusOr BuildGraph( - absl::string_view model_name, const proto::ClassifierOptions& options, - bool connect_timestamps = false) { - ASSIGN_OR_RETURN(auto model_resources, - CreateModelResourcesForModel(model_name)); - - Graph graph; - auto& postprocessing = graph.AddNode( - "mediapipe.tasks.components.processors." - "ClassificationPostprocessingGraph"); - MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessingGraph( - *model_resources, options, - &postprocessing - .GetOptions())); - graph[Input>(kTensorsTag)].SetName(kTensorsName) >> - postprocessing.In(kTensorsTag); - if (connect_timestamps) { - graph[Input>(kTimestampsTag)].SetName( - kTimestampsName) >> - postprocessing.In(kTimestampsTag); - } - postprocessing.Out(kClassificationResultTag) - .SetName(kClassificationResultName) >> - graph[Output(kClassificationResultTag)]; - - MP_RETURN_IF_ERROR(calculator_graph_.Initialize(graph.GetConfig())); - ASSIGN_OR_RETURN(auto poller, calculator_graph_.AddOutputStreamPoller( - kClassificationResultName)); - MP_RETURN_IF_ERROR(calculator_graph_.StartRun(/*extra_side_packets=*/{})); - return poller; - } - - template - void AddTensor( - const std::vector& tensor, const Tensor::ElementType& element_type, - const Tensor::QuantizationParameters& quantization_parameters = {}) { - tensors_->emplace_back(element_type, - Tensor::Shape{1, static_cast(tensor.size())}, - quantization_parameters); - auto view = tensors_->back().GetCpuWriteView(); - T* buffer = view.buffer(); - std::copy(tensor.begin(), tensor.end(), buffer); - } - - absl::Status Run( - std::optional> aggregation_timestamps = std::nullopt, - int timestamp = 0) { - MP_RETURN_IF_ERROR(calculator_graph_.AddPacketToInputStream( - kTensorsName, Adopt(tensors_.release()).At(Timestamp(timestamp)))); - // Reset tensors for future calls. 
- tensors_ = absl::make_unique>(); - if (aggregation_timestamps.has_value()) { - auto packet = absl::make_unique>(); - for (const auto& timestamp : *aggregation_timestamps) { - packet->emplace_back(Timestamp(timestamp)); - } - MP_RETURN_IF_ERROR(calculator_graph_.AddPacketToInputStream( - kTimestampsName, Adopt(packet.release()).At(Timestamp(timestamp)))); - } - return absl::OkStatus(); - } - - absl::StatusOr GetClassificationResult( - OutputStreamPoller& poller) { - MP_RETURN_IF_ERROR(calculator_graph_.WaitUntilIdle()); - MP_RETURN_IF_ERROR(calculator_graph_.CloseAllInputStreams()); - - Packet packet; - if (!poller.Next(&packet)) { - return absl::InternalError("Unable to get output packet"); - } - auto result = packet.Get(); - MP_RETURN_IF_ERROR(calculator_graph_.WaitUntilDone()); - return result; - } - - private: - CalculatorGraph calculator_graph_; - std::unique_ptr> tensors_ = - absl::make_unique>(); -}; - -TEST_F(LegacyPostprocessingTest, SucceedsWithoutMetadata) { - // Build graph. - proto::ClassifierOptions options; - options.set_max_results(3); - options.set_score_threshold(0.5); - MP_ASSERT_OK_AND_ASSIGN( - auto poller, - BuildGraph(kQuantizedImageClassifierWithoutMetadata, options)); - // Build input tensors. - std::vector tensor(kMobileNetNumClasses, 0); - tensor[1] = 18; - tensor[2] = 16; - - // Send tensors and get results. - AddTensor(tensor, Tensor::ElementType::kUInt8, - /*quantization_parameters=*/{0.1, 10}); - MP_ASSERT_OK(Run()); - MP_ASSERT_OK_AND_ASSIGN(auto results, GetClassificationResult(poller)); - - // Validate results. - EXPECT_THAT(results, EqualsProto(R"pb(classifications { - entries { - categories { index: 1 score: 0.8 } - categories { index: 2 score: 0.6 } - timestamp_ms: 0 - } - })pb")); -} - -TEST_F(LegacyPostprocessingTest, SucceedsWithMetadata) { - // Build graph. - proto::ClassifierOptions options; - options.set_max_results(3); - MP_ASSERT_OK_AND_ASSIGN( - auto poller, BuildGraph(kQuantizedImageClassifierWithMetadata, options)); - // Build input tensors. - std::vector tensor(kMobileNetNumClasses, 0); - tensor[1] = 12; - tensor[2] = 14; - tensor[3] = 16; - tensor[4] = 18; - - // Send tensors and get results. - AddTensor(tensor, Tensor::ElementType::kUInt8, - /*quantization_parameters=*/{0.1, 10}); - MP_ASSERT_OK(Run()); - MP_ASSERT_OK_AND_ASSIGN(auto results, GetClassificationResult(poller)); - - // Validate results. - EXPECT_THAT( - results, - EqualsProto( - R"pb(classifications { - entries { - categories { - index: 4 - score: 0.8 - category_name: "tiger shark" - } - categories { - index: 3 - score: 0.6 - category_name: "great white shark" - } - categories { index: 2 score: 0.4 category_name: "goldfish" } - timestamp_ms: 0 - } - head_index: 0 - head_name: "probability" - })pb")); -} - -TEST_F(LegacyPostprocessingTest, SucceedsWithScoreCalibration) { - // Build graph. - proto::ClassifierOptions options; - options.set_max_results(3); - MP_ASSERT_OK_AND_ASSIGN( - auto poller, - BuildGraph(kQuantizedImageClassifierWithDummyScoreCalibration, options)); - // Build input tensors. - std::vector tensor(kMobileNetNumClasses, 0); - tensor[1] = 12; - tensor[2] = 14; - tensor[3] = 16; - tensor[4] = 18; - - // Send tensors and get results. - AddTensor(tensor, Tensor::ElementType::kUInt8, - /*quantization_parameters=*/{0.1, 10}); - MP_ASSERT_OK(Run()); - MP_ASSERT_OK_AND_ASSIGN(auto results, GetClassificationResult(poller)); - - // Validate results. 
- EXPECT_THAT(results, EqualsProto( - R"pb(classifications { - entries { - categories { - index: 4 - score: 0.6899744811 - category_name: "tiger shark" - } - categories { - index: 3 - score: 0.6456563062 - category_name: "great white shark" - } - categories { - index: 2 - score: 0.5986876601 - category_name: "goldfish" - } - timestamp_ms: 0 - } - head_index: 0 - head_name: "probability" - })pb")); -} - -TEST_F(LegacyPostprocessingTest, SucceedsWithMultipleHeads) { - // Build graph. - proto::ClassifierOptions options; - options.set_max_results(2); - MP_ASSERT_OK_AND_ASSIGN( - auto poller, - BuildGraph(kFloatTwoHeadsAudioClassifierWithMetadata, options)); - // Build input tensors. - std::vector tensor_0(kTwoHeadsNumClasses[0], 0); - tensor_0[1] = 0.2; - tensor_0[2] = 0.4; - tensor_0[3] = 0.6; - std::vector tensor_1(kTwoHeadsNumClasses[1], 0); - tensor_1[1] = 0.2; - tensor_1[2] = 0.4; - tensor_1[3] = 0.6; - - // Send tensors and get results. - AddTensor(tensor_0, Tensor::ElementType::kFloat32); - AddTensor(tensor_1, Tensor::ElementType::kFloat32); - MP_ASSERT_OK(Run()); - MP_ASSERT_OK_AND_ASSIGN(auto results, GetClassificationResult(poller)); - - EXPECT_THAT(results, EqualsProto( - R"pb(classifications { - entries { - categories { - index: 3 - score: 0.6 - category_name: "Narration, monologue" - } - categories { - index: 2 - score: 0.4 - category_name: "Conversation" - } - timestamp_ms: 0 - } - head_index: 0 - head_name: "yamnet_classification" - } - classifications { - entries { - categories { - index: 3 - score: 0.6 - category_name: "Azara\'s Spinetail" - } - categories { - index: 2 - score: 0.4 - category_name: "House Sparrow" - } - timestamp_ms: 0 - } - head_index: 1 - head_name: "bird_classification" - })pb")); -} - -TEST_F(LegacyPostprocessingTest, SucceedsWithTimestamps) { - // Build graph. - proto::ClassifierOptions options; - options.set_max_results(2); - MP_ASSERT_OK_AND_ASSIGN( - auto poller, BuildGraph(kQuantizedImageClassifierWithMetadata, options, - /*connect_timestamps=*/true)); - // Build input tensors. - std::vector tensor_0(kMobileNetNumClasses, 0); - tensor_0[1] = 12; - tensor_0[2] = 14; - tensor_0[3] = 16; - std::vector tensor_1(kMobileNetNumClasses, 0); - tensor_1[5] = 12; - tensor_1[6] = 14; - tensor_1[7] = 16; - - // Send tensors and get results. - AddTensor(tensor_0, Tensor::ElementType::kUInt8, - /*quantization_parameters=*/{0.1, 10}); - MP_ASSERT_OK(Run()); - AddTensor(tensor_1, Tensor::ElementType::kUInt8, - /*quantization_parameters=*/{0.1, 10}); - MP_ASSERT_OK(Run( - /*aggregation_timestamps=*/std::optional>({0, 1000}), - /*timestamp=*/1000)); - - MP_ASSERT_OK_AND_ASSIGN(auto results, GetClassificationResult(poller)); - - // Validate results. 
- EXPECT_THAT( - results, - EqualsProto( - R"pb(classifications { - entries { - categories { - index: 3 - score: 0.6 - category_name: "great white shark" - } - categories { index: 2 score: 0.4 category_name: "goldfish" } - timestamp_ms: 0 - } - entries { - categories { index: 7 score: 0.6 category_name: "stingray" } - categories { - index: 6 - score: 0.4 - category_name: "electric ray" - } - timestamp_ms: 1 - } - head_index: 0 - head_name: "probability" - })pb")); -} - } // namespace } // namespace processors } // namespace components diff --git a/mediapipe/tasks/cc/text/text_classifier/text_classifier_graph.cc b/mediapipe/tasks/cc/text/text_classifier/text_classifier_graph.cc index 36ff68a07..9a7dce1aa 100644 --- a/mediapipe/tasks/cc/text/text_classifier/text_classifier_graph.cc +++ b/mediapipe/tasks/cc/text/text_classifier/text_classifier_graph.cc @@ -46,19 +46,11 @@ using ::mediapipe::api2::builder::Source; using ::mediapipe::tasks::components::containers::proto::ClassificationResult; using ::mediapipe::tasks::core::ModelResources; -constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT"; constexpr char kClassificationsTag[] = "CLASSIFICATIONS"; constexpr char kTextTag[] = "TEXT"; constexpr char kMetadataExtractorTag[] = "METADATA_EXTRACTOR"; constexpr char kTensorsTag[] = "TENSORS"; -// TODO: remove once Java API migration is over. -// Struct holding the different output streams produced by the text classifier. -struct TextClassifierOutputStreams { - Source classification_result; - Source classifications; -}; - } // namespace // A "TextClassifierGraph" performs Natural Language classification (including @@ -72,10 +64,6 @@ struct TextClassifierOutputStreams { // Outputs: // CLASSIFICATIONS - ClassificationResult @Optional // The classification results aggregated by classifier head. -// TODO: remove once Java API migration is over. -// CLASSIFICATION_RESULT - (DEPRECATED) ClassificationResult @Optional -// The aggregated classification result object that has 3 dimensions: -// (classification head, classification timestamp, classification category). // // Example: // node { @@ -102,14 +90,11 @@ class TextClassifierGraph : public core::ModelTaskGraph { CreateModelResources(sc)); Graph graph; ASSIGN_OR_RETURN( - auto output_streams, + auto classifications, BuildTextClassifierTask( sc->Options(), *model_resources, graph[Input(kTextTag)], graph)); - output_streams.classification_result >> - graph[Output(kClassificationResultTag)]; - output_streams.classifications >> - graph[Output(kClassificationsTag)]; + classifications >> graph[Output(kClassificationsTag)]; return graph.GetConfig(); } @@ -124,7 +109,7 @@ class TextClassifierGraph : public core::ModelTaskGraph { // TextClassifier model file with model metadata. // text_in: (std::string) stream to run text classification on. // graph: the mediapipe builder::Graph instance to be updated. - absl::StatusOr BuildTextClassifierTask( + absl::StatusOr> BuildTextClassifierTask( const proto::TextClassifierGraphOptions& task_options, const ModelResources& model_resources, Source text_in, Graph& graph) { @@ -161,11 +146,7 @@ class TextClassifierGraph : public core::ModelTaskGraph { // Outputs the aggregated classification result as the subgraph output // stream. 
- return TextClassifierOutputStreams{ - /*classification_result=*/postprocessing[Output( - kClassificationResultTag)], - /*classifications=*/postprocessing[Output( - kClassificationsTag)]}; + return postprocessing[Output(kClassificationsTag)]; } }; diff --git a/mediapipe/tasks/cc/vision/image_classifier/image_classifier_graph.cc b/mediapipe/tasks/cc/vision/image_classifier/image_classifier_graph.cc index 8fa1a0d2a..2fc88bcb6 100644 --- a/mediapipe/tasks/cc/vision/image_classifier/image_classifier_graph.cc +++ b/mediapipe/tasks/cc/vision/image_classifier/image_classifier_graph.cc @@ -47,7 +47,6 @@ using ::mediapipe::tasks::components::containers::proto::ClassificationResult; constexpr float kDefaultScoreThreshold = std::numeric_limits::lowest(); -constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT"; constexpr char kClassificationsTag[] = "CLASSIFICATIONS"; constexpr char kImageTag[] = "IMAGE"; constexpr char kNormRectTag[] = "NORM_RECT"; @@ -56,7 +55,6 @@ constexpr char kTensorsTag[] = "TENSORS"; // Struct holding the different output streams produced by the image classifier // subgraph. struct ImageClassifierOutputStreams { - Source classification_result; Source classifications; Source image; }; @@ -77,9 +75,6 @@ struct ImageClassifierOutputStreams { // The classification results aggregated by classifier head. // IMAGE - Image // The image that object detection runs on. -// TODO: remove this output once Java API migration is over. -// CLASSIFICATION_RESULT - (DEPRECATED) ClassificationResult @Optional -// The aggregated classification result. // // Example: // node { @@ -117,8 +112,6 @@ class ImageClassifierGraph : public core::ModelTaskGraph { sc->Options(), *model_resources, graph[Input(kImageTag)], graph[Input::Optional(kNormRectTag)], graph)); - output_streams.classification_result >> - graph[Output(kClassificationResultTag)]; output_streams.classifications >> graph[Output(kClassificationsTag)]; output_streams.image >> graph[Output(kImageTag)]; @@ -174,8 +167,6 @@ class ImageClassifierGraph : public core::ModelTaskGraph { // Outputs the aggregated classification result as the subgraph output // stream. 
return ImageClassifierOutputStreams{ - /*classification_result=*/postprocessing[Output( - kClassificationResultTag)], /*classifications=*/ postprocessing[Output(kClassificationsTag)], /*image=*/preprocessing[Output(kImageTag)]}; diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD index 0e72878ab..023a1f286 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD @@ -48,7 +48,6 @@ android_library( deps = [ "//mediapipe/framework:calculator_options_java_proto_lite", "//mediapipe/java/com/google/mediapipe/framework:android_framework", - "//mediapipe/tasks/cc/components/containers/proto:category_java_proto_lite", "//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite", "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_java_proto_lite", diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD index 289e3000d..72cee133f 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD @@ -97,7 +97,6 @@ android_library( "//mediapipe/framework:calculator_options_java_proto_lite", "//mediapipe/java/com/google/mediapipe/framework:android_framework", "//mediapipe/java/com/google/mediapipe/framework/image", - "//mediapipe/tasks/cc/components/containers/proto:category_java_proto_lite", "//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite", "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_java_proto_lite", diff --git a/mediapipe/tasks/python/components/containers/BUILD b/mediapipe/tasks/python/components/containers/BUILD index d931c26c7..9d275e167 100644 --- a/mediapipe/tasks/python/components/containers/BUILD +++ b/mediapipe/tasks/python/components/containers/BUILD @@ -68,7 +68,7 @@ py_library( name = "category", srcs = ["category.py"], deps = [ - "//mediapipe/tasks/cc/components/containers/proto:category_py_pb2", + "//mediapipe/framework/formats:classification_py_pb2", "//mediapipe/tasks/python/core:optional_dependencies", ], ) diff --git a/mediapipe/tasks/python/components/containers/category.py b/mediapipe/tasks/python/components/containers/category.py index cfdb83740..9b5419883 100644 --- a/mediapipe/tasks/python/components/containers/category.py +++ b/mediapipe/tasks/python/components/containers/category.py @@ -16,10 +16,10 @@ import dataclasses from typing import Any, Optional -from mediapipe.tasks.cc.components.containers.proto import category_pb2 +from mediapipe.framework.formats import classification_pb2 from mediapipe.tasks.python.core.optional_dependencies import doc_controls -_CategoryProto = category_pb2.Category +_ClassificationProto = classification_pb2.Classification @dataclasses.dataclass @@ -45,23 +45,23 @@ class Category: category_name: Optional[str] = None @doc_controls.do_not_generate_docs - def to_pb2(self) -> _CategoryProto: + def to_pb2(self) -> _ClassificationProto: """Generates a Category protobuf object.""" - return _CategoryProto( + return _ClassificationProto( index=self.index, score=self.score, - display_name=self.display_name, - category_name=self.category_name) + label=self.category_name, + 
display_name=self.display_name) @classmethod @doc_controls.do_not_generate_docs - def create_from_pb2(cls, pb2_obj: _CategoryProto) -> 'Category': + def create_from_pb2(cls, pb2_obj: _ClassificationProto) -> 'Category': """Creates a `Category` object from the given protobuf object.""" return Category( index=pb2_obj.index, score=pb2_obj.score, display_name=pb2_obj.display_name, - category_name=pb2_obj.category_name) + category_name=pb2_obj.label) def __eq__(self, other: Any) -> bool: """Checks if this object is equal to the given object. diff --git a/mediapipe/tasks/python/components/containers/classification_result.py b/mediapipe/tasks/python/components/containers/classification_result.py index 6ffdabe51..000468041 100644 --- a/mediapipe/tasks/python/components/containers/classification_result.py +++ b/mediapipe/tasks/python/components/containers/classification_result.py @@ -49,11 +49,7 @@ class Classifications: """Generates a Classifications protobuf object.""" classification_list_proto = _ClassificationListProto() for category in self.categories: - classification_proto = _ClassificationProto( - index=category.index, - score=category.score, - label=category.category_name, - display_name=category.display_name) + classification_proto = category.to_pb2() classification_list_proto.classification.append(classification_proto) return _ClassificationsProto( classification_list=classification_list_proto, @@ -65,14 +61,9 @@ class Classifications: def create_from_pb2(cls, pb2_obj: _ClassificationsProto) -> 'Classifications': """Creates a `Classifications` object from the given protobuf object.""" categories = [] - for entry in pb2_obj.classification_list.classification: + for classification in pb2_obj.classification_list.classification: categories.append( - category_module.Category( - index=entry.index, - score=entry.score, - display_name=entry.display_name, - category_name=entry.label)) - + category_module.Category.create_from_pb2(classification)) return Classifications( categories=categories, head_index=pb2_obj.head_index, From 7f0134eecbe75a94bcda7cf113e1ae8aa47cd916 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Mon, 21 Nov 2022 12:13:38 -0800 Subject: [PATCH 103/469] Internal change PiperOrigin-RevId: 490041386 --- mediapipe/tasks/python/core/BUILD | 1 + mediapipe/tasks/python/text/BUILD | 1 + 2 files changed, 2 insertions(+) diff --git a/mediapipe/tasks/python/core/BUILD b/mediapipe/tasks/python/core/BUILD index 76e2f4f4a..fc0018ab1 100644 --- a/mediapipe/tasks/python/core/BUILD +++ b/mediapipe/tasks/python/core/BUILD @@ -31,6 +31,7 @@ py_library( py_library( name = "base_options", srcs = ["base_options.py"], + visibility = ["//mediapipe/tasks:users"], deps = [ ":optional_dependencies", "//mediapipe/tasks/cc/core/proto:base_options_py_pb2", diff --git a/mediapipe/tasks/python/text/BUILD b/mediapipe/tasks/python/text/BUILD index bb42da912..10b4b8a6e 100644 --- a/mediapipe/tasks/python/text/BUILD +++ b/mediapipe/tasks/python/text/BUILD @@ -23,6 +23,7 @@ py_library( srcs = [ "text_classifier.py", ], + visibility = ["//mediapipe/tasks:users"], deps = [ "//mediapipe/python:packet_creator", "//mediapipe/python:packet_getter", From 652423a23d9a69d5c3dabe61926a55bd77d6d610 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Mon, 21 Nov 2022 13:04:53 -0800 Subject: [PATCH 104/469] Internal change PiperOrigin-RevId: 490053179 --- mediapipe/calculators/tensor/image_to_tensor_utils.cc | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git 
a/mediapipe/calculators/tensor/image_to_tensor_utils.cc b/mediapipe/calculators/tensor/image_to_tensor_utils.cc index d27c595b5..3f91f3dc2 100644 --- a/mediapipe/calculators/tensor/image_to_tensor_utils.cc +++ b/mediapipe/calculators/tensor/image_to_tensor_utils.cc @@ -253,11 +253,15 @@ int GetNumOutputChannels(const mediapipe::Image& image) { } #endif // MEDIAPIPE_METAL_ENABLED #endif // !MEDIAPIPE_DISABLE_GPU - // The output tensor channel is 1 for the input image with 1 channel; And the - // output tensor channels is 3 for the input image with 3 or 4 channels. // TODO: Add a unittest here to test the behavior on GPU, i.e. // failure. - return image.channels() == 1 ? 1 : 3; + // Only output channel == 1 when running on CPU and the input image channel + // is 1. Ideally, we want to also support GPU for output channel == 1. But + // setting this on the safer side to prevent unintentional failure. + if (!image.UsesGpu() && image.channels() == 1) { + return 1; + } + return 3; } absl::StatusOr> GetInputImage( From adddf2c2abe953b0280507b6168a41bcbb5a08f3 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Mon, 21 Nov 2022 14:37:42 -0800 Subject: [PATCH 105/469] Extracted common test helper functions out from the unittest into a sharable library. Also migrated away from OpenCVX. PiperOrigin-RevId: 490074410 --- mediapipe/calculators/tensor/BUILD | 2 + .../tensor/image_to_tensor_calculator_test.cc | 169 ++++++------------ mediapipe/util/BUILD | 18 ++ mediapipe/util/image_test_utils.cc | 57 ++++++ mediapipe/util/image_test_utils.h | 32 ++++ 5 files changed, 166 insertions(+), 112 deletions(-) create mode 100644 mediapipe/util/image_test_utils.cc create mode 100644 mediapipe/util/image_test_utils.h diff --git a/mediapipe/calculators/tensor/BUILD b/mediapipe/calculators/tensor/BUILD index 2a573fc44..645189a07 100644 --- a/mediapipe/calculators/tensor/BUILD +++ b/mediapipe/calculators/tensor/BUILD @@ -30,6 +30,7 @@ exports_files( glob(["testdata/image_to_tensor/*"]), visibility = [ "//mediapipe/calculators/image:__subpackages__", + "//mediapipe/util:__subpackages__", ], ) @@ -1133,6 +1134,7 @@ cc_test( "//mediapipe/framework/port:opencv_imgcodecs", "//mediapipe/framework/port:opencv_imgproc", "//mediapipe/framework/port:parse_text_proto", + "//mediapipe/util:image_test_utils", "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", diff --git a/mediapipe/calculators/tensor/image_to_tensor_calculator_test.cc b/mediapipe/calculators/tensor/image_to_tensor_calculator_test.cc index 7ea60d98e..ceb1fc502 100644 --- a/mediapipe/calculators/tensor/image_to_tensor_calculator_test.cc +++ b/mediapipe/calculators/tensor/image_to_tensor_calculator_test.cc @@ -36,29 +36,17 @@ #include "mediapipe/framework/port/opencv_imgproc_inc.h" #include "mediapipe/framework/port/parse_text_proto.h" #include "mediapipe/framework/port/status_matchers.h" +#include "mediapipe/util/image_test_utils.h" namespace mediapipe { namespace { -cv::Mat GetRgb(absl::string_view path) { - cv::Mat bgr = cv::imread(file::JoinPath("./", path)); - cv::Mat rgb; - cv::cvtColor(bgr, rgb, cv::COLOR_BGR2RGB); - return rgb; -} +constexpr char kTestDataDir[] = + "/mediapipe/calculators/tensor/testdata/" + "image_to_tensor/"; -cv::Mat GetRgba(absl::string_view path) { - cv::Mat bgr = cv::imread(file::JoinPath("./", path)); - cv::Mat rgb; - cv::cvtColor(bgr, rgb, cv::COLOR_BGR2RGBA); - return rgb; -} - -cv::Mat GetGray(absl::string_view path) { - cv::Mat bgr = cv::imread(file::JoinPath("./", path)); - 
cv::Mat gray; - cv::cvtColor(bgr, gray, cv::COLOR_BGR2GRAY); - return gray; +std::string GetFilePath(absl::string_view filename) { + return file::JoinPath("./", kTestDataDir, filename); } // Image to tensor test template. @@ -259,15 +247,12 @@ TEST(ImageToTensorCalculatorTest, MediumSubRectKeepAspect) { roi.set_width(0.5f); roi.set_height(0.5f); roi.set_rotation(0); - RunTest( - GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/input.jpg"), - GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/medium_sub_rect_keep_aspect.png"), - /*float_ranges=*/{{0.0f, 1.0f}}, - /*int_ranges=*/{{0, 255}, {-128, 127}}, - /*tensor_width=*/256, /*tensor_height=*/256, /*keep_aspect=*/true, - /*border mode*/ {}, roi); + RunTest(GetRgb(GetFilePath("input.jpg")), + GetRgb(GetFilePath("medium_sub_rect_keep_aspect.png")), + /*float_ranges=*/{{0.0f, 1.0f}}, + /*int_ranges=*/{{0, 255}, {-128, 127}}, + /*tensor_width=*/256, /*tensor_height=*/256, /*keep_aspect=*/true, + /*border mode*/ {}, roi); } TEST(ImageToTensorCalculatorTest, MediumSubRectKeepAspectBorderZero) { @@ -277,11 +262,8 @@ TEST(ImageToTensorCalculatorTest, MediumSubRectKeepAspectBorderZero) { roi.set_width(0.5f); roi.set_height(0.5f); roi.set_rotation(0); - RunTest(GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/input.jpg"), - GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/" - "medium_sub_rect_keep_aspect_border_zero.png"), + RunTest(GetRgb(GetFilePath("input.jpg")), + GetRgb(GetFilePath("medium_sub_rect_keep_aspect_border_zero.png")), /*float_ranges=*/{{0.0f, 1.0f}}, /*int_ranges=*/{{0, 255}, {-128, 127}}, /*tensor_width=*/256, /*tensor_height=*/256, /*keep_aspect=*/true, @@ -295,11 +277,8 @@ TEST(ImageToTensorCalculatorTest, MediumSubRectKeepAspectWithRotation) { roi.set_width(0.5f); roi.set_height(0.5f); roi.set_rotation(M_PI * 90.0f / 180.0f); - RunTest(GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/input.jpg"), - GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/" - "medium_sub_rect_keep_aspect_with_rotation.png"), + RunTest(GetRgb(GetFilePath("input.jpg")), + GetRgb(GetFilePath("medium_sub_rect_keep_aspect_with_rotation.png")), /*float_ranges=*/{{0.0f, 1.0f}}, /*int_ranges=*/{{0, 255}}, /*tensor_width=*/256, /*tensor_height=*/256, /*keep_aspect=*/true, @@ -314,11 +293,9 @@ TEST(ImageToTensorCalculatorTest, roi.set_width(0.5f); roi.set_height(0.5f); roi.set_rotation(M_PI * 90.0f / 180.0f); - RunTest(GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/input.jpg"), - GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/" - "medium_sub_rect_keep_aspect_with_rotation_border_zero.png"), + RunTest(GetRgb(GetFilePath("input.jpg")), + GetRgb(GetFilePath( + "medium_sub_rect_keep_aspect_with_rotation_border_zero.png")), /*float_ranges=*/{{0.0f, 1.0f}}, /*int_ranges=*/{{0, 255}, {-128, 127}}, /*tensor_width=*/256, /*tensor_height=*/256, /*keep_aspect=*/true, @@ -332,16 +309,12 @@ TEST(ImageToTensorCalculatorTest, MediumSubRectWithRotation) { roi.set_width(0.5f); roi.set_height(0.5f); roi.set_rotation(M_PI * -45.0f / 180.0f); - RunTest( - GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/input.jpg"), - GetRgb( - "/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/medium_sub_rect_with_rotation.png"), - /*float_ranges=*/{{-1.0f, 1.0f}}, - /*int_ranges=*/{{0, 255}, {-128, 127}}, - /*tensor_width=*/256, /*tensor_height=*/256, /*keep_aspect=*/false, - BorderMode::kReplicate, roi); + 
RunTest(GetRgb(GetFilePath("input.jpg")), + GetRgb(GetFilePath("medium_sub_rect_with_rotation.png")), + /*float_ranges=*/{{-1.0f, 1.0f}}, + /*int_ranges=*/{{0, 255}, {-128, 127}}, + /*tensor_width=*/256, /*tensor_height=*/256, /*keep_aspect=*/false, + BorderMode::kReplicate, roi); } TEST(ImageToTensorCalculatorTest, MediumSubRectWithRotationBorderZero) { @@ -351,11 +324,8 @@ TEST(ImageToTensorCalculatorTest, MediumSubRectWithRotationBorderZero) { roi.set_width(0.5f); roi.set_height(0.5f); roi.set_rotation(M_PI * -45.0f / 180.0f); - RunTest(GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/input.jpg"), - GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/" - "medium_sub_rect_with_rotation_border_zero.png"), + RunTest(GetRgb(GetFilePath("input.jpg")), + GetRgb(GetFilePath("medium_sub_rect_with_rotation_border_zero.png")), /*float_ranges=*/{{-1.0f, 1.0f}}, /*int_ranges=*/{{0, 255}, {-128, 127}}, /*tensor_width=*/256, /*tensor_height=*/256, /*keep_aspect=*/false, @@ -369,10 +339,8 @@ TEST(ImageToTensorCalculatorTest, LargeSubRect) { roi.set_width(1.5f); roi.set_height(1.1f); roi.set_rotation(0); - RunTest(GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/input.jpg"), - GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/large_sub_rect.png"), + RunTest(GetRgb(GetFilePath("input.jpg")), + GetRgb(GetFilePath("large_sub_rect.png")), /*float_ranges=*/{{0.0f, 1.0f}}, /*int_ranges=*/{{0, 255}}, /*tensor_width=*/128, /*tensor_height=*/128, /*keep_aspect=*/false, @@ -386,15 +354,12 @@ TEST(ImageToTensorCalculatorTest, LargeSubRectBorderZero) { roi.set_width(1.5f); roi.set_height(1.1f); roi.set_rotation(0); - RunTest( - GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/input.jpg"), - GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/large_sub_rect_border_zero.png"), - /*float_ranges=*/{{0.0f, 1.0f}}, - /*int_ranges=*/{{0, 255}, {-128, 127}}, - /*tensor_width=*/128, /*tensor_height=*/128, /*keep_aspect=*/false, - BorderMode::kZero, roi); + RunTest(GetRgb(GetFilePath("input.jpg")), + GetRgb(GetFilePath("large_sub_rect_border_zero.png")), + /*float_ranges=*/{{0.0f, 1.0f}}, + /*int_ranges=*/{{0, 255}, {-128, 127}}, + /*tensor_width=*/128, /*tensor_height=*/128, /*keep_aspect=*/false, + BorderMode::kZero, roi); } TEST(ImageToTensorCalculatorTest, LargeSubRectKeepAspect) { @@ -404,15 +369,12 @@ TEST(ImageToTensorCalculatorTest, LargeSubRectKeepAspect) { roi.set_width(1.5f); roi.set_height(1.1f); roi.set_rotation(0); - RunTest( - GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/input.jpg"), - GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/large_sub_rect_keep_aspect.png"), - /*float_ranges=*/{{0.0f, 1.0f}}, - /*int_ranges=*/{{0, 255}, {-128, 127}}, - /*tensor_width=*/128, /*tensor_height=*/128, /*keep_aspect=*/true, - BorderMode::kReplicate, roi); + RunTest(GetRgb(GetFilePath("input.jpg")), + GetRgb(GetFilePath("large_sub_rect_keep_aspect.png")), + /*float_ranges=*/{{0.0f, 1.0f}}, + /*int_ranges=*/{{0, 255}, {-128, 127}}, + /*tensor_width=*/128, /*tensor_height=*/128, /*keep_aspect=*/true, + BorderMode::kReplicate, roi); } TEST(ImageToTensorCalculatorTest, LargeSubRectKeepAspectBorderZero) { @@ -422,11 +384,8 @@ TEST(ImageToTensorCalculatorTest, LargeSubRectKeepAspectBorderZero) { roi.set_width(1.5f); roi.set_height(1.1f); roi.set_rotation(0); - RunTest(GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/input.jpg"), - GetRgb("/mediapipe/calculators/" - 
"tensor/testdata/image_to_tensor/" - "large_sub_rect_keep_aspect_border_zero.png"), + RunTest(GetRgb(GetFilePath("input.jpg")), + GetRgb(GetFilePath("large_sub_rect_keep_aspect_border_zero.png")), /*float_ranges=*/{{0.0f, 1.0f}}, /*int_ranges=*/{{0, 255}, {-128, 127}}, /*tensor_width=*/128, /*tensor_height=*/128, /*keep_aspect=*/true, @@ -440,11 +399,8 @@ TEST(ImageToTensorCalculatorTest, LargeSubRectKeepAspectWithRotation) { roi.set_width(1.5f); roi.set_height(1.1f); roi.set_rotation(M_PI * -15.0f / 180.0f); - RunTest(GetRgba("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/input.jpg"), - GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/" - "large_sub_rect_keep_aspect_with_rotation.png"), + RunTest(GetRgba(GetFilePath("input.jpg")), + GetRgb(GetFilePath("large_sub_rect_keep_aspect_with_rotation.png")), /*float_ranges=*/{{0.0f, 1.0f}}, /*int_ranges=*/{{0, 255}, {-128, 127}}, /*tensor_width=*/128, /*tensor_height=*/128, /*keep_aspect=*/true, @@ -458,11 +414,8 @@ TEST(ImageToTensorCalculatorTest, LargeSubRectKeepAspectWithRotationGray) { roi.set_width(1.5f); roi.set_height(1.1f); roi.set_rotation(M_PI * -15.0f / 180.0f); - RunTest(GetGray("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/input.jpg"), - GetGray("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/" - "large_sub_rect_keep_aspect_with_rotation.png"), + RunTest(GetGray(GetFilePath("input.jpg")), + GetGray(GetFilePath("large_sub_rect_keep_aspect_with_rotation.png")), /*float_ranges=*/{{0.0f, 1.0f}}, /*int_ranges=*/{{0, 255}, {-128, 127}}, /*tensor_width=*/128, /*tensor_height=*/128, /*keep_aspect=*/true, @@ -477,11 +430,9 @@ TEST(ImageToTensorCalculatorTest, roi.set_width(1.5f); roi.set_height(1.1f); roi.set_rotation(M_PI * -15.0f / 180.0f); - RunTest(GetRgba("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/input.jpg"), - GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/" - "large_sub_rect_keep_aspect_with_rotation_border_zero.png"), + RunTest(GetRgba(GetFilePath("input.jpg")), + GetRgb(GetFilePath( + "large_sub_rect_keep_aspect_with_rotation_border_zero.png")), /*float_ranges=*/{{0.0f, 1.0f}}, /*int_ranges=*/{{0, 255}}, /*tensor_width=*/128, /*tensor_height=*/128, /*keep_aspect=*/true, @@ -496,11 +447,9 @@ TEST(ImageToTensorCalculatorTest, roi.set_width(1.5f); roi.set_height(1.1f); roi.set_rotation(M_PI * -15.0f / 180.0f); - RunTest(GetGray("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/input.jpg"), - GetGray("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/" - "large_sub_rect_keep_aspect_with_rotation_border_zero.png"), + RunTest(GetGray(GetFilePath("input.jpg")), + GetGray(GetFilePath( + "large_sub_rect_keep_aspect_with_rotation_border_zero.png")), /*float_ranges=*/{{0.0f, 1.0f}}, /*int_ranges=*/{{0, 255}}, /*tensor_width=*/128, /*tensor_height=*/128, /*keep_aspect=*/true, @@ -514,10 +463,8 @@ TEST(ImageToTensorCalculatorTest, NoOpExceptRange) { roi.set_width(1.0f); roi.set_height(1.0f); roi.set_rotation(0); - RunTest(GetRgba("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/input.jpg"), - GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/noop_except_range.png"), + RunTest(GetRgba(GetFilePath("input.jpg")), + GetRgb(GetFilePath("noop_except_range.png")), /*float_ranges=*/{{0.0f, 1.0f}}, /*int_ranges=*/{{0, 255}, {-128, 127}}, /*tensor_width=*/64, /*tensor_height=*/128, /*keep_aspect=*/true, @@ -531,10 +478,8 @@ TEST(ImageToTensorCalculatorTest, NoOpExceptRangeBorderZero) { roi.set_width(1.0f); 
roi.set_height(1.0f); roi.set_rotation(0); - RunTest(GetRgba("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/input.jpg"), - GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/noop_except_range.png"), + RunTest(GetRgba(GetFilePath("input.jpg")), + GetRgb(GetFilePath("noop_except_range.png")), /*float_ranges=*/{{0.0f, 1.0f}}, /*int_ranges=*/{{0, 255}, {-128, 127}}, /*tensor_width=*/64, /*tensor_height=*/128, /*keep_aspect=*/true, diff --git a/mediapipe/util/BUILD b/mediapipe/util/BUILD index 15835aea5..55c1df59f 100644 --- a/mediapipe/util/BUILD +++ b/mediapipe/util/BUILD @@ -368,3 +368,21 @@ cc_test( "//mediapipe/framework/port:gtest_main", ], ) + +cc_library( + name = "image_test_utils", + testonly = 1, + srcs = ["image_test_utils.cc"], + hdrs = ["image_test_utils.h"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework:packet", + "//mediapipe/framework:timestamp", + "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:image_frame", + "//mediapipe/framework/formats:image_frame_opencv", + "//mediapipe/framework/port:opencv_core", + "//mediapipe/framework/port:opencv_imgcodecs", + "//mediapipe/framework/port:opencv_imgproc", + ], +) diff --git a/mediapipe/util/image_test_utils.cc b/mediapipe/util/image_test_utils.cc new file mode 100644 index 000000000..815666985 --- /dev/null +++ b/mediapipe/util/image_test_utils.cc @@ -0,0 +1,57 @@ +#include "mediapipe/util/image_test_utils.h" + +#include "mediapipe/framework/formats/image_frame.h" +#include "mediapipe/framework/formats/image_frame_opencv.h" +#include "mediapipe/framework/port/opencv_core_inc.h" +#include "mediapipe/framework/port/opencv_imgcodecs_inc.h" +#include "mediapipe/framework/port/opencv_imgproc_inc.h" +#include "mediapipe/framework/timestamp.h" + +namespace mediapipe { + +cv::Mat GetRgb(const std::string& path) { + cv::Mat bgr = cv::imread(path); + cv::Mat rgb; + cv::cvtColor(bgr, rgb, cv::COLOR_BGR2RGB); + return rgb; +} + +cv::Mat GetRgba(const std::string& path) { + cv::Mat bgr = cv::imread(path); + cv::Mat rgba; + cv::cvtColor(bgr, rgba, cv::COLOR_BGR2RGBA); + return rgba; +} + +cv::Mat GetGray(const std::string& path) { + cv::Mat bgr = cv::imread(path); + cv::Mat gray; + cv::cvtColor(bgr, gray, cv::COLOR_BGR2GRAY); + return gray; +} + +mediapipe::ImageFormat::Format GetImageFormat(int image_channels) { + if (image_channels == 4) { + return ImageFormat::SRGBA; + } else if (image_channels == 3) { + return ImageFormat::SRGB; + } else if (image_channels == 1) { + return ImageFormat::GRAY8; + } + LOG(FATAL) << "Unsupported input image channels: " << image_channels; +} + +Packet MakeImageFramePacket(cv::Mat input, int timestamp) { + ImageFrame input_image(GetImageFormat(input.channels()), input.cols, + input.rows, input.step, input.data, [](uint8*) {}); + return MakePacket<ImageFrame>(std::move(input_image)) + .At(Timestamp(timestamp)); +} + +Packet MakeImagePacket(cv::Mat input, int timestamp) { + mediapipe::Image input_image(std::make_shared<ImageFrame>( + GetImageFormat(input.channels()), input.cols, input.rows, input.step, + input.data, [](uint8*) {})); + return MakePacket<Image>(std::move(input_image)).At(Timestamp(timestamp)); +} + +} // namespace mediapipe diff --git a/mediapipe/util/image_test_utils.h b/mediapipe/util/image_test_utils.h new file mode 100644 index 000000000..6df9644d2 --- /dev/null +++ b/mediapipe/util/image_test_utils.h @@ -0,0 +1,32 @@ +#ifndef MEDIAPIPE_UTIL_IMAGE_TEST_UTILS_H_ +#define MEDIAPIPE_UTIL_IMAGE_TEST_UTILS_H_ + +#include <string> + +#include "mediapipe/framework/formats/image.h"
+#include "mediapipe/framework/packet.h" +#include "mediapipe/framework/port/opencv_core_inc.h" + +namespace mediapipe { + +// Reads the image file into cv::Mat with RGB channels. +cv::Mat GetRgb(const std::string& path); + +// Reads the image file into cv::Mat with RGBA channels. +cv::Mat GetRgba(const std::string& path); + +// Reads the image file into cv::Mat with Gray channel. +cv::Mat GetGray(const std::string& path); + +// Converts the image channels into corresponding ImageFormat. +mediapipe::ImageFormat::Format GetImageFormat(int image_channels); + +// Converts the cv::Mat into ImageFrame packet. +Packet MakeImageFramePacket(cv::Mat input, int timestamp = 0); + +// Converts the cv::Mat into Image packet. +Packet MakeImagePacket(cv::Mat input, int timestamp = 0); + +} // namespace mediapipe + +#endif // MEDIAPIPE_UTIL_IMAGE_TEST_UTILS_H_ From d43d0ff615030abb9241c28e6de6e345a8dba7eb Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Mon, 21 Nov 2022 15:45:29 -0800 Subject: [PATCH 106/469] Internal change PiperOrigin-RevId: 490089940 --- .../image_to_tensor_converter_opencv.cc | 43 +++++++++++++------ 1 file changed, 30 insertions(+), 13 deletions(-) diff --git a/mediapipe/calculators/tensor/image_to_tensor_converter_opencv.cc b/mediapipe/calculators/tensor/image_to_tensor_converter_opencv.cc index 76e46f99d..95e38f89c 100644 --- a/mediapipe/calculators/tensor/image_to_tensor_converter_opencv.cc +++ b/mediapipe/calculators/tensor/image_to_tensor_converter_opencv.cc @@ -76,31 +76,49 @@ class OpenCvProcessor : public ImageToTensorConverter { return InvalidArgumentError(absl::StrCat( "Unsupported format: ", static_cast(input.image_format()))); } - // TODO: Remove the check once tensor_buffer_offset > 0 is - // supported. - RET_CHECK_EQ(tensor_buffer_offset, 0) - << "The non-zero tensor_buffer_offset input is not supported yet."; + + RET_CHECK_GE(tensor_buffer_offset, 0) + << "The input tensor_buffer_offset needs to be non-negative."; const auto& output_shape = output_tensor.shape(); MP_RETURN_IF_ERROR(ValidateTensorShape(output_shape)); const int output_height = output_shape.dims[1]; const int output_width = output_shape.dims[2]; const int output_channels = output_shape.dims[3]; + const int num_elements_per_img = + output_height * output_width * output_channels; auto buffer_view = output_tensor.GetCpuWriteView(); cv::Mat dst; const int dst_data_type = output_channels == 1 ? 
mat_gray_type_ : mat_type_; switch (tensor_type_) { case Tensor::ElementType::kInt8: - dst = cv::Mat(output_height, output_width, dst_data_type, - buffer_view.buffer<int8>()); + RET_CHECK_GE(output_shape.num_elements(), + tensor_buffer_offset / sizeof(int8) + num_elements_per_img) + << "The buffer offset + the input image size is larger than the " "allocated tensor buffer."; + dst = cv::Mat( + output_height, output_width, dst_data_type, + buffer_view.buffer<int8>() + tensor_buffer_offset / sizeof(int8)); break; case Tensor::ElementType::kFloat32: - dst = cv::Mat(output_height, output_width, dst_data_type, - buffer_view.buffer<float>()); + RET_CHECK_GE( + output_shape.num_elements(), + tensor_buffer_offset / sizeof(float) + num_elements_per_img) + << "The buffer offset + the input image size is larger than the " "allocated tensor buffer."; + dst = cv::Mat( + output_height, output_width, dst_data_type, + buffer_view.buffer<float>() + tensor_buffer_offset / sizeof(float)); break; case Tensor::ElementType::kUInt8: - dst = cv::Mat(output_height, output_width, dst_data_type, - buffer_view.buffer<uint8>()); + RET_CHECK_GE( + output_shape.num_elements(), + tensor_buffer_offset / sizeof(uint8) + num_elements_per_img) + << "The buffer offset + the input image size is larger than the " "allocated tensor buffer."; + dst = cv::Mat( + output_height, output_width, dst_data_type, + buffer_view.buffer<uint8>() + tensor_buffer_offset / sizeof(uint8)); break; default: return InvalidArgumentError( @@ -153,9 +171,8 @@ class OpenCvProcessor : public ImageToTensorConverter { absl::Status ValidateTensorShape(const Tensor::Shape& output_shape) { RET_CHECK_EQ(output_shape.dims.size(), 4) << "Wrong output dims size: " << output_shape.dims.size(); - RET_CHECK_EQ(output_shape.dims[0], 1) - << "Handling batch dimension not equal to 1 is not implemented in this " "converter."; + RET_CHECK_GE(output_shape.dims[0], 1) + << "The batch dimension needs to be equal to or larger than 1."; RET_CHECK(output_shape.dims[3] == 3 || output_shape.dims[3] == 1) << "Wrong output channel: " << output_shape.dims[3]; return absl::OkStatus(); From 7c9fc9a6428b1c40738b5dce80abbacd627c4bdf Mon Sep 17 00:00:00 2001 From: Mark McDonald Date: Mon, 21 Nov 2022 21:45:58 -0800 Subject: [PATCH 107/469] Remove `mp.solutions` from doc generation. These need to be excluded from the current package, so do it automatically. PiperOrigin-RevId: 490146934 --- docs/build_py_api_docs.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/docs/build_py_api_docs.py b/docs/build_py_api_docs.py index fa1e4314f..fe706acd3 100644 --- a/docs/build_py_api_docs.py +++ b/docs/build_py_api_docs.py @@ -30,7 +30,7 @@ from tensorflow_docs.api_generator import public_api try: # mediapipe has not been set up to work with bazel yet, so catch & report.
- import mediapipe # pytype: disable=import-error + import mediapipe as mp # pytype: disable=import-error except ImportError as e: raise ImportError('Please `pip install mediapipe`.') from e @@ -58,11 +58,13 @@ _SITE_PATH = flags.DEFINE_string('site_path', '/mediapipe/api_docs/python', def gen_api_docs(): """Generates API docs for the mediapipe package.""" + if hasattr(mp, 'solutions'): + del mp.solutions doc_generator = generate_lib.DocGenerator( root_title=PROJECT_FULL_NAME, - py_modules=[(PROJECT_SHORT_NAME, mediapipe)], - base_dir=os.path.dirname(mediapipe.__file__), + py_modules=[(PROJECT_SHORT_NAME, mp)], + base_dir=os.path.dirname(mp.__file__), code_url_prefix=_URL_PREFIX.value, search_hints=_SEARCH_HINTS.value, site_path=_SITE_PATH.value, From 54a684717fa39cd39315f8f6cb60b6c5a7fa76aa Mon Sep 17 00:00:00 2001 From: Camillo Lugaresi Date: Mon, 21 Nov 2022 23:22:49 -0800 Subject: [PATCH 108/469] Internal change PiperOrigin-RevId: 490159674 --- mediapipe/gpu/attachments.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mediapipe/gpu/attachments.h b/mediapipe/gpu/attachments.h index ca9f074c4..3a73e4676 100644 --- a/mediapipe/gpu/attachments.h +++ b/mediapipe/gpu/attachments.h @@ -31,8 +31,8 @@ class AttachmentBase {}; template <class Context, class T> class Attachment : public AttachmentBase<Context> { public: - using FactoryT = std::function<AttachmentPtr<T>(Context&)>; - Attachment(FactoryT factory) : factory_(factory) {} + using FactoryT = AttachmentPtr<T> (*)(Context&); + explicit constexpr Attachment(FactoryT factory) : factory_(factory) {} Attachment(const Attachment&) = delete; Attachment(Attachment&&) = delete; From a8b776102240ecb73f1a7aeb8ace9db42eb05f96 Mon Sep 17 00:00:00 2001 From: Camillo Lugaresi Date: Mon, 21 Nov 2022 23:27:55 -0800 Subject: [PATCH 109/469] Define a kUtilityFramebuffer context attachment A framebuffer object is often needed to render to a texture or read data from it. Currently we create one in each GlCalculatorHelper, but that is redundant (we only need one per context, and multiple calculators can share the same context). Other times, the code that needs to use this doesn't own a helper. For both reasons, this should be attached to the context. We could just make this a member of GlContext since it's so common. However, I figured we might as well use the attachment system. PiperOrigin-RevId: 490160214 --- mediapipe/gpu/gl_context.cc | 12 ++++++++++++ mediapipe/gpu/gl_context.h | 6 ++++++ 2 files changed, 18 insertions(+) diff --git a/mediapipe/gpu/gl_context.cc b/mediapipe/gpu/gl_context.cc index 53e3ff8b7..99b995dda 100644 --- a/mediapipe/gpu/gl_context.cc +++ b/mediapipe/gpu/gl_context.cc @@ -1054,4 +1054,16 @@ void GlContext::SetStandardTextureParams(GLenum target, GLint internal_format) { glTexParameteri(target, GL_TEXTURE_WRAP_T, GL_CLAMP_TO_EDGE); } +const GlContext::Attachment<GLuint> kUtilityFramebuffer( + [](GlContext&) -> GlContext::Attachment<GLuint>::Ptr { + GLuint framebuffer; + glGenFramebuffers(1, &framebuffer); + if (!framebuffer) return nullptr; + return {new GLuint(framebuffer), [](void* ptr) { + GLuint* fb = static_cast<GLuint*>(ptr); + glDeleteFramebuffers(1, fb); + delete fb; + }}; + }); + } // namespace mediapipe diff --git a/mediapipe/gpu/gl_context.h b/mediapipe/gpu/gl_context.h index 7f5168d8b..4f2390404 100644 --- a/mediapipe/gpu/gl_context.h +++ b/mediapipe/gpu/gl_context.h @@ -474,6 +474,12 @@ class GlContext : public std::enable_shared_from_this<GlContext> { bool destructing_ = false; }; +// A framebuffer that the framework can use to attach textures for rendering +// etc.
+// This could just be a member of GlContext, but it serves as a basic example +// of an attachment. +ABSL_CONST_INIT extern const GlContext::Attachment<GLuint> kUtilityFramebuffer; + // For backward compatibility. TODO: migrate remaining callers. ABSL_DEPRECATED( "Prefer passing an explicit GlVersion argument (use " From bacbac8d926d769bf51f770914d603b942094ebb Mon Sep 17 00:00:00 2001 From: Camillo Lugaresi Date: Mon, 21 Nov 2022 23:57:33 -0800 Subject: [PATCH 110/469] Use kUtilityFramebuffer in ReadTexture This avoids creating a temporary framebuffer each time. PiperOrigin-RevId: 490163892 --- mediapipe/gpu/gl_texture_buffer.cc | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/mediapipe/gpu/gl_texture_buffer.cc b/mediapipe/gpu/gl_texture_buffer.cc index 7f77cd4b3..3d2642552 100644 --- a/mediapipe/gpu/gl_texture_buffer.cc +++ b/mediapipe/gpu/gl_texture_buffer.cc @@ -15,6 +15,7 @@ #include "mediapipe/gpu/gl_texture_buffer.h" #include "mediapipe/framework/formats/image_frame.h" +#include "mediapipe/gpu/gl_context.h" #include "mediapipe/gpu/gl_texture_view.h" #include "mediapipe/gpu/gpu_buffer_storage_image_frame.h" @@ -333,8 +334,8 @@ void GlTextureBuffer::ViewDoneWriting(const GlTextureView& view) { #endif // __ANDROID__ } -static void ReadTexture(const GlTextureView& view, GpuBufferFormat format, - void* output, size_t size) { +static void ReadTexture(GlContext& ctx, const GlTextureView& view, + GpuBufferFormat format, void* output, size_t size) { // TODO: check buffer size? We could use glReadnPixels where available // (OpenGL ES 3.2, i.e. nowhere). Note that, to fully check that the read // won't overflow the buffer with glReadPixels, we'd also need to check or @@ -347,10 +348,7 @@ static void ReadTexture(const GlTextureView& view, GpuBufferFormat format, GLint previous_fbo; glGetIntegerv(GL_FRAMEBUFFER_BINDING, &previous_fbo); - // We use a temp fbo to avoid depending on the app having an existing one. - // TODO: keep a utility fbo around in the context? - GLuint fbo = 0; - glGenFramebuffers(1, &fbo); + GLuint fbo = kUtilityFramebuffer.Get(ctx); glBindFramebuffer(GL_FRAMEBUFFER, fbo); glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, view.target(), view.name(), 0); @@ -360,7 +358,6 @@ static void ReadTexture(const GlTextureView& view, GpuBufferFormat format, 0); // TODO: just set the binding to 0 to avoid the get call? glBindFramebuffer(GL_FRAMEBUFFER, previous_fbo); - glDeleteFramebuffers(1, &fbo); } static std::shared_ptr<GpuBufferStorageImageFrame> ConvertToImageFrame( @@ -370,9 +367,10 @@ static std::shared_ptr<GpuBufferStorageImageFrame> ConvertToImageFrame( auto output = absl::make_unique<ImageFrame>(image_format, buf->width(), buf->height(), ImageFrame::kGlDefaultAlignmentBoundary); - buf->GetProducerContext()->Run([buf, &output] { + auto ctx = buf->GetProducerContext(); + ctx->Run([buf, &output, &ctx] { auto view = buf->GetReadView(internal::types<GlTextureView>{}, /*plane=*/0); - ReadTexture(view, buf->format(), output->MutablePixelData(), + ReadTexture(*ctx, view, buf->format(), output->MutablePixelData(), output->PixelDataSize()); }); return std::make_shared<GpuBufferStorageImageFrame>(std::move(output)); From d648926155d19cb6665895661624ec19cc7d33c6 Mon Sep 17 00:00:00 2001 From: Camillo Lugaresi Date: Tue, 22 Nov 2022 00:35:27 -0800 Subject: [PATCH 111/469] Just reset the fb binding to 0 in ReadTexture This saves a get operation. We already have precedent in lots of other MediaPipe code where we just reset bindings to 0.
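In practice, the pattern these framebuffer changes converge on looks like the following minimal sketch. The helper function here is hypothetical; kUtilityFramebuffer, Get(), and the GL calls mirror the diffs above (the first Get() on a given GlContext runs the factory once, later calls return the cached framebuffer, and the deleter runs when the context is torn down):

    #include "mediapipe/gpu/gl_context.h"

    namespace mediapipe {

    // Hypothetical helper: attach `texture` to the shared per-context utility
    // framebuffer instead of generating and deleting a temporary FBO per call.
    void DrawIntoTexture(GlContext& ctx, GLuint texture) {
      GLuint fbo = kUtilityFramebuffer.Get(ctx);  // created once, then cached
      glBindFramebuffer(GL_FRAMEBUFFER, fbo);
      glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0,
                             GL_TEXTURE_2D, texture, 0);
      // ... issue draw calls targeting the texture here ...
      glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0,
                             GL_TEXTURE_2D, 0, 0);
      // Reset the binding to 0 rather than saving and restoring it.
      glBindFramebuffer(GL_FRAMEBUFFER, 0);
    }

    }  // namespace mediapipe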
PiperOrigin-RevId: 490170691 --- mediapipe/gpu/gl_texture_buffer.cc | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/mediapipe/gpu/gl_texture_buffer.cc b/mediapipe/gpu/gl_texture_buffer.cc index 3d2642552..d530d5d12 100644 --- a/mediapipe/gpu/gl_texture_buffer.cc +++ b/mediapipe/gpu/gl_texture_buffer.cc @@ -345,9 +345,6 @@ static void ReadTexture(GlContext& ctx, const GlTextureView& view, GlTextureInfo info = GlTextureInfoForGpuBufferFormat( format, view.plane(), view.gl_context()->GetGlVersion()); - GLint previous_fbo; - glGetIntegerv(GL_FRAMEBUFFER_BINDING, &previous_fbo); - GLuint fbo = kUtilityFramebuffer.Get(ctx); glBindFramebuffer(GL_FRAMEBUFFER, fbo); glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, view.target(), @@ -356,8 +353,7 @@ static void ReadTexture(GlContext& ctx, const GlTextureView& view, output); glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, GL_TEXTURE_2D, 0, 0); - // TODO: just set the binding to 0 to avoid the get call? - glBindFramebuffer(GL_FRAMEBUFFER, previous_fbo); + glBindFramebuffer(GL_FRAMEBUFFER, 0); } static std::shared_ptr ConvertToImageFrame( From 872d1afda7f8a465db59dfcf9ab56e6d60832646 Mon Sep 17 00:00:00 2001 From: vrabaud Date: Tue, 22 Nov 2022 03:10:35 -0800 Subject: [PATCH 112/469] Internal change PiperOrigin-RevId: 490196129 --- mediapipe/framework/port/BUILD | 11 ++++++++++ mediapipe/framework/port/opencv_videoio_inc.h | 21 +++++++++++++++++++ 2 files changed, 32 insertions(+) create mode 100644 mediapipe/framework/port/opencv_videoio_inc.h diff --git a/mediapipe/framework/port/BUILD b/mediapipe/framework/port/BUILD index 87944d80f..e499ca3a6 100644 --- a/mediapipe/framework/port/BUILD +++ b/mediapipe/framework/port/BUILD @@ -311,6 +311,17 @@ cc_library( ], ) +cc_library( + name = "opencv_videoio", + hdrs = ["opencv_videoio_inc.h"], + visibility = ["//visibility:public"], + deps = [ + ":opencv_core", + "//mediapipe/framework:port", + "//third_party:opencv", + ], +) + cc_library( name = "parse_text_proto", hdrs = [ diff --git a/mediapipe/framework/port/opencv_videoio_inc.h b/mediapipe/framework/port/opencv_videoio_inc.h new file mode 100644 index 000000000..63029b69f --- /dev/null +++ b/mediapipe/framework/port/opencv_videoio_inc.h @@ -0,0 +1,21 @@ +// Copyright 2022 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef MEDIAPIPE_PORT_OPENCV_VIDEOIO_INC_H_ +#define MEDIAPIPE_PORT_OPENCV_VIDEOIO_INC_H_ + +#include "mediapipe/framework/port/opencv_core_inc.h" +#include "third_party/OpenCV/videoio.hpp" + +#endif // MEDIAPIPE_PORT_OPENCV_VIDEOIO_INC_H_ From 515d00fc22100bfb948aecfa39408a0b599a0c89 Mon Sep 17 00:00:00 2001 From: Camillo Lugaresi Date: Tue, 22 Nov 2022 15:16:52 -0800 Subject: [PATCH 113/469] Internal change PiperOrigin-RevId: 490349260 --- mediapipe/framework/formats/BUILD | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/mediapipe/framework/formats/BUILD b/mediapipe/framework/formats/BUILD index e13bb2704..4276ffc3a 100644 --- a/mediapipe/framework/formats/BUILD +++ b/mediapipe/framework/formats/BUILD @@ -312,9 +312,7 @@ mediapipe_register_type( mediapipe_proto_library( name = "landmark_proto", srcs = ["landmark.proto"], - visibility = [ - "//visibility:public", - ], + visibility = ["//visibility:public"], ) mediapipe_register_type( From 7ce4aa6592c30c2ac5d0c075304e50ae7d01b38f Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Tue, 22 Nov 2022 16:38:51 -0800 Subject: [PATCH 114/469] Internal change PiperOrigin-RevId: 490366250 --- mediapipe/util/sequence/media_sequence_test.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mediapipe/util/sequence/media_sequence_test.cc b/mediapipe/util/sequence/media_sequence_test.cc index 40a474599..42b0e3889 100644 --- a/mediapipe/util/sequence/media_sequence_test.cc +++ b/mediapipe/util/sequence/media_sequence_test.cc @@ -802,7 +802,7 @@ TEST(MediaSequenceTest, ReconcileMetadataImages) { tensorflow::SequenceExample sequence; cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255)); std::vector<uchar> bytes; - ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {80})); + ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {})); std::string encoded_image(bytes.begin(), bytes.end()); AddImageEncoded(encoded_image, &sequence); AddImageEncoded(encoded_image, &sequence); @@ -843,7 +843,7 @@ TEST(MediaSequenceTest, ReconcileMetadataFlowEncoded) { tensorflow::SequenceExample sequence; cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255)); std::vector<uchar> bytes; - ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {80})); + ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {})); std::string encoded_flow(bytes.begin(), bytes.end()); AddForwardFlowEncoded(encoded_flow, &sequence); From efa9e737f80e245aec4c6ef9483fc92547e6d1d9 Mon Sep 17 00:00:00 2001 From: Camillo Lugaresi Date: Tue, 22 Nov 2022 17:22:18 -0800 Subject: [PATCH 115/469] Use current context if available in ConvertToImageFrame If we're already running in a GlContext, there's no need to go back to the producer context, which may be different.
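The selection logic, in isolation, is a two-line fallback; a sketch follows (the function name is hypothetical, while GlContext::GetCurrent(), which returns the context current on the calling thread or null, and GetProducerContext() are the calls used in the diff below):

    #include <memory>

    #include "mediapipe/gpu/gl_context.h"
    #include "mediapipe/gpu/gl_texture_buffer.h"

    namespace mediapipe {

    // Prefer the GL context we are already running on; only fall back to the
    // buffer's producer context when no context is current.
    std::shared_ptr<GlContext> ChooseContextForRead(
        const std::shared_ptr<GlTextureBuffer>& buf) {
      auto ctx = GlContext::GetCurrent();
      if (!ctx) ctx = buf->GetProducerContext();
      return ctx;
    }

    }  // namespace mediapipe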
PiperOrigin-RevId: 490373829 --- mediapipe/gpu/gl_texture_buffer.cc | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/mediapipe/gpu/gl_texture_buffer.cc b/mediapipe/gpu/gl_texture_buffer.cc index d530d5d12..69b9889c7 100644 --- a/mediapipe/gpu/gl_texture_buffer.cc +++ b/mediapipe/gpu/gl_texture_buffer.cc @@ -363,7 +363,8 @@ static std::shared_ptr<GpuBufferStorageImageFrame> ConvertToImageFrame( auto output = absl::make_unique<ImageFrame>(image_format, buf->width(), buf->height(), ImageFrame::kGlDefaultAlignmentBoundary); - auto ctx = buf->GetProducerContext(); + auto ctx = GlContext::GetCurrent(); + if (!ctx) ctx = buf->GetProducerContext(); ctx->Run([buf, &output, &ctx] { auto view = buf->GetReadView(internal::types<GlTextureView>{}, /*plane=*/0); ReadTexture(*ctx, view, buf->format(), output->MutablePixelData(), @@ -392,7 +393,9 @@ static std::shared_ptr<GpuBufferStorageCvPixelBuffer> ConvertToCvPixelBuffer( std::shared_ptr<GlTextureBuffer> buf) { auto output = absl::make_unique<GpuBufferStorageCvPixelBuffer>( buf->width(), buf->height(), buf->format()); - buf->GetProducerContext()->Run([buf, &output] { + auto ctx = GlContext::GetCurrent(); + if (!ctx) ctx = buf->GetProducerContext(); + ctx->Run([buf, &output] { TempGlFramebuffer framebuffer; auto src = buf->GetReadView(internal::types<GlTextureView>{}, /*plane=*/0); auto dst = From fac97554dfb80e8c14ecbfb2cbe12e0ad26ce0b4 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Tue, 22 Nov 2022 17:23:48 -0800 Subject: [PATCH 116/469] Small TS audio API improvement PiperOrigin-RevId: 490374083 --- .../audio_classifier/audio_classifier.ts | 14 +- .../audio/audio_embedder/audio_embedder.ts | 14 +- mediapipe/web/graph_runner/graph_runner.ts | 129 ++++++++++++++---- 3 files changed, 105 insertions(+), 52 deletions(-) diff --git a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts index 0c54a4718..20c745383 100644 --- a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts +++ b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts @@ -35,11 +35,7 @@ export * from './audio_classifier_result'; const MEDIAPIPE_GRAPH = 'mediapipe.tasks.audio.audio_classifier.AudioClassifierGraph'; -// Note: `input_audio` is hardcoded in 'gl_graph_runner_internal_audio' and -// cannot be changed -// TODO: Change this to `audio_in` to match the name in the CC -// implementation -const AUDIO_STREAM = 'input_audio'; +const AUDIO_STREAM = 'audio_in'; const SAMPLE_RATE_STREAM = 'sample_rate'; const TIMESTAMPED_CLASSIFICATIONS_STREAM = 'timestamped_classifications'; @@ -154,14 +150,8 @@ export class AudioClassifier extends AudioTaskRunner<AudioClassifierResult[]> { protected override process( audioData: Float32Array, sampleRate: number, timestampMs: number): AudioClassifierResult[] { - // Configures the number of samples in the WASM layer. We re-configure the - // number of samples and the sample rate for every frame, but ignore other - // side effects of this function (such as sending the input side packet and - // the input stream header).
- this.configureAudio( - /* numChannels= */ 1, /* numSamples= */ audioData.length, sampleRate); this.addDoubleToStream(sampleRate, SAMPLE_RATE_STREAM, timestampMs); - this.addAudioToStream(audioData, timestampMs); + this.addAudioToStreamWithShape(audioData, /* numChannels= */ 1, /* numSamples= */ audioData.length, AUDIO_STREAM, timestampMs); this.classificationResults = []; this.finishProcessing(); diff --git a/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts index 51cb819de..46a7b6729 100644 --- a/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts +++ b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts @@ -35,11 +35,7 @@ export * from './audio_embedder_result'; // The OSS JS API does not support the builder pattern. // tslint:disable:jspb-use-builder-pattern -// Note: `input_audio` is hardcoded in 'gl_graph_runner_internal_audio' cannot -// be changed -// TODO: Change this to `audio_in` to match the name in the CC -// implementation -const AUDIO_STREAM = 'input_audio'; +const AUDIO_STREAM = 'audio_in'; const SAMPLE_RATE_STREAM = 'sample_rate'; const EMBEDDINGS_STREAM = 'embeddings_out'; const TIMESTAMPED_EMBEDDINGS_STREAM = 'timestamped_embeddings_out'; @@ -151,14 +147,8 @@ export class AudioEmbedder extends AudioTaskRunner { protected override process( audioData: Float32Array, sampleRate: number, timestampMs: number): AudioEmbedderResult[] { - // Configures the number of samples in the WASM layer. We re-configure the - // number of samples and the sample rate for every frame, but ignore other - // side effects of this function (such as sending the input side packet and - // the input stream header). - this.configureAudio( - /* numChannels= */ 1, /* numSamples= */ audioData.length, sampleRate); this.addDoubleToStream(sampleRate, SAMPLE_RATE_STREAM, timestampMs); - this.addAudioToStream(audioData, timestampMs); + this.addAudioToStreamWithShape(audioData, /* numChannels= */ 1, /* numSamples= */ audioData.length, AUDIO_STREAM, timestampMs); this.embeddingResults = []; this.finishProcessing(); diff --git a/mediapipe/web/graph_runner/graph_runner.ts b/mediapipe/web/graph_runner/graph_runner.ts index 7de5aa33b..c4654794c 100644 --- a/mediapipe/web/graph_runner/graph_runner.ts +++ b/mediapipe/web/graph_runner/graph_runner.ts @@ -15,9 +15,6 @@ export declare interface FileLocator { locateFile: (filename: string) => string; } -/** Listener to be passed in by user for handling output audio data. */ -export type AudioOutputListener = (output: Float32Array) => void; - /** * Declarations for Emscripten's WebAssembly Module behavior, so TS compiler * doesn't break our JS/C++ bridge. @@ -32,19 +29,14 @@ export declare interface WasmModule { _bindTextureToCanvas: () => boolean; _changeBinaryGraph: (size: number, dataPtr: number) => void; _changeTextGraph: (size: number, dataPtr: number) => void; - _configureAudio: - (channels: number, samples: number, sampleRate: number) => void; _free: (ptr: number) => void; _malloc: (size: number) => number; - _processAudio: (dataPtr: number, timestamp: number) => void; _processFrame: (width: number, height: number, timestamp: number) => void; _setAutoRenderToScreen: (enabled: boolean) => void; _waitUntilIdle: () => void; // Exposed so that clients of this lib can access this field dataFileDownloads?: {[url: string]: {loaded: number, total: number}}; - // Wasm module will call us back at this function when given audio data. 
- onAudioOutput?: AudioOutputListener; // Wasm Module multistream entrypoints. Require // gl_graph_runner_internal_multi_input as a build dependency. @@ -100,11 +92,14 @@ export declare interface WasmModule { _attachProtoVectorListener: (streamNamePtr: number, makeDeepCopy?: boolean) => void; - // Requires dependency ":gl_graph_runner_audio_out", and will register an - // audio output listening function which can be tapped into dynamically during - // graph running via onAudioOutput. This call must be made before graph is - // initialized, but after wasmModule is instantiated. - _attachAudioOutputListener: () => void; + // Requires dependency ":gl_graph_runner_audio_out" + _attachAudioListener: (streamNamePtr: number, makeDeepCopy?: boolean) => void; + + // Requires dependency ":gl_graph_runner_audio" + _addAudioToInputStream: (dataPtr: number, numChannels: number, + numSamples: number, streamNamePtr: number, timestamp: number) => void; + _configureAudio: (channels: number, samples: number, sampleRate: number, + streamNamePtr: number, headerNamePtr: number) => void; // TODO: Refactor to just use a few numbers (perhaps refactor away // from gl_graph_runner_internal.cc entirely to use something a little more @@ -235,19 +230,38 @@ export class GraphRunner { } /** - * Configures the current graph to handle audio in a certain way. Must be - * called before the graph is set/started in order to use processAudio. + * Configures the current graph to handle audio processing in a certain way + * for all its audio input streams. It can additionally configure audio headers + * (both input side packets as well as input stream headers), but these + * configurations only take effect if called before the graph is set/started. * @param numChannels The number of channels of audio input. Only 1 * is supported for now. * @param numSamples The number of samples that are taken in each * audio capture. * @param sampleRate The rate, in Hz, of the sampling. + * @param streamName The optional name of the input stream to additionally + * configure with audio information. This configuration only occurs before + * the graph is set/started. If unset, a default stream name will be used. + * @param headerName The optional name of the header input side packet to + * additionally configure with audio information. This configuration only + * occurs before the graph is set/started. If unset, a default header name + * will be used. */ - configureAudio(numChannels: number, numSamples: number, sampleRate: number) { - this.wasmModule._configureAudio(numChannels, numSamples, sampleRate); - if (this.wasmModule._attachAudioOutputListener) { - this.wasmModule._attachAudioOutputListener(); + configureAudio(numChannels: number, numSamples: number, sampleRate: number, + streamName?: string, headerName?: string) { + if (!this.wasmModule._configureAudio) { + console.warn( + 'Attempting to use configureAudio without support for input audio. ' + + 'Is build dep ":gl_graph_runner_audio" missing?'); } + streamName = streamName || 'input_audio'; + this.wrapStringPtr(streamName, (streamNamePtr: number) => { + headerName = headerName || 'audio_header'; + this.wrapStringPtr(headerName, (headerNamePtr: number) => { + this.wasmModule._configureAudio(numChannels, numSamples, sampleRate, + streamNamePtr, headerNamePtr); + }); + }); } /** @@ -437,9 +451,36 @@ export class GraphRunner { * processed. * @param audioData An array of raw audio capture data, like * from a call to getChannelData on an AudioBuffer.
+ * @param streamName The name of the MediaPipe graph stream to add the audio + * data to. * @param timestamp The timestamp of the current frame, in ms. */ - addAudioToStream(audioData: Float32Array, timestamp: number) { + addAudioToStream( + audioData: Float32Array, streamName: string, timestamp: number) { + // numChannels and numSamples being 0 will cause defaults to be used, + // which will reflect values from last call to configureAudio. + this.addAudioToStreamWithShape(audioData, 0, 0, streamName, timestamp); + } + + /** + * Takes the raw data from a JS audio capture array, and sends it to C++ to be + * processed, shaping the audioData array into an audio matrix according to + * the numChannels and numSamples parameters. + * @param audioData An array of raw audio capture data, like + * from a call to getChannelData on an AudioBuffer. + * @param numChannels The number of audio channels this data represents. If 0 + * is passed, then the value will be taken from the last call to + * configureAudio. + * @param numSamples The number of audio samples captured in this data packet. + * If 0 is passed, then the value will be taken from the last call to + * configureAudio. + * @param streamName The name of the MediaPipe graph stream to add the audio + * data to. + * @param timestamp The timestamp of the current frame, in ms. + */ + addAudioToStreamWithShape( + audioData: Float32Array, numChannels: number, numSamples: number, + streamName: string, timestamp: number) { // 4 bytes for each F32 const size = audioData.length * 4; if (this.audioSize !== size) { @@ -450,7 +491,11 @@ export class GraphRunner { this.audioSize = size; } this.wasmModule.HEAPF32.set(audioData, this.audioPtr! / 4); - this.wasmModule._processAudio(this.audioPtr!, timestamp); + + this.wrapStringPtr(streamName, (streamNamePtr: number) => { + this.wasmModule._addAudioToInputStream( + this.audioPtr!, numChannels, numSamples, streamNamePtr, timestamp); + }); } /** @@ -943,17 +988,45 @@ export class GraphRunner { } /** - * Sets a listener to be called back with audio output packet data, as a - * Float32Array, when graph has finished processing it. - * @param audioOutputListener The caller's listener function. + * Attaches an audio packet listener to the specified output_stream, to be + * given a Float32Array as output. + * @param outputStreamName The name of the graph output stream to grab audio + * data from. + * @param callbackFcn The function that will be called back with the data, as + * it is received. Note that the data is only guaranteed to exist for the + * duration of the callback, and the callback will be called inline, so it + * should not perform overly complicated (or any async) behavior. If the + * audio data needs to be able to outlive the call, you may set the + * optional makeDeepCopy parameter to true, or can manually deep-copy the + * data yourself. + * @param makeDeepCopy Optional convenience parameter which, if set to true, + * will override the default memory management behavior and make a deep + * copy of the underlying data, rather than just returning a view into the + * C++-managed memory. At the cost of a data copy, this allows the + * returned data to outlive the callback lifetime (and it will be cleaned + * up automatically by JS garbage collection whenever the user is finished + * with it). 
*/ - setOnAudioOutput(audioOutputListener: AudioOutputListener) { - this.wasmModule.onAudioOutput = audioOutputListener; - if (!this.wasmModule._attachAudioOutputListener) { + attachAudioListener(outputStreamName: string, + callbackFcn: (data: Float32Array) => void, makeDeepCopy?: boolean): void { + if (!this.wasmModule._attachAudioListener) { console.warn( - 'Attempting to use AudioOutputListener without support for ' + + 'Attempting to use attachAudioListener without support for ' + 'output audio. Is build dep ":gl_graph_runner_audio_out" missing?'); } + + // Set up our TS listener to receive any packets for this stream, and + // additionally reformat our Uint8Array into a Float32Array for the user. + this.setListener(outputStreamName, (data: Uint8Array) => { + const floatArray = new Float32Array(data.buffer); // Should be very fast + callbackFcn(floatArray); + }); + + // Tell our graph to listen for string packets on this stream. + this.wrapStringPtr(outputStreamName, (outputStreamNamePtr: number) => { + this.wasmModule._attachAudioListener( + outputStreamNamePtr, makeDeepCopy || false); + }); } /** From 8ba9d87e667f0c6e67026f96aa58ee1a980b0ce1 Mon Sep 17 00:00:00 2001 From: Camillo Lugaresi Date: Tue, 22 Nov 2022 17:25:55 -0800 Subject: [PATCH 117/469] Update ImageFrameToGpuBufferCalculator to use api2 and GpuBuffer conversions PiperOrigin-RevId: 490374387 --- mediapipe/gpu/BUILD | 2 + .../image_frame_to_gpu_buffer_calculator.cc | 62 ++++++++----------- 2 files changed, 28 insertions(+), 36 deletions(-) diff --git a/mediapipe/gpu/BUILD b/mediapipe/gpu/BUILD index 10a8d7fff..f97eed678 100644 --- a/mediapipe/gpu/BUILD +++ b/mediapipe/gpu/BUILD @@ -901,6 +901,8 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":gl_calculator_helper", + ":gpu_buffer_storage_image_frame", + "//mediapipe/framework/api2:node", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:image_frame", "//mediapipe/framework/port:status", diff --git a/mediapipe/gpu/image_frame_to_gpu_buffer_calculator.cc b/mediapipe/gpu/image_frame_to_gpu_buffer_calculator.cc index 2a8331db8..c67fb0c62 100644 --- a/mediapipe/gpu/image_frame_to_gpu_buffer_calculator.cc +++ b/mediapipe/gpu/image_frame_to_gpu_buffer_calculator.cc @@ -12,73 +12,63 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "mediapipe/framework/api2/node.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/formats/image_frame.h" #include "mediapipe/framework/port/status.h" #include "mediapipe/gpu/gl_calculator_helper.h" -#ifdef __APPLE__ -#include "mediapipe/objc/util.h" -#endif - namespace mediapipe { +namespace api2 { -// Convert ImageFrame to GpuBuffer. 
-class ImageFrameToGpuBufferCalculator : public CalculatorBase { +class ImageFrameToGpuBufferCalculator + : public RegisteredNode<ImageFrameToGpuBufferCalculator> { public: - ImageFrameToGpuBufferCalculator() {} + static constexpr Input<ImageFrame> kIn{""}; + static constexpr Output<GpuBuffer> kOut{""}; - static absl::Status GetContract(CalculatorContract* cc); + MEDIAPIPE_NODE_INTERFACE(ImageFrameToGpuBufferCalculator, kIn, kOut); + + static absl::Status UpdateContract(CalculatorContract* cc); absl::Status Open(CalculatorContext* cc) override; absl::Status Process(CalculatorContext* cc) override; private: -#if !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER GlCalculatorHelper helper_; -#endif // !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER }; -REGISTER_CALCULATOR(ImageFrameToGpuBufferCalculator); // static -absl::Status ImageFrameToGpuBufferCalculator::GetContract( +absl::Status ImageFrameToGpuBufferCalculator::UpdateContract( CalculatorContract* cc) { - cc->Inputs().Index(0).Set<ImageFrame>(); - cc->Outputs().Index(0).Set<GpuBuffer>(); // Note: we call this method even on platforms where we don't use the helper, // to ensure the calculator's contract is the same. In particular, the helper // enables support for the legacy side packet, which several graphs still use. - MP_RETURN_IF_ERROR(GlCalculatorHelper::UpdateContract(cc)); - return absl::OkStatus(); + return GlCalculatorHelper::UpdateContract(cc); } absl::Status ImageFrameToGpuBufferCalculator::Open(CalculatorContext* cc) { - // Inform the framework that we always output at the same timestamp - // as we receive a packet at. - cc->SetOffset(TimestampDiff(0)); -#if !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER MP_RETURN_IF_ERROR(helper_.Open(cc)); -#endif // !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER return absl::OkStatus(); } absl::Status ImageFrameToGpuBufferCalculator::Process(CalculatorContext* cc) { -#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER - CFHolder<CVPixelBufferRef> buffer; - MP_RETURN_IF_ERROR(CreateCVPixelBufferForImageFramePacket( - cc->Inputs().Index(0).Value(), &buffer)); - cc->Outputs().Index(0).Add(new GpuBuffer(buffer), cc->InputTimestamp()); -#else - const auto& input = cc->Inputs().Index(0).Get<ImageFrame>(); - helper_.RunInGlContext([this, &input, &cc]() { - auto src = helper_.CreateSourceTexture(input); - auto output = src.GetFrame<GpuBuffer>(); - glFlush(); - cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp()); - src.Release(); - }); -#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER + auto image_frame = std::const_pointer_cast<ImageFrame>( + mediapipe::SharedPtrWithPacket<ImageFrame>(kIn(cc).packet())); + auto gpu_buffer = api2::MakePacket<GpuBuffer>( + std::make_shared<GpuBufferStorageImageFrame>( + std::move(image_frame))) + .At(cc->InputTimestamp()); + // This calculator's behavior has been to do the texture upload eagerly, and + // some graphs may rely on running this on a separate GL context to avoid + // blocking another context with the read operation. So let's request GPU + // access here to ensure that the behavior stays the same. + // TODO: have a better way to do this, or defer until later.
+ helper_.RunInGlContext( + [&gpu_buffer] { auto view = gpu_buffer->GetReadView(0); }); + kOut(cc).Send(std::move(gpu_buffer)); return absl::OkStatus(); } +} // namespace api2 } // namespace mediapipe From 837225c53d55700ff485367bb0fa71890f905e2e Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Tue, 22 Nov 2022 17:30:23 -0800 Subject: [PATCH 118/469] Internal change PiperOrigin-RevId: 490374976 --- mediapipe/framework/validated_graph_config.cc | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/mediapipe/framework/validated_graph_config.cc b/mediapipe/framework/validated_graph_config.cc index 16aad6e9b..01e3da83e 100644 --- a/mediapipe/framework/validated_graph_config.cc +++ b/mediapipe/framework/validated_graph_config.cc @@ -1048,6 +1048,14 @@ absl::Status ValidatedGraphConfig::ValidateRequiredSidePacketTypes( for (const auto& required_item : required_side_packets_) { auto iter = side_packet_types.find(required_item.first); if (iter == side_packet_types.end()) { + bool is_optional = true; + for (int index : required_item.second) { + is_optional &= input_side_packets_[index].packet_type->IsOptional(); + } + if (is_optional) { + // Side packets that are optional and not provided are ignored. + continue; + } statuses.push_back(mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) << "Side packet \"" << required_item.first << "\" is required but was not provided."); From 3bbc0e9af9150797142295f47b1d87a0403d8f44 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Tue, 22 Nov 2022 17:34:58 -0800 Subject: [PATCH 119/469] Internal change PiperOrigin-RevId: 490375672 --- mediapipe/tasks/web/BUILD | 18 +++--------------- mediapipe/tasks/web/audio.ts | 3 +-- mediapipe/tasks/web/text.ts | 3 +-- mediapipe/tasks/web/vision.ts | 6 +----- 4 files changed, 6 insertions(+), 24 deletions(-) diff --git a/mediapipe/tasks/web/BUILD b/mediapipe/tasks/web/BUILD index af76a1fe8..7e5d02892 100644 --- a/mediapipe/tasks/web/BUILD +++ b/mediapipe/tasks/web/BUILD @@ -24,10 +24,7 @@ mediapipe_files(srcs = [ mediapipe_ts_library( name = "audio_lib", srcs = ["audio.ts"], - deps = [ - "//mediapipe/tasks/web/audio/audio_classifier", - "//mediapipe/tasks/web/audio/audio_embedder", - ], + deps = ["//mediapipe/tasks/web/audio:audio_lib"], ) rollup_bundle( @@ -69,10 +66,7 @@ pkg_npm( mediapipe_ts_library( name = "text_lib", srcs = ["text.ts"], - deps = [ - "//mediapipe/tasks/web/text/text_classifier", - "//mediapipe/tasks/web/text/text_embedder", - ], + deps = ["//mediapipe/tasks/web/text:text_lib"], ) rollup_bundle( @@ -114,13 +108,7 @@ pkg_npm( mediapipe_ts_library( name = "vision_lib", srcs = ["vision.ts"], - deps = [ - "//mediapipe/tasks/web/vision/gesture_recognizer", - "//mediapipe/tasks/web/vision/hand_landmarker", - "//mediapipe/tasks/web/vision/image_classifier", - "//mediapipe/tasks/web/vision/image_embedder", - "//mediapipe/tasks/web/vision/object_detector", - ], + deps = ["//mediapipe/tasks/web/vision:vision_lib"], ) rollup_bundle( diff --git a/mediapipe/tasks/web/audio.ts b/mediapipe/tasks/web/audio.ts index 056426f50..8c522efcc 100644 --- a/mediapipe/tasks/web/audio.ts +++ b/mediapipe/tasks/web/audio.ts @@ -14,8 +14,7 @@ * limitations under the License. 
*/ -import {AudioClassifier as AudioClassifierImpl} from '../../tasks/web/audio/audio_classifier/audio_classifier'; -import {AudioEmbedder as AudioEmbedderImpl} from '../../tasks/web/audio/audio_embedder/audio_embedder'; +import {AudioClassifier as AudioClassifierImpl, AudioEmbedder as AudioEmbedderImpl} from '../../tasks/web/audio/index'; // Declare the variables locally so that Rollup in OSS includes them explicitly // as exports. diff --git a/mediapipe/tasks/web/text.ts b/mediapipe/tasks/web/text.ts index 39d101237..8f15075c5 100644 --- a/mediapipe/tasks/web/text.ts +++ b/mediapipe/tasks/web/text.ts @@ -14,8 +14,7 @@ * limitations under the License. */ -import {TextClassifier as TextClassifierImpl} from '../../tasks/web/text/text_classifier/text_classifier'; -import {TextEmbedder as TextEmbedderImpl} from '../../tasks/web/text/text_embedder/text_embedder'; +import {TextClassifier as TextClassifierImpl, TextEmbedder as TextEmbedderImpl} from '../../tasks/web/text/index'; // Declare the variables locally so that Rollup in OSS includes them explicitly // as exports. diff --git a/mediapipe/tasks/web/vision.ts b/mediapipe/tasks/web/vision.ts index 4e4fab43f..74a056464 100644 --- a/mediapipe/tasks/web/vision.ts +++ b/mediapipe/tasks/web/vision.ts @@ -14,11 +14,7 @@ * limitations under the License. */ -import {GestureRecognizer as GestureRecognizerImpl} from '../../tasks/web/vision/gesture_recognizer/gesture_recognizer'; -import {HandLandmarker as HandLandmarkerImpl} from '../../tasks/web/vision/hand_landmarker/hand_landmarker'; -import {ImageClassifier as ImageClassifierImpl} from '../../tasks/web/vision/image_classifier/image_classifier'; -import {ImageEmbedder as ImageEmbedderImpl} from '../../tasks/web/vision/image_embedder/image_embedder'; -import {ObjectDetector as ObjectDetectorImpl} from '../../tasks/web/vision/object_detector/object_detector'; +import {GestureRecognizer as GestureRecognizerImpl, HandLandmarker as HandLandmarkerImpl, ImageClassifier as ImageClassifierImpl, ImageEmbedder as ImageEmbedderImpl, ObjectDetector as ObjectDetectorImpl} from '../../tasks/web/vision/index'; // Declare the variables locally so that Rollup in OSS includes them explicitly // as exports. From a55839de51dafe27b4c2b705954444895a842c3c Mon Sep 17 00:00:00 2001 From: Camillo Lugaresi Date: Tue, 22 Nov 2022 18:07:26 -0800 Subject: [PATCH 120/469] This storage only needs a "done writing" callback on simulator, so only set it there - When not on simulator, we pass nullptr instead of a do-nothing callback. - The callback is no longer a method, but a function. Only the CVPixelBuffer is captured.
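The lifetime rationale in miniature, as a self-contained sketch with placeholder types (not the actual storage class): capturing `this` ties the callback to the owner's lifetime, while capturing a reference-counted handle keeps only the buffer alive.

    #include <functional>
    #include <memory>

    // Placeholder stand-in for a storage object owning a refcounted buffer
    // (the buffer plays the role of the retained CVPixelBufferRef).
    struct Storage {
      std::shared_ptr<int> buffer = std::make_shared<int>(0);

      // Capturing `this` makes the callback dangle if Storage dies first.
      std::function<int()> CallbackCapturingThis() {
        return [this] { return *buffer; };
      }

      // Capturing the handle shares ownership of the buffer only, so the
      // callback stays valid regardless of the Storage object's lifetime.
      std::function<int()> CallbackCapturingBuffer() {
        return [buf = buffer] { return *buf; };
      }
    };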
PiperOrigin-RevId: 490380248 --- .../gpu/gpu_buffer_storage_cv_pixel_buffer.cc | 45 +++++++++++-------- .../gpu/gpu_buffer_storage_cv_pixel_buffer.h | 1 - 2 files changed, 26 insertions(+), 20 deletions(-) diff --git a/mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.cc b/mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.cc index f3954a6e4..014cc1c69 100644 --- a/mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.cc +++ b/mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.cc @@ -70,25 +70,9 @@ GlTextureView GpuBufferStorageCvPixelBuffer::GetReadView( return GetTexture(plane, nullptr); } -GlTextureView GpuBufferStorageCvPixelBuffer::GetWriteView( - internal::types, int plane) { - return GetTexture(plane, [this](const mediapipe::GlTextureView& view) { - ViewDoneWriting(view); - }); -} - -std::shared_ptr GpuBufferStorageCvPixelBuffer::GetReadView( - internal::types) const { - return CreateImageFrameForCVPixelBuffer(**this); -} -std::shared_ptr GpuBufferStorageCvPixelBuffer::GetWriteView( - internal::types) { - return CreateImageFrameForCVPixelBuffer(**this); -} - -void GpuBufferStorageCvPixelBuffer::ViewDoneWriting(const GlTextureView& view) { #if TARGET_IPHONE_SIMULATOR - CVPixelBufferRef pixel_buffer = **this; +static void ViewDoneWritingSimulatorWorkaround(CVPixelBufferRef pixel_buffer, + const GlTextureView& view) { CHECK(pixel_buffer); CVReturn err = CVPixelBufferLockBaseAddress(pixel_buffer, 0); CHECK(err == kCVReturnSuccess) @@ -126,7 +110,30 @@ void GpuBufferStorageCvPixelBuffer::ViewDoneWriting(const GlTextureView& view) { err = CVPixelBufferUnlockBaseAddress(pixel_buffer, 0); CHECK(err == kCVReturnSuccess) << "CVPixelBufferUnlockBaseAddress failed: " << err; -#endif +} +#endif // TARGET_IPHONE_SIMULATOR + +GlTextureView GpuBufferStorageCvPixelBuffer::GetWriteView( + internal::types, int plane) { + return GetTexture(plane, +#if TARGET_IPHONE_SIMULATOR + [pixel_buffer = CFHolder(*this)]( + const mediapipe::GlTextureView& view) { + ViewDoneWritingSimulatorWorkaround(*pixel_buffer, view); + } +#else + nullptr +#endif // TARGET_IPHONE_SIMULATOR + ); +} + +std::shared_ptr GpuBufferStorageCvPixelBuffer::GetReadView( + internal::types) const { + return CreateImageFrameForCVPixelBuffer(**this); +} +std::shared_ptr GpuBufferStorageCvPixelBuffer::GetWriteView( + internal::types) { + return CreateImageFrameForCVPixelBuffer(**this); } static std::shared_ptr ConvertFromImageFrame( diff --git a/mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.h b/mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.h index a9389ab8a..8723a1087 100644 --- a/mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.h +++ b/mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.h @@ -63,7 +63,6 @@ class GpuBufferStorageCvPixelBuffer private: GlTextureView GetTexture(int plane, GlTextureView::DoneWritingFn done_writing) const; - void ViewDoneWriting(const GlTextureView& view); }; inline CFHolder GpuBufferStorageCvPixelBuffer::GetReadView( From 05681fc0e17089a4e1d3f999bd17f3020cabb9bc Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Wed, 23 Nov 2022 01:26:15 -0800 Subject: [PATCH 121/469] Internal PiperOrigin-RevId: 490439195 --- .../java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl | 1 - 1 file changed, 1 deletion(-) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl b/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl index 8b09260bd..762184842 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl +++ 
b/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl @@ -18,7 +18,6 @@ load("//mediapipe/java/com/google/mediapipe:mediapipe_aar.bzl", "mediapipe_build load("@build_bazel_rules_android//android:rules.bzl", "android_library") _CORE_TASKS_JAVA_PROTO_LITE_TARGETS = [ - "//mediapipe/tasks/cc/components/containers/proto:category_java_proto_lite", "//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite", "//mediapipe/tasks/cc/components/containers/proto:embeddings_java_proto_lite", "//mediapipe/tasks/cc/components/containers/proto:landmarks_detection_result_java_proto_lite", From c5ce5236972a6045f42bb23d526ebb27a7e58bb7 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Wed, 23 Nov 2022 02:02:18 -0800 Subject: [PATCH 122/469] Add cosine APIs to Embedder tasks PiperOrigin-RevId: 490444597 --- .../tasks/web/audio/audio_embedder/BUILD | 1 + .../audio/audio_embedder/audio_embedder.ts | 15 +++++ mediapipe/tasks/web/components/utils/BUILD | 11 ++++ .../web/components/utils/cosine_similarity.ts | 62 +++++++++++++++++++ mediapipe/tasks/web/text/text_embedder/BUILD | 1 + .../web/text/text_embedder/text_embedder.ts | 15 +++++ .../tasks/web/vision/image_embedder/BUILD | 1 + .../vision/image_embedder/image_embedder.ts | 15 +++++ 8 files changed, 121 insertions(+) create mode 100644 mediapipe/tasks/web/components/utils/BUILD create mode 100644 mediapipe/tasks/web/components/utils/cosine_similarity.ts diff --git a/mediapipe/tasks/web/audio/audio_embedder/BUILD b/mediapipe/tasks/web/audio/audio_embedder/BUILD index 7d9a994a3..1a66464bd 100644 --- a/mediapipe/tasks/web/audio/audio_embedder/BUILD +++ b/mediapipe/tasks/web/audio/audio_embedder/BUILD @@ -22,6 +22,7 @@ mediapipe_ts_library( "//mediapipe/tasks/web/components/containers:embedding_result", "//mediapipe/tasks/web/components/processors:embedder_options", "//mediapipe/tasks/web/components/processors:embedder_result", + "//mediapipe/tasks/web/components/utils:cosine_similarity", "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:embedder_options", "//mediapipe/web/graph_runner:graph_runner_ts", diff --git a/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts index 46a7b6729..9dce02862 100644 --- a/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts +++ b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts @@ -20,8 +20,10 @@ import {AudioEmbedderGraphOptions as AudioEmbedderGraphOptionsProto} from '../.. 
import {EmbeddingResult} from '../../../../tasks/cc/components/containers/proto/embeddings_pb'; import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; import {AudioTaskRunner} from '../../../../tasks/web/audio/core/audio_task_runner'; +import {Embedding} from '../../../../tasks/web/components/containers/embedding_result'; import {convertEmbedderOptionsToProto} from '../../../../tasks/web/components/processors/embedder_options'; import {convertFromEmbeddingResultProto} from '../../../../tasks/web/components/processors/embedder_result'; +import {computeCosineSimilarity} from '../../../../tasks/web/components/utils/cosine_similarity'; import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options'; import {createMediaPipeLib, FileLocator} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url @@ -144,6 +146,19 @@ export class AudioEmbedder extends AudioTaskRunner { return this.processAudioClip(audioData, sampleRate); } + /** + * Utility function to compute cosine similarity[1] between two `Embedding` + * objects. + * + * [1]: https://en.wikipedia.org/wiki/Cosine_similarity + * + * @throws if the embeddings are of different types(float vs. quantized), have + * different sizes, or have an L2-norm of 0. + */ + static cosineSimilarity(u: Embedding, v: Embedding): number { + return computeCosineSimilarity(u, v); + } + protected override process( audioData: Float32Array, sampleRate: number, timestampMs: number): AudioEmbedderResult[] { diff --git a/mediapipe/tasks/web/components/utils/BUILD b/mediapipe/tasks/web/components/utils/BUILD new file mode 100644 index 000000000..1c1ba69ca --- /dev/null +++ b/mediapipe/tasks/web/components/utils/BUILD @@ -0,0 +1,11 @@ +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +mediapipe_ts_library( + name = "cosine_similarity", + srcs = ["cosine_similarity.ts"], + deps = [ + "//mediapipe/tasks/web/components/containers:embedding_result", + ], +) diff --git a/mediapipe/tasks/web/components/utils/cosine_similarity.ts b/mediapipe/tasks/web/components/utils/cosine_similarity.ts new file mode 100644 index 000000000..fb1d0c185 --- /dev/null +++ b/mediapipe/tasks/web/components/utils/cosine_similarity.ts @@ -0,0 +1,62 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + *
+ * Licensed under the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License. You may obtain a
+ * copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ *
Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +import {Embedding} from '../../../../tasks/web/components/containers/embedding_result'; + +/** + * Computes cosine similarity[1] between two `Embedding` objects. + * + * [1]: https://en.wikipedia.org/wiki/Cosine_similarity + * + * @throws if the embeddings are of different types (float vs. quantized), + * have different sizes, or have an L2-norm of 0. + */ +export function computeCosineSimilarity(u: Embedding, v: Embedding): number { + if (u.floatEmbedding && v.floatEmbedding) { + return compute(u.floatEmbedding, v.floatEmbedding); + } + if (u.quantizedEmbedding && v.quantizedEmbedding) { + return compute( + convertToBytes(u.quantizedEmbedding), + convertToBytes(v.quantizedEmbedding)); + } + throw new Error( + 'Cannot compute cosine similarity between quantized and float embeddings.'); +} +function convertToBytes(data: Uint8Array): number[] { + return Array.from(data, v => v - 128); +} + +function compute(u: number[], v: number[]) { + if (u.length !== v.length) { + throw new Error( + `Cannot compute cosine similarity between embeddings of different sizes (${ + u.length} vs. ${v.length}).`); + } + let dotProduct = 0.0; + let normU = 0.0; + let normV = 0.0; + for (let i = 0; i < u.length; i++) { + dotProduct += u[i] * v[i]; + normU += u[i] * u[i]; + normV += v[i] * v[i]; + } + if (normU <= 0 || normV <= 0) { + throw new Error( + 'Cannot compute cosine similarity on embedding with 0 norm.'); + } + return dotProduct / Math.sqrt(normU * normV); +} diff --git a/mediapipe/tasks/web/text/text_embedder/BUILD b/mediapipe/tasks/web/text/text_embedder/BUILD index c555f8d33..3f92b8ae1 100644 --- a/mediapipe/tasks/web/text/text_embedder/BUILD +++ b/mediapipe/tasks/web/text/text_embedder/BUILD @@ -22,6 +22,7 @@ mediapipe_ts_library( "//mediapipe/tasks/web/components/processors:base_options", "//mediapipe/tasks/web/components/processors:embedder_options", "//mediapipe/tasks/web/components/processors:embedder_result", + "//mediapipe/tasks/web/components/utils:cosine_similarity", "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:embedder_options", "//mediapipe/tasks/web/core:task_runner", diff --git a/mediapipe/tasks/web/text/text_embedder/text_embedder.ts b/mediapipe/tasks/web/text/text_embedder/text_embedder.ts index 57b91d575..2042a0985 100644 --- a/mediapipe/tasks/web/text/text_embedder/text_embedder.ts +++ b/mediapipe/tasks/web/text/text_embedder/text_embedder.ts @@ -18,9 +18,11 @@ import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; import {CalculatorOptions} from '../../../../framework/calculator_options_pb'; import {EmbeddingResult} from '../../../../tasks/cc/components/containers/proto/embeddings_pb'; import {TextEmbedderGraphOptions as TextEmbedderGraphOptionsProto} from '../../../../tasks/cc/text/text_embedder/proto/text_embedder_graph_options_pb'; +import {Embedding} from '../../../../tasks/web/components/containers/embedding_result'; import {convertBaseOptionsToProto} from '../../../../tasks/web/components/processors/base_options'; import {convertEmbedderOptionsToProto} from '../../../../tasks/web/components/processors/embedder_options'; import {convertFromEmbeddingResultProto} from 
'../../../../tasks/web/components/processors/embedder_result'; +import {computeCosineSimilarity} from '../../../../tasks/web/components/utils/cosine_similarity'; import {TaskRunner} from '../../../../tasks/web/core/task_runner'; import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options'; import {createMediaPipeLib, FileLocator} from '../../../../web/graph_runner/graph_runner'; @@ -143,6 +145,19 @@ export class TextEmbedder extends TaskRunner { return this.embeddingResult; } + /** + * Utility function to compute cosine similarity[1] between two `Embedding` + * objects. + * + * [1]: https://en.wikipedia.org/wiki/Cosine_similarity + * + * @throws if the embeddings are of different types(float vs. quantized), have + * different sizes, or have an L2-norm of 0. + */ + static cosineSimilarity(u: Embedding, v: Embedding): number { + return computeCosineSimilarity(u, v); + } + /** Updates the MediaPipe graph configuration. */ private refreshGraph(): void { const graphConfig = new CalculatorGraphConfig(); diff --git a/mediapipe/tasks/web/vision/image_embedder/BUILD b/mediapipe/tasks/web/vision/image_embedder/BUILD index feb3ae054..2f012dc5e 100644 --- a/mediapipe/tasks/web/vision/image_embedder/BUILD +++ b/mediapipe/tasks/web/vision/image_embedder/BUILD @@ -21,6 +21,7 @@ mediapipe_ts_library( "//mediapipe/tasks/web/components/containers:embedding_result", "//mediapipe/tasks/web/components/processors:embedder_options", "//mediapipe/tasks/web/components/processors:embedder_result", + "//mediapipe/tasks/web/components/utils:cosine_similarity", "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:embedder_options", "//mediapipe/tasks/web/vision/core:vision_task_options", diff --git a/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts b/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts index c60665052..f96f1e961 100644 --- a/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts +++ b/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts @@ -19,8 +19,10 @@ import {CalculatorOptions} from '../../../../framework/calculator_options_pb'; import {EmbeddingResult} from '../../../../tasks/cc/components/containers/proto/embeddings_pb'; import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; import {ImageEmbedderGraphOptions} from '../../../../tasks/cc/vision/image_embedder/proto/image_embedder_graph_options_pb'; +import {Embedding} from '../../../../tasks/web/components/containers/embedding_result'; import {convertEmbedderOptionsToProto} from '../../../../tasks/web/components/processors/embedder_options'; import {convertFromEmbeddingResultProto} from '../../../../tasks/web/components/processors/embedder_result'; +import {computeCosineSimilarity} from '../../../../tasks/web/components/utils/cosine_similarity'; import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options'; import {VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; import {createMediaPipeLib, FileLocator, ImageSource} from '../../../../web/graph_runner/graph_runner'; @@ -157,6 +159,19 @@ export class ImageEmbedder extends VisionTaskRunner { return this.processVideoData(imageFrame, timestamp); } + /** + * Utility function to compute cosine similarity[1] between two `Embedding` + * objects. + * + * [1]: https://en.wikipedia.org/wiki/Cosine_similarity + * + * @throws if the embeddings are of different types(float vs. quantized), have + * different sizes, or have an L2-norm of 0. 
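+   *
+   * Example (an illustrative sketch, not part of the original patch;
+   * `resultA` and `resultB` are assumed to be previously computed
+   * `ImageEmbedderResult`s):
+   *
+   *   const similarity = ImageEmbedder.cosineSimilarity(
+   *       resultA.embeddings[0], resultB.embeddings[0]);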
+ */ + static cosineSimilarity(u: Embedding, v: Embedding): number { + return computeCosineSimilarity(u, v); + } + /** Runs the embedding extraction and blocks on the response. */ protected process(image: ImageSource, timestamp: number): ImageEmbedderResult { From b5189758f7fc913e050ae0e6d4f7f999365e8118 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Wed, 23 Nov 2022 02:03:35 -0800 Subject: [PATCH 123/469] Move ImagePreprocessing to "processors" folder. PiperOrigin-RevId: 490444821 --- mediapipe/tasks/cc/components/BUILD | 45 --- .../tasks/cc/components/processors/BUILD | 33 ++ .../image_preprocessing_graph.cc} | 42 ++- .../image_preprocessing_graph.h} | 26 +- .../image_preprocessing_graph_test.cc | 343 ++++++++++++++++++ .../cc/components/processors/proto/BUILD | 10 + .../image_preprocessing_graph_options.proto} | 6 +- .../tasks/cc/vision/gesture_recognizer/BUILD | 4 - .../gesture_recognizer/gesture_recognizer.cc | 1 - .../hand_gesture_recognizer_graph.cc | 2 - mediapipe/tasks/cc/vision/hand_detector/BUILD | 2 +- .../hand_detector/hand_detector_graph.cc | 20 +- .../tasks/cc/vision/hand_landmarker/BUILD | 3 +- .../vision/hand_landmarker/hand_landmarker.cc | 1 - .../hand_landmarks_detector_graph.cc | 17 +- .../tasks/cc/vision/image_classifier/BUILD | 4 +- .../image_classifier_graph.cc | 19 +- .../tasks/cc/vision/image_embedder/BUILD | 4 +- .../image_embedder/image_embedder_graph.cc | 19 +- .../tasks/cc/vision/image_segmenter/BUILD | 4 +- .../image_segmenter/image_segmenter_graph.cc | 19 +- .../tasks/cc/vision/object_detector/BUILD | 2 +- .../object_detector/object_detector_graph.cc | 17 +- 23 files changed, 493 insertions(+), 150 deletions(-) rename mediapipe/tasks/cc/components/{image_preprocessing.cc => processors/image_preprocessing_graph.cc} (90%) rename mediapipe/tasks/cc/components/{image_preprocessing.h => processors/image_preprocessing_graph.h} (72%) create mode 100644 mediapipe/tasks/cc/components/processors/image_preprocessing_graph_test.cc rename mediapipe/tasks/cc/components/{image_preprocessing_options.proto => processors/proto/image_preprocessing_graph_options.proto} (89%) diff --git a/mediapipe/tasks/cc/components/BUILD b/mediapipe/tasks/cc/components/BUILD index c90349ab2..54a5207d2 100644 --- a/mediapipe/tasks/cc/components/BUILD +++ b/mediapipe/tasks/cc/components/BUILD @@ -12,55 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") - package(default_visibility = ["//mediapipe/tasks:internal"]) licenses(["notice"]) -mediapipe_proto_library( - name = "image_preprocessing_options_proto", - srcs = ["image_preprocessing_options.proto"], - deps = [ - "//mediapipe/calculators/tensor:image_to_tensor_calculator_proto", - "//mediapipe/framework:calculator_options_proto", - "//mediapipe/framework:calculator_proto", - ], -) - -cc_library( - name = "image_preprocessing", - srcs = ["image_preprocessing.cc"], - hdrs = ["image_preprocessing.h"], - deps = [ - ":image_preprocessing_options_cc_proto", - "//mediapipe/calculators/core:pass_through_calculator", - "//mediapipe/calculators/image:image_clone_calculator", - "//mediapipe/calculators/image:image_clone_calculator_cc_proto", - "//mediapipe/calculators/image:image_properties_calculator", - "//mediapipe/calculators/tensor:image_to_tensor_calculator", - "//mediapipe/calculators/tensor:image_to_tensor_calculator_cc_proto", - "//mediapipe/calculators/tensor:inference_calculator_cc_proto", - "//mediapipe/framework:calculator_framework", - "//mediapipe/framework/api2:builder", - "//mediapipe/framework/api2:port", - "//mediapipe/framework/formats:image", - "//mediapipe/framework/formats:rect_cc_proto", - "//mediapipe/framework/formats:tensor", - "//mediapipe/gpu:gpu_origin_cc_proto", - "//mediapipe/tasks/cc:common", - "//mediapipe/tasks/cc/core:model_resources", - "//mediapipe/tasks/cc/core/proto:acceleration_cc_proto", - "//mediapipe/tasks/cc/vision/utils:image_tensor_specs", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@org_tensorflow//tensorflow/lite/schema:schema_fbs", - ], - alwayslink = 1, -) - -# TODO: Enable this test - # TODO: Investigate rewriting the build rule to only link # the Bert Preprocessor if it's needed. 
cc_library( diff --git a/mediapipe/tasks/cc/components/processors/BUILD b/mediapipe/tasks/cc/components/processors/BUILD index 32a628db7..4946683f5 100644 --- a/mediapipe/tasks/cc/components/processors/BUILD +++ b/mediapipe/tasks/cc/components/processors/BUILD @@ -100,3 +100,36 @@ cc_library( ], alwayslink = 1, ) + +cc_library( + name = "image_preprocessing_graph", + srcs = ["image_preprocessing_graph.cc"], + hdrs = ["image_preprocessing_graph.h"], + deps = [ + "//mediapipe/calculators/core:pass_through_calculator", + "//mediapipe/calculators/image:image_clone_calculator", + "//mediapipe/calculators/image:image_clone_calculator_cc_proto", + "//mediapipe/calculators/image:image_properties_calculator", + "//mediapipe/calculators/tensor:image_to_tensor_calculator", + "//mediapipe/calculators/tensor:image_to_tensor_calculator_cc_proto", + "//mediapipe/calculators/tensor:inference_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/framework/formats:tensor", + "//mediapipe/gpu:gpu_origin_cc_proto", + "//mediapipe/tasks/cc:common", + "//mediapipe/tasks/cc/components/processors/proto:image_preprocessing_graph_options_cc_proto", + "//mediapipe/tasks/cc/core:model_resources", + "//mediapipe/tasks/cc/core/proto:acceleration_cc_proto", + "//mediapipe/tasks/cc/vision/utils:image_tensor_specs", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@org_tensorflow//tensorflow/lite/schema:schema_fbs", + ], + alwayslink = 1, +) + +# TODO: Enable this test diff --git a/mediapipe/tasks/cc/components/image_preprocessing.cc b/mediapipe/tasks/cc/components/processors/image_preprocessing_graph.cc similarity index 90% rename from mediapipe/tasks/cc/components/image_preprocessing.cc rename to mediapipe/tasks/cc/components/processors/image_preprocessing_graph.cc index ef447df97..b24b7f0cb 100644 --- a/mediapipe/tasks/cc/components/image_preprocessing.cc +++ b/mediapipe/tasks/cc/components/processors/image_preprocessing_graph.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "mediapipe/tasks/cc/components/image_preprocessing.h" +#include "mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h" #include #include @@ -33,7 +33,7 @@ limitations under the License. #include "mediapipe/framework/formats/tensor.h" #include "mediapipe/gpu/gpu_origin.pb.h" #include "mediapipe/tasks/cc/common.h" -#include "mediapipe/tasks/cc/components/image_preprocessing_options.pb.h" +#include "mediapipe/tasks/cc/components/processors/proto/image_preprocessing_graph_options.pb.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/proto/acceleration.pb.h" #include "mediapipe/tasks/cc/vision/utils/image_tensor_specs.h" @@ -42,6 +42,7 @@ limitations under the License. 
namespace mediapipe { namespace tasks { namespace components { +namespace processors { namespace { using ::mediapipe::Tensor; @@ -144,9 +145,9 @@ bool DetermineImagePreprocessingGpuBackend( return acceleration.has_gpu(); } -absl::Status ConfigureImagePreprocessing(const ModelResources& model_resources, - bool use_gpu, - ImagePreprocessingOptions* options) { +absl::Status ConfigureImagePreprocessingGraph( + const ModelResources& model_resources, bool use_gpu, + proto::ImagePreprocessingGraphOptions* options) { ASSIGN_OR_RETURN(auto image_tensor_specs, BuildImageTensorSpecs(model_resources)); MP_RETURN_IF_ERROR(ConfigureImageToTensorCalculator( @@ -154,9 +155,9 @@ absl::Status ConfigureImagePreprocessing(const ModelResources& model_resources, // The GPU backend isn't able to process int data. If the input tensor is // quantized, forces the image preprocessing graph to use CPU backend. if (use_gpu && image_tensor_specs.tensor_type != tflite::TensorType_UINT8) { - options->set_backend(ImagePreprocessingOptions::GPU_BACKEND); + options->set_backend(proto::ImagePreprocessingGraphOptions::GPU_BACKEND); } else { - options->set_backend(ImagePreprocessingOptions::CPU_BACKEND); + options->set_backend(proto::ImagePreprocessingGraphOptions::CPU_BACKEND); } return absl::OkStatus(); } @@ -170,8 +171,7 @@ Source AddDataConverter(Source image_in, Graph& graph, return image_converter[Output("")]; } -// A "mediapipe.tasks.components.ImagePreprocessingSubgraph" performs image -// preprocessing. +// An ImagePreprocessingGraph performs image preprocessing. // - Accepts CPU input images and outputs CPU tensors. // // Inputs: @@ -192,7 +192,7 @@ Source AddDataConverter(Source image_in, Graph& graph, // An std::array representing the letterbox padding from the 4 // sides ([left, top, right, bottom]) of the output image, normalized to // [0.f, 1.f] by the output dimensions. The padding values are non-zero only -// when the "keep_aspect_ratio" is true in ImagePreprocessingOptions. +// when the "keep_aspect_ratio" is true in ImagePreprocessingGraphOptions. // IMAGE_SIZE - std::pair @Optional // The size of the original input image as a pair. // IMAGE - Image @Optional @@ -200,15 +200,15 @@ Source AddDataConverter(Source image_in, Graph& graph, // GPU). // // The recommended way of using this subgraph is through the GraphBuilder API -// using the 'ConfigureImagePreprocessing()' function. See header file for more -// details. -class ImagePreprocessingSubgraph : public Subgraph { +// using the 'ConfigureImagePreprocessingGraph()' function. See header file for +// more details. +class ImagePreprocessingGraph : public Subgraph { public: absl::StatusOr GetConfig( SubgraphContext* sc) override { Graph graph; auto output_streams = BuildImagePreprocessing( - sc->Options(), + sc->Options(), graph[Input(kImageTag)], graph[Input::Optional(kNormRectTag)], graph); output_streams.tensors >> graph[Output>(kTensorsTag)]; @@ -233,24 +233,25 @@ class ImagePreprocessingSubgraph : public Subgraph { // - the image that has pixel data stored on the target storage // (mediapipe::Image). // - // options: the mediapipe tasks ImagePreprocessingOptions. + // options: the mediapipe tasks ImagePreprocessingGraphOptions. // image_in: (mediapipe::Image) stream to preprocess. // graph: the mediapipe builder::Graph instance to be updated. 
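  // (Illustrative addition, not in the original comment: the function returns
  // an ImagePreprocessingOutputStreams struct bundling the tensors, matrix,
  // letterbox padding, image size, and passthrough image streams.)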
ImagePreprocessingOutputStreams BuildImagePreprocessing( - const ImagePreprocessingOptions& options, Source image_in, - Source norm_rect_in, Graph& graph) { + const proto::ImagePreprocessingGraphOptions& options, + Source image_in, Source norm_rect_in, + Graph& graph) { // Convert image to tensor. auto& image_to_tensor = graph.AddNode("ImageToTensorCalculator"); image_to_tensor.GetOptions() .CopyFrom(options.image_to_tensor_options()); switch (options.backend()) { - case ImagePreprocessingOptions::CPU_BACKEND: { + case proto::ImagePreprocessingGraphOptions::CPU_BACKEND: { auto cpu_image = AddDataConverter(image_in, graph, /*output_on_gpu=*/false); cpu_image >> image_to_tensor.In(kImageTag); break; } - case ImagePreprocessingOptions::GPU_BACKEND: { + case proto::ImagePreprocessingGraphOptions::GPU_BACKEND: { auto gpu_image = AddDataConverter(image_in, graph, /*output_on_gpu=*/true); gpu_image >> image_to_tensor.In(kImageTag); @@ -284,8 +285,9 @@ class ImagePreprocessingSubgraph : public Subgraph { } }; REGISTER_MEDIAPIPE_GRAPH( - ::mediapipe::tasks::components::ImagePreprocessingSubgraph); + ::mediapipe::tasks::components::processors::ImagePreprocessingGraph); +} // namespace processors } // namespace components } // namespace tasks } // namespace mediapipe diff --git a/mediapipe/tasks/cc/components/image_preprocessing.h b/mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h similarity index 72% rename from mediapipe/tasks/cc/components/image_preprocessing.h rename to mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h index 6963b6556..455a9b316 100644 --- a/mediapipe/tasks/cc/components/image_preprocessing.h +++ b/mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h @@ -13,35 +13,36 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_IMAGE_PREPROCESSING_H_ -#define MEDIAPIPE_TASKS_CC_COMPONENTS_IMAGE_PREPROCESSING_H_ +#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_IMAGE_PREPROCESSING_GRAPH_H_ +#define MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_IMAGE_PREPROCESSING_GRAPH_H_ #include "absl/status/status.h" -#include "mediapipe/tasks/cc/components/image_preprocessing_options.pb.h" +#include "mediapipe/tasks/cc/components/processors/proto/image_preprocessing_graph_options.pb.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/proto/acceleration.pb.h" namespace mediapipe { namespace tasks { namespace components { +namespace processors { -// Configures an ImagePreprocessing subgraph using the provided model resources +// Configures an ImagePreprocessingGraph using the provided model resources // When use_gpu is true, use GPU as backend to convert image to tensor. // - Accepts CPU input images and outputs CPU tensors. 
// // Example usage: // // auto& preprocessing = -// graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph"); +// graph.AddNode("mediapipe.tasks.components.processors.ImagePreprocessingGraph"); // core::proto::Acceleration acceleration; // acceleration.mutable_xnnpack(); // bool use_gpu = DetermineImagePreprocessingGpuBackend(acceleration); -// MP_RETURN_IF_ERROR(ConfigureImagePreprocessing( +// MP_RETURN_IF_ERROR(ConfigureImagePreprocessingGraph( // model_resources, // use_gpu, -// &preprocessing.GetOptions())); +// &preprocessing.GetOptions())); // -// The resulting ImagePreprocessing subgraph has the following I/O: +// The resulting ImagePreprocessingGraph has the following I/O: // Inputs: // IMAGE - Image // The image to preprocess. @@ -61,17 +62,18 @@ namespace components { // IMAGE - Image @Optional // The image that has the pixel data stored on the target storage (CPU vs // GPU). -absl::Status ConfigureImagePreprocessing( +absl::Status ConfigureImagePreprocessingGraph( const core::ModelResources& model_resources, bool use_gpu, - ImagePreprocessingOptions* options); + proto::ImagePreprocessingGraphOptions* options); -// Determine if the image preprocessing subgraph should use GPU as the backend +// Determine if the image preprocessing graph should use GPU as the backend // according to the given acceleration setting. bool DetermineImagePreprocessingGpuBackend( const core::proto::Acceleration& acceleration); +} // namespace processors } // namespace components } // namespace tasks } // namespace mediapipe -#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_IMAGE_PREPROCESSING_H_ +#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_IMAGE_PREPROCESSING_GRAPH_H_ diff --git a/mediapipe/tasks/cc/components/processors/image_preprocessing_graph_test.cc b/mediapipe/tasks/cc/components/processors/image_preprocessing_graph_test.cc new file mode 100644 index 000000000..6c094c6bc --- /dev/null +++ b/mediapipe/tasks/cc/components/processors/image_preprocessing_graph_test.cc @@ -0,0 +1,343 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h" + +#include +#include + +#include "absl/flags/flag.h" +#include "absl/status/statusor.h" +#include "mediapipe/calculators/tensor/image_to_tensor_calculator.pb.h" +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/api2/port.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/deps/file_path.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/tensor.h" +#include "mediapipe/framework/packet.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/status_matchers.h" +#include "mediapipe/tasks/cc/components/processors/proto/image_preprocessing_graph_options.pb.h" +#include "mediapipe/tasks/cc/core/model_resources.h" +#include "mediapipe/tasks/cc/core/proto/acceleration.pb.h" +#include "mediapipe/tasks/cc/core/proto/external_file.pb.h" +#include "mediapipe/tasks/cc/core/task_runner.h" +#include "mediapipe/tasks/cc/vision/utils/image_utils.h" +#include "tensorflow/lite/core/shims/cc/shims_test_util.h" + +namespace mediapipe { +namespace tasks { +namespace components { +namespace processors { +namespace { + +using ::mediapipe::api2::Input; +using ::mediapipe::api2::Output; +using ::mediapipe::api2::builder::Graph; +using ::mediapipe::api2::builder::Source; +using ::mediapipe::file::JoinPath; +using ::mediapipe::tasks::core::ModelResources; +using ::mediapipe::tasks::core::TaskRunner; +using ::mediapipe::tasks::vision::DecodeImageFromFile; +using ::testing::ContainerEq; +using ::testing::HasSubstr; +using ::testing::TestParamInfo; +using ::testing::TestWithParam; +using ::testing::Values; + +constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/"; +constexpr char kMobileNetFloatWithMetadata[] = "mobilenet_v2_1.0_224.tflite"; +constexpr char kMobileNetFloatWithoutMetadata[] = + "mobilenet_v1_0.25_224_1_default_1.tflite"; +constexpr char kMobileNetQuantizedWithMetadata[] = + "mobilenet_v1_0.25_224_quant.tflite"; +constexpr char kMobileNetQuantizedWithoutMetadata[] = + "mobilenet_v1_0.25_192_quantized_1_default_1.tflite"; + +constexpr char kTestImage[] = "burger.jpg"; +constexpr int kTestImageWidth = 480; +constexpr int kTestImageHeight = 325; + +constexpr char kTestModelResourcesTag[] = "test_model_resources"; +constexpr std::array kIdentityMatrix = {1, 0, 0, 0, 0, 1, 0, 0, + 0, 0, 1, 0, 0, 0, 0, 1}; + +constexpr char kImageTag[] = "IMAGE"; +constexpr char kImageName[] = "image_in"; +constexpr char kMatrixTag[] = "MATRIX"; +constexpr char kMatrixName[] = "matrix_out"; +constexpr char kTensorsTag[] = "TENSORS"; +constexpr char kTensorsName[] = "tensors_out"; +constexpr char kImageSizeTag[] = "IMAGE_SIZE"; +constexpr char kImageSizeName[] = "image_size_out"; +constexpr char kLetterboxPaddingTag[] = "LETTERBOX_PADDING"; +constexpr char kLetterboxPaddingName[] = "letterbox_padding_out"; + +constexpr float kLetterboxMaxAbsError = 1e-5; + +// Helper function to get ModelResources. +absl::StatusOr> CreateModelResourcesForModel( + absl::string_view model_name) { + auto external_file = std::make_unique(); + external_file->set_file_name(JoinPath("./", kTestDataDirectory, model_name)); + return ModelResources::Create(kTestModelResourcesTag, + std::move(external_file)); +} + +// Helper function to create a TaskRunner from ModelResources. 
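+// (Illustrative note: the graph built below has the ImagePreprocessingGraph
+// under test as its only node, exposing the TENSORS, MATRIX, IMAGE_SIZE, and
+// LETTERBOX_PADDING outputs so the test can inspect each of them from the
+// output packets.)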
+absl::StatusOr> CreateTaskRunner( + const ModelResources& model_resources, bool keep_aspect_ratio) { + Graph graph; + + auto& preprocessing = graph.AddNode( + "mediapipe.tasks.components.processors.ImagePreprocessingGraph"); + auto& options = + preprocessing.GetOptions(); + options.mutable_image_to_tensor_options()->set_keep_aspect_ratio( + keep_aspect_ratio); + MP_RETURN_IF_ERROR( + ConfigureImagePreprocessingGraph(model_resources, false, &options)); + graph[Input(kImageTag)].SetName(kImageName) >> + preprocessing.In(kImageTag); + preprocessing.Out(kTensorsTag).SetName(kTensorsName) >> + graph[Output>(kTensorsTag)]; + preprocessing.Out(kMatrixTag).SetName(kMatrixName) >> + graph[Output>(kMatrixTag)]; + preprocessing.Out(kImageSizeTag).SetName(kImageSizeName) >> + graph[Output>(kImageSizeTag)]; + preprocessing.Out(kLetterboxPaddingTag).SetName(kLetterboxPaddingName) >> + graph[Output>(kLetterboxPaddingTag)]; + + return TaskRunner::Create(graph.GetConfig()); +} + +class ConfigureTest : public tflite_shims::testing::Test {}; + +TEST_F(ConfigureTest, SucceedsWithQuantizedModelWithMetadata) { + MP_ASSERT_OK_AND_ASSIGN( + auto model_resources, + CreateModelResourcesForModel(kMobileNetQuantizedWithMetadata)); + + proto::ImagePreprocessingGraphOptions options; + MP_EXPECT_OK( + ConfigureImagePreprocessingGraph(*model_resources, false, &options)); + + EXPECT_THAT(options, EqualsProto( + R"pb(image_to_tensor_options { + output_tensor_width: 224 + output_tensor_height: 224 + output_tensor_uint_range { min: 0 max: 255 } + gpu_origin: TOP_LEFT + } + backend: CPU_BACKEND)pb")); +} + +TEST_F(ConfigureTest, SucceedsWithQuantizedModelWithoutMetadata) { + MP_ASSERT_OK_AND_ASSIGN( + auto model_resources, + CreateModelResourcesForModel(kMobileNetQuantizedWithoutMetadata)); + + proto::ImagePreprocessingGraphOptions options; + MP_EXPECT_OK( + ConfigureImagePreprocessingGraph(*model_resources, false, &options)); + + EXPECT_THAT(options, EqualsProto( + R"pb(image_to_tensor_options { + output_tensor_width: 192 + output_tensor_height: 192 + output_tensor_uint_range { min: 0 max: 255 } + gpu_origin: TOP_LEFT + } + backend: CPU_BACKEND)pb")); +} + +TEST_F(ConfigureTest, SucceedsWithFloatModelWithMetadata) { + MP_ASSERT_OK_AND_ASSIGN( + auto model_resources, + CreateModelResourcesForModel(kMobileNetFloatWithMetadata)); + + proto::ImagePreprocessingGraphOptions options; + MP_EXPECT_OK( + ConfigureImagePreprocessingGraph(*model_resources, false, &options)); + + EXPECT_THAT(options, EqualsProto( + R"pb(image_to_tensor_options { + output_tensor_width: 224 + output_tensor_height: 224 + output_tensor_float_range { min: -1 max: 1 } + gpu_origin: TOP_LEFT + } + backend: CPU_BACKEND)pb")); +} + +TEST_F(ConfigureTest, SucceedsWithQuantizedModelFallbacksCpuBackend) { + MP_ASSERT_OK_AND_ASSIGN( + auto model_resources, + CreateModelResourcesForModel(kMobileNetQuantizedWithMetadata)); + + proto::ImagePreprocessingGraphOptions options; + core::proto::Acceleration acceleration; + acceleration.mutable_gpu(); + bool use_gpu = DetermineImagePreprocessingGpuBackend(acceleration); + EXPECT_TRUE(use_gpu); + MP_EXPECT_OK( + ConfigureImagePreprocessingGraph(*model_resources, use_gpu, &options)); + + EXPECT_THAT(options, EqualsProto( + R"pb(image_to_tensor_options { + output_tensor_width: 224 + output_tensor_height: 224 + output_tensor_uint_range { min: 0 max: 255 } + gpu_origin: TOP_LEFT + } + backend: CPU_BACKEND)pb")); +} + +TEST_F(ConfigureTest, SucceedsWithFloatModelGpuBackend) { + MP_ASSERT_OK_AND_ASSIGN( + auto model_resources, + 
CreateModelResourcesForModel(kMobileNetFloatWithMetadata)); + + proto::ImagePreprocessingGraphOptions options; + core::proto::Acceleration acceleration; + acceleration.mutable_gpu(); + bool use_gpu = DetermineImagePreprocessingGpuBackend(acceleration); + EXPECT_TRUE(use_gpu); + MP_EXPECT_OK( + ConfigureImagePreprocessingGraph(*model_resources, use_gpu, &options)); + + EXPECT_THAT(options, EqualsProto( + R"pb(image_to_tensor_options { + output_tensor_width: 224 + output_tensor_height: 224 + output_tensor_float_range { min: -1 max: 1 } + gpu_origin: TOP_LEFT + } + backend: GPU_BACKEND)pb")); +} + +TEST_F(ConfigureTest, FailsWithFloatModelWithoutMetadata) { + MP_ASSERT_OK_AND_ASSIGN( + auto model_resources, + CreateModelResourcesForModel(kMobileNetFloatWithoutMetadata)); + + proto::ImagePreprocessingGraphOptions options; + auto status = + ConfigureImagePreprocessingGraph(*model_resources, false, &options); + + EXPECT_EQ(status.code(), absl::StatusCode::kNotFound); + EXPECT_THAT(status.message(), + HasSubstr("requires specifying NormalizationOptions metadata")); +} + +// Struct holding the parameters for parameterized PreprocessingTest class. +struct PreprocessingParams { + // The name of this test, for convenience when displaying test results. + std::string test_name; + // The filename of the model to test. + std::string input_model_name; + // If true, keep test image aspect ratio. + bool keep_aspect_ratio; + // The expected output tensor type. + Tensor::ElementType expected_type; + // The expected outoput tensor shape. + std::vector expected_shape; + // The expected output letterbox padding; + std::array expected_letterbox_padding; +}; + +class PreprocessingTest : public testing::TestWithParam {}; + +TEST_P(PreprocessingTest, Succeeds) { + MP_ASSERT_OK_AND_ASSIGN( + Image image, + DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kTestImage))); + MP_ASSERT_OK_AND_ASSIGN( + auto model_resources, + CreateModelResourcesForModel(GetParam().input_model_name)); + MP_ASSERT_OK_AND_ASSIGN( + auto task_runner, + CreateTaskRunner(*model_resources, GetParam().keep_aspect_ratio)); + + auto output_packets = + task_runner->Process({{kImageName, MakePacket(std::move(image))}}); + MP_ASSERT_OK(output_packets); + + const std::vector& tensors = + (*output_packets)[kTensorsName].Get>(); + EXPECT_EQ(tensors.size(), 1); + EXPECT_EQ(tensors[0].element_type(), GetParam().expected_type); + EXPECT_THAT(tensors[0].shape().dims, ContainerEq(GetParam().expected_shape)); + auto& matrix = (*output_packets)[kMatrixName].Get>(); + if (!GetParam().keep_aspect_ratio) { + for (int i = 0; i < matrix.size(); ++i) { + EXPECT_FLOAT_EQ(matrix[i], kIdentityMatrix[i]); + } + } + auto& image_size = + (*output_packets)[kImageSizeName].Get>(); + EXPECT_EQ(image_size.first, kTestImageWidth); + EXPECT_EQ(image_size.second, kTestImageHeight); + std::array letterbox_padding = + (*output_packets)[kLetterboxPaddingName].Get>(); + for (int i = 0; i < letterbox_padding.size(); ++i) { + EXPECT_NEAR(letterbox_padding[i], GetParam().expected_letterbox_padding[i], + kLetterboxMaxAbsError); + } +} + +INSTANTIATE_TEST_SUITE_P( + PreprocessingTest, PreprocessingTest, + Values( + PreprocessingParams{.test_name = "kMobileNetQuantizedWithMetadata", + .input_model_name = kMobileNetQuantizedWithMetadata, + .keep_aspect_ratio = false, + .expected_type = Tensor::ElementType::kUInt8, + .expected_shape = {1, 224, 224, 3}, + .expected_letterbox_padding = {0, 0, 0, 0}}, + PreprocessingParams{ + .test_name = "kMobileNetQuantizedWithoutMetadata", + 
.input_model_name = kMobileNetQuantizedWithoutMetadata, + .keep_aspect_ratio = false, + .expected_type = Tensor::ElementType::kUInt8, + .expected_shape = {1, 192, 192, 3}, + .expected_letterbox_padding = {0, 0, 0, 0}}, + PreprocessingParams{.test_name = "kMobileNetFloatWithMetadata", + .input_model_name = kMobileNetFloatWithMetadata, + .keep_aspect_ratio = false, + .expected_type = Tensor::ElementType::kFloat32, + .expected_shape = {1, 224, 224, 3}, + .expected_letterbox_padding = {0, 0, 0, 0}}, + PreprocessingParams{ + .test_name = "kMobileNetFloatWithMetadataKeepAspectRatio", + .input_model_name = kMobileNetFloatWithMetadata, + .keep_aspect_ratio = true, + .expected_type = Tensor::ElementType::kFloat32, + .expected_shape = {1, 224, 224, 3}, + .expected_letterbox_padding = {/*left*/ 0, + /*top*/ 0.161458, + /*right*/ 0, + /*bottom*/ 0.161458}}), + [](const TestParamInfo& info) { + return info.param.test_name; + }); + +} // namespace +} // namespace processors +} // namespace components +} // namespace tasks +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/components/processors/proto/BUILD b/mediapipe/tasks/cc/components/processors/proto/BUILD index 23ebbe008..9c58a8585 100644 --- a/mediapipe/tasks/cc/components/processors/proto/BUILD +++ b/mediapipe/tasks/cc/components/processors/proto/BUILD @@ -49,3 +49,13 @@ mediapipe_proto_library( "//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator_proto", ], ) + +mediapipe_proto_library( + name = "image_preprocessing_graph_options_proto", + srcs = ["image_preprocessing_graph_options.proto"], + deps = [ + "//mediapipe/calculators/tensor:image_to_tensor_calculator_proto", + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + ], +) diff --git a/mediapipe/tasks/cc/components/image_preprocessing_options.proto b/mediapipe/tasks/cc/components/processors/proto/image_preprocessing_graph_options.proto similarity index 89% rename from mediapipe/tasks/cc/components/image_preprocessing_options.proto rename to mediapipe/tasks/cc/components/processors/proto/image_preprocessing_graph_options.proto index d1685c319..bf4fc9067 100644 --- a/mediapipe/tasks/cc/components/image_preprocessing_options.proto +++ b/mediapipe/tasks/cc/components/processors/proto/image_preprocessing_graph_options.proto @@ -15,14 +15,14 @@ limitations under the License. 
syntax = "proto2"; -package mediapipe.tasks.components; +package mediapipe.tasks.components.processors.proto; import "mediapipe/calculators/tensor/image_to_tensor_calculator.proto"; import "mediapipe/framework/calculator.proto"; -message ImagePreprocessingOptions { +message ImagePreprocessingGraphOptions { extend mediapipe.CalculatorOptions { - optional ImagePreprocessingOptions ext = 456882436; + optional ImagePreprocessingGraphOptions ext = 456882436; } // Options for the ImageToTensor calculator encapsulated by the diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/BUILD b/mediapipe/tasks/cc/vision/gesture_recognizer/BUILD index 7b144e7aa..d473a8dc3 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/BUILD +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/BUILD @@ -37,7 +37,6 @@ cc_library( "//mediapipe/framework/formats:landmark_cc_proto", "//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/tasks/cc:common", - "//mediapipe/tasks/cc/components:image_preprocessing", "//mediapipe/tasks/cc/components/processors:classifier_options", "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto", "//mediapipe/tasks/cc/core:base_options", @@ -105,10 +104,7 @@ cc_library( "//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework/formats:tensor", "//mediapipe/tasks/cc:common", - "//mediapipe/tasks/cc/components:image_preprocessing", "//mediapipe/tasks/cc/components/processors:classification_postprocessing_graph", - "//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto", - "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto", "//mediapipe/tasks/cc/core:model_asset_bundle_resources", "//mediapipe/tasks/cc/core:model_resources", "//mediapipe/tasks/cc/core:model_resources_cache", diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.cc index 8d555b12c..e7fcf6fd9 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.cc +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.cc @@ -31,7 +31,6 @@ limitations under the License. #include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/framework/packet.h" #include "mediapipe/tasks/cc/common.h" -#include "mediapipe/tasks/cc/components/image_preprocessing.h" #include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h" #include "mediapipe/tasks/cc/core/base_task_api.h" #include "mediapipe/tasks/cc/core/model_resources.h" diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/hand_gesture_recognizer_graph.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/hand_gesture_recognizer_graph.cc index 7b6a8c79d..d7e983d81 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/hand_gesture_recognizer_graph.cc +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/hand_gesture_recognizer_graph.cc @@ -29,8 +29,6 @@ limitations under the License. 
#include "mediapipe/framework/formats/tensor.h" #include "mediapipe/tasks/cc/common.h" #include "mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h" -#include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h" -#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h" #include "mediapipe/tasks/cc/core/model_asset_bundle_resources.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_resources_cache.h" diff --git a/mediapipe/tasks/cc/vision/hand_detector/BUILD b/mediapipe/tasks/cc/vision/hand_detector/BUILD index 71cef6270..55162d09b 100644 --- a/mediapipe/tasks/cc/vision/hand_detector/BUILD +++ b/mediapipe/tasks/cc/vision/hand_detector/BUILD @@ -46,7 +46,7 @@ cc_library( "//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework/formats:tensor", "//mediapipe/tasks/cc:common", - "//mediapipe/tasks/cc/components:image_preprocessing", + "//mediapipe/tasks/cc/components/processors:image_preprocessing_graph", "//mediapipe/tasks/cc/core:model_resources", "//mediapipe/tasks/cc/core:model_task_graph", "//mediapipe/tasks/cc/core:utils", diff --git a/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph.cc b/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph.cc index 06bb2e549..c24548c9b 100644 --- a/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph.cc +++ b/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph.cc @@ -35,7 +35,7 @@ limitations under the License. #include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/framework/formats/tensor.h" #include "mediapipe/tasks/cc/common.h" -#include "mediapipe/tasks/cc/components/image_preprocessing.h" +#include "mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_task_graph.h" #include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h" @@ -226,21 +226,23 @@ class HandDetectorGraph : public core::ModelTaskGraph { Source norm_rect_in, Graph& graph) { // Add image preprocessing subgraph. The model expects aspect ratio // unchanged. 
- auto& preprocessing = - graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph"); + auto& preprocessing = graph.AddNode( + "mediapipe.tasks.components.processors.ImagePreprocessingGraph"); auto& image_to_tensor_options = *preprocessing - .GetOptions() + .GetOptions() .mutable_image_to_tensor_options(); image_to_tensor_options.set_keep_aspect_ratio(true); image_to_tensor_options.set_border_mode( mediapipe::ImageToTensorCalculatorOptions::BORDER_ZERO); - bool use_gpu = components::DetermineImagePreprocessingGpuBackend( - subgraph_options.base_options().acceleration()); - MP_RETURN_IF_ERROR(ConfigureImagePreprocessing( + bool use_gpu = + components::processors::DetermineImagePreprocessingGpuBackend( + subgraph_options.base_options().acceleration()); + MP_RETURN_IF_ERROR(components::processors::ConfigureImagePreprocessingGraph( model_resources, use_gpu, - &preprocessing - .GetOptions())); + &preprocessing.GetOptions< + components::processors::proto::ImagePreprocessingGraphOptions>())); image_in >> preprocessing.In("IMAGE"); norm_rect_in >> preprocessing.In("NORM_RECT"); auto preprocessed_tensors = preprocessing.Out("TENSORS"); diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/BUILD b/mediapipe/tasks/cc/vision/hand_landmarker/BUILD index 3b869eab4..46948ee6c 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/BUILD +++ b/mediapipe/tasks/cc/vision/hand_landmarker/BUILD @@ -35,7 +35,6 @@ cc_library( "//mediapipe/framework/formats:landmark_cc_proto", "//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/tasks/cc:common", - "//mediapipe/tasks/cc/components:image_preprocessing", "//mediapipe/tasks/cc/components/processors:classifier_options", "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto", "//mediapipe/tasks/cc/core:base_options", @@ -89,7 +88,7 @@ cc_library( "//mediapipe/modules/hand_landmark/calculators:hand_landmarks_to_rect_calculator", "//mediapipe/tasks/cc:common", "//mediapipe/tasks/cc/components/utils:gate", - "//mediapipe/tasks/cc/components:image_preprocessing", + "//mediapipe/tasks/cc/components/processors:image_preprocessing_graph", "//mediapipe/tasks/cc/core:model_resources", "//mediapipe/tasks/cc/core:model_task_graph", "//mediapipe/tasks/cc/core:utils", diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker.cc b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker.cc index 3a9ed5bc2..2b818b2e5 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker.cc +++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker.cc @@ -22,7 +22,6 @@ limitations under the License. #include "mediapipe/framework/formats/landmark.pb.h" #include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/tasks/cc/common.h" -#include "mediapipe/tasks/cc/components/image_preprocessing.h" #include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h" #include "mediapipe/tasks/cc/core/base_task_api.h" #include "mediapipe/tasks/cc/core/model_resources.h" diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph.cc b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph.cc index 1f127deb8..014830ba2 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph.cc +++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph.cc @@ -33,7 +33,7 @@ limitations under the License. 
#include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/framework/formats/tensor.h" #include "mediapipe/tasks/cc/common.h" -#include "mediapipe/tasks/cc/components/image_preprocessing.h" +#include "mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h" #include "mediapipe/tasks/cc/components/utils/gate.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_task_graph.h" @@ -281,14 +281,15 @@ class SingleHandLandmarksDetectorGraph : public core::ModelTaskGraph { Source hand_rect, Graph& graph) { MP_RETURN_IF_ERROR(SanityCheckOptions(subgraph_options)); - auto& preprocessing = - graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph"); - bool use_gpu = components::DetermineImagePreprocessingGpuBackend( - subgraph_options.base_options().acceleration()); - MP_RETURN_IF_ERROR(ConfigureImagePreprocessing( + auto& preprocessing = graph.AddNode( + "mediapipe.tasks.components.processors.ImagePreprocessingGraph"); + bool use_gpu = + components::processors::DetermineImagePreprocessingGpuBackend( + subgraph_options.base_options().acceleration()); + MP_RETURN_IF_ERROR(components::processors::ConfigureImagePreprocessingGraph( model_resources, use_gpu, - &preprocessing - .GetOptions())); + &preprocessing.GetOptions())); image_in >> preprocessing.In("IMAGE"); hand_rect >> preprocessing.In("NORM_RECT"); auto image_size = preprocessing[Output>("IMAGE_SIZE")]; diff --git a/mediapipe/tasks/cc/vision/image_classifier/BUILD b/mediapipe/tasks/cc/vision/image_classifier/BUILD index 2b93aa262..514e601ef 100644 --- a/mediapipe/tasks/cc/vision/image_classifier/BUILD +++ b/mediapipe/tasks/cc/vision/image_classifier/BUILD @@ -59,11 +59,11 @@ cc_library( "//mediapipe/framework/api2:port", "//mediapipe/framework/formats:image", "//mediapipe/framework/formats:rect_cc_proto", - "//mediapipe/tasks/cc/components:image_preprocessing", - "//mediapipe/tasks/cc/components:image_preprocessing_options_cc_proto", "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", "//mediapipe/tasks/cc/components/processors:classification_postprocessing_graph", + "//mediapipe/tasks/cc/components/processors:image_preprocessing_graph", "//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto", + "//mediapipe/tasks/cc/components/processors/proto:image_preprocessing_graph_options_cc_proto", "//mediapipe/tasks/cc/core:model_resources", "//mediapipe/tasks/cc/core:model_task_graph", "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_cc_proto", diff --git a/mediapipe/tasks/cc/vision/image_classifier/image_classifier_graph.cc b/mediapipe/tasks/cc/vision/image_classifier/image_classifier_graph.cc index 2fc88bcb6..2d0379c66 100644 --- a/mediapipe/tasks/cc/vision/image_classifier/image_classifier_graph.cc +++ b/mediapipe/tasks/cc/vision/image_classifier/image_classifier_graph.cc @@ -23,10 +23,10 @@ limitations under the License. 
#include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" -#include "mediapipe/tasks/cc/components/image_preprocessing.h" -#include "mediapipe/tasks/cc/components/image_preprocessing_options.pb.h" #include "mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h" +#include "mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h" #include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h" +#include "mediapipe/tasks/cc/components/processors/proto/image_preprocessing_graph_options.pb.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_task_graph.h" #include "mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_graph_options.pb.h" @@ -135,14 +135,15 @@ class ImageClassifierGraph : public core::ModelTaskGraph { Source norm_rect_in, Graph& graph) { // Adds preprocessing calculators and connects them to the graph input image // stream. - auto& preprocessing = - graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph"); - bool use_gpu = components::DetermineImagePreprocessingGpuBackend( - task_options.base_options().acceleration()); - MP_RETURN_IF_ERROR(ConfigureImagePreprocessing( + auto& preprocessing = graph.AddNode( + "mediapipe.tasks.components.processors.ImagePreprocessingGraph"); + bool use_gpu = + components::processors::DetermineImagePreprocessingGpuBackend( + task_options.base_options().acceleration()); + MP_RETURN_IF_ERROR(components::processors::ConfigureImagePreprocessingGraph( model_resources, use_gpu, - &preprocessing - .GetOptions())); + &preprocessing.GetOptions())); image_in >> preprocessing.In(kImageTag); norm_rect_in >> preprocessing.In(kNormRectTag); diff --git a/mediapipe/tasks/cc/vision/image_embedder/BUILD b/mediapipe/tasks/cc/vision/image_embedder/BUILD index 8fdb97ccd..d729eaf1a 100644 --- a/mediapipe/tasks/cc/vision/image_embedder/BUILD +++ b/mediapipe/tasks/cc/vision/image_embedder/BUILD @@ -57,12 +57,12 @@ cc_library( "//mediapipe/framework/api2:port", "//mediapipe/framework/formats:image", "//mediapipe/framework/formats:rect_cc_proto", - "//mediapipe/tasks/cc/components:image_preprocessing", - "//mediapipe/tasks/cc/components:image_preprocessing_options_cc_proto", "//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator", "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto", "//mediapipe/tasks/cc/components/processors:embedding_postprocessing_graph", + "//mediapipe/tasks/cc/components/processors:image_preprocessing_graph", "//mediapipe/tasks/cc/components/processors/proto:embedding_postprocessing_graph_options_cc_proto", + "//mediapipe/tasks/cc/components/processors/proto:image_preprocessing_graph_options_cc_proto", "//mediapipe/tasks/cc/core:model_task_graph", "//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_cc_proto", "@com_google_absl//absl/status:statusor", diff --git a/mediapipe/tasks/cc/vision/image_embedder/image_embedder_graph.cc b/mediapipe/tasks/cc/vision/image_embedder/image_embedder_graph.cc index bf0dcf3c7..81ccb5361 100644 --- a/mediapipe/tasks/cc/vision/image_embedder/image_embedder_graph.cc +++ b/mediapipe/tasks/cc/vision/image_embedder/image_embedder_graph.cc @@ -20,10 +20,10 @@ limitations under the License. 
#include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h" -#include "mediapipe/tasks/cc/components/image_preprocessing.h" -#include "mediapipe/tasks/cc/components/image_preprocessing_options.pb.h" #include "mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.h" +#include "mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h" #include "mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options.pb.h" +#include "mediapipe/tasks/cc/components/processors/proto/image_preprocessing_graph_options.pb.h" #include "mediapipe/tasks/cc/core/model_task_graph.h" #include "mediapipe/tasks/cc/vision/image_embedder/proto/image_embedder_graph_options.pb.h" @@ -130,14 +130,15 @@ class ImageEmbedderGraph : public core::ModelTaskGraph { Source norm_rect_in, Graph& graph) { // Adds preprocessing calculators and connects them to the graph input image // stream. - auto& preprocessing = - graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph"); - bool use_gpu = components::DetermineImagePreprocessingGpuBackend( - task_options.base_options().acceleration()); - MP_RETURN_IF_ERROR(ConfigureImagePreprocessing( + auto& preprocessing = graph.AddNode( + "mediapipe.tasks.components.processors.ImagePreprocessingGraph"); + bool use_gpu = + components::processors::DetermineImagePreprocessingGpuBackend( + task_options.base_options().acceleration()); + MP_RETURN_IF_ERROR(components::processors::ConfigureImagePreprocessingGraph( model_resources, use_gpu, - &preprocessing - .GetOptions())); + &preprocessing.GetOptions())); image_in >> preprocessing.In(kImageTag); norm_rect_in >> preprocessing.In(kNormRectTag); diff --git a/mediapipe/tasks/cc/vision/image_segmenter/BUILD b/mediapipe/tasks/cc/vision/image_segmenter/BUILD index 595eef568..2124fe6e0 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/BUILD +++ b/mediapipe/tasks/cc/vision/image_segmenter/BUILD @@ -56,10 +56,10 @@ cc_library( "//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework/port:status", "//mediapipe/tasks/cc:common", - "//mediapipe/tasks/cc/components:image_preprocessing", - "//mediapipe/tasks/cc/components:image_preprocessing_options_cc_proto", "//mediapipe/tasks/cc/components/calculators/tensor:tensors_to_segmentation_calculator", "//mediapipe/tasks/cc/components/calculators/tensor:tensors_to_segmentation_calculator_cc_proto", + "//mediapipe/tasks/cc/components/processors:image_preprocessing_graph", + "//mediapipe/tasks/cc/components/processors/proto:image_preprocessing_graph_options_cc_proto", "//mediapipe/tasks/cc/components/proto:segmenter_options_cc_proto", "//mediapipe/tasks/cc/core:model_resources", "//mediapipe/tasks/cc/core:model_task_graph", diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc index 44742e043..d5eb5af0d 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc @@ -27,8 +27,8 @@ limitations under the License. 
#include "mediapipe/framework/port/status_macros.h" #include "mediapipe/tasks/cc/common.h" #include "mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.pb.h" -#include "mediapipe/tasks/cc/components/image_preprocessing.h" -#include "mediapipe/tasks/cc/components/image_preprocessing_options.pb.h" +#include "mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h" +#include "mediapipe/tasks/cc/components/processors/proto/image_preprocessing_graph_options.pb.h" #include "mediapipe/tasks/cc/components/proto/segmenter_options.pb.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_task_graph.h" @@ -243,14 +243,15 @@ class ImageSegmenterGraph : public core::ModelTaskGraph { // Adds preprocessing calculators and connects them to the graph input image // stream. - auto& preprocessing = - graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph"); - bool use_gpu = components::DetermineImagePreprocessingGpuBackend( - task_options.base_options().acceleration()); - MP_RETURN_IF_ERROR(ConfigureImagePreprocessing( + auto& preprocessing = graph.AddNode( + "mediapipe.tasks.components.processors.ImagePreprocessingGraph"); + bool use_gpu = + components::processors::DetermineImagePreprocessingGpuBackend( + task_options.base_options().acceleration()); + MP_RETURN_IF_ERROR(components::processors::ConfigureImagePreprocessingGraph( model_resources, use_gpu, - &preprocessing - .GetOptions())); + &preprocessing.GetOptions())); image_in >> preprocessing.In(kImageTag); norm_rect_in >> preprocessing.In(kNormRectTag); diff --git a/mediapipe/tasks/cc/vision/object_detector/BUILD b/mediapipe/tasks/cc/vision/object_detector/BUILD index b8002fa96..c2dd9995d 100644 --- a/mediapipe/tasks/cc/vision/object_detector/BUILD +++ b/mediapipe/tasks/cc/vision/object_detector/BUILD @@ -71,9 +71,9 @@ cc_library( "//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework/formats:tensor", "//mediapipe/tasks/cc:common", - "//mediapipe/tasks/cc/components:image_preprocessing", "//mediapipe/tasks/cc/components/calculators:score_calibration_calculator_cc_proto", "//mediapipe/tasks/cc/components/calculators:score_calibration_utils", + "//mediapipe/tasks/cc/components/processors:image_preprocessing_graph", "//mediapipe/tasks/cc/components/utils:source_or_node_output", "//mediapipe/tasks/cc/core:model_resources", "//mediapipe/tasks/cc/core:model_task_graph", diff --git a/mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc b/mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc index b149cea0f..f5dc7e061 100644 --- a/mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc +++ b/mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc @@ -33,7 +33,7 @@ limitations under the License. #include "mediapipe/tasks/cc/common.h" #include "mediapipe/tasks/cc/components/calculators/score_calibration_calculator.pb.h" #include "mediapipe/tasks/cc/components/calculators/score_calibration_utils.h" -#include "mediapipe/tasks/cc/components/image_preprocessing.h" +#include "mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h" #include "mediapipe/tasks/cc/components/utils/source_or_node_output.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_task_graph.h" @@ -561,14 +561,15 @@ class ObjectDetectorGraph : public core::ModelTaskGraph { // Adds preprocessing calculators and connects them to the graph input image // stream. 
- auto& preprocessing = - graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph"); - bool use_gpu = components::DetermineImagePreprocessingGpuBackend( - task_options.base_options().acceleration()); - MP_RETURN_IF_ERROR(ConfigureImagePreprocessing( + auto& preprocessing = graph.AddNode( + "mediapipe.tasks.components.processors.ImagePreprocessingGraph"); + bool use_gpu = + components::processors::DetermineImagePreprocessingGpuBackend( + task_options.base_options().acceleration()); + MP_RETURN_IF_ERROR(components::processors::ConfigureImagePreprocessingGraph( model_resources, use_gpu, - &preprocessing - .GetOptions<tasks::components::ImagePreprocessingOptions>())); + &preprocessing.GetOptions<tasks::components::processors::proto::ImagePreprocessingGraphOptions>())); image_in >> preprocessing.In(kImageTag); norm_rect_in >> preprocessing.In(kNormRectTag); From 3c53ec2cdbe5df2aabf6a20f3b6c9b4efa76cb71 Mon Sep 17 00:00:00 2001 From: Camillo Lugaresi Date: Wed, 23 Nov 2022 10:09:42 -0800 Subject: [PATCH 124/469] Do not expose DrishtiGraphGPUData.h in public header This class is an implementation detail. PiperOrigin-RevId: 490530823 --- mediapipe/gpu/BUILD | 7 +------ mediapipe/gpu/MPPMetalHelper.h | 24 +++++++++++------------- mediapipe/gpu/MPPMetalHelper.mm | 6 ++++++ mediapipe/objc/MPPGraph.mm | 1 - 4 files changed, 18 insertions(+), 20 deletions(-) diff --git a/mediapipe/gpu/BUILD b/mediapipe/gpu/BUILD index f97eed678..42cd9cdc6 100644 --- a/mediapipe/gpu/BUILD +++ b/mediapipe/gpu/BUILD @@ -550,12 +550,7 @@ cc_library( name = "gpu_shared_data_header", textual_hdrs = [ "gpu_shared_data_internal.h", - ] + select({ - "//conditions:default": [], - "//mediapipe:apple": [ - "MPPGraphGPUData.h", - ], - }), + ], visibility = ["//visibility:private"], deps = [ ":gl_base", diff --git a/mediapipe/gpu/MPPMetalHelper.h b/mediapipe/gpu/MPPMetalHelper.h index f3662422e..6ae0f3cf9 100644 --- a/mediapipe/gpu/MPPMetalHelper.h +++ b/mediapipe/gpu/MPPMetalHelper.h @@ -21,37 +21,35 @@ #include "mediapipe/framework/packet.h" #include "mediapipe/framework/packet_type.h" -#include "mediapipe/gpu/MPPGraphGPUData.h" #include "mediapipe/gpu/gpu_shared_data_internal.h" NS_ASSUME_NONNULL_BEGIN @interface MPPMetalHelper : NSObject { - MPPGraphGPUData* _gpuShared; } - (instancetype)init NS_UNAVAILABLE; /// Initialize. This initializer is recommended for calculators. -- (instancetype)initWithCalculatorContext:(mediapipe::CalculatorContext*)cc; +- (instancetype)initWithCalculatorContext:(mediapipe::CalculatorContext *)cc; /// Initialize. -- (instancetype)initWithGpuResources:(mediapipe::GpuResources*)gpuResources +- (instancetype)initWithGpuResources:(mediapipe::GpuResources *)gpuResources NS_DESIGNATED_INITIALIZER; /// Configures a calculator's contract for accessing GPU resources. /// Calculators should use this in GetContract. -+ (absl::Status)updateContract:(mediapipe::CalculatorContract*)cc; ++ (absl::Status)updateContract:(mediapipe::CalculatorContract *)cc; /// Deprecated initializer. -- (instancetype)initWithSidePackets:(const mediapipe::PacketSet&)inputSidePackets; +- (instancetype)initWithSidePackets:(const mediapipe::PacketSet &)inputSidePackets; /// Deprecated initializer. -- (instancetype)initWithGpuSharedData:(mediapipe::GpuSharedData*)gpuShared; +- (instancetype)initWithGpuSharedData:(mediapipe::GpuSharedData *)gpuShared; /// Configures a calculator's side packets for accessing GPU resources. /// Calculators should use this in FillExpectations.
-+ (absl::Status)setupInputSidePackets:(mediapipe::PacketTypeSet*)inputSidePackets; ++ (absl::Status)setupInputSidePackets:(mediapipe::PacketTypeSet *)inputSidePackets; /// Get a metal command buffer. /// Calculators should use this method instead of getting a buffer from the @@ -63,23 +61,23 @@ NS_ASSUME_NONNULL_BEGIN /// Creates a CVMetalTextureRef linked to the provided GpuBuffer. /// Ownership follows the copy rule, so the caller is responsible for /// releasing the CVMetalTextureRef. -- (CVMetalTextureRef)copyCVMetalTextureWithGpuBuffer:(const mediapipe::GpuBuffer&)gpuBuffer; +- (CVMetalTextureRef)copyCVMetalTextureWithGpuBuffer:(const mediapipe::GpuBuffer &)gpuBuffer; /// Creates a CVMetalTextureRef linked to the provided GpuBuffer given a specific plane. /// Ownership follows the copy rule, so the caller is responsible for /// releasing the CVMetalTextureRef. -- (CVMetalTextureRef)copyCVMetalTextureWithGpuBuffer:(const mediapipe::GpuBuffer&)gpuBuffer +- (CVMetalTextureRef)copyCVMetalTextureWithGpuBuffer:(const mediapipe::GpuBuffer &)gpuBuffer plane:(size_t)plane; /// Returns a MTLTexture linked to the provided GpuBuffer. /// A calculator can freely use it as a rendering source, but it should not /// use it as a rendering target if the GpuBuffer was provided as an input. -- (id<MTLTexture>)metalTextureWithGpuBuffer:(const mediapipe::GpuBuffer&)gpuBuffer; +- (id<MTLTexture>)metalTextureWithGpuBuffer:(const mediapipe::GpuBuffer &)gpuBuffer; /// Returns a MTLTexture linked to the provided GpuBuffer given a specific plane. /// A calculator can freely use it as a rendering source, but it should not /// use it as a rendering target if the GpuBuffer was provided as an input. -- (id<MTLTexture>)metalTextureWithGpuBuffer:(const mediapipe::GpuBuffer&)gpuBuffer +- (id<MTLTexture>)metalTextureWithGpuBuffer:(const mediapipe::GpuBuffer &)gpuBuffer plane:(size_t)plane; /// Obtains a new GpuBuffer to be used as an output destination. @@ -91,7 +89,7 @@ NS_ASSUME_NONNULL_BEGIN format:(mediapipe::GpuBufferFormat)format; /// Convenience method to load a Metal library stored as a bundle resource. -- (id<MTLLibrary>)newLibraryWithResourceName:(NSString*)name error:(NSError* _Nullable*)error; +- (id<MTLLibrary>)newLibraryWithResourceName:(NSString *)name error:(NSError *_Nullable *)error; /// Shared Metal resources. @property(readonly) id<MTLDevice> mtlDevice; diff --git a/mediapipe/gpu/MPPMetalHelper.mm b/mediapipe/gpu/MPPMetalHelper.mm index ce6620972..dc1e27a5c 100644 --- a/mediapipe/gpu/MPPMetalHelper.mm +++ b/mediapipe/gpu/MPPMetalHelper.mm @@ -14,11 +14,17 @@ #import "mediapipe/gpu/MPPMetalHelper.h" +#import "mediapipe/gpu/MPPGraphGPUData.h" #import "mediapipe/gpu/graph_support.h" #import "GTMDefines.h" #include "mediapipe/framework/port/ret_check.h" +@interface MPPMetalHelper () { + MPPGraphGPUData* _gpuShared; +} +@end + namespace mediapipe { // Using a C++ class so it can be declared as a friend of LegacyCalculatorSupport.
diff --git a/mediapipe/objc/MPPGraph.mm b/mediapipe/objc/MPPGraph.mm index 080cca20f..1bd177e80 100644 --- a/mediapipe/objc/MPPGraph.mm +++ b/mediapipe/objc/MPPGraph.mm @@ -24,7 +24,6 @@ #include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/image_frame.h" #include "mediapipe/framework/graph_service.h" -#include "mediapipe/gpu/MPPGraphGPUData.h" #include "mediapipe/gpu/gl_base.h" #include "mediapipe/gpu/gpu_shared_data_internal.h" #include "mediapipe/objc/util.h" From 54d1744c8f5ee102679386b84e3e3812e352bc7a Mon Sep 17 00:00:00 2001 From: Camillo Lugaresi Date: Wed, 23 Nov 2022 10:13:48 -0800 Subject: [PATCH 125/469] Remove DrishtiGraphGPUData, add MetalSharedResources This class is unused except by the Metal helper; let's narrow it down and simplify gpu_shared_data. PiperOrigin-RevId: 490531767 --- mediapipe/gpu/BUILD | 50 +++------ mediapipe/gpu/MPPGraphGPUData.h | 71 ------------- mediapipe/gpu/MPPGraphGPUData.mm | 124 ---------------------- mediapipe/gpu/MPPGraphGPUDataTests.mm | 86 --------------- mediapipe/gpu/MPPMetalHelper.mm | 31 +++--- mediapipe/gpu/gpu_shared_data_internal.cc | 13 +-- mediapipe/gpu/gpu_shared_data_internal.h | 18 ++-- mediapipe/objc/BUILD | 2 +- 8 files changed, 46 insertions(+), 349 deletions(-) delete mode 100644 mediapipe/gpu/MPPGraphGPUData.h delete mode 100644 mediapipe/gpu/MPPGraphGPUData.mm delete mode 100644 mediapipe/gpu/MPPGraphGPUDataTests.mm diff --git a/mediapipe/gpu/BUILD b/mediapipe/gpu/BUILD index 42cd9cdc6..9cc670fb6 100644 --- a/mediapipe/gpu/BUILD +++ b/mediapipe/gpu/BUILD @@ -470,12 +470,9 @@ objc_library( ) objc_library( - name = "MPPGraphGPUData", - srcs = [ - "MPPGraphGPUData.mm", - "gpu_shared_data_internal.cc", - ], - hdrs = ["MPPGraphGPUData.h"], + name = "metal_shared_resources", + srcs = ["metal_shared_resources.mm"], + hdrs = ["metal_shared_resources.h"], copts = [ "-x objective-c++", "-Wno-shorten-64-to-32", @@ -484,25 +481,9 @@ objc_library( sdk_frameworks = [ "CoreVideo", "Metal", - ] + select({ - "//conditions:default": [ - "OpenGLES", - ], - "//mediapipe:macos": [ - "OpenGL", - "AppKit", - ], - }), + ], visibility = ["//visibility:public"], deps = [ - ":gl_base", - ":gl_context", - ":gpu_buffer_multi_pool", - ":gpu_shared_data_header", - ":graph_support", - ":cv_texture_cache_manager", - "//mediapipe/gpu:gl_context_options_cc_proto", - "//mediapipe/framework:calculator_context", "//mediapipe/framework/port:ret_check", "@google_toolbox_for_mac//:GTM_Defines", ] + [ @@ -584,16 +565,19 @@ cc_library( cc_library( name = "gpu_shared_data_internal_actual", - srcs = select({ - "//conditions:default": [ - "gpu_shared_data_internal.cc", - ], - # iOS uses an Objective-C++ version of this, built in MPPGraphGPUData. 
- "//mediapipe:apple": [], - }), + srcs = [ + "gpu_shared_data_internal.cc", + ], hdrs = [ "gpu_shared_data_internal.h", ], + copts = select({ + "//conditions:default": [], + "//mediapipe:apple": [ + "-x objective-c++", + "-fobjc-arc", # enable reference-counting + ], + }), visibility = ["//visibility:private"], deps = [ "//mediapipe/gpu:gl_context_options_cc_proto", @@ -610,7 +594,7 @@ cc_library( ] + select({ "//conditions:default": [], "//mediapipe:apple": [ - ":MPPGraphGPUData", + ":metal_shared_resources", ":cv_texture_cache_manager", ], }), @@ -1139,8 +1123,8 @@ objc_library( name = "gl_ios_test_lib", testonly = 1, srcs = [ - "MPPGraphGPUDataTests.mm", "gl_ios_test.mm", + "metal_shared_resources_test.mm", ], copts = [ "-Wno-shorten-64-to-32", @@ -1150,7 +1134,7 @@ objc_library( ], features = ["-layering_check"], deps = [ - ":MPPGraphGPUData", + ":metal_shared_resources", ":gl_scaler_calculator", ":gpu_buffer_to_image_frame_calculator", ":gpu_shared_data_internal", diff --git a/mediapipe/gpu/MPPGraphGPUData.h b/mediapipe/gpu/MPPGraphGPUData.h deleted file mode 100644 index 3d8fc0c94..000000000 --- a/mediapipe/gpu/MPPGraphGPUData.h +++ /dev/null @@ -1,71 +0,0 @@ -// Copyright 2019 The MediaPipe Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef MEDIAPIPE_GPU_MPPGRAPHGPUDATA_H_ -#define MEDIAPIPE_GPU_MPPGRAPHGPUDATA_H_ - -#import -#import -#import - -#import "mediapipe/gpu/gl_base.h" -#import "mediapipe/gpu/gl_context.h" - -namespace mediapipe { -class GlContext; -class GpuBufferMultiPool; -} // namespace mediapipe - -@interface MPPGraphGPUData : NSObject { - // Shared buffer pool for GPU calculators. - mediapipe::GpuBufferMultiPool* _gpuBufferPool; - mediapipe::GlContext* _glContext; -} - -- (instancetype)init NS_UNAVAILABLE; - -/// Initialize. The provided multipool pointer must remain valid throughout -/// this object's lifetime. -- (instancetype)initWithContext:(mediapipe::GlContext*)context - multiPool:(mediapipe::GpuBufferMultiPool*)pool NS_DESIGNATED_INITIALIZER; - -/// Shared texture pool for GPU calculators. -/// For internal use by GlCalculatorHelper. -@property(readonly) mediapipe::GpuBufferMultiPool* gpuBufferPool; - -/// Shared OpenGL context. -#if TARGET_OS_OSX -@property(readonly) NSOpenGLContext* glContext; -@property(readonly) NSOpenGLPixelFormat* glPixelFormat; -#else -@property(readonly) EAGLContext* glContext; -#endif // TARGET_OS_OSX - -/// Shared texture cache. -#if TARGET_OS_OSX -@property(readonly) CVOpenGLTextureCacheRef textureCache; -#else -@property(readonly) CVOpenGLESTextureCacheRef textureCache; -#endif // TARGET_OS_OSX - -/// Shared Metal resources. 
-@property(readonly) id mtlDevice; -@property(readonly) id mtlCommandQueue; -#if COREVIDEO_SUPPORTS_METAL -@property(readonly) CVMetalTextureCacheRef mtlTextureCache; -#endif - -@end - -#endif // MEDIAPIPE_GPU_MPPGRAPHGPUDATA_H_ diff --git a/mediapipe/gpu/MPPGraphGPUData.mm b/mediapipe/gpu/MPPGraphGPUData.mm deleted file mode 100644 index 8ac1eefa5..000000000 --- a/mediapipe/gpu/MPPGraphGPUData.mm +++ /dev/null @@ -1,124 +0,0 @@ -// Copyright 2019 The MediaPipe Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#import "mediapipe/gpu/MPPGraphGPUData.h" - -#import "GTMDefines.h" - -#include "mediapipe/gpu/gl_context.h" -#include "mediapipe/gpu/gpu_buffer_multi_pool.h" - -#if TARGET_OS_OSX -#import -#else -#import -#endif // TARGET_OS_OSX - -@implementation MPPGraphGPUData - -@synthesize textureCache = _textureCache; -@synthesize mtlDevice = _mtlDevice; -@synthesize mtlCommandQueue = _mtlCommandQueue; -#if COREVIDEO_SUPPORTS_METAL -@synthesize mtlTextureCache = _mtlTextureCache; -#endif - -#if TARGET_OS_OSX -typedef CVOpenGLTextureCacheRef CVTextureCacheType; -#else -typedef CVOpenGLESTextureCacheRef CVTextureCacheType; -#endif // TARGET_OS_OSX - -- (instancetype)initWithContext:(mediapipe::GlContext *)context - multiPool:(mediapipe::GpuBufferMultiPool *)pool { - self = [super init]; - if (self) { - _gpuBufferPool = pool; - _glContext = context; - } - return self; -} - -- (void)dealloc { - if (_textureCache) { - _textureCache = NULL; - } -#if COREVIDEO_SUPPORTS_METAL - if (_mtlTextureCache) { - CFRelease(_mtlTextureCache); - _mtlTextureCache = NULL; - } -#endif -} - -#if TARGET_OS_OSX -- (NSOpenGLContext *)glContext { - return _glContext->nsgl_context(); -} - -- (NSOpenGLPixelFormat *) glPixelFormat { - return _glContext->nsgl_pixel_format(); -} -#else -- (EAGLContext *)glContext { - return _glContext->eagl_context(); -} -#endif // TARGET_OS_OSX - -- (CVTextureCacheType)textureCache { - @synchronized(self) { - if (!_textureCache) { - _textureCache = _glContext->cv_texture_cache(); - } - } - return _textureCache; -} - -- (mediapipe::GpuBufferMultiPool *)gpuBufferPool { - return _gpuBufferPool; -} - -- (id)mtlDevice { - @synchronized(self) { - if (!_mtlDevice) { - _mtlDevice = MTLCreateSystemDefaultDevice(); - } - } - return _mtlDevice; -} - -- (id)mtlCommandQueue { - @synchronized(self) { - if (!_mtlCommandQueue) { - _mtlCommandQueue = [self.mtlDevice newCommandQueue]; - } - } - return _mtlCommandQueue; -} - -#if COREVIDEO_SUPPORTS_METAL -- (CVMetalTextureCacheRef)mtlTextureCache { - @synchronized(self) { - if (!_mtlTextureCache) { - CVReturn __unused err = - CVMetalTextureCacheCreate(NULL, NULL, self.mtlDevice, NULL, &_mtlTextureCache); - NSAssert(err == kCVReturnSuccess, @"Error at CVMetalTextureCacheCreate %d", err); - // TODO: register and flush metal caches too. 
- } - } - return _mtlTextureCache; -} -#endif - -@end diff --git a/mediapipe/gpu/MPPGraphGPUDataTests.mm b/mediapipe/gpu/MPPGraphGPUDataTests.mm deleted file mode 100644 index e8b50845b..000000000 --- a/mediapipe/gpu/MPPGraphGPUDataTests.mm +++ /dev/null @@ -1,86 +0,0 @@ -// Copyright 2019 The MediaPipe Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#import -#import - -#include - -#include "absl/memory/memory.h" -#include "mediapipe/framework/port/threadpool.h" - -#import "mediapipe/gpu/MPPGraphGPUData.h" -#import "mediapipe/gpu/gpu_shared_data_internal.h" - -@interface MPPGraphGPUDataTests : XCTestCase { -} -@end - -@implementation MPPGraphGPUDataTests - -// This test verifies that the internal Objective-C object is correctly -// released when the C++ wrapper is released. -- (void)testCorrectlyReleased { - __weak id gpuData = nil; - std::weak_ptr gpuRes; - @autoreleasepool { - mediapipe::GpuSharedData gpu_shared; - gpuRes = gpu_shared.gpu_resources; - gpuData = gpu_shared.gpu_resources->ios_gpu_data(); - XCTAssertNotEqual(gpuRes.lock(), nullptr); - XCTAssertNotNil(gpuData); - } - XCTAssertEqual(gpuRes.lock(), nullptr); - XCTAssertNil(gpuData); -} - -// This test verifies that the lazy initialization of the glContext instance -// variable is thread-safe. All threads should read the same value. -- (void)testGlContextThreadSafeLazyInitialization { - mediapipe::GpuSharedData gpu_shared; - constexpr int kNumThreads = 10; - EAGLContext* ogl_context[kNumThreads]; - auto pool = absl::make_unique(kNumThreads); - pool->StartWorkers(); - for (int i = 0; i < kNumThreads; ++i) { - pool->Schedule([&gpu_shared, &ogl_context, i] { - ogl_context[i] = gpu_shared.gpu_resources->ios_gpu_data().glContext; - }); - } - pool.reset(); - for (int i = 0; i < kNumThreads - 1; ++i) { - XCTAssertEqual(ogl_context[i], ogl_context[i + 1]); - } -} - -// This test verifies that the lazy initialization of the textureCache instance -// variable is thread-safe. All threads should read the same value. 
-- (void)testTextureCacheThreadSafeLazyInitialization { - mediapipe::GpuSharedData gpu_shared; - constexpr int kNumThreads = 10; - CFHolder texture_cache[kNumThreads]; - auto pool = absl::make_unique(kNumThreads); - pool->StartWorkers(); - for (int i = 0; i < kNumThreads; ++i) { - pool->Schedule([&gpu_shared, &texture_cache, i] { - texture_cache[i].reset(gpu_shared.gpu_resources->ios_gpu_data().textureCache); - }); - } - pool.reset(); - for (int i = 0; i < kNumThreads - 1; ++i) { - XCTAssertEqual(*texture_cache[i], *texture_cache[i + 1]); - } -} - -@end diff --git a/mediapipe/gpu/MPPMetalHelper.mm b/mediapipe/gpu/MPPMetalHelper.mm index dc1e27a5c..1acf7cbfb 100644 --- a/mediapipe/gpu/MPPMetalHelper.mm +++ b/mediapipe/gpu/MPPMetalHelper.mm @@ -14,14 +14,15 @@ #import "mediapipe/gpu/MPPMetalHelper.h" -#import "mediapipe/gpu/MPPGraphGPUData.h" +#import "mediapipe/gpu/gpu_buffer.h" #import "mediapipe/gpu/graph_support.h" +#import "mediapipe/gpu/metal_shared_resources.h" #import "GTMDefines.h" #include "mediapipe/framework/port/ret_check.h" @interface MPPMetalHelper () { - MPPGraphGPUData* _gpuShared; + mediapipe::GpuResources* _gpuResources; } @end @@ -46,7 +47,7 @@ class MetalHelperLegacySupport { - (instancetype)initWithGpuResources:(mediapipe::GpuResources*)gpuResources { self = [super init]; if (self) { - _gpuShared = gpuResources->ios_gpu_data(); + _gpuResources = gpuResources; } return self; } @@ -111,19 +112,19 @@ class MetalHelperLegacySupport { } - (id)mtlDevice { - return _gpuShared.mtlDevice; + return _gpuResources->metal_shared().resources().mtlDevice; } - (id)mtlCommandQueue { - return _gpuShared.mtlCommandQueue; + return _gpuResources->metal_shared().resources().mtlCommandQueue; } - (CVMetalTextureCacheRef)mtlTextureCache { - return _gpuShared.mtlTextureCache; + return _gpuResources->metal_shared().resources().mtlTextureCache; } - (id)commandBuffer { - return [_gpuShared.mtlCommandQueue commandBuffer]; + return [_gpuResources->metal_shared().resources().mtlCommandQueue commandBuffer]; } - (CVMetalTextureRef)copyCVMetalTextureWithGpuBuffer:(const mediapipe::GpuBuffer&)gpuBuffer @@ -175,8 +176,9 @@ class MetalHelperLegacySupport { CVMetalTextureRef texture; CVReturn err = CVMetalTextureCacheCreateTextureFromImage( - NULL, _gpuShared.mtlTextureCache, mediapipe::GetCVPixelBufferRef(gpuBuffer), NULL, - metalPixelFormat, width, height, plane, &texture); + NULL, _gpuResources->metal_shared().resources().mtlTextureCache, + mediapipe::GetCVPixelBufferRef(gpuBuffer), NULL, metalPixelFormat, width, height, plane, + &texture); CHECK_EQ(err, kCVReturnSuccess); return texture; } @@ -197,19 +199,20 @@ class MetalHelperLegacySupport { } - (mediapipe::GpuBuffer)mediapipeGpuBufferWithWidth:(int)width height:(int)height { - return _gpuShared.gpuBufferPool->GetBuffer(width, height); + return _gpuResources->gpu_buffer_pool().GetBuffer(width, height); } - (mediapipe::GpuBuffer)mediapipeGpuBufferWithWidth:(int)width height:(int)height format:(mediapipe::GpuBufferFormat)format { - return _gpuShared.gpuBufferPool->GetBuffer(width, height, format); + return _gpuResources->gpu_buffer_pool().GetBuffer(width, height, format); } - (id)newLibraryWithResourceName:(NSString*)name error:(NSError * _Nullable *)error { - return [_gpuShared.mtlDevice newLibraryWithFile:[[NSBundle bundleForClass:[self class]] - pathForResource:name ofType:@"metallib"] - error:error]; + return [_gpuResources->metal_shared().resources().mtlDevice + newLibraryWithFile:[[NSBundle bundleForClass:[self class]] pathForResource:name + 
ofType:@"metallib"] + error:error]; } @end diff --git a/mediapipe/gpu/gpu_shared_data_internal.cc b/mediapipe/gpu/gpu_shared_data_internal.cc index 91723a7d1..203a8dfd1 100644 --- a/mediapipe/gpu/gpu_shared_data_internal.cc +++ b/mediapipe/gpu/gpu_shared_data_internal.cc @@ -21,7 +21,7 @@ #include "mediapipe/gpu/graph_support.h" #if __APPLE__ -#import "mediapipe/gpu/MPPGraphGPUData.h" +#include "mediapipe/gpu/metal_shared_resources.h" #endif // __APPLE__ namespace mediapipe { @@ -97,15 +97,14 @@ GpuResources::GpuResources(std::shared_ptr gl_context) #if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER texture_caches_->RegisterTextureCache(gl_context->cv_texture_cache()); #endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER - ios_gpu_data_ = [[MPPGraphGPUData alloc] initWithContext:gl_context.get() - multiPool:&gpu_buffer_pool_]; + metal_shared_ = std::make_unique(); #endif // __APPLE__ } GpuResources::~GpuResources() { #if __APPLE__ - // Note: on Apple platforms, this object contains Objective-C objects. The - // destructor will release them, but ARC must be on. + // Note: on Apple platforms, this object contains Objective-C objects. + // The destructor will release them, but ARC must be on. #if !__has_feature(objc_arc) #error This file must be built with ARC. #endif @@ -196,10 +195,6 @@ GlContext::StatusOrGlContext GpuResources::GetOrCreateGlContext( GpuSharedData::GpuSharedData() : GpuSharedData(kPlatformGlContextNone) {} -#if __APPLE__ -MPPGraphGPUData* GpuResources::ios_gpu_data() { return ios_gpu_data_; } -#endif // __APPLE__ - extern const GraphService kGpuService; #if !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER diff --git a/mediapipe/gpu/gpu_shared_data_internal.h b/mediapipe/gpu/gpu_shared_data_internal.h index 4fe6ba04e..3f7c67e2e 100644 --- a/mediapipe/gpu/gpu_shared_data_internal.h +++ b/mediapipe/gpu/gpu_shared_data_internal.h @@ -31,15 +31,14 @@ #ifdef __APPLE__ #include "mediapipe/gpu/cv_texture_cache_manager.h" -#ifdef __OBJC__ -@class MPPGraphGPUData; -#else -struct MPPGraphGPUData; -#endif // __OBJC__ #endif // defined(__APPLE__) namespace mediapipe { +#ifdef __APPLE__ +class MetalSharedResources; +#endif // defined(__APPLE__) + // TODO: rename to GpuService or GpuManager or something. class GpuResources { public: @@ -56,9 +55,7 @@ class GpuResources { // Shared GL context for calculators. // TODO: require passing a context or node identifier. - const std::shared_ptr& gl_context() { - return gl_context(nullptr); - }; + const std::shared_ptr& gl_context() { return gl_context(nullptr); } const std::shared_ptr& gl_context(CalculatorContext* cc); @@ -66,7 +63,7 @@ class GpuResources { GpuBufferMultiPool& gpu_buffer_pool() { return gpu_buffer_pool_; } #ifdef __APPLE__ - MPPGraphGPUData* ios_gpu_data(); + MetalSharedResources& metal_shared() { return *metal_shared_; } #endif // defined(__APPLE__)§ absl::Status PrepareGpuNode(CalculatorNode* node); @@ -96,8 +93,7 @@ class GpuResources { GpuBufferMultiPool gpu_buffer_pool_; #ifdef __APPLE__ - // Note that this is an Objective-C object. 
- MPPGraphGPUData* ios_gpu_data_; + std::unique_ptr metal_shared_; #endif // defined(__APPLE__) std::map> named_executors_; diff --git a/mediapipe/objc/BUILD b/mediapipe/objc/BUILD index d77692164..fafdfee8a 100644 --- a/mediapipe/objc/BUILD +++ b/mediapipe/objc/BUILD @@ -83,11 +83,11 @@ objc_library( "//mediapipe/framework/port:status", "//mediapipe/framework/port:statusor", "//mediapipe/framework/port:threadpool", - "//mediapipe/gpu:MPPGraphGPUData", "//mediapipe/gpu:gl_base", "//mediapipe/gpu:gpu_buffer", "//mediapipe/gpu:gpu_shared_data_internal", "//mediapipe/gpu:graph_support", + "//mediapipe/gpu:metal_shared_resources", "//mediapipe/gpu:pixel_buffer_pool_util", "//mediapipe/util:cpu_util", "@com_google_absl//absl/base:core_headers", From bfa57310c4dfb43e9ea3d5b24059b7e042836911 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Wed, 23 Nov 2022 10:17:46 -0800 Subject: [PATCH 126/469] Move TextPreprocessing to "processors" folder. PiperOrigin-RevId: 490532670 --- mediapipe/tasks/cc/components/BUILD | 43 ------------------- .../tasks/cc/components/processors/BUILD | 26 +++++++++++ .../cc/components/processors/proto/BUILD | 9 ++++ .../text_preprocessing_graph_options.proto | 2 +- .../text_preprocessing_graph.cc | 22 +++++----- .../text_preprocessing_graph.h | 30 +++++++------ mediapipe/tasks/cc/components/proto/BUILD | 9 ---- mediapipe/tasks/cc/text/text_classifier/BUILD | 4 +- .../text_classifier/text_classifier_graph.cc | 12 +++--- mediapipe/tasks/cc/text/text_embedder/BUILD | 4 +- .../text/text_embedder/text_embedder_graph.cc | 12 +++--- 11 files changed, 80 insertions(+), 93 deletions(-) delete mode 100644 mediapipe/tasks/cc/components/BUILD rename mediapipe/tasks/cc/components/{ => processors}/proto/text_preprocessing_graph_options.proto (96%) rename mediapipe/tasks/cc/components/{ => processors}/text_preprocessing_graph.cc (94%) rename mediapipe/tasks/cc/components/{ => processors}/text_preprocessing_graph.h (67%) diff --git a/mediapipe/tasks/cc/components/BUILD b/mediapipe/tasks/cc/components/BUILD deleted file mode 100644 index 54a5207d2..000000000 --- a/mediapipe/tasks/cc/components/BUILD +++ /dev/null @@ -1,43 +0,0 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -package(default_visibility = ["//mediapipe/tasks:internal"]) - -licenses(["notice"]) - -# TODO: Investigate rewriting the build rule to only link -# the Bert Preprocessor if it's needed. 
-cc_library( - name = "text_preprocessing_graph", - srcs = ["text_preprocessing_graph.cc"], - hdrs = ["text_preprocessing_graph.h"], - deps = [ - "//mediapipe/calculators/tensor:bert_preprocessor_calculator", - "//mediapipe/calculators/tensor:bert_preprocessor_calculator_cc_proto", - "//mediapipe/calculators/tensor:regex_preprocessor_calculator", - "//mediapipe/calculators/tensor:regex_preprocessor_calculator_cc_proto", - "//mediapipe/calculators/tensor:text_to_tensor_calculator", - "//mediapipe/framework:subgraph", - "//mediapipe/framework/api2:builder", - "//mediapipe/framework/api2:port", - "//mediapipe/framework/formats:tensor", - "//mediapipe/tasks/cc/components/proto:text_preprocessing_graph_options_cc_proto", - "//mediapipe/tasks/cc/core:model_resources", - "//mediapipe/tasks/cc/metadata:metadata_extractor", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - ], - alwayslink = 1, -) diff --git a/mediapipe/tasks/cc/components/processors/BUILD b/mediapipe/tasks/cc/components/processors/BUILD index 4946683f5..185bf231b 100644 --- a/mediapipe/tasks/cc/components/processors/BUILD +++ b/mediapipe/tasks/cc/components/processors/BUILD @@ -133,3 +133,29 @@ cc_library( ) # TODO: Enable this test + +# TODO: Investigate rewriting the build rule to only link +# the Bert Preprocessor if it's needed. +cc_library( + name = "text_preprocessing_graph", + srcs = ["text_preprocessing_graph.cc"], + hdrs = ["text_preprocessing_graph.h"], + deps = [ + "//mediapipe/calculators/tensor:bert_preprocessor_calculator", + "//mediapipe/calculators/tensor:bert_preprocessor_calculator_cc_proto", + "//mediapipe/calculators/tensor:regex_preprocessor_calculator", + "//mediapipe/calculators/tensor:regex_preprocessor_calculator_cc_proto", + "//mediapipe/calculators/tensor:text_to_tensor_calculator", + "//mediapipe/framework:subgraph", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/formats:tensor", + "//mediapipe/tasks/cc/components/processors/proto:text_preprocessing_graph_options_cc_proto", + "//mediapipe/tasks/cc/core:model_resources", + "//mediapipe/tasks/cc/metadata:metadata_extractor", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + ], + alwayslink = 1, +) diff --git a/mediapipe/tasks/cc/components/processors/proto/BUILD b/mediapipe/tasks/cc/components/processors/proto/BUILD index 9c58a8585..f48c4bad8 100644 --- a/mediapipe/tasks/cc/components/processors/proto/BUILD +++ b/mediapipe/tasks/cc/components/processors/proto/BUILD @@ -59,3 +59,12 @@ mediapipe_proto_library( "//mediapipe/framework:calculator_proto", ], ) + +mediapipe_proto_library( + name = "text_preprocessing_graph_options_proto", + srcs = ["text_preprocessing_graph_options.proto"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + ], +) diff --git a/mediapipe/tasks/cc/components/proto/text_preprocessing_graph_options.proto b/mediapipe/tasks/cc/components/processors/proto/text_preprocessing_graph_options.proto similarity index 96% rename from mediapipe/tasks/cc/components/proto/text_preprocessing_graph_options.proto rename to mediapipe/tasks/cc/components/processors/proto/text_preprocessing_graph_options.proto index 926e3d7fb..a67cfd8a9 100644 --- a/mediapipe/tasks/cc/components/proto/text_preprocessing_graph_options.proto +++ 
b/mediapipe/tasks/cc/components/processors/proto/text_preprocessing_graph_options.proto @@ -15,7 +15,7 @@ limitations under the License. syntax = "proto2"; -package mediapipe.tasks.components.proto; +package mediapipe.tasks.components.processors.proto; import "mediapipe/framework/calculator.proto"; diff --git a/mediapipe/tasks/cc/components/text_preprocessing_graph.cc b/mediapipe/tasks/cc/components/processors/text_preprocessing_graph.cc similarity index 94% rename from mediapipe/tasks/cc/components/text_preprocessing_graph.cc rename to mediapipe/tasks/cc/components/processors/text_preprocessing_graph.cc index 6aad8fdd5..de16375bd 100644 --- a/mediapipe/tasks/cc/components/text_preprocessing_graph.cc +++ b/mediapipe/tasks/cc/components/processors/text_preprocessing_graph.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "mediapipe/tasks/cc/components/text_preprocessing_graph.h" +#include "mediapipe/tasks/cc/components/processors/text_preprocessing_graph.h" #include @@ -25,13 +25,14 @@ limitations under the License. #include "mediapipe/framework/api2/port.h" #include "mediapipe/framework/formats/tensor.h" #include "mediapipe/framework/subgraph.h" -#include "mediapipe/tasks/cc/components/proto/text_preprocessing_graph_options.pb.h" +#include "mediapipe/tasks/cc/components/processors/proto/text_preprocessing_graph_options.pb.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/metadata/metadata_extractor.h" namespace mediapipe { namespace tasks { namespace components { +namespace processors { namespace { @@ -41,7 +42,8 @@ using ::mediapipe::api2::SideInput; using ::mediapipe::api2::builder::Graph; using ::mediapipe::api2::builder::SideSource; using ::mediapipe::api2::builder::Source; -using ::mediapipe::tasks::components::proto::TextPreprocessingGraphOptions; +using ::mediapipe::tasks::components::processors::proto:: + TextPreprocessingGraphOptions; using ::mediapipe::tasks::core::ModelResources; using ::mediapipe::tasks::metadata::ModelMetadataExtractor; @@ -169,7 +171,7 @@ absl::StatusOr GetMaxSeqLen(const tflite::SubGraph& model_graph) { } } // namespace -absl::Status ConfigureTextPreprocessingSubgraph( +absl::Status ConfigureTextPreprocessingGraph( const ModelResources& model_resources, TextPreprocessingGraphOptions& options) { if (model_resources.GetTfLiteModel()->subgraphs()->size() != 1) { @@ -200,8 +202,7 @@ absl::Status ConfigureTextPreprocessingSubgraph( return absl::OkStatus(); } -// A "mediapipe.tasks.components.TextPreprocessingSubgraph" performs text -// preprocessing. +// A TextPreprocessingGraph performs text preprocessing. // - Accepts a std::string input and outputs CPU tensors. // // Inputs: @@ -216,9 +217,9 @@ absl::Status ConfigureTextPreprocessingSubgraph( // Vector containing the preprocessed input tensors for the TFLite model. // // The recommended way of using this subgraph is through the GraphBuilder API -// using the 'ConfigureTextPreprocessing()' function. See header file for more -// details. -class TextPreprocessingSubgraph : public mediapipe::Subgraph { +// using the 'ConfigureTextPreprocessingGraph()' function. See header file for +// more details. 
+class TextPreprocessingGraph : public mediapipe::Subgraph { public: absl::StatusOr GetConfig( mediapipe::SubgraphContext* sc) override { @@ -267,8 +268,9 @@ class TextPreprocessingSubgraph : public mediapipe::Subgraph { } }; REGISTER_MEDIAPIPE_GRAPH( - ::mediapipe::tasks::components::TextPreprocessingSubgraph); + ::mediapipe::tasks::components::processors::TextPreprocessingGraph); +} // namespace processors } // namespace components } // namespace tasks } // namespace mediapipe diff --git a/mediapipe/tasks/cc/components/text_preprocessing_graph.h b/mediapipe/tasks/cc/components/processors/text_preprocessing_graph.h similarity index 67% rename from mediapipe/tasks/cc/components/text_preprocessing_graph.h rename to mediapipe/tasks/cc/components/processors/text_preprocessing_graph.h index b031a5550..43d57be29 100644 --- a/mediapipe/tasks/cc/components/text_preprocessing_graph.h +++ b/mediapipe/tasks/cc/components/processors/text_preprocessing_graph.h @@ -13,26 +13,31 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_TEXT_PREPROCESSING_H_ -#define MEDIAPIPE_TASKS_CC_COMPONENTS_TEXT_PREPROCESSING_H_ +#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_TEXT_PREPROCESSING_GRAPH_H_ +#define MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_TEXT_PREPROCESSING_GRAPH_H_ #include "absl/status/status.h" -#include "mediapipe/tasks/cc/components/proto/text_preprocessing_graph_options.pb.h" +#include "mediapipe/tasks/cc/components/processors/proto/text_preprocessing_graph_options.pb.h" #include "mediapipe/tasks/cc/core/model_resources.h" -// Configures a TextPreprocessing subgraph using the provided `model_resources` +namespace mediapipe { +namespace tasks { +namespace components { +namespace processors { + +// Configures a TextPreprocessingGraph using the provided `model_resources` // and TextPreprocessingGraphOptions. // - Accepts a std::string input and outputs CPU tensors. // // Example usage: // // auto& preprocessing = -// graph.AddNode("mediapipe.tasks.components.TextPreprocessingSubgraph"); +// graph.AddNode("mediapipe.tasks.components.processors.TextPreprocessingSubgraph"); // MP_RETURN_IF_ERROR(ConfigureTextPreprocessingSubgraph( // model_resources, // &preprocessing.GetOptions())); // -// The resulting TextPreprocessing subgraph has the following I/O: +// The resulting TextPreprocessingGraph has the following I/O: // Inputs: // TEXT - std::string // The text to preprocess. @@ -43,16 +48,13 @@ limitations under the License. // Outputs: // TENSORS - std::vector // Vector containing the preprocessed input tensors for the TFLite model. 
-namespace mediapipe { -namespace tasks { -namespace components { - -absl::Status ConfigureTextPreprocessingSubgraph( - const tasks::core::ModelResources& model_resources, - tasks::components::proto::TextPreprocessingGraphOptions& options); +absl::Status ConfigureTextPreprocessingGraph( + const core::ModelResources& model_resources, + proto::TextPreprocessingGraphOptions& options); +} // namespace processors } // namespace components } // namespace tasks } // namespace mediapipe -#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_TEXT_PREPROCESSING_H_ +#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_TEXT_PREPROCESSING_GRAPH_H_ diff --git a/mediapipe/tasks/cc/components/proto/BUILD b/mediapipe/tasks/cc/components/proto/BUILD index 4534a1652..569023753 100644 --- a/mediapipe/tasks/cc/components/proto/BUILD +++ b/mediapipe/tasks/cc/components/proto/BUILD @@ -22,12 +22,3 @@ mediapipe_proto_library( name = "segmenter_options_proto", srcs = ["segmenter_options.proto"], ) - -mediapipe_proto_library( - name = "text_preprocessing_graph_options_proto", - srcs = ["text_preprocessing_graph_options.proto"], - deps = [ - "//mediapipe/framework:calculator_options_proto", - "//mediapipe/framework:calculator_proto", - ], -) diff --git a/mediapipe/tasks/cc/text/text_classifier/BUILD b/mediapipe/tasks/cc/text/text_classifier/BUILD index 01adc9fc3..61395cf4e 100644 --- a/mediapipe/tasks/cc/text/text_classifier/BUILD +++ b/mediapipe/tasks/cc/text/text_classifier/BUILD @@ -52,11 +52,11 @@ cc_library( "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:builder", "//mediapipe/framework/api2:port", - "//mediapipe/tasks/cc/components:text_preprocessing_graph", "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", "//mediapipe/tasks/cc/components/processors:classification_postprocessing_graph", + "//mediapipe/tasks/cc/components/processors:text_preprocessing_graph", "//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto", - "//mediapipe/tasks/cc/components/proto:text_preprocessing_graph_options_cc_proto", + "//mediapipe/tasks/cc/components/processors/proto:text_preprocessing_graph_options_cc_proto", "//mediapipe/tasks/cc/core:model_resources", "//mediapipe/tasks/cc/core:model_resources_calculator", "//mediapipe/tasks/cc/core:model_task_graph", diff --git a/mediapipe/tasks/cc/text/text_classifier/text_classifier_graph.cc b/mediapipe/tasks/cc/text/text_classifier/text_classifier_graph.cc index 9a7dce1aa..3be92f309 100644 --- a/mediapipe/tasks/cc/text/text_classifier/text_classifier_graph.cc +++ b/mediapipe/tasks/cc/text/text_classifier/text_classifier_graph.cc @@ -25,8 +25,8 @@ limitations under the License. 
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" #include "mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h" #include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h" -#include "mediapipe/tasks/cc/components/proto/text_preprocessing_graph_options.pb.h" -#include "mediapipe/tasks/cc/components/text_preprocessing_graph.h" +#include "mediapipe/tasks/cc/components/processors/proto/text_preprocessing_graph_options.pb.h" +#include "mediapipe/tasks/cc/components/processors/text_preprocessing_graph.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_task_graph.h" #include "mediapipe/tasks/cc/core/proto/model_resources_calculator.pb.h" @@ -115,12 +115,12 @@ class TextClassifierGraph : public core::ModelTaskGraph { Graph& graph) { // Adds preprocessing calculators and connects them to the text input // stream. - auto& preprocessing = - graph.AddNode("mediapipe.tasks.components.TextPreprocessingSubgraph"); - MP_RETURN_IF_ERROR(components::ConfigureTextPreprocessingSubgraph( + auto& preprocessing = graph.AddNode( + "mediapipe.tasks.components.processors.TextPreprocessingGraph"); + MP_RETURN_IF_ERROR(components::processors::ConfigureTextPreprocessingGraph( model_resources, preprocessing.GetOptions< - tasks::components::proto::TextPreprocessingGraphOptions>())); + components::processors::proto::TextPreprocessingGraphOptions>())); text_in >> preprocessing.In(kTextTag); // Adds both InferenceCalculator and ModelResourcesCalculator. diff --git a/mediapipe/tasks/cc/text/text_embedder/BUILD b/mediapipe/tasks/cc/text/text_embedder/BUILD index 27c9cb730..f19af35be 100644 --- a/mediapipe/tasks/cc/text/text_embedder/BUILD +++ b/mediapipe/tasks/cc/text/text_embedder/BUILD @@ -54,11 +54,11 @@ cc_library( "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:builder", "//mediapipe/framework/api2:port", - "//mediapipe/tasks/cc/components:text_preprocessing_graph", "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto", "//mediapipe/tasks/cc/components/processors:embedding_postprocessing_graph", + "//mediapipe/tasks/cc/components/processors:text_preprocessing_graph", "//mediapipe/tasks/cc/components/processors/proto:embedding_postprocessing_graph_options_cc_proto", - "//mediapipe/tasks/cc/components/proto:text_preprocessing_graph_options_cc_proto", + "//mediapipe/tasks/cc/components/processors/proto:text_preprocessing_graph_options_cc_proto", "//mediapipe/tasks/cc/core:model_resources", "//mediapipe/tasks/cc/core:model_task_graph", "//mediapipe/tasks/cc/core/proto:model_resources_calculator_cc_proto", diff --git a/mediapipe/tasks/cc/text/text_embedder/text_embedder_graph.cc b/mediapipe/tasks/cc/text/text_embedder/text_embedder_graph.cc index c54636ee2..225ef07bd 100644 --- a/mediapipe/tasks/cc/text/text_embedder/text_embedder_graph.cc +++ b/mediapipe/tasks/cc/text/text_embedder/text_embedder_graph.cc @@ -23,8 +23,8 @@ limitations under the License. 
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h" #include "mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.h" #include "mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options.pb.h" -#include "mediapipe/tasks/cc/components/proto/text_preprocessing_graph_options.pb.h" -#include "mediapipe/tasks/cc/components/text_preprocessing_graph.h" +#include "mediapipe/tasks/cc/components/processors/proto/text_preprocessing_graph_options.pb.h" +#include "mediapipe/tasks/cc/components/processors/text_preprocessing_graph.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_task_graph.h" #include "mediapipe/tasks/cc/core/proto/model_resources_calculator.pb.h" @@ -107,12 +107,12 @@ class TextEmbedderGraph : public core::ModelTaskGraph { Graph& graph) { // Adds preprocessing calculators and connects them to the text input // stream. - auto& preprocessing = - graph.AddNode("mediapipe.tasks.components.TextPreprocessingSubgraph"); - MP_RETURN_IF_ERROR(components::ConfigureTextPreprocessingSubgraph( + auto& preprocessing = graph.AddNode( + "mediapipe.tasks.components.processors.TextPreprocessingGraph"); + MP_RETURN_IF_ERROR(components::processors::ConfigureTextPreprocessingGraph( model_resources, preprocessing.GetOptions< - tasks::components::proto::TextPreprocessingGraphOptions>())); + components::processors::proto::TextPreprocessingGraphOptions>())); text_in >> preprocessing.In(kTextTag); // Adds both InferenceCalculator and ModelResourcesCalculator. From 41a7f9d7d6fdc0bfd1c9e7d4cc00532512474de2 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Wed, 23 Nov 2022 15:23:02 -0800 Subject: [PATCH 127/469] Internal change PiperOrigin-RevId: 490595529 --- mediapipe/web/graph_runner/graph_runner.ts | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/mediapipe/web/graph_runner/graph_runner.ts b/mediapipe/web/graph_runner/graph_runner.ts index c4654794c..378bc0a4d 100644 --- a/mediapipe/web/graph_runner/graph_runner.ts +++ b/mediapipe/web/graph_runner/graph_runner.ts @@ -176,10 +176,14 @@ export class GraphRunner { if (glCanvas !== undefined) { this.wasmModule.canvas = glCanvas; - } else { + } else if (typeof OffscreenCanvas !== 'undefined') { // If no canvas is provided, assume Chrome/Firefox and just make an // OffscreenCanvas for GPU processing. this.wasmModule.canvas = new OffscreenCanvas(1, 1); + } else { + console.warn('OffscreenCanvas not detected and GraphRunner constructor ' + + 'glCanvas parameter is undefined. Creating backup canvas.'); + this.wasmModule.canvas = document.createElement('canvas'); } } From 0bdb48ceb18a772158b92793daf6ac4bf8ce6f76 Mon Sep 17 00:00:00 2001 From: Camillo Lugaresi Date: Wed, 23 Nov 2022 16:17:02 -0800 Subject: [PATCH 128/469] Use kUtilityFramebuffer in GlCalculatorHelper All calculators using the same context can share a single framebuffer object. 
PiperOrigin-RevId: 490605074 --- mediapipe/gpu/gl_calculator_helper.cc | 18 +++--------------- 1 file changed, 3 insertions(+), 15 deletions(-) diff --git a/mediapipe/gpu/gl_calculator_helper.cc b/mediapipe/gpu/gl_calculator_helper.cc index 7d317e0f1..9b217ddfd 100644 --- a/mediapipe/gpu/gl_calculator_helper.cc +++ b/mediapipe/gpu/gl_calculator_helper.cc @@ -27,19 +27,7 @@ namespace mediapipe { GlCalculatorHelper::GlCalculatorHelper() {} -GlCalculatorHelper::~GlCalculatorHelper() { - if (!Initialized()) return; - RunInGlContext( - [this] { - if (framebuffer_) { - glDeleteFramebuffers(1, &framebuffer_); - framebuffer_ = 0; - } - return absl::OkStatus(); - }, - /*calculator_context=*/nullptr) - .IgnoreError(); -} +GlCalculatorHelper::~GlCalculatorHelper() {} void GlCalculatorHelper::InitializeInternal(CalculatorContext* cc, GpuResources* gpu_resources) { @@ -125,9 +113,9 @@ void GlCalculatorHelper::CreateFramebuffer() { // Our framebuffer will have a color attachment but no depth attachment, // so it's important that the depth test be off. It is disabled by default, // but we wanted to be explicit. - // TODO: move this to glBindFramebuffer? + // TODO: move this to glBindFramebuffer? Or just remove. glDisable(GL_DEPTH_TEST); - glGenFramebuffers(1, &framebuffer_); + framebuffer_ = kUtilityFramebuffer.Get(*gl_context_); } void GlCalculatorHelper::BindFramebuffer(const GlTexture& dst) { From 395d9d8ea21c93bbefb37ad980ad41f66b9a2f9f Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Sun, 27 Nov 2022 00:05:08 -0800 Subject: [PATCH 129/469] Instantiate GetDetectionVectorItemCalculator variant of GetVectorItemCalculator<>. PiperOrigin-RevId: 491123314 --- mediapipe/calculators/core/BUILD | 1 + mediapipe/calculators/core/get_vector_item_calculator.cc | 5 +++++ 2 files changed, 6 insertions(+) diff --git a/mediapipe/calculators/core/BUILD b/mediapipe/calculators/core/BUILD index 39837fadb..3b658eb5b 100644 --- a/mediapipe/calculators/core/BUILD +++ b/mediapipe/calculators/core/BUILD @@ -1299,6 +1299,7 @@ cc_library( "//mediapipe/framework/api2:packet", "//mediapipe/framework/api2:port", "//mediapipe/framework/formats:classification_cc_proto", + "//mediapipe/framework/formats:detection_cc_proto", "//mediapipe/framework/formats:landmark_cc_proto", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", diff --git a/mediapipe/calculators/core/get_vector_item_calculator.cc b/mediapipe/calculators/core/get_vector_item_calculator.cc index 51fb46b98..3306e4ff3 100644 --- a/mediapipe/calculators/core/get_vector_item_calculator.cc +++ b/mediapipe/calculators/core/get_vector_item_calculator.cc @@ -15,6 +15,7 @@ #include "mediapipe/calculators/core/get_vector_item_calculator.h" #include "mediapipe/framework/formats/classification.pb.h" +#include "mediapipe/framework/formats/detection.pb.h" #include "mediapipe/framework/formats/landmark.pb.h" namespace mediapipe { @@ -32,5 +33,9 @@ using GetClassificationListVectorItemCalculator = GetVectorItemCalculator; REGISTER_CALCULATOR(GetClassificationListVectorItemCalculator); +using GetDetectionVectorItemCalculator = + GetVectorItemCalculator; +REGISTER_CALCULATOR(GetDetectionVectorItemCalculator); + } // namespace api2 } // namespace mediapipe From 153edc59a111c12b940169a272b36772fcd519a1 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Mon, 28 Nov 2022 09:52:40 -0800 Subject: [PATCH 130/469] Add support for browsers without SIMD PiperOrigin-RevId: 491371277 --- mediapipe/tasks/web/BUILD | 12 ++ mediapipe/tasks/web/audio.ts | 5 +- 
mediapipe/tasks/web/audio/BUILD | 1 + .../tasks/web/audio/audio_classifier/BUILD | 2 +- .../audio_classifier/audio_classifier.ts | 41 ++---- .../audio/audio_embedder/audio_embedder.ts | 28 ++-- mediapipe/tasks/web/audio/index.ts | 1 + mediapipe/tasks/web/core/BUILD | 9 +- mediapipe/tasks/web/core/fileset_resolver.ts | 130 ++++++++++++++++++ mediapipe/tasks/web/core/task_runner.ts | 45 +++++- ..._loader_options.d.ts => wasm_fileset.d.ts} | 4 +- mediapipe/tasks/web/text.ts | 5 +- mediapipe/tasks/web/text/BUILD | 1 + mediapipe/tasks/web/text/index.ts | 1 + .../tasks/web/text/text_classifier/BUILD | 1 - .../text/text_classifier/text_classifier.ts | 39 ++---- mediapipe/tasks/web/text/text_embedder/BUILD | 1 - .../web/text/text_embedder/text_embedder.ts | 42 ++---- mediapipe/tasks/web/vision.ts | 4 +- mediapipe/tasks/web/vision/BUILD | 1 + .../gesture_recognizer/gesture_recognizer.ts | 46 +++---- .../vision/hand_landmarker/hand_landmarker.ts | 46 +++---- .../image_classifier/image_classifier.ts | 41 ++---- .../vision/image_embedder/image_embedder.ts | 40 ++---- mediapipe/tasks/web/vision/index.ts | 1 + .../vision/object_detector/object_detector.ts | 40 ++---- mediapipe/web/graph_runner/graph_runner.ts | 8 +- third_party/wasm_files.bzl | 76 +++++++--- 28 files changed, 410 insertions(+), 261 deletions(-) create mode 100644 mediapipe/tasks/web/core/fileset_resolver.ts rename mediapipe/tasks/web/core/{wasm_loader_options.d.ts => wasm_fileset.d.ts} (88%) diff --git a/mediapipe/tasks/web/BUILD b/mediapipe/tasks/web/BUILD index 7e5d02892..20e717433 100644 --- a/mediapipe/tasks/web/BUILD +++ b/mediapipe/tasks/web/BUILD @@ -13,10 +13,16 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) mediapipe_files(srcs = [ "wasm/audio_wasm_internal.js", "wasm/audio_wasm_internal.wasm", + "wasm/audio_wasm_nosimd_internal.js", + "wasm/audio_wasm_nosimd_internal.wasm", "wasm/text_wasm_internal.js", "wasm/text_wasm_internal.wasm", + "wasm/text_wasm_nosimd_internal.js", + "wasm/text_wasm_nosimd_internal.wasm", "wasm/vision_wasm_internal.js", "wasm/vision_wasm_internal.wasm", + "wasm/vision_wasm_nosimd_internal.js", + "wasm/vision_wasm_nosimd_internal.wasm", ]) # Audio @@ -57,6 +63,8 @@ pkg_npm( deps = [ "wasm/audio_wasm_internal.js", "wasm/audio_wasm_internal.wasm", + "wasm/audio_wasm_nosimd_internal.js", + "wasm/audio_wasm_nosimd_internal.wasm", ":audio_bundle", ], ) @@ -99,6 +107,8 @@ pkg_npm( deps = [ "wasm/text_wasm_internal.js", "wasm/text_wasm_internal.wasm", + "wasm/text_wasm_nosimd_internal.js", + "wasm/text_wasm_nosimd_internal.wasm", ":text_bundle", ], ) @@ -141,6 +151,8 @@ pkg_npm( deps = [ "wasm/vision_wasm_internal.js", "wasm/vision_wasm_internal.wasm", + "wasm/vision_wasm_nosimd_internal.js", + "wasm/vision_wasm_nosimd_internal.wasm", ":vision_bundle", ], ) diff --git a/mediapipe/tasks/web/audio.ts b/mediapipe/tasks/web/audio.ts index 8c522efcc..2f4fb0315 100644 --- a/mediapipe/tasks/web/audio.ts +++ b/mediapipe/tasks/web/audio.ts @@ -14,11 +14,12 @@ * limitations under the License. */ -import {AudioClassifier as AudioClassifierImpl, AudioEmbedder as AudioEmbedderImpl} from '../../tasks/web/audio/index'; +import {AudioClassifier as AudioClassifierImpl, AudioEmbedder as AudioEmbedderImpl, FilesetResolver as FilesetResolverImpl} from '../../tasks/web/audio/index'; // Declare the variables locally so that Rollup in OSS includes them explcilty // as exports. 
const AudioClassifier = AudioClassifierImpl; const AudioEmbedder = AudioEmbedderImpl; +const FilesetResolver = FilesetResolverImpl; -export {AudioClassifier, AudioEmbedder}; +export {AudioClassifier, AudioEmbedder, FilesetResolver}; diff --git a/mediapipe/tasks/web/audio/BUILD b/mediapipe/tasks/web/audio/BUILD index acd7494d7..d08602521 100644 --- a/mediapipe/tasks/web/audio/BUILD +++ b/mediapipe/tasks/web/audio/BUILD @@ -10,5 +10,6 @@ mediapipe_ts_library( deps = [ "//mediapipe/tasks/web/audio/audio_classifier", "//mediapipe/tasks/web/audio/audio_embedder", + "//mediapipe/tasks/web/core:fileset_resolver", ], ) diff --git a/mediapipe/tasks/web/audio/audio_classifier/BUILD b/mediapipe/tasks/web/audio/audio_classifier/BUILD index 498b17845..c419d3b98 100644 --- a/mediapipe/tasks/web/audio/audio_classifier/BUILD +++ b/mediapipe/tasks/web/audio/audio_classifier/BUILD @@ -25,7 +25,7 @@ mediapipe_ts_library( "//mediapipe/tasks/web/components/processors:classifier_result", "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:classifier_options", - "//mediapipe/web/graph_runner:graph_runner_ts", + "//mediapipe/tasks/web/core:task_runner", ], ) diff --git a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts index 20c745383..e606019f2 100644 --- a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts +++ b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts @@ -22,8 +22,8 @@ import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/b import {AudioTaskRunner} from '../../../../tasks/web/audio/core/audio_task_runner'; import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/processors/classifier_options'; import {convertFromClassificationResultProto} from '../../../../tasks/web/components/processors/classifier_result'; -import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options'; -import {createMediaPipeLib, FileLocator} from '../../../../web/graph_runner/graph_runner'; +import {TaskRunner} from '../../../../tasks/web/core/task_runner'; +import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; // Placeholder for internal dependency on trusted resource url import {AudioClassifierOptions} from './audio_classifier_options'; @@ -50,28 +50,17 @@ export class AudioClassifier extends AudioTaskRunner { /** * Initializes the Wasm runtime and creates a new audio classifier from the * provided options. - * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. * @param audioClassifierOptions The options for the audio classifier. Note * that either a path to the model asset or a model buffer needs to be * provided (via `baseOptions`). 
*/ static async createFromOptions( - wasmLoaderOptions: WasmLoaderOptions, - audioClassifierOptions: AudioClassifierOptions): + wasmFileset: WasmFileset, audioClassifierOptions: AudioClassifierOptions): Promise { - // Create a file locator based on the loader options - const fileLocator: FileLocator = { - locateFile() { - // The only file loaded with this mechanism is the Wasm binary - return wasmLoaderOptions.wasmBinaryPath.toString(); - } - }; - - const classifier = await createMediaPipeLib( - AudioClassifier, wasmLoaderOptions.wasmLoaderPath, - /* assetLoaderScript= */ undefined, - /* glCanvas= */ undefined, fileLocator); + const classifier = await TaskRunner.createInstance( + AudioClassifier, /* initializeCanvas= */ false, wasmFileset); await classifier.setOptions(audioClassifierOptions); return classifier; } @@ -79,31 +68,31 @@ export class AudioClassifier extends AudioTaskRunner { /** * Initializes the Wasm runtime and creates a new audio classifier based on * the provided model asset buffer. - * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. * @param modelAssetBuffer A binary representation of the model. */ static createFromModelBuffer( - wasmLoaderOptions: WasmLoaderOptions, + wasmFileset: WasmFileset, modelAssetBuffer: Uint8Array): Promise { return AudioClassifier.createFromOptions( - wasmLoaderOptions, {baseOptions: {modelAssetBuffer}}); + wasmFileset, {baseOptions: {modelAssetBuffer}}); } /** * Initializes the Wasm runtime and creates a new audio classifier based on * the path to the model asset. - * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. * @param modelAssetPath The path to the model asset. 
*/ static async createFromModelPath( - wasmLoaderOptions: WasmLoaderOptions, + wasmFileset: WasmFileset, modelAssetPath: string): Promise { const response = await fetch(modelAssetPath.toString()); const graphData = await response.arrayBuffer(); return AudioClassifier.createFromModelBuffer( - wasmLoaderOptions, new Uint8Array(graphData)); + wasmFileset, new Uint8Array(graphData)); } protected override get baseOptions(): BaseOptionsProto|undefined { diff --git a/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts index 9dce02862..c87aceabe 100644 --- a/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts +++ b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts @@ -24,7 +24,7 @@ import {Embedding} from '../../../../tasks/web/components/containers/embedding_r import {convertEmbedderOptionsToProto} from '../../../../tasks/web/components/processors/embedder_options'; import {convertFromEmbeddingResultProto} from '../../../../tasks/web/components/processors/embedder_result'; import {computeCosineSimilarity} from '../../../../tasks/web/components/utils/cosine_similarity'; -import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options'; +import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; import {createMediaPipeLib, FileLocator} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url @@ -52,25 +52,25 @@ export class AudioEmbedder extends AudioTaskRunner { /** * Initializes the Wasm runtime and creates a new audio embedder from the * provided options. - * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. * @param audioEmbedderOptions The options for the audio embedder. Note that * either a path to the TFLite model or the model itself needs to be * provided (via `baseOptions`). */ static async createFromOptions( - wasmLoaderOptions: WasmLoaderOptions, + wasmFileset: WasmFileset, audioEmbedderOptions: AudioEmbedderOptions): Promise { // Create a file locator based on the loader options const fileLocator: FileLocator = { locateFile() { // The only file we load is the Wasm binary - return wasmLoaderOptions.wasmBinaryPath.toString(); + return wasmFileset.wasmBinaryPath.toString(); } }; const embedder = await createMediaPipeLib( - AudioEmbedder, wasmLoaderOptions.wasmLoaderPath, + AudioEmbedder, wasmFileset.wasmLoaderPath, /* assetLoaderScript= */ undefined, /* glCanvas= */ undefined, fileLocator); await embedder.setOptions(audioEmbedderOptions); @@ -80,31 +80,31 @@ export class AudioEmbedder extends AudioTaskRunner { /** * Initializes the Wasm runtime and creates a new audio embedder based on the * provided model asset buffer. - * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. * @param modelAssetBuffer A binary representation of the TFLite model. 
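To make the migrated audio surface concrete, a minimal usage sketch follows. The `@mediapipe/tasks-audio` import path, the `wasm` base path, and the model file name are illustrative assumptions, not part of this patch:

    import {AudioClassifier, FilesetResolver} from '@mediapipe/tasks-audio';

    async function createClassifier(): Promise<AudioClassifier> {
      // Picks the SIMD or non-SIMD Wasm assets served under ./wasm.
      const fileset = await FilesetResolver.forAudioTasks('wasm');
      // Placeholder model path; substitute your own audio classification model.
      return AudioClassifier.createFromModelPath(fileset, 'yamnet.tflite');
    }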
*/ static createFromModelBuffer( - wasmLoaderOptions: WasmLoaderOptions, + wasmFileset: WasmFileset, modelAssetBuffer: Uint8Array): Promise { return AudioEmbedder.createFromOptions( - wasmLoaderOptions, {baseOptions: {modelAssetBuffer}}); + wasmFileset, {baseOptions: {modelAssetBuffer}}); } /** * Initializes the Wasm runtime and creates a new audio embedder based on the * path to the model asset. - * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. * @param modelAssetPath The path to the TFLite model. */ static async createFromModelPath( - wasmLoaderOptions: WasmLoaderOptions, + wasmFileset: WasmFileset, modelAssetPath: string): Promise { const response = await fetch(modelAssetPath.toString()); const graphData = await response.arrayBuffer(); return AudioEmbedder.createFromModelBuffer( - wasmLoaderOptions, new Uint8Array(graphData)); + wasmFileset, new Uint8Array(graphData)); } protected override get baseOptions(): BaseOptionsProto|undefined { diff --git a/mediapipe/tasks/web/audio/index.ts b/mediapipe/tasks/web/audio/index.ts index 17a908f30..dbad8c617 100644 --- a/mediapipe/tasks/web/audio/index.ts +++ b/mediapipe/tasks/web/audio/index.ts @@ -16,3 +16,4 @@ export * from '../../../tasks/web/audio/audio_classifier/audio_classifier'; export * from '../../../tasks/web/audio/audio_embedder/audio_embedder'; +export * from '../../../tasks/web/core/fileset_resolver'; diff --git a/mediapipe/tasks/web/core/BUILD b/mediapipe/tasks/web/core/BUILD index 6eca8bb4a..d709e3409 100644 --- a/mediapipe/tasks/web/core/BUILD +++ b/mediapipe/tasks/web/core/BUILD @@ -8,7 +8,7 @@ mediapipe_ts_declaration( name = "core", srcs = [ "base_options.d.ts", - "wasm_loader_options.d.ts", + "wasm_fileset.d.ts", ], ) @@ -18,12 +18,19 @@ mediapipe_ts_library( "task_runner.ts", ], deps = [ + ":core", "//mediapipe/web/graph_runner:graph_runner_image_lib_ts", "//mediapipe/web/graph_runner:graph_runner_ts", "//mediapipe/web/graph_runner:register_model_resources_graph_service_ts", ], ) +mediapipe_ts_library( + name = "fileset_resolver", + srcs = ["fileset_resolver.ts"], + deps = [":core"], +) + mediapipe_ts_declaration( name = "classifier_options", srcs = ["classifier_options.d.ts"], diff --git a/mediapipe/tasks/web/core/fileset_resolver.ts b/mediapipe/tasks/web/core/fileset_resolver.ts new file mode 100644 index 000000000..7d68dbc16 --- /dev/null +++ b/mediapipe/tasks/web/core/fileset_resolver.ts @@ -0,0 +1,130 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Placeholder for internal dependency on trusted resource URL builder + +import {WasmFileset} from './wasm_fileset'; + +let supportsSimd: boolean|undefined; + +/** + * Simple WASM program to test compatibility with the M91 instruction set. 
+ * Compiled from
+ * https://github.com/GoogleChromeLabs/wasm-feature-detect/blob/main/src/detectors/simd/module.wat
+ */
+const WASM_SIMD_CHECK = new Uint8Array([
+  0, 97, 115, 109, 1, 0, 0, 0, 1, 5, 1, 96, 0, 1, 123, 3,
+  2, 1, 0, 10, 10, 1, 8, 0, 65, 0, 253, 15, 253, 98, 11
+]);
+
+async function isSimdSupported(): Promise<boolean> {
+  if (supportsSimd === undefined) {
+    try {
+      await WebAssembly.instantiate(WASM_SIMD_CHECK);
+      supportsSimd = true;
+    } catch {
+      supportsSimd = false;
+    }
+  }
+
+  return supportsSimd;
+}
+
+async function createFileset(
+    taskName: string, basePath: string = '.'): Promise<WasmFileset> {
+  if (await isSimdSupported()) {
+    return {
+      wasmLoaderPath:
+          `/${basePath}/${taskName}_wasm_internal.js`,
+      wasmBinaryPath:
+          `/${basePath}/${taskName}_wasm_internal.wasm`,
+    };
+  } else {
+    return {
+      wasmLoaderPath:
+          `/${basePath}/${taskName}_wasm_nosimd_internal.js`,
+      wasmBinaryPath: `/${basePath}/${
+          taskName}_wasm_nosimd_internal.wasm`,
+    };
+  }
+}
+
+// tslint:disable:class-as-namespace
+
+/**
+ * Resolves the files required for the MediaPipe Task APIs.
+ *
+ * This class verifies whether SIMD is supported in the current environment and
+ * loads the SIMD files only if support is detected. The returned filesets
+ * require that the Wasm files are published without renaming. If this is not
+ * possible, you can invoke the MediaPipe Tasks APIs using a manually created
+ * `WasmFileset`.
+ */
+export class FilesetResolver {
+  /**
+   * Returns whether SIMD is supported in the current environment.
+   *
+   * If your environment requires custom locations for the MediaPipe Wasm files,
+   * you can use `isSimdSupported()` to decide whether to load the SIMD-based
+   * assets.
+   *
+   * @return Whether SIMD support was detected in the current environment.
+   */
+  static isSimdSupported(): Promise<boolean> {
+    return isSimdSupported();
+  }
+
+  /**
+   * Creates a fileset for the MediaPipe Audio tasks.
+   *
+   * @param basePath An optional base path to specify the directory the Wasm
+   *     files should be loaded from. If not specified, the Wasm files are
+   *     loaded from the host's root directory.
+   * @return A `WasmFileset` that can be used to initialize MediaPipe Audio
+   *     tasks.
+   */
+  static forAudioTasks(basePath?: string): Promise<WasmFileset> {
+    return createFileset('audio', basePath);
+  }
+
+  /**
+   * Creates a fileset for the MediaPipe Text tasks.
+   *
+   * @param basePath An optional base path to specify the directory the Wasm
+   *     files should be loaded from. If not specified, the Wasm files are
+   *     loaded from the host's root directory.
+   * @return A `WasmFileset` that can be used to initialize MediaPipe Text
+   *     tasks.
+   */
+  static forTextTasks(basePath?: string): Promise<WasmFileset> {
+    return createFileset('text', basePath);
+  }
+
+  /**
+   * Creates a fileset for the MediaPipe Vision tasks.
+   *
+   * @param basePath An optional base path to specify the directory the Wasm
+   *     files should be loaded from. If not specified, the Wasm files are
+   *     loaded from the host's root directory.
+   * @return A `WasmFileset` that can be used to initialize MediaPipe Vision
+   *     tasks.
+   */
+  static forVisionTasks(basePath?: string): Promise<WasmFileset> {
+    return createFileset('vision', basePath);
+  }
+}
+
+
diff --git a/mediapipe/tasks/web/core/task_runner.ts b/mediapipe/tasks/web/core/task_runner.ts
index 67aa4e4df..4085be697 100644
--- a/mediapipe/tasks/web/core/task_runner.ts
+++ b/mediapipe/tasks/web/core/task_runner.ts
@@ -14,9 +14,14 @@
  * limitations under the License.
 */

-import {SupportModelResourcesGraphService} from '../../../web/graph_runner/register_model_resources_graph_service';
+import {createMediaPipeLib, FileLocator, GraphRunner, WasmMediaPipeConstructor, WasmModule} from '../../../web/graph_runner/graph_runner';
 import {SupportImage} from '../../../web/graph_runner/graph_runner_image_lib';
-import {GraphRunner, WasmModule} from '../../../web/graph_runner/graph_runner';
+import {SupportModelResourcesGraphService} from '../../../web/graph_runner/register_model_resources_graph_service';
+
+import {WasmFileset} from './wasm_fileset';
+
+// None of the MP Tasks ship bundle assets.
+const NO_ASSETS = undefined;

 // tslint:disable-next-line:enforce-name-casing
 const WasmMediaPipeImageLib =
@@ -26,8 +31,40 @@ const WasmMediaPipeImageLib =
 export abstract class TaskRunner extends WasmMediaPipeImageLib {
   private processingErrors: Error[] = [];

-  constructor(wasmModule: WasmModule) {
-    super(wasmModule);
+  /**
+   * Creates a new instance of a MediaPipe Task. Determines if SIMD is
+   * supported and loads the relevant WASM binary.
+   * @return A fully instantiated instance of `T`.
+   */
+  protected static async createInstance<T extends TaskRunner>(
+      type: WasmMediaPipeConstructor<T>, initializeCanvas: boolean,
+      fileset: WasmFileset): Promise<T> {
+    const fileLocator: FileLocator = {
+      locateFile() {
+        // The only file loaded with this mechanism is the Wasm binary
+        return fileset.wasmBinaryPath.toString();
+      }
+    };
+
+    if (initializeCanvas) {
+      // Fall back to an OffscreenCanvas created by the GraphRunner if
+      // OffscreenCanvas is available
+      const canvas = typeof OffscreenCanvas === 'undefined' ?
+          document.createElement('canvas') :
+          undefined;
+      return createMediaPipeLib(
+          type, fileset.wasmLoaderPath, NO_ASSETS, canvas, fileLocator);
+    } else {
+      return createMediaPipeLib(
+          type, fileset.wasmLoaderPath, NO_ASSETS, /* glCanvas= */ null,
+          fileLocator);
+    }
+  }
+
+  constructor(
+      wasmModule: WasmModule,
+      glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) {
+    super(wasmModule, glCanvas);

     // Disables the automatic render-to-screen code, which allows for pure
     // CPU processing.
diff --git a/mediapipe/tasks/web/core/wasm_loader_options.d.ts b/mediapipe/tasks/web/core/wasm_fileset.d.ts
similarity index 88%
rename from mediapipe/tasks/web/core/wasm_loader_options.d.ts
rename to mediapipe/tasks/web/core/wasm_fileset.d.ts
index 74436583d..18227eab9 100644
--- a/mediapipe/tasks/web/core/wasm_loader_options.d.ts
+++ b/mediapipe/tasks/web/core/wasm_fileset.d.ts
@@ -16,8 +16,8 @@

 // Placeholder for internal dependency on trusted resource url

-/** An object containing the locations of all Wasm assets */
-export declare interface WasmLoaderOptions {
+/** An object containing the locations of the Wasm assets */
+export declare interface WasmFileset {
   /** The path to the Wasm loader script. */
   wasmLoaderPath: string;
   /** The path to the Wasm binary. */
diff --git a/mediapipe/tasks/web/text.ts b/mediapipe/tasks/web/text.ts
index 8f15075c5..0636714b8 100644
--- a/mediapipe/tasks/web/text.ts
+++ b/mediapipe/tasks/web/text.ts
@@ -14,11 +14,12 @@
  * limitations under the License.
  */

-import {TextClassifier as TextClassifierImpl, TextEmbedder as TextEmbedderImpl} from '../../tasks/web/text/index';
+import {FilesetResolver as FilesetResolverImpl, TextClassifier as TextClassifierImpl, TextEmbedder as TextEmbedderImpl} from '../../tasks/web/text/index';

 // Declare the variables locally so that Rollup in OSS includes them explicitly
 // as exports.
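Because the fileset is now a plain data object, callers whose Wasm assets live under custom names can skip `FilesetResolver` and pass a hand-built `WasmFileset`, as the `FilesetResolver` class comment above suggests. A sketch with placeholder URLs and a hypothetical model path:

    import {TextClassifier} from '@mediapipe/tasks-text';

    // Shape matches the `WasmFileset` declaration renamed above.
    const customFileset = {
      wasmLoaderPath: 'https://example.com/assets/text_wasm_internal.js',
      wasmBinaryPath: 'https://example.com/assets/text_wasm_internal.wasm',
    };

    async function createClassifier(): Promise<TextClassifier> {
      return TextClassifier.createFromOptions(
          customFileset,
          {baseOptions: {modelAssetPath: 'https://example.com/model.tflite'}});
    }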
+const FilesetResolver = FilesetResolverImpl; const TextClassifier = TextClassifierImpl; const TextEmbedder = TextEmbedderImpl; -export {TextClassifier, TextEmbedder}; +export {FilesetResolver, TextClassifier, TextEmbedder}; diff --git a/mediapipe/tasks/web/text/BUILD b/mediapipe/tasks/web/text/BUILD index 4b465b0f5..159db1a0d 100644 --- a/mediapipe/tasks/web/text/BUILD +++ b/mediapipe/tasks/web/text/BUILD @@ -8,6 +8,7 @@ mediapipe_ts_library( name = "text_lib", srcs = ["index.ts"], deps = [ + "//mediapipe/tasks/web/core:fileset_resolver", "//mediapipe/tasks/web/text/text_classifier", "//mediapipe/tasks/web/text/text_embedder", ], diff --git a/mediapipe/tasks/web/text/index.ts b/mediapipe/tasks/web/text/index.ts index d50db209c..a28e4dd1c 100644 --- a/mediapipe/tasks/web/text/index.ts +++ b/mediapipe/tasks/web/text/index.ts @@ -16,3 +16,4 @@ export * from '../../../tasks/web/text/text_classifier/text_classifier'; export * from '../../../tasks/web/text/text_embedder/text_embedder'; +export * from '../../../tasks/web/core/fileset_resolver'; diff --git a/mediapipe/tasks/web/text/text_classifier/BUILD b/mediapipe/tasks/web/text/text_classifier/BUILD index 71ef02c92..f3d272daa 100644 --- a/mediapipe/tasks/web/text/text_classifier/BUILD +++ b/mediapipe/tasks/web/text/text_classifier/BUILD @@ -26,7 +26,6 @@ mediapipe_ts_library( "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:classifier_options", "//mediapipe/tasks/web/core:task_runner", - "//mediapipe/web/graph_runner:graph_runner_ts", ], ) diff --git a/mediapipe/tasks/web/text/text_classifier/text_classifier.ts b/mediapipe/tasks/web/text/text_classifier/text_classifier.ts index 04789f5e1..197869a36 100644 --- a/mediapipe/tasks/web/text/text_classifier/text_classifier.ts +++ b/mediapipe/tasks/web/text/text_classifier/text_classifier.ts @@ -22,8 +22,7 @@ import {convertBaseOptionsToProto} from '../../../../tasks/web/components/proces import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/processors/classifier_options'; import {convertFromClassificationResultProto} from '../../../../tasks/web/components/processors/classifier_result'; import {TaskRunner} from '../../../../tasks/web/core/task_runner'; -import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options'; -import {createMediaPipeLib, FileLocator} from '../../../../web/graph_runner/graph_runner'; +import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; // Placeholder for internal dependency on trusted resource url import {TextClassifierOptions} from './text_classifier_options'; @@ -48,27 +47,17 @@ export class TextClassifier extends TaskRunner { /** * Initializes the Wasm runtime and creates a new text classifier from the * provided options. - * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. * @param textClassifierOptions The options for the text classifier. Note that * either a path to the TFLite model or the model itself needs to be * provided (via `baseOptions`). 
*/ static async createFromOptions( - wasmLoaderOptions: WasmLoaderOptions, + wasmFileset: WasmFileset, textClassifierOptions: TextClassifierOptions): Promise { - // Create a file locator based on the loader options - const fileLocator: FileLocator = { - locateFile() { - // The only file we load is the Wasm binary - return wasmLoaderOptions.wasmBinaryPath.toString(); - } - }; - - const classifier = await createMediaPipeLib( - TextClassifier, wasmLoaderOptions.wasmLoaderPath, - /* assetLoaderScript= */ undefined, - /* glCanvas= */ undefined, fileLocator); + const classifier = await TaskRunner.createInstance( + TextClassifier, /* initializeCanvas= */ false, wasmFileset); await classifier.setOptions(textClassifierOptions); return classifier; } @@ -76,31 +65,31 @@ export class TextClassifier extends TaskRunner { /** * Initializes the Wasm runtime and creates a new text classifier based on the * provided model asset buffer. - * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. * @param modelAssetBuffer A binary representation of the model. */ static createFromModelBuffer( - wasmLoaderOptions: WasmLoaderOptions, + wasmFileset: WasmFileset, modelAssetBuffer: Uint8Array): Promise { return TextClassifier.createFromOptions( - wasmLoaderOptions, {baseOptions: {modelAssetBuffer}}); + wasmFileset, {baseOptions: {modelAssetBuffer}}); } /** * Initializes the Wasm runtime and creates a new text classifier based on the * path to the model asset. - * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. * @param modelAssetPath The path to the model asset. 
*/ static async createFromModelPath( - wasmLoaderOptions: WasmLoaderOptions, + wasmFileset: WasmFileset, modelAssetPath: string): Promise { const response = await fetch(modelAssetPath.toString()); const graphData = await response.arrayBuffer(); return TextClassifier.createFromModelBuffer( - wasmLoaderOptions, new Uint8Array(graphData)); + wasmFileset, new Uint8Array(graphData)); } /** diff --git a/mediapipe/tasks/web/text/text_embedder/BUILD b/mediapipe/tasks/web/text/text_embedder/BUILD index 3f92b8ae1..b858f6b83 100644 --- a/mediapipe/tasks/web/text/text_embedder/BUILD +++ b/mediapipe/tasks/web/text/text_embedder/BUILD @@ -26,7 +26,6 @@ mediapipe_ts_library( "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:embedder_options", "//mediapipe/tasks/web/core:task_runner", - "//mediapipe/web/graph_runner:graph_runner_ts", ], ) diff --git a/mediapipe/tasks/web/text/text_embedder/text_embedder.ts b/mediapipe/tasks/web/text/text_embedder/text_embedder.ts index 2042a0985..511fd2411 100644 --- a/mediapipe/tasks/web/text/text_embedder/text_embedder.ts +++ b/mediapipe/tasks/web/text/text_embedder/text_embedder.ts @@ -24,8 +24,7 @@ import {convertEmbedderOptionsToProto} from '../../../../tasks/web/components/pr import {convertFromEmbeddingResultProto} from '../../../../tasks/web/components/processors/embedder_result'; import {computeCosineSimilarity} from '../../../../tasks/web/components/utils/cosine_similarity'; import {TaskRunner} from '../../../../tasks/web/core/task_runner'; -import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options'; -import {createMediaPipeLib, FileLocator} from '../../../../web/graph_runner/graph_runner'; +import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; // Placeholder for internal dependency on trusted resource url import {TextEmbedderOptions} from './text_embedder_options'; @@ -52,27 +51,17 @@ export class TextEmbedder extends TaskRunner { /** * Initializes the Wasm runtime and creates a new text embedder from the * provided options. - * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. * @param textEmbedderOptions The options for the text embedder. Note that * either a path to the TFLite model or the model itself needs to be * provided (via `baseOptions`). */ static async createFromOptions( - wasmLoaderOptions: WasmLoaderOptions, + wasmFileset: WasmFileset, textEmbedderOptions: TextEmbedderOptions): Promise { - // Create a file locator based on the loader options - const fileLocator: FileLocator = { - locateFile() { - // The only file we load is the Wasm binary - return wasmLoaderOptions.wasmBinaryPath.toString(); - } - }; - - const embedder = await createMediaPipeLib( - TextEmbedder, wasmLoaderOptions.wasmLoaderPath, - /* assetLoaderScript= */ undefined, - /* glCanvas= */ undefined, fileLocator); + const embedder = await TaskRunner.createInstance( + TextEmbedder, /* initializeCanvas= */ false, wasmFileset); await embedder.setOptions(textEmbedderOptions); return embedder; } @@ -80,31 +69,31 @@ export class TextEmbedder extends TaskRunner { /** * Initializes the Wasm runtime and creates a new text embedder based on the * provided model asset buffer. - * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. 
+   * @param wasmFileset A configuration object that provides the location of the
+   *     Wasm binary and its loader.
    * @param modelAssetBuffer A binary representation of the TFLite model.
    */
   static createFromModelBuffer(
-      wasmLoaderOptions: WasmLoaderOptions,
+      wasmFileset: WasmFileset,
       modelAssetBuffer: Uint8Array): Promise<TextEmbedder> {
     return TextEmbedder.createFromOptions(
-        wasmLoaderOptions, {baseOptions: {modelAssetBuffer}});
+        wasmFileset, {baseOptions: {modelAssetBuffer}});
   }

   /**
    * Initializes the Wasm runtime and creates a new text embedder based on the
    * path to the model asset.
-   * @param wasmLoaderOptions A configuration object that provides the location
-   *     of the Wasm binary and its loader.
+   * @param wasmFileset A configuration object that provides the location of the
+   *     Wasm binary and its loader.
    * @param modelAssetPath The path to the TFLite model.
    */
   static async createFromModelPath(
-      wasmLoaderOptions: WasmLoaderOptions,
+      wasmFileset: WasmFileset,
       modelAssetPath: string): Promise<TextEmbedder> {
     const response = await fetch(modelAssetPath.toString());
     const graphData = await response.arrayBuffer();
     return TextEmbedder.createFromModelBuffer(
-        wasmLoaderOptions, new Uint8Array(graphData));
+        wasmFileset, new Uint8Array(graphData));
   }

   /**
@@ -122,14 +111,11 @@ export class TextEmbedder extends TaskRunner {
           options.baseOptions, this.options.getBaseOptions());
       this.options.setBaseOptions(baseOptionsProto);
     }
-
     this.options.setEmbedderOptions(convertEmbedderOptionsToProto(
         options, this.options.getEmbedderOptions()));
-
     this.refreshGraph();
   }
-
   /**
    * Performs embedding extraction on the provided text and waits synchronously
    * for the response.
diff --git a/mediapipe/tasks/web/vision.ts b/mediapipe/tasks/web/vision.ts
index 74a056464..f1ced59af 100644
--- a/mediapipe/tasks/web/vision.ts
+++ b/mediapipe/tasks/web/vision.ts
@@ -14,10 +14,11 @@
  * limitations under the License.
  */

-import {GestureRecognizer as GestureRecognizerImpl, HandLandmarker as HandLandmarkerImpl, ImageClassifier as ImageClassifierImpl, ImageEmbedder as ImageEmbedderImpl, ObjectDetector as ObjectDetectorImpl} from '../../tasks/web/vision/index';
+import {FilesetResolver as FilesetResolverImpl, GestureRecognizer as GestureRecognizerImpl, HandLandmarker as HandLandmarkerImpl, ImageClassifier as ImageClassifierImpl, ImageEmbedder as ImageEmbedderImpl, ObjectDetector as ObjectDetectorImpl} from '../../tasks/web/vision/index';

 // Declare the variables locally so that Rollup in OSS includes them explicitly
 // as exports.
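A short sketch of the embedder flow after this rename, pairing `createFromModelPath` with the static `TextEmbedder.cosineSimilarity` helper (backed by the `computeCosineSimilarity` import shown above); the import path and model file name are assumptions:

    import {FilesetResolver, TextEmbedder} from '@mediapipe/tasks-text';

    async function compareSentences(a: string, b: string): Promise<number> {
      const fileset = await FilesetResolver.forTextTasks('wasm');
      // Hypothetical sentence-embedding model path.
      const embedder = await TextEmbedder.createFromModelPath(
          fileset, 'universal_sentence_encoder.tflite');
      const resultA = embedder.embed(a);
      const resultB = embedder.embed(b);
      // Compare the first embedding head of each result.
      return TextEmbedder.cosineSimilarity(
          resultA.embeddings[0], resultB.embeddings[0]);
    }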
+const FilesetResolver = FilesetResolverImpl; const GestureRecognizer = GestureRecognizerImpl; const HandLandmarker = HandLandmarkerImpl; const ImageClassifier = ImageClassifierImpl; @@ -25,6 +26,7 @@ const ImageEmbedder = ImageEmbedderImpl; const ObjectDetector = ObjectDetectorImpl; export { + FilesetResolver, GestureRecognizer, HandLandmarker, ImageClassifier, diff --git a/mediapipe/tasks/web/vision/BUILD b/mediapipe/tasks/web/vision/BUILD index 3c45fbfa6..42bc0a494 100644 --- a/mediapipe/tasks/web/vision/BUILD +++ b/mediapipe/tasks/web/vision/BUILD @@ -8,6 +8,7 @@ mediapipe_ts_library( name = "vision_lib", srcs = ["index.ts"], deps = [ + "//mediapipe/tasks/web/core:fileset_resolver", "//mediapipe/tasks/web/vision/gesture_recognizer", "//mediapipe/tasks/web/vision/hand_landmarker", "//mediapipe/tasks/web/vision/image_classifier", diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts index dd050d0f1..7441911c1 100644 --- a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts +++ b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts @@ -29,9 +29,9 @@ import {HandLandmarksDetectorGraphOptions} from '../../../../tasks/cc/vision/han import {Category} from '../../../../tasks/web/components/containers/category'; import {Landmark} from '../../../../tasks/web/components/containers/landmark'; import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/processors/classifier_options'; -import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options'; +import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; import {VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; -import {createMediaPipeLib, FileLocator, ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner'; +import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url import {GestureRecognizerOptions} from './gesture_recognizer_options'; @@ -82,28 +82,18 @@ export class GestureRecognizer extends /** * Initializes the Wasm runtime and creates a new gesture recognizer from the * provided options. - * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. * @param gestureRecognizerOptions The options for the gesture recognizer. * Note that either a path to the model asset or a model buffer needs to * be provided (via `baseOptions`). 
*/ static async createFromOptions( - wasmLoaderOptions: WasmLoaderOptions, + wasmFileset: WasmFileset, gestureRecognizerOptions: GestureRecognizerOptions): Promise { - // Create a file locator based on the loader options - const fileLocator: FileLocator = { - locateFile() { - // The only file we load via this mechanism is the Wasm binary - return wasmLoaderOptions.wasmBinaryPath.toString(); - } - }; - - const recognizer = await createMediaPipeLib( - GestureRecognizer, wasmLoaderOptions.wasmLoaderPath, - /* assetLoaderScript= */ undefined, - /* glCanvas= */ undefined, fileLocator); + const recognizer = await VisionTaskRunner.createInstance( + GestureRecognizer, /* initializeCanvas= */ true, wasmFileset); await recognizer.setOptions(gestureRecognizerOptions); return recognizer; } @@ -111,35 +101,37 @@ export class GestureRecognizer extends /** * Initializes the Wasm runtime and creates a new gesture recognizer based on * the provided model asset buffer. - * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. * @param modelAssetBuffer A binary representation of the model. */ static createFromModelBuffer( - wasmLoaderOptions: WasmLoaderOptions, + wasmFileset: WasmFileset, modelAssetBuffer: Uint8Array): Promise { return GestureRecognizer.createFromOptions( - wasmLoaderOptions, {baseOptions: {modelAssetBuffer}}); + wasmFileset, {baseOptions: {modelAssetBuffer}}); } /** * Initializes the Wasm runtime and creates a new gesture recognizer based on * the path to the model asset. - * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. * @param modelAssetPath The path to the model asset. 
*/ static async createFromModelPath( - wasmLoaderOptions: WasmLoaderOptions, + wasmFileset: WasmFileset, modelAssetPath: string): Promise { const response = await fetch(modelAssetPath.toString()); const graphData = await response.arrayBuffer(); return GestureRecognizer.createFromModelBuffer( - wasmLoaderOptions, new Uint8Array(graphData)); + wasmFileset, new Uint8Array(graphData)); } - constructor(wasmModule: WasmModule) { - super(wasmModule); + constructor( + wasmModule: WasmModule, + glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { + super(wasmModule, glCanvas); this.options = new GestureRecognizerGraphOptions(); this.handLandmarkerGraphOptions = new HandLandmarkerGraphOptions(); diff --git a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts index 32b1eed4b..6d69d568c 100644 --- a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts +++ b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts @@ -25,9 +25,9 @@ import {HandLandmarkerGraphOptions} from '../../../../tasks/cc/vision/hand_landm import {HandLandmarksDetectorGraphOptions} from '../../../../tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options_pb'; import {Category} from '../../../../tasks/web/components/containers/category'; import {Landmark} from '../../../../tasks/web/components/containers/landmark'; -import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options'; +import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; import {VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; -import {createMediaPipeLib, FileLocator, ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner'; +import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url import {HandLandmarkerOptions} from './hand_landmarker_options'; @@ -71,27 +71,17 @@ export class HandLandmarker extends VisionTaskRunner { /** * Initializes the Wasm runtime and creates a new `HandLandmarker` from the * provided options. - * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. * @param handLandmarkerOptions The options for the HandLandmarker. * Note that either a path to the model asset or a model buffer needs to * be provided (via `baseOptions`). */ static async createFromOptions( - wasmLoaderOptions: WasmLoaderOptions, + wasmFileset: WasmFileset, handLandmarkerOptions: HandLandmarkerOptions): Promise { - // Create a file locator based on the loader options - const fileLocator: FileLocator = { - locateFile() { - // The only file we load via this mechanism is the Wasm binary - return wasmLoaderOptions.wasmBinaryPath.toString(); - } - }; - - const landmarker = await createMediaPipeLib( - HandLandmarker, wasmLoaderOptions.wasmLoaderPath, - /* assetLoaderScript= */ undefined, - /* glCanvas= */ undefined, fileLocator); + const landmarker = await VisionTaskRunner.createInstance( + HandLandmarker, /* initializeCanvas= */ true, wasmFileset); await landmarker.setOptions(handLandmarkerOptions); return landmarker; } @@ -99,35 +89,37 @@ export class HandLandmarker extends VisionTaskRunner { /** * Initializes the Wasm runtime and creates a new `HandLandmarker` based on * the provided model asset buffer. 
-   * @param wasmLoaderOptions A configuration object that provides the location
-   *     of the Wasm binary and its loader.
+   * @param wasmFileset A configuration object that provides the location of the
+   *     Wasm binary and its loader.
    * @param modelAssetBuffer A binary representation of the model.
    */
   static createFromModelBuffer(
-      wasmLoaderOptions: WasmLoaderOptions,
+      wasmFileset: WasmFileset,
       modelAssetBuffer: Uint8Array): Promise<HandLandmarker> {
     return HandLandmarker.createFromOptions(
-        wasmLoaderOptions, {baseOptions: {modelAssetBuffer}});
+        wasmFileset, {baseOptions: {modelAssetBuffer}});
   }

   /**
    * Initializes the Wasm runtime and creates a new `HandLandmarker` based on
    * the path to the model asset.
-   * @param wasmLoaderOptions A configuration object that provides the location
-   *     of the Wasm binary and its loader.
+   * @param wasmFileset A configuration object that provides the location of the
+   *     Wasm binary and its loader.
    * @param modelAssetPath The path to the model asset.
    */
   static async createFromModelPath(
-      wasmLoaderOptions: WasmLoaderOptions,
+      wasmFileset: WasmFileset,
       modelAssetPath: string): Promise<HandLandmarker> {
     const response = await fetch(modelAssetPath.toString());
     const graphData = await response.arrayBuffer();
     return HandLandmarker.createFromModelBuffer(
-        wasmLoaderOptions, new Uint8Array(graphData));
+        wasmFileset, new Uint8Array(graphData));
   }

-  constructor(wasmModule: WasmModule) {
-    super(wasmModule);
+  constructor(
+      wasmModule: WasmModule,
+      glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) {
+    super(wasmModule, glCanvas);

     this.options = new HandLandmarkerGraphOptions();
     this.handLandmarksDetectorGraphOptions =
diff --git a/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts b/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts
index b59cb6fb1..604795f9f 100644
--- a/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts
+++ b/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts
@@ -21,9 +21,9 @@ import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/b
 import {ImageClassifierGraphOptions} from '../../../../tasks/cc/vision/image_classifier/proto/image_classifier_graph_options_pb';
 import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/processors/classifier_options';
 import {convertFromClassificationResultProto} from '../../../../tasks/web/components/processors/classifier_result';
-import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options';
+import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
 import {VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
-import {createMediaPipeLib, FileLocator, ImageSource} from '../../../../web/graph_runner/graph_runner';
+import {ImageSource} from '../../../../web/graph_runner/graph_runner';

 // Placeholder for internal dependency on trusted resource url

 import {ImageClassifierOptions} from './image_classifier_options';
@@ -49,28 +49,17 @@ export class ImageClassifier extends VisionTaskRunner {
   /**
    * Initializes the Wasm runtime and creates a new image classifier from the
    * provided options.
-   * @param wasmLoaderOptions A configuration object that provides the location
-   *     of the Wasm binary and its loader.
+   * @param wasmFileset A configuration object that provides the location of the
+   *     Wasm binary and its loader.
    * @param imageClassifierOptions The options for the image classifier. Note
    *     that either a path to the model asset or a model buffer needs to be
    *     provided (via `baseOptions`).
*/ static async createFromOptions( - wasmLoaderOptions: WasmLoaderOptions, - imageClassifierOptions: ImageClassifierOptions): + wasmFileset: WasmFileset, imageClassifierOptions: ImageClassifierOptions): Promise { - // Create a file locator based on the loader options - const fileLocator: FileLocator = { - locateFile() { - // The only file we load is the Wasm binary - return wasmLoaderOptions.wasmBinaryPath.toString(); - } - }; - - const classifier = await createMediaPipeLib( - ImageClassifier, wasmLoaderOptions.wasmLoaderPath, - /* assetLoaderScript= */ undefined, - /* glCanvas= */ undefined, fileLocator); + const classifier = await VisionTaskRunner.createInstance( + ImageClassifier, /* initializeCanvas= */ true, wasmFileset); await classifier.setOptions(imageClassifierOptions); return classifier; } @@ -78,31 +67,31 @@ export class ImageClassifier extends VisionTaskRunner { /** * Initializes the Wasm runtime and creates a new image classifier based on * the provided model asset buffer. - * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. * @param modelAssetBuffer A binary representation of the model. */ static createFromModelBuffer( - wasmLoaderOptions: WasmLoaderOptions, + wasmFileset: WasmFileset, modelAssetBuffer: Uint8Array): Promise { return ImageClassifier.createFromOptions( - wasmLoaderOptions, {baseOptions: {modelAssetBuffer}}); + wasmFileset, {baseOptions: {modelAssetBuffer}}); } /** * Initializes the Wasm runtime and creates a new image classifier based on * the path to the model asset. - * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. * @param modelAssetPath The path to the model asset. 
*/ static async createFromModelPath( - wasmLoaderOptions: WasmLoaderOptions, + wasmFileset: WasmFileset, modelAssetPath: string): Promise { const response = await fetch(modelAssetPath.toString()); const graphData = await response.arrayBuffer(); return ImageClassifier.createFromModelBuffer( - wasmLoaderOptions, new Uint8Array(graphData)); + wasmFileset, new Uint8Array(graphData)); } protected override get baseOptions(): BaseOptionsProto|undefined { diff --git a/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts b/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts index f96f1e961..68068db6d 100644 --- a/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts +++ b/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts @@ -23,9 +23,9 @@ import {Embedding} from '../../../../tasks/web/components/containers/embedding_r import {convertEmbedderOptionsToProto} from '../../../../tasks/web/components/processors/embedder_options'; import {convertFromEmbeddingResultProto} from '../../../../tasks/web/components/processors/embedder_result'; import {computeCosineSimilarity} from '../../../../tasks/web/components/utils/cosine_similarity'; -import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options'; +import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; import {VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; -import {createMediaPipeLib, FileLocator, ImageSource} from '../../../../web/graph_runner/graph_runner'; +import {ImageSource} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url import {ImageEmbedderOptions} from './image_embedder_options'; @@ -51,27 +51,17 @@ export class ImageEmbedder extends VisionTaskRunner { /** * Initializes the Wasm runtime and creates a new image embedder from the * provided options. - * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. * @param imageEmbedderOptions The options for the image embedder. Note that * either a path to the TFLite model or the model itself needs to be * provided (via `baseOptions`). */ static async createFromOptions( - wasmLoaderOptions: WasmLoaderOptions, + wasmFileset: WasmFileset, imageEmbedderOptions: ImageEmbedderOptions): Promise { - // Create a file locator based on the loader options - const fileLocator: FileLocator = { - locateFile() { - // The only file we load is the Wasm binary - return wasmLoaderOptions.wasmBinaryPath.toString(); - } - }; - - const embedder = await createMediaPipeLib( - ImageEmbedder, wasmLoaderOptions.wasmLoaderPath, - /* assetLoaderScript= */ undefined, - /* glCanvas= */ undefined, fileLocator); + const embedder = await VisionTaskRunner.createInstance( + ImageEmbedder, /* initializeCanvas= */ true, wasmFileset); await embedder.setOptions(imageEmbedderOptions); return embedder; } @@ -79,31 +69,31 @@ export class ImageEmbedder extends VisionTaskRunner { /** * Initializes the Wasm runtime and creates a new image embedder based on the * provided model asset buffer. - * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. * @param modelAssetBuffer A binary representation of the TFLite model. 
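The vision classes above are all created with `/* initializeCanvas= */ true`, so `TaskRunner.createInstance` can back them with an `OffscreenCanvas` where one is available. A usage sketch with an assumed bundle path and placeholder model file:

    import {FilesetResolver, ImageClassifier} from '@mediapipe/tasks-vision';

    async function classifyImage(image: HTMLImageElement) {
      const fileset = await FilesetResolver.forVisionTasks('wasm');
      // Placeholder model path; substitute your own image classifier.
      const classifier = await ImageClassifier.createFromModelPath(
          fileset, 'efficientnet_lite0.tflite');
      return classifier.classify(image);
    }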
*/ static createFromModelBuffer( - wasmLoaderOptions: WasmLoaderOptions, + wasmFileset: WasmFileset, modelAssetBuffer: Uint8Array): Promise { return ImageEmbedder.createFromOptions( - wasmLoaderOptions, {baseOptions: {modelAssetBuffer}}); + wasmFileset, {baseOptions: {modelAssetBuffer}}); } /** * Initializes the Wasm runtime and creates a new image embedder based on the * path to the model asset. - * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. * @param modelAssetPath The path to the TFLite model. */ static async createFromModelPath( - wasmLoaderOptions: WasmLoaderOptions, + wasmFileset: WasmFileset, modelAssetPath: string): Promise { const response = await fetch(modelAssetPath.toString()); const graphData = await response.arrayBuffer(); return ImageEmbedder.createFromModelBuffer( - wasmLoaderOptions, new Uint8Array(graphData)); + wasmFileset, new Uint8Array(graphData)); } protected override get baseOptions(): BaseOptionsProto|undefined { diff --git a/mediapipe/tasks/web/vision/index.ts b/mediapipe/tasks/web/vision/index.ts index d68c00cc7..0337a0f2f 100644 --- a/mediapipe/tasks/web/vision/index.ts +++ b/mediapipe/tasks/web/vision/index.ts @@ -19,3 +19,4 @@ export * from '../../../tasks/web/vision/image_embedder/image_embedder'; export * from '../../../tasks/web/vision/gesture_recognizer/gesture_recognizer'; export * from '../../../tasks/web/vision/hand_landmarker/hand_landmarker'; export * from '../../../tasks/web/vision/object_detector/object_detector'; +export * from '../../../tasks/web/core/fileset_resolver'; diff --git a/mediapipe/tasks/web/vision/object_detector/object_detector.ts b/mediapipe/tasks/web/vision/object_detector/object_detector.ts index 44046cd1e..0f039acb2 100644 --- a/mediapipe/tasks/web/vision/object_detector/object_detector.ts +++ b/mediapipe/tasks/web/vision/object_detector/object_detector.ts @@ -19,9 +19,9 @@ import {CalculatorOptions} from '../../../../framework/calculator_options_pb'; import {Detection as DetectionProto} from '../../../../framework/formats/detection_pb'; import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; import {ObjectDetectorOptions as ObjectDetectorOptionsProto} from '../../../../tasks/cc/vision/object_detector/proto/object_detector_options_pb'; -import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options'; +import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; import {VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; -import {createMediaPipeLib, FileLocator, ImageSource} from '../../../../web/graph_runner/graph_runner'; +import {ImageSource} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url import {ObjectDetectorOptions} from './object_detector_options'; @@ -48,27 +48,17 @@ export class ObjectDetector extends VisionTaskRunner { /** * Initializes the Wasm runtime and creates a new object detector from the * provided options. - * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. * @param objectDetectorOptions The options for the Object Detector. 
Note that * either a path to the model asset or a model buffer needs to be * provided (via `baseOptions`). */ static async createFromOptions( - wasmLoaderOptions: WasmLoaderOptions, + wasmFileset: WasmFileset, objectDetectorOptions: ObjectDetectorOptions): Promise { - // Create a file locator based on the loader options - const fileLocator: FileLocator = { - locateFile() { - // The only file we load is the Wasm binary - return wasmLoaderOptions.wasmBinaryPath.toString(); - } - }; - - const detector = await createMediaPipeLib( - ObjectDetector, wasmLoaderOptions.wasmLoaderPath, - /* assetLoaderScript= */ undefined, - /* glCanvas= */ undefined, fileLocator); + const detector = await VisionTaskRunner.createInstance( + ObjectDetector, /* initializeCanvas= */ true, wasmFileset); await detector.setOptions(objectDetectorOptions); return detector; } @@ -76,31 +66,31 @@ export class ObjectDetector extends VisionTaskRunner { /** * Initializes the Wasm runtime and creates a new object detector based on the * provided model asset buffer. - * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. * @param modelAssetBuffer A binary representation of the model. */ static createFromModelBuffer( - wasmLoaderOptions: WasmLoaderOptions, + wasmFileset: WasmFileset, modelAssetBuffer: Uint8Array): Promise { return ObjectDetector.createFromOptions( - wasmLoaderOptions, {baseOptions: {modelAssetBuffer}}); + wasmFileset, {baseOptions: {modelAssetBuffer}}); } /** * Initializes the Wasm runtime and creates a new object detector based on the * path to the model asset. - * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. * @param modelAssetPath The path to the model asset. */ static async createFromModelPath( - wasmLoaderOptions: WasmLoaderOptions, + wasmFileset: WasmFileset, modelAssetPath: string): Promise { const response = await fetch(modelAssetPath.toString()); const graphData = await response.arrayBuffer(); return ObjectDetector.createFromModelBuffer( - wasmLoaderOptions, new Uint8Array(graphData)); + wasmFileset, new Uint8Array(graphData)); } protected override get baseOptions(): BaseOptionsProto|undefined { diff --git a/mediapipe/web/graph_runner/graph_runner.ts b/mediapipe/web/graph_runner/graph_runner.ts index 378bc0a4d..9a0f7148c 100644 --- a/mediapipe/web/graph_runner/graph_runner.ts +++ b/mediapipe/web/graph_runner/graph_runner.ts @@ -133,9 +133,11 @@ export type ImageSource = /** A listener that will be invoked with an absl::StatusCode and message. */ export type ErrorListener = (code: number, message: string) => void; -// Internal type of constructors used for initializing GraphRunner and -// subclasses. -type WasmMediaPipeConstructor = +/** + * Internal type of constructors used for initializing GraphRunner and + * subclasses. 
+ */ +export type WasmMediaPipeConstructor = (new ( module: WasmModule, canvas?: HTMLCanvasElement|OffscreenCanvas|null) => LibType); diff --git a/third_party/wasm_files.bzl b/third_party/wasm_files.bzl index 6bfde21ba..504f8567a 100644 --- a/third_party/wasm_files.bzl +++ b/third_party/wasm_files.bzl @@ -12,36 +12,72 @@ def wasm_files(): http_file( name = "com_google_mediapipe_wasm_audio_wasm_internal_js", - sha256 = "9419766229f24790388805d891af907cf11fe8e2cdacabcf016feb054b720c82", - urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_internal.js?generation=1667934266184984"], - ) - - http_file( - name = "com_google_mediapipe_wasm_text_wasm_internal_js", - sha256 = "39d9445ab3b90f625a3332251fe82e59b40cd0501a5657475f3b115b7c6122c8", - urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_internal.js?generation=1667934268229056"], - ) - - http_file( - name = "com_google_mediapipe_wasm_vision_wasm_internal_js", - sha256 = "b43c7078fe5da72990394af4fefd798bd844b4ac47849a49067bd68c3c910a3d", - urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_internal.js?generation=1667934270239845"], + sha256 = "42d2d0ade6e2e8b81425b23686be93eb1423b7777f043eb8f18ad671e2ca803f", + urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_internal.js?generation=1669173769507080"], ) http_file( name = "com_google_mediapipe_wasm_audio_wasm_internal_wasm", - sha256 = "9f2abe2a51d1ebc854859f620759cec1cc643773f3748d0d19e0868578c3d746", - urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_internal.wasm?generation=1667934272818542"], + sha256 = "20200ee9b0866d5176f633a9b375e8a44e53204c01ea2e159e2f9245afb00e80", + urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_internal.wasm?generation=1669173772528997"], + ) + + http_file( + name = "com_google_mediapipe_wasm_audio_wasm_nosimd_internal_js", + sha256 = "11bbf73d48723b19a5a6a13ec296ecdb2aa178cdc3db9d7bc54265a7d4b94c6a", + urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_nosimd_internal.js?generation=1669173774625527"], + ) + + http_file( + name = "com_google_mediapipe_wasm_audio_wasm_nosimd_internal_wasm", + sha256 = "d4528972219033996a83a62798952b6ee8b6b396bcffd96fd5bda5458d57d3a3", + urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_nosimd_internal.wasm?generation=1669173777474822"], + ) + + http_file( + name = "com_google_mediapipe_wasm_text_wasm_internal_js", + sha256 = "29e72e177122f92bda6a3ecd463ebacf30b920559b06c97068112a22eeea4d0e", + urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_internal.js?generation=1669173779706893"], ) http_file( name = "com_google_mediapipe_wasm_text_wasm_internal_wasm", - sha256 = "8334caec5fb10cd1f936f6ee41f8853771c7bf3a421f5c15c39ee41aa503ca54", - urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_internal.wasm?generation=1667934275451198"], + sha256 = "84e5f5ac70f7718baeaa09a89b155abbea67386e7d50663301b3af7ef0941e74", + urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_internal.wasm?generation=1669173782728605"], + ) + + http_file( + name = "com_google_mediapipe_wasm_text_wasm_nosimd_internal_js", + sha256 = "36f247673124e32535f217265b96508c1badee8fe2458c11c1efa95b6bec5daa", + urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_nosimd_internal.js?generation=1669173785027190"], + ) + + http_file( + name = "com_google_mediapipe_wasm_text_wasm_nosimd_internal_wasm", + sha256 = 
"cc74d90a8aaf6d006ec24048cc80c33f96baeeb0075a6c6739f30d41da54e450", + urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_nosimd_internal.wasm?generation=1669173787903754"], + ) + + http_file( + name = "com_google_mediapipe_wasm_vision_wasm_internal_js", + sha256 = "c3451423186766b08008e07ef6d52f628fcc0aca75beedd9bb4d87d380f29edd", + urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_internal.js?generation=1669173790070986"], ) http_file( name = "com_google_mediapipe_wasm_vision_wasm_internal_wasm", - sha256 = "b996eaa324da151359ad8e16edad27d9768505f1fd073625bc50dbb0f252e098", - urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_internal.wasm?generation=1667934277855507"], + sha256 = "d1e8ad748913e3f190bfd3f72e0e8a4a308f78b918d54c79cec60a2cf30a49f0", + urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_internal.wasm?generation=1669173792993881"], + ) + + http_file( + name = "com_google_mediapipe_wasm_vision_wasm_nosimd_internal_js", + sha256 = "e5f1b5e8264ff9a90371653cb0fdbf9ce3b30b712acbd72068af18ebca2293ac", + urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_nosimd_internal.js?generation=1669173794969702"], + ) + + http_file( + name = "com_google_mediapipe_wasm_vision_wasm_nosimd_internal_wasm", + sha256 = "24351fe580e88f2065b1978b8b3c0f3ad7b90f1c95805aafa07971ce422b5854", + urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_nosimd_internal.wasm?generation=1669173797596874"], ) From c48ca1f674e2fef6b23a28100fd092ebe656e96a Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Mon, 28 Nov 2022 13:29:35 -0800 Subject: [PATCH 131/469] internal change PiperOrigin-RevId: 491429214 --- .../tasks/cc/components/containers/BUILD | 5 --- .../tasks/cc/vision/hand_landmarker/BUILD | 6 +++ .../hand_landmarker/hand_landmark.h} | 10 ++--- .../tasks/components/containers/BUILD | 12 ------ .../com/google/mediapipe/tasks/vision/BUILD | 2 + .../handlandmarker}/HandLandmark.java | 2 +- .../python/components/containers/landmark.py | 26 ------------ .../tasks/python/vision/hand_landmarker.py | 26 ++++++++++++ .../web/components/containers/landmark.d.ts | 25 ----------- .../tasks/web/vision/hand_landmarker/BUILD | 1 + .../vision/hand_landmarker/hand_landmark.d.ts | 41 +++++++++++++++++++ 11 files changed, 82 insertions(+), 74 deletions(-) rename mediapipe/tasks/cc/{components/containers/landmark.h => vision/hand_landmarker/hand_landmark.h} (78%) rename mediapipe/tasks/java/com/google/mediapipe/tasks/{components/containers => vision/handlandmarker}/HandLandmark.java (97%) create mode 100644 mediapipe/tasks/web/vision/hand_landmarker/hand_landmark.d.ts diff --git a/mediapipe/tasks/cc/components/containers/BUILD b/mediapipe/tasks/cc/components/containers/BUILD index dec977fb8..35d3f4785 100644 --- a/mediapipe/tasks/cc/components/containers/BUILD +++ b/mediapipe/tasks/cc/components/containers/BUILD @@ -49,8 +49,3 @@ cc_library( "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto", ], ) - -cc_library( - name = "landmark", - hdrs = ["landmark.h"], -) diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/BUILD b/mediapipe/tasks/cc/vision/hand_landmarker/BUILD index 46948ee6c..03ec45f7d 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/BUILD +++ b/mediapipe/tasks/cc/vision/hand_landmarker/BUILD @@ -54,6 +54,12 @@ cc_library( ], ) +cc_library( + name = "hand_landmark", + hdrs = ["hand_landmark.h"], + visibility = ["//visibility:public"], +) + cc_library( name = 
"hand_landmarks_detector_graph", srcs = ["hand_landmarks_detector_graph.cc"], diff --git a/mediapipe/tasks/cc/components/containers/landmark.h b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmark.h similarity index 78% rename from mediapipe/tasks/cc/components/containers/landmark.h rename to mediapipe/tasks/cc/vision/hand_landmarker/hand_landmark.h index 6fdd294ae..c8dbc9254 100644 --- a/mediapipe/tasks/cc/components/containers/landmark.h +++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmark.h @@ -13,10 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_LANDMARK_H_ -#define MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_LANDMARK_H_ +#ifndef MEDIAPIPE_TASKS_CC_VISION_HAND_LANDMARKER_HAND_LANDMARK_H_ +#define MEDIAPIPE_TASKS_CC_VISION_HAND_LANDMARKER_HAND_LANDMARK_H_ -namespace mediapipe::tasks::components::containers { +namespace mediapipe::tasks::vision::hand_landmarker { // The 21 hand landmarks. enum HandLandmark { @@ -43,6 +43,6 @@ enum HandLandmark { PINKY_TIP = 20 }; -} // namespace mediapipe::tasks::components::containers +} // namespace mediapipe::tasks::vision::hand_landmarker -#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_LANDMARK_H_ +#endif // MEDIAPIPE_TASKS_CC_VISION_HAND_LANDMARKER_HAND_LANDMARK_H_ diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD index 869157295..d6e6ac740 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD @@ -74,18 +74,6 @@ android_library( ], ) -android_library( - name = "handlandmark", - srcs = ["HandLandmark.java"], - javacopts = [ - "-Xep:AndroidJdkLibsChecker:OFF", - ], - deps = [ - "@maven//:androidx_annotation_annotation", - "@maven//:com_google_guava_guava", - ], -) - android_library( name = "landmark", srcs = ["Landmark.java"], diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD index 72cee133f..b7febb118 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD @@ -145,6 +145,7 @@ android_library( android_library( name = "handlandmarker", srcs = [ + "handlandmarker/HandLandmark.java", "handlandmarker/HandLandmarker.java", "handlandmarker/HandLandmarkerResult.java", ], @@ -168,6 +169,7 @@ android_library( "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:landmark", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", "//third_party:autovalue", + "@maven//:androidx_annotation_annotation", "@maven//:com_google_guava_guava", ], ) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/HandLandmark.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmark.java similarity index 97% rename from mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/HandLandmark.java rename to mediapipe/tasks/java/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmark.java index da7c4e0ca..7b21ebddf 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/HandLandmark.java +++ 
b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmark.java @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package com.google.mediapipe.tasks.components.containers; +package com.google.mediapipe.tasks.vision.handlandmarker; import androidx.annotation.IntDef; diff --git a/mediapipe/tasks/python/components/containers/landmark.py b/mediapipe/tasks/python/components/containers/landmark.py index 81b2943dc..dee2a16ad 100644 --- a/mediapipe/tasks/python/components/containers/landmark.py +++ b/mediapipe/tasks/python/components/containers/landmark.py @@ -14,7 +14,6 @@ """Landmark data class.""" import dataclasses -import enum from typing import Optional from mediapipe.framework.formats import landmark_pb2 @@ -121,28 +120,3 @@ class NormalizedLandmark: z=pb2_obj.z, visibility=pb2_obj.visibility, presence=pb2_obj.presence) - - -class HandLandmark(enum.IntEnum): - """The 21 hand landmarks.""" - WRIST = 0 - THUMB_CMC = 1 - THUMB_MCP = 2 - THUMB_IP = 3 - THUMB_TIP = 4 - INDEX_FINGER_MCP = 5 - INDEX_FINGER_PIP = 6 - INDEX_FINGER_DIP = 7 - INDEX_FINGER_TIP = 8 - MIDDLE_FINGER_MCP = 9 - MIDDLE_FINGER_PIP = 10 - MIDDLE_FINGER_DIP = 11 - MIDDLE_FINGER_TIP = 12 - RING_FINGER_MCP = 13 - RING_FINGER_PIP = 14 - RING_FINGER_DIP = 15 - RING_FINGER_TIP = 16 - PINKY_MCP = 17 - PINKY_PIP = 18 - PINKY_DIP = 19 - PINKY_TIP = 20 diff --git a/mediapipe/tasks/python/vision/hand_landmarker.py b/mediapipe/tasks/python/vision/hand_landmarker.py index 3367f1da7..a0cd99a83 100644 --- a/mediapipe/tasks/python/vision/hand_landmarker.py +++ b/mediapipe/tasks/python/vision/hand_landmarker.py @@ -14,6 +14,7 @@ """MediaPipe hand landmarker task.""" import dataclasses +import enum from typing import Callable, Mapping, Optional, List from mediapipe.framework.formats import classification_pb2 @@ -53,6 +54,31 @@ _TASK_GRAPH_NAME = 'mediapipe.tasks.vision.hand_landmarker.HandLandmarkerGraph' _MICRO_SECONDS_PER_MILLISECOND = 1000 +class HandLandmark(enum.IntEnum): + """The 21 hand landmarks.""" + WRIST = 0 + THUMB_CMC = 1 + THUMB_MCP = 2 + THUMB_IP = 3 + THUMB_TIP = 4 + INDEX_FINGER_MCP = 5 + INDEX_FINGER_PIP = 6 + INDEX_FINGER_DIP = 7 + INDEX_FINGER_TIP = 8 + MIDDLE_FINGER_MCP = 9 + MIDDLE_FINGER_PIP = 10 + MIDDLE_FINGER_DIP = 11 + MIDDLE_FINGER_TIP = 12 + RING_FINGER_MCP = 13 + RING_FINGER_PIP = 14 + RING_FINGER_DIP = 15 + RING_FINGER_TIP = 16 + PINKY_MCP = 17 + PINKY_PIP = 18 + PINKY_DIP = 19 + PINKY_TIP = 20 + + @dataclasses.dataclass class HandLandmarkerResult: """The hand landmarks result from HandLandmarker, where each vector element represents a single hand detected in the image. diff --git a/mediapipe/tasks/web/components/containers/landmark.d.ts b/mediapipe/tasks/web/components/containers/landmark.d.ts index 352717a2f..c887303d0 100644 --- a/mediapipe/tasks/web/components/containers/landmark.d.ts +++ b/mediapipe/tasks/web/components/containers/landmark.d.ts @@ -33,28 +33,3 @@ export declare interface Landmark { /** Whether this landmark is normalized with respect to the image size. */ normalized: boolean; } - -/** The 21 hand landmarks. 
*/ -export const enum HandLandmark { - WRIST = 0, - THUMB_CMC = 1, - THUMB_MCP = 2, - THUMB_IP = 3, - THUMB_TIP = 4, - INDEX_FINGER_MCP = 5, - INDEX_FINGER_PIP = 6, - INDEX_FINGER_DIP = 7, - INDEX_FINGER_TIP = 8, - MIDDLE_FINGER_MCP = 9, - MIDDLE_FINGER_PIP = 10, - MIDDLE_FINGER_DIP = 11, - MIDDLE_FINGER_TIP = 12, - RING_FINGER_MCP = 13, - RING_FINGER_PIP = 14, - RING_FINGER_DIP = 15, - RING_FINGER_TIP = 16, - PINKY_MCP = 17, - PINKY_PIP = 18, - PINKY_DIP = 19, - PINKY_TIP = 20 -} diff --git a/mediapipe/tasks/web/vision/hand_landmarker/BUILD b/mediapipe/tasks/web/vision/hand_landmarker/BUILD index 1849687c5..fc3e6ef1f 100644 --- a/mediapipe/tasks/web/vision/hand_landmarker/BUILD +++ b/mediapipe/tasks/web/vision/hand_landmarker/BUILD @@ -34,6 +34,7 @@ mediapipe_ts_library( mediapipe_ts_declaration( name = "hand_landmarker_types", srcs = [ + "hand_landmark.d.ts", "hand_landmarker_options.d.ts", "hand_landmarker_result.d.ts", ], diff --git a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmark.d.ts b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmark.d.ts new file mode 100644 index 000000000..ca2543f78 --- /dev/null +++ b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmark.d.ts @@ -0,0 +1,41 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +/** The 21 hand landmarks. */ +export const enum HandLandmark { + WRIST = 0, + THUMB_CMC = 1, + THUMB_MCP = 2, + THUMB_IP = 3, + THUMB_TIP = 4, + INDEX_FINGER_MCP = 5, + INDEX_FINGER_PIP = 6, + INDEX_FINGER_DIP = 7, + INDEX_FINGER_TIP = 8, + MIDDLE_FINGER_MCP = 9, + MIDDLE_FINGER_PIP = 10, + MIDDLE_FINGER_DIP = 11, + MIDDLE_FINGER_TIP = 12, + RING_FINGER_MCP = 13, + RING_FINGER_PIP = 14, + RING_FINGER_DIP = 15, + RING_FINGER_TIP = 16, + PINKY_MCP = 17, + PINKY_PIP = 18, + PINKY_DIP = 19, + PINKY_TIP = 20 +} From 342f95fa2044c4957ea7cb65352268a868e3d680 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Mon, 28 Nov 2022 13:51:59 -0800 Subject: [PATCH 132/469] Typo fix PiperOrigin-RevId: 491434987 --- mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h | 2 +- mediapipe/tasks/python/vision/image_segmenter.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h index 43bf5b7e6..511d3b9c1 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h @@ -98,7 +98,7 @@ struct ImageSegmenterOptions { // - list of segmented masks. // - if `output_type` is CATEGORY_MASK, uint8 Image, Image vector of size 1. // - if `output_type` is CONFIDENCE_MASK, float32 Image list of size -// `cahnnels`. +// `channels`. 
// - batch is always 1 // An example of such model can be found at: // https://tfhub.dev/tensorflow/lite-model/deeplabv3/1/metadata/2 diff --git a/mediapipe/tasks/python/vision/image_segmenter.py b/mediapipe/tasks/python/vision/image_segmenter.py index 9ef911f75..62fc8bb7c 100644 --- a/mediapipe/tasks/python/vision/image_segmenter.py +++ b/mediapipe/tasks/python/vision/image_segmenter.py @@ -110,7 +110,7 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi): - list of segmented masks. - if `output_type` is CATEGORY_MASK, uint8 Image, Image vector of size 1. - if `output_type` is CONFIDENCE_MASK, float32 Image list of size - `cahnnels`. + `channels`. - batch is always 1 An example of such model can be found at: From b65c40b302ccf397d6da3c27ab2795335e5c63cd Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Mon, 28 Nov 2022 14:15:16 -0800 Subject: [PATCH 133/469] Internal change PiperOrigin-RevId: 491441446 --- mediapipe/objc/MPPLayerRenderer.m | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/mediapipe/objc/MPPLayerRenderer.m b/mediapipe/objc/MPPLayerRenderer.m index 7c3027fb6..edd2216ee 100644 --- a/mediapipe/objc/MPPLayerRenderer.m +++ b/mediapipe/objc/MPPLayerRenderer.m @@ -54,10 +54,11 @@ glGenRenderbuffers(1, &renderbuffer_); glBindRenderbuffer(GL_RENDERBUFFER, renderbuffer_); glFramebufferRenderbuffer(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, GL_RENDERBUFFER, renderbuffer_); - BOOL success = [_glRenderer.glContext renderbufferStorage:GL_RENDERBUFFER fromDrawable:_layer]; + BOOL success __unused = [_glRenderer.glContext renderbufferStorage:GL_RENDERBUFFER + fromDrawable:_layer]; NSAssert(success, @"could not create renderbuffer storage for layer with bounds %@", NSStringFromCGRect(_layer.bounds)); - GLenum status = glCheckFramebufferStatus(GL_FRAMEBUFFER); + GLenum status __unused = glCheckFramebufferStatus(GL_FRAMEBUFFER); NSAssert(status == GL_FRAMEBUFFER_COMPLETE, @"failed to make complete framebuffer object %x", status); } From 26a7ca5c64cd885978677931a7218d33cd7d1dec Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Mon, 28 Nov 2022 15:02:55 -0800 Subject: [PATCH 134/469] fix typo and minor formatting issues PiperOrigin-RevId: 491453662 --- mediapipe/python/solutions/drawing_utils.py | 42 ++++++++++----------- 1 file changed, 20 insertions(+), 22 deletions(-) diff --git a/mediapipe/python/solutions/drawing_utils.py b/mediapipe/python/solutions/drawing_utils.py index bebcbe97c..1b8b173f7 100644 --- a/mediapipe/python/solutions/drawing_utils.py +++ b/mediapipe/python/solutions/drawing_utils.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """MediaPipe solution drawing utils.""" import math @@ -135,15 +134,14 @@ def draw_landmarks( the image. connections: A list of landmark index tuples that specifies how landmarks to be connected in the drawing. - landmark_drawing_spec: Either a DrawingSpec object or a mapping from - hand landmarks to the DrawingSpecs that specifies the landmarks' drawing - settings such as color, line thickness, and circle radius. - If this argument is explicitly set to None, no landmarks will be drawn. - connection_drawing_spec: Either a DrawingSpec object or a mapping from - hand connections to the DrawingSpecs that specifies the - connections' drawing settings such as color and line thickness. - If this argument is explicitly set to None, no landmark connections will - be drawn. 
+ landmark_drawing_spec: Either a DrawingSpec object or a mapping from hand + landmarks to the DrawingSpecs that specifies the landmarks' drawing + settings such as color, line thickness, and circle radius. If this + argument is explicitly set to None, no landmarks will be drawn. + connection_drawing_spec: Either a DrawingSpec object or a mapping from hand + connections to the DrawingSpecs that specifies the connections' drawing + settings such as color and line thickness. If this argument is explicitly + set to None, no landmark connections will be drawn. Raises: ValueError: If one of the followings: @@ -197,14 +195,13 @@ def draw_landmarks( drawing_spec.color, drawing_spec.thickness) -def draw_axis( - image: np.ndarray, - rotation: np.ndarray, - translation: np.ndarray, - focal_length: Tuple[float, float] = (1.0, 1.0), - principal_point: Tuple[float, float] = (0.0, 0.0), - axis_length: float = 0.1, - axis_drawing_spec: DrawingSpec = DrawingSpec()): +def draw_axis(image: np.ndarray, + rotation: np.ndarray, + translation: np.ndarray, + focal_length: Tuple[float, float] = (1.0, 1.0), + principal_point: Tuple[float, float] = (0.0, 0.0), + axis_length: float = 0.1, + axis_drawing_spec: DrawingSpec = DrawingSpec()): """Draws the 3D axis on the image. Args: @@ -214,8 +211,8 @@ def draw_axis( focal_length: camera focal length along x and y directions. principal_point: camera principal point in x and y. axis_length: length of the axis in the drawing. - axis_drawing_spec: A DrawingSpec object that specifies the xyz axis - drawing settings such as line thickness. + axis_drawing_spec: A DrawingSpec object that specifies the xyz axis drawing + settings such as line thickness. Raises: ValueError: If one of the followings: @@ -226,7 +223,7 @@ def draw_axis( image_rows, image_cols, _ = image.shape # Create axis points in camera coordinate frame. axis_world = np.float32([[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]]) - axis_cam = np.matmul(rotation, axis_length*axis_world.T).T + translation + axis_cam = np.matmul(rotation, axis_length * axis_world.T).T + translation x = axis_cam[..., 0] y = axis_cam[..., 1] z = axis_cam[..., 2] @@ -274,8 +271,9 @@ def plot_landmarks(landmark_list: landmark_pb2.NormalizedLandmarkList, connections' drawing settings such as color and line thickness. elevation: The elevation from which to view the plot. azimuth: the azimuth angle to rotate the plot. + Raises: - ValueError: If any connetions contain invalid landmark index. + ValueError: If any connection contains an invalid landmark index. 
""" if not landmark_list: return From 7b74fd53f592ab115f60180278952eafeeb61634 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Mon, 28 Nov 2022 15:46:30 -0800 Subject: [PATCH 135/469] Verify that kernel cache is only used when OpenCL is active PiperOrigin-RevId: 491463306 --- .../calculators/tensor/inference_calculator_gl_advanced.cc | 6 +++--- mediapipe/calculators/tflite/tflite_inference_calculator.cc | 6 +++--- mediapipe/util/tflite/tflite_gpu_runner.h | 4 +++- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/mediapipe/calculators/tensor/inference_calculator_gl_advanced.cc b/mediapipe/calculators/tensor/inference_calculator_gl_advanced.cc index c2c723402..b226dbbd8 100644 --- a/mediapipe/calculators/tensor/inference_calculator_gl_advanced.cc +++ b/mediapipe/calculators/tensor/inference_calculator_gl_advanced.cc @@ -258,9 +258,9 @@ InferenceCalculatorGlAdvancedImpl::OnDiskCacheHelper::SaveGpuCaches( tflite::gpu::TFLiteGPURunner* gpu_runner) const { if (use_kernel_caching_) { // Save kernel file. - auto kernel_cache = absl::make_unique>( - gpu_runner->GetSerializedBinaryCache()); - std::string cache_str(kernel_cache->begin(), kernel_cache->end()); + ASSIGN_OR_RETURN(std::vector kernel_cache, + gpu_runner->GetSerializedBinaryCache()); + std::string cache_str(kernel_cache.begin(), kernel_cache.end()); MP_RETURN_IF_ERROR( mediapipe::file::SetContents(cached_kernel_filename_, cache_str)); } diff --git a/mediapipe/calculators/tflite/tflite_inference_calculator.cc b/mediapipe/calculators/tflite/tflite_inference_calculator.cc index afdc9ed6f..0f7fa933e 100644 --- a/mediapipe/calculators/tflite/tflite_inference_calculator.cc +++ b/mediapipe/calculators/tflite/tflite_inference_calculator.cc @@ -485,9 +485,9 @@ absl::Status TfLiteInferenceCalculator::WriteKernelsToFile() { #if MEDIAPIPE_TFLITE_GL_INFERENCE && defined(MEDIAPIPE_ANDROID) if (use_kernel_caching_) { // Save kernel file. 
diff --git a/mediapipe/calculators/tflite/tflite_inference_calculator.cc b/mediapipe/calculators/tflite/tflite_inference_calculator.cc
index afdc9ed6f..0f7fa933e 100644
--- a/mediapipe/calculators/tflite/tflite_inference_calculator.cc
+++ b/mediapipe/calculators/tflite/tflite_inference_calculator.cc
@@ -485,9 +485,9 @@ absl::Status TfLiteInferenceCalculator::WriteKernelsToFile() {
 #if MEDIAPIPE_TFLITE_GL_INFERENCE && defined(MEDIAPIPE_ANDROID)
   if (use_kernel_caching_) {
     // Save kernel file.
-    auto kernel_cache = absl::make_unique<std::vector<uint8_t>>(
-        tflite_gpu_runner_->GetSerializedBinaryCache());
-    std::string cache_str(kernel_cache->begin(), kernel_cache->end());
+    ASSIGN_OR_RETURN(std::vector<uint8_t> kernel_cache,
+                     tflite_gpu_runner_->GetSerializedBinaryCache());
+    std::string cache_str(kernel_cache.begin(), kernel_cache.end());
     MP_RETURN_IF_ERROR(
         mediapipe::file::SetContents(cached_kernel_filename_, cache_str));
   }
diff --git a/mediapipe/util/tflite/tflite_gpu_runner.h b/mediapipe/util/tflite/tflite_gpu_runner.h
index dfbc8d659..5eeaa230f 100644
--- a/mediapipe/util/tflite/tflite_gpu_runner.h
+++ b/mediapipe/util/tflite/tflite_gpu_runner.h
@@ -21,6 +21,7 @@
 #include "absl/status/status.h"
 #include "mediapipe/framework/port.h"
+#include "mediapipe/framework/port/ret_check.h"
 #include "mediapipe/framework/port/status.h"
 #include "mediapipe/framework/port/statusor.h"
 #include "tensorflow/lite/core/api/op_resolver.h"
@@ -89,7 +90,8 @@ class TFLiteGPURunner {
     serialized_binary_cache_ = std::move(cache);
   }

-  std::vector<uint8_t> GetSerializedBinaryCache() {
+  absl::StatusOr<std::vector<uint8_t>> GetSerializedBinaryCache() {
+    RET_CHECK(cl_environment_) << "CL environment is not initialized.";
     return cl_environment_->GetSerializedBinaryCache();
   }
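The next patch makes the on-disk kernel cache filename unique per model, derived from either the model path or the model token. A rough stand-alone sketch of that policy, assuming nothing beyond the standard library (KernelCacheFilename is an invented helper; the real code below uses mediapipe::File::Basename and mediapipe::file::JoinPath):

#include <string>

// Invented helper: pick a unique kernel-cache filename, preferring the
// model path's basename and falling back to the model token, so two models
// in one graph never share a .ker file.
std::string KernelCacheFilename(const std::string& cache_dir,
                                const std::string& model_path,
                                const std::string& model_token) {
  // Poor-man's File::Basename: strip everything up to the last '/'.
  std::string basename = model_path;
  std::string::size_type pos = basename.find_last_of('/');
  if (pos != std::string::npos) basename = basename.substr(pos + 1);
  if (basename.empty()) basename = model_token;
  return cache_dir + "/" + basename + ".ker";
}

// E.g. KernelCacheFilename("/cache", "/models/hand.tflite", "") yields
// "/cache/hand.tflite.ker"; with an empty path the token is used instead.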
From e987b69f397af3d7bb4976d4e77029dacaae999a Mon Sep 17 00:00:00 2001
From: MediaPipe Team 
Date: Mon, 28 Nov 2022 16:48:17 -0800
Subject: [PATCH 136/469] Add alternative method to determine unique kernel cache path

PiperOrigin-RevId: 491476293
---
 .../tensor/inference_calculator_gl_advanced.cc | 13 ++++++++++---
 1 file changed, 10 insertions(+), 3 deletions(-)

diff --git a/mediapipe/calculators/tensor/inference_calculator_gl_advanced.cc b/mediapipe/calculators/tensor/inference_calculator_gl_advanced.cc
index b226dbbd8..8fd55efa7 100644
--- a/mediapipe/calculators/tensor/inference_calculator_gl_advanced.cc
+++ b/mediapipe/calculators/tensor/inference_calculator_gl_advanced.cc
@@ -236,14 +236,21 @@ absl::Status InferenceCalculatorGlAdvancedImpl::OnDiskCacheHelper::Init(
     const mediapipe::InferenceCalculatorOptions& options,
     const mediapipe::InferenceCalculatorOptions::Delegate::Gpu&
         gpu_delegate_options) {
-  use_kernel_caching_ = gpu_delegate_options.has_cached_kernel_path();
+  // The kernel cache needs a unique filename based on either model_path or the
+  // model token, to prevent the cache from being overwritten if the graph has
+  // more than one model.
+  use_kernel_caching_ =
+      gpu_delegate_options.has_cached_kernel_path() &&
+      (options.has_model_path() || gpu_delegate_options.has_model_token());
   use_serialized_model_ = gpu_delegate_options.has_serialized_model_dir() &&
                           gpu_delegate_options.has_model_token();

   if (use_kernel_caching_) {
+    std::string basename = options.has_model_path()
+                               ? mediapipe::File::Basename(options.model_path())
+                               : gpu_delegate_options.model_token();
     cached_kernel_filename_ = mediapipe::file::JoinPath(
-        gpu_delegate_options.cached_kernel_path(),
-        mediapipe::File::Basename(options.model_path()) + ".ker");
+        gpu_delegate_options.cached_kernel_path(), basename + ".ker");
   }
   if (use_serialized_model_) {
     serialized_model_path_ =

From fc526374abac9e1080e06470004ab292fe0c162a Mon Sep 17 00:00:00 2001
From: Camillo Lugaresi 
Date: Mon, 28 Nov 2022 17:48:37 -0800
Subject: [PATCH 137/469] Use GpuResources in GpuTestBase and update GpuBufferMultiPoolTest

PiperOrigin-RevId: 491486495
---
 mediapipe/gpu/gpu_test_base.h | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/mediapipe/gpu/gpu_test_base.h b/mediapipe/gpu/gpu_test_base.h
index e9fd64725..6ec53603b 100644
--- a/mediapipe/gpu/gpu_test_base.h
+++ b/mediapipe/gpu/gpu_test_base.h
@@ -24,13 +24,14 @@ namespace mediapipe {

 class GpuTestBase : public ::testing::Test {
  protected:
-  GpuTestBase() { helper_.InitializeForTest(&gpu_shared_); }
+  GpuTestBase() { helper_.InitializeForTest(gpu_resources_.get()); }

   void RunInGlContext(std::function<void(void)> gl_func) {
     helper_.RunInGlContext(std::move(gl_func));
   }

   GpuSharedData gpu_shared_;
+  std::shared_ptr<GpuResources> gpu_resources_ = gpu_shared_.gpu_resources;
   GlCalculatorHelper helper_;
 };

From cc11b4522837ce2f3763831fca0447e3b7cef495 Mon Sep 17 00:00:00 2001
From: Camillo Lugaresi 
Date: Mon, 28 Nov 2022 17:52:35 -0800
Subject: [PATCH 138/469] Remove unneeded GPU_SHARED side packet for GlSurfaceSink

PiperOrigin-RevId: 491487092
---
 mediapipe/gpu/gl_surface_sink_calculator.cc | 1 -
 mediapipe/java/com/google/mediapipe/framework/jni/graph.cc | 2 --
 2 files changed, 3 deletions(-)

diff --git a/mediapipe/gpu/gl_surface_sink_calculator.cc b/mediapipe/gpu/gl_surface_sink_calculator.cc
index 31500ed9a..ad867c2be 100644
--- a/mediapipe/gpu/gl_surface_sink_calculator.cc
+++ b/mediapipe/gpu/gl_surface_sink_calculator.cc
@@ -37,7 +37,6 @@ enum { kAttribVertex, kAttribTexturePosition, kNumberOfAttributes };
 //   VIDEO or index 0: GpuBuffers to be rendered.
 // Side inputs:
 //   SURFACE: unique_ptr to an EglSurfaceHolder to draw to.
-//   GPU_SHARED: shared GPU resources.
 //
 // See GlSurfaceSinkCalculatorOptions for options.
 class GlSurfaceSinkCalculator : public Node {
diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/graph.cc b/mediapipe/java/com/google/mediapipe/framework/jni/graph.cc
index 6a67c01cb..23bd553af 100644
--- a/mediapipe/java/com/google/mediapipe/framework/jni/graph.cc
+++ b/mediapipe/java/com/google/mediapipe/framework/jni/graph.cc
@@ -231,8 +231,6 @@ int64_t Graph::AddSurfaceOutput(const std::string& output_stream_name) {
       *graph_config(), absl::StrCat("egl_surface_sink_", output_stream_name)));
   sink_node->set_calculator("GlSurfaceSinkCalculator");
   sink_node->add_input_stream(output_stream_name);
-  sink_node->add_input_side_packet(
-      absl::StrCat(kGpuSharedTagName, ":", kGpuSharedSidePacketName));

   const std::string input_side_packet_name =
       mediapipe::tool::GetUnusedSidePacketName(

From c8a413bb4e5da6b977695987809a27b8f044f15a Mon Sep 17 00:00:00 2001
From: Jiuqiang Tang 
Date: Tue, 29 Nov 2022 10:17:21 -0800
Subject: [PATCH 139/469] Open up mediapipe framework's visibility.
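The BUILD files touched below either switch their package default visibility from //visibility:private to //visibility:public or already defaulted to public. Because Bazel applies the package default to every target that does not set its own visibility attribute, the per-target visibility = ["//visibility:public"] lines become redundant and are removed. A few targets that carried narrower settings, such as scale_image_utils (previously limited to //mediapipe:__subpackages__), become publicly visible as well.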
PiperOrigin-RevId: 491672877 --- mediapipe/calculators/image/BUILD | 41 +-------- mediapipe/calculators/tensorflow/BUILD | 70 +--------------- mediapipe/calculators/tflite/BUILD | 20 +---- mediapipe/calculators/util/BUILD | 83 ------------------- mediapipe/calculators/video/BUILD | 29 +------ mediapipe/examples/desktop/hello_world/BUILD | 3 +- mediapipe/framework/BUILD | 2 +- mediapipe/framework/formats/BUILD | 28 +------ mediapipe/framework/formats/annotation/BUILD | 4 +- mediapipe/framework/formats/motion/BUILD | 7 +- .../framework/formats/object_detection/BUILD | 4 +- mediapipe/framework/stream_handler/BUILD | 19 +---- .../holistic_landmark/calculators/BUILD | 3 - mediapipe/util/tracking/BUILD | 17 ---- 14 files changed, 11 insertions(+), 319 deletions(-) diff --git a/mediapipe/calculators/image/BUILD b/mediapipe/calculators/image/BUILD index c78bc5cf7..530dd3d4a 100644 --- a/mediapipe/calculators/image/BUILD +++ b/mediapipe/calculators/image/BUILD @@ -16,12 +16,11 @@ load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") licenses(["notice"]) -package(default_visibility = ["//visibility:private"]) +package(default_visibility = ["//visibility:public"]) mediapipe_proto_library( name = "opencv_image_encoder_calculator_proto", srcs = ["opencv_image_encoder_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -31,7 +30,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "scale_image_calculator_proto", srcs = ["scale_image_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -42,7 +40,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "set_alpha_calculator_proto", srcs = ["set_alpha_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -52,7 +49,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "image_cropping_calculator_proto", srcs = ["image_cropping_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -62,7 +58,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "bilateral_filter_calculator_proto", srcs = ["bilateral_filter_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -72,7 +67,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "recolor_calculator_proto", srcs = ["recolor_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -83,7 +77,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "segmentation_smoothing_calculator_proto", srcs = ["segmentation_smoothing_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -93,7 +86,6 @@ mediapipe_proto_library( cc_library( name = "color_convert_calculator", srcs = ["color_convert_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework:timestamp", @@ -112,7 +104,6 @@ cc_library( cc_library( name = 
"opencv_encoded_image_to_image_frame_calculator", srcs = ["opencv_encoded_image_to_image_frame_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":opencv_encoded_image_to_image_frame_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -127,7 +118,6 @@ cc_library( cc_library( name = "opencv_image_encoder_calculator", srcs = ["opencv_image_encoder_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":opencv_image_encoder_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -142,7 +132,6 @@ cc_library( cc_library( name = "opencv_put_text_calculator", srcs = ["opencv_put_text_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:image_frame_opencv", @@ -156,7 +145,6 @@ cc_library( cc_library( name = "set_alpha_calculator", srcs = ["set_alpha_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":set_alpha_calculator_cc_proto", "//mediapipe/framework/formats:image_format_cc_proto", @@ -183,7 +171,6 @@ cc_library( cc_library( name = "bilateral_filter_calculator", srcs = ["bilateral_filter_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":bilateral_filter_calculator_cc_proto", "//mediapipe/framework/formats:image_format_cc_proto", @@ -212,13 +199,11 @@ cc_library( mediapipe_proto_library( name = "rotation_mode_proto", srcs = ["rotation_mode.proto"], - visibility = ["//visibility:public"], ) mediapipe_proto_library( name = "image_transformation_calculator_proto", srcs = ["image_transformation_calculator.proto"], - visibility = ["//visibility:public"], deps = [ ":rotation_mode_proto", "//mediapipe/framework:calculator_options_proto", @@ -243,7 +228,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":rotation_mode_cc_proto", ":image_transformation_calculator_cc_proto", @@ -287,7 +271,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":image_cropping_calculator_cc_proto", "//mediapipe/framework/formats:rect_cc_proto", @@ -330,7 +313,6 @@ cc_test( cc_library( name = "luminance_calculator", srcs = ["luminance_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", @@ -344,7 +326,6 @@ cc_library( cc_library( name = "sobel_edges_calculator", srcs = ["sobel_edges_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", @@ -358,7 +339,6 @@ cc_library( cc_library( name = "recolor_calculator", srcs = ["recolor_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":recolor_calculator_cc_proto", "//mediapipe/util:color_cc_proto", @@ -385,9 +365,6 @@ cc_library( name = "scale_image_utils", srcs = ["scale_image_utils.cc"], hdrs = ["scale_image_utils.h"], - visibility = [ - "//mediapipe:__subpackages__", - ], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:logging", @@ -400,9 +377,6 @@ cc_library( cc_library( name = "scale_image_calculator", srcs = ["scale_image_calculator.cc"], - visibility = [ - "//visibility:public", - ], deps = [ ":scale_image_utils", "//mediapipe/calculators/image:scale_image_calculator_cc_proto", @@ -429,7 +403,6 @@ cc_library( mediapipe_proto_library( name = "image_clone_calculator_proto", srcs = ["image_clone_calculator.proto"], - visibility = ["//visibility:public"], deps = [ 
"//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -439,7 +412,6 @@ mediapipe_proto_library( cc_library( name = "image_clone_calculator", srcs = ["image_clone_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":image_clone_calculator_cc_proto", "//mediapipe/framework/api2:node", @@ -459,7 +431,6 @@ cc_library( cc_library( name = "image_properties_calculator", srcs = ["image_properties_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework/api2:node", "//mediapipe/framework:calculator_framework", @@ -524,7 +495,6 @@ cc_test( mediapipe_proto_library( name = "mask_overlay_calculator_proto", srcs = ["mask_overlay_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -534,7 +504,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "opencv_encoded_image_to_image_frame_calculator_proto", srcs = ["opencv_encoded_image_to_image_frame_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -544,7 +513,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "feature_detector_calculator_proto", srcs = ["feature_detector_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -554,7 +522,6 @@ mediapipe_proto_library( cc_library( name = "mask_overlay_calculator", srcs = ["mask_overlay_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":mask_overlay_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -570,7 +537,6 @@ cc_library( cc_library( name = "feature_detector_calculator", srcs = ["feature_detector_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":feature_detector_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -597,7 +563,6 @@ cc_library( cc_library( name = "image_file_properties_calculator", srcs = ["image_file_properties_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:image_file_properties_cc_proto", @@ -627,7 +592,6 @@ cc_test( cc_library( name = "segmentation_smoothing_calculator", srcs = ["segmentation_smoothing_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":segmentation_smoothing_calculator_cc_proto", "//mediapipe/framework/formats:image_format_cc_proto", @@ -724,7 +688,6 @@ cc_library( mediapipe_proto_library( name = "warp_affine_calculator_proto", srcs = ["warp_affine_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -736,7 +699,6 @@ cc_library( name = "warp_affine_calculator", srcs = ["warp_affine_calculator.cc"], hdrs = ["warp_affine_calculator.h"], - visibility = ["//visibility:public"], deps = [ ":affine_transformation", ":warp_affine_calculator_cc_proto", @@ -817,7 +779,6 @@ cc_test( cc_library( name = "yuv_to_image_calculator", srcs = ["yuv_to_image_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_context", "//mediapipe/framework:calculator_framework", diff --git a/mediapipe/calculators/tensorflow/BUILD b/mediapipe/calculators/tensorflow/BUILD index 45f64f4f7..0f8f8706a 100644 --- 
a/mediapipe/calculators/tensorflow/BUILD +++ b/mediapipe/calculators/tensorflow/BUILD @@ -17,12 +17,11 @@ load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library" licenses(["notice"]) -package(default_visibility = ["//visibility:private"]) +package(default_visibility = ["//visibility:public"]) proto_library( name = "graph_tensors_packet_generator_proto", srcs = ["graph_tensors_packet_generator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_proto", "//mediapipe/framework:packet_generator_proto", @@ -32,49 +31,42 @@ proto_library( proto_library( name = "matrix_to_tensor_calculator_options_proto", srcs = ["matrix_to_tensor_calculator_options.proto"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework:calculator_proto"], ) proto_library( name = "lapped_tensor_buffer_calculator_proto", srcs = ["lapped_tensor_buffer_calculator.proto"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework:calculator_proto"], ) proto_library( name = "object_detection_tensors_to_detections_calculator_proto", srcs = ["object_detection_tensors_to_detections_calculator.proto"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework:calculator_proto"], ) proto_library( name = "tensorflow_inference_calculator_proto", srcs = ["tensorflow_inference_calculator.proto"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework:calculator_proto"], ) proto_library( name = "tensor_squeeze_dimensions_calculator_proto", srcs = ["tensor_squeeze_dimensions_calculator.proto"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework:calculator_proto"], ) proto_library( name = "tensor_to_image_frame_calculator_proto", srcs = ["tensor_to_image_frame_calculator.proto"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework:calculator_proto"], ) proto_library( name = "tensor_to_matrix_calculator_proto", srcs = ["tensor_to_matrix_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_proto", "//mediapipe/framework/formats:time_series_header_proto", @@ -84,30 +76,24 @@ proto_library( proto_library( name = "tensor_to_vector_float_calculator_options_proto", srcs = ["tensor_to_vector_float_calculator_options.proto"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework:calculator_proto"], ) proto_library( name = "tensor_to_vector_int_calculator_options_proto", srcs = ["tensor_to_vector_int_calculator_options.proto"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework:calculator_proto"], ) proto_library( name = "tensor_to_vector_string_calculator_options_proto", srcs = ["tensor_to_vector_string_calculator_options.proto"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework:calculator_proto"], ) mediapipe_proto_library( name = "unpack_media_sequence_calculator_proto", srcs = ["unpack_media_sequence_calculator.proto"], - visibility = [ - "//visibility:public", - ], deps = [ "//mediapipe/calculators/core:packet_resampler_calculator_proto", "//mediapipe/framework:calculator_proto", @@ -118,14 +104,12 @@ mediapipe_proto_library( proto_library( name = "vector_float_to_tensor_calculator_options_proto", srcs = ["vector_float_to_tensor_calculator_options.proto"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework:calculator_proto"], ) proto_library( name = "vector_string_to_tensor_calculator_options_proto", srcs = 
["vector_string_to_tensor_calculator_options.proto"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework:calculator_proto"], ) @@ -136,7 +120,6 @@ mediapipe_cc_proto_library( "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework:packet_generator_cc_proto", ], - visibility = ["//visibility:public"], deps = [":graph_tensors_packet_generator_proto"], ) @@ -147,7 +130,6 @@ mediapipe_cc_proto_library( "//mediapipe/framework:calculator_cc_proto", "@org_tensorflow//tensorflow/core:protos_all_cc", ], - visibility = ["//visibility:public"], deps = [":image_frame_to_tensor_calculator_proto"], ) @@ -155,7 +137,6 @@ mediapipe_cc_proto_library( name = "matrix_to_tensor_calculator_options_cc_proto", srcs = ["matrix_to_tensor_calculator_options.proto"], cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = ["//visibility:public"], deps = [":matrix_to_tensor_calculator_options_proto"], ) @@ -163,7 +144,6 @@ mediapipe_cc_proto_library( name = "lapped_tensor_buffer_calculator_cc_proto", srcs = ["lapped_tensor_buffer_calculator.proto"], cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = ["//visibility:public"], deps = [":lapped_tensor_buffer_calculator_proto"], ) @@ -171,7 +151,6 @@ mediapipe_cc_proto_library( name = "object_detection_tensors_to_detections_calculator_cc_proto", srcs = ["object_detection_tensors_to_detections_calculator.proto"], cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = ["//visibility:public"], deps = [":object_detection_tensors_to_detections_calculator_proto"], ) @@ -179,7 +158,6 @@ mediapipe_cc_proto_library( name = "tensorflow_inference_calculator_cc_proto", srcs = ["tensorflow_inference_calculator.proto"], cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = ["//visibility:public"], deps = [":tensorflow_inference_calculator_proto"], ) @@ -190,7 +168,6 @@ mediapipe_cc_proto_library( "//mediapipe/framework:packet_generator_cc_proto", "@org_tensorflow//tensorflow/core:protos_all_cc", ], - visibility = ["//visibility:public"], deps = [":tensorflow_session_from_frozen_graph_generator_proto"], ) @@ -201,7 +178,6 @@ mediapipe_cc_proto_library( "//mediapipe/framework:calculator_cc_proto", "@org_tensorflow//tensorflow/core:protos_all_cc", ], - visibility = ["//visibility:public"], deps = [":tensorflow_session_from_frozen_graph_calculator_proto"], ) @@ -212,7 +188,6 @@ mediapipe_cc_proto_library( "//mediapipe/framework:packet_generator_cc_proto", "@org_tensorflow//tensorflow/core:protos_all_cc", ], - visibility = ["//visibility:public"], deps = [":tensorflow_session_from_saved_model_generator_proto"], ) @@ -223,7 +198,6 @@ mediapipe_cc_proto_library( "//mediapipe/framework:calculator_cc_proto", "@org_tensorflow//tensorflow/core:protos_all_cc", ], - visibility = ["//visibility:public"], deps = [":tensorflow_session_from_saved_model_calculator_proto"], ) @@ -231,7 +205,6 @@ mediapipe_cc_proto_library( name = "tensor_squeeze_dimensions_calculator_cc_proto", srcs = ["tensor_squeeze_dimensions_calculator.proto"], cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = ["//visibility:public"], deps = [":tensor_squeeze_dimensions_calculator_proto"], ) @@ -239,7 +212,6 @@ mediapipe_cc_proto_library( name = "tensor_to_image_frame_calculator_cc_proto", srcs = ["tensor_to_image_frame_calculator.proto"], cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = ["//visibility:public"], deps = [":tensor_to_image_frame_calculator_proto"], ) @@ -250,7 +222,6 @@ 
mediapipe_cc_proto_library( "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework/formats:time_series_header_cc_proto", ], - visibility = ["//visibility:public"], deps = [":tensor_to_matrix_calculator_proto"], ) @@ -258,7 +229,6 @@ mediapipe_cc_proto_library( name = "tensor_to_vector_float_calculator_options_cc_proto", srcs = ["tensor_to_vector_float_calculator_options.proto"], cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = ["//visibility:public"], deps = [":tensor_to_vector_float_calculator_options_proto"], ) @@ -266,7 +236,6 @@ mediapipe_cc_proto_library( name = "tensor_to_vector_int_calculator_options_cc_proto", srcs = ["tensor_to_vector_int_calculator_options.proto"], cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = ["//visibility:public"], deps = [":tensor_to_vector_int_calculator_options_proto"], ) @@ -274,7 +243,6 @@ mediapipe_cc_proto_library( name = "tensor_to_vector_string_calculator_options_cc_proto", srcs = ["tensor_to_vector_string_calculator_options.proto"], cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = ["//visibility:public"], deps = [":tensor_to_vector_string_calculator_options_proto"], ) @@ -285,7 +253,6 @@ mediapipe_cc_proto_library( "//mediapipe/framework:calculator_cc_proto", "@org_tensorflow//tensorflow/core:protos_all_cc", ], - visibility = ["//visibility:public"], deps = [":vector_int_to_tensor_calculator_options_proto"], ) @@ -293,7 +260,6 @@ mediapipe_cc_proto_library( name = "vector_float_to_tensor_calculator_options_cc_proto", srcs = ["vector_float_to_tensor_calculator_options.proto"], cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = ["//visibility:public"], deps = [":vector_float_to_tensor_calculator_options_proto"], ) @@ -301,14 +267,12 @@ mediapipe_cc_proto_library( name = "vector_string_to_tensor_calculator_options_cc_proto", srcs = ["vector_string_to_tensor_calculator_options.proto"], cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = ["//visibility:public"], deps = [":vector_string_to_tensor_calculator_options_proto"], ) cc_library( name = "graph_tensors_packet_generator", srcs = ["graph_tensors_packet_generator.cc"], - visibility = ["//visibility:public"], deps = [ ":graph_tensors_packet_generator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -323,7 +287,6 @@ cc_library( cc_library( name = "image_frame_to_tensor_calculator", srcs = ["image_frame_to_tensor_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":image_frame_to_tensor_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -344,7 +307,6 @@ cc_library( cc_library( name = "matrix_to_tensor_calculator", srcs = ["matrix_to_tensor_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":matrix_to_tensor_calculator_options_cc_proto", "//mediapipe/framework/formats:time_series_header_cc_proto", @@ -366,7 +328,6 @@ cc_library( cc_library( name = "lapped_tensor_buffer_calculator", srcs = ["lapped_tensor_buffer_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":lapped_tensor_buffer_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -388,9 +349,6 @@ cc_library( # Layering check doesn't play nicely with portable proto wrappers. 
"no_layering_check", ], - visibility = [ - "//visibility:public", - ], deps = [ ":object_detection_tensors_to_detections_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -407,9 +365,6 @@ cc_library( cc_library( name = "pack_media_sequence_calculator", srcs = ["pack_media_sequence_calculator.cc"], - visibility = [ - "//visibility:public", - ], deps = [ "//mediapipe/calculators/image:opencv_image_encoder_calculator_cc_proto", "//mediapipe/calculators/tensorflow:pack_media_sequence_calculator_cc_proto", @@ -432,9 +387,6 @@ cc_library( cc_library( name = "string_to_sequence_example_calculator", srcs = ["string_to_sequence_example_calculator.cc"], - visibility = [ - "//visibility:public", - ], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:ret_check", @@ -449,7 +401,6 @@ cc_library( cc_library( name = "tensorflow_inference_calculator", srcs = ["tensorflow_inference_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":tensorflow_inference_calculator_cc_proto", ":tensorflow_session", @@ -487,7 +438,6 @@ cc_library( "tensorflow_session.h", ], features = ["no_layering_check"], - visibility = ["//visibility:public"], deps = select({ "//conditions:default": [ "@org_tensorflow//tensorflow/core:core", @@ -505,7 +455,6 @@ cc_library( name = "tensorflow_session_from_frozen_graph_calculator", srcs = ["tensorflow_session_from_frozen_graph_calculator.cc"], features = ["no_layering_check"], - visibility = ["//visibility:public"], deps = [ ":tensorflow_session", "//mediapipe/calculators/tensorflow:tensorflow_session_from_frozen_graph_calculator_cc_proto", @@ -537,7 +486,6 @@ cc_library( name = "tensorflow_session_from_frozen_graph_generator", srcs = ["tensorflow_session_from_frozen_graph_generator.cc"], features = ["no_layering_check"], - visibility = ["//visibility:public"], deps = [ ":tensorflow_session", ":tensorflow_session_from_frozen_graph_generator_cc_proto", @@ -572,7 +520,6 @@ cc_library( "//mediapipe:android": ["__ANDROID__"], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":tensorflow_session", ":tensorflow_session_from_saved_model_calculator_cc_proto", @@ -611,7 +558,6 @@ cc_library( "//mediapipe:android": ["__ANDROID__"], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":tensorflow_session", ":tensorflow_session_from_saved_model_generator_cc_proto", @@ -637,7 +583,6 @@ cc_library( cc_library( name = "tensor_squeeze_dimensions_calculator", srcs = ["tensor_squeeze_dimensions_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":tensor_squeeze_dimensions_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -651,7 +596,6 @@ cc_library( cc_library( name = "tensor_to_image_frame_calculator", srcs = ["tensor_to_image_frame_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":tensor_to_image_frame_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -666,7 +610,6 @@ cc_library( cc_library( name = "tensor_to_matrix_calculator", srcs = ["tensor_to_matrix_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":tensor_to_matrix_calculator_cc_proto", "//mediapipe/framework/formats:time_series_header_cc_proto", @@ -688,7 +631,6 @@ cc_library( cc_library( name = "tfrecord_reader_calculator", srcs = ["tfrecord_reader_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:integral_types", @@ -704,7 +646,6 @@ 
cc_library( cc_library( name = "tensor_to_vector_float_calculator", srcs = ["tensor_to_vector_float_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":tensor_to_vector_float_calculator_options_cc_proto", "//mediapipe/framework:calculator_framework", @@ -724,7 +665,6 @@ cc_library( cc_library( name = "tensor_to_vector_int_calculator", srcs = ["tensor_to_vector_int_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":tensor_to_vector_int_calculator_options_cc_proto", "@com_google_absl//absl/base:core_headers", @@ -746,7 +686,6 @@ cc_library( cc_library( name = "tensor_to_vector_string_calculator", srcs = ["tensor_to_vector_string_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:status", @@ -766,9 +705,6 @@ cc_library( cc_library( name = "unpack_media_sequence_calculator", srcs = ["unpack_media_sequence_calculator.cc"], - visibility = [ - "//visibility:public", - ], deps = [ "//mediapipe/calculators/core:packet_resampler_calculator_cc_proto", "//mediapipe/calculators/tensorflow:unpack_media_sequence_calculator_cc_proto", @@ -786,7 +722,6 @@ cc_library( cc_library( name = "vector_int_to_tensor_calculator", srcs = ["vector_int_to_tensor_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":vector_int_to_tensor_calculator_options_cc_proto", "//mediapipe/framework:calculator_framework", @@ -800,7 +735,6 @@ cc_library( cc_library( name = "vector_float_to_tensor_calculator", srcs = ["vector_float_to_tensor_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":vector_float_to_tensor_calculator_options_cc_proto", "//mediapipe/framework:calculator_framework", @@ -814,7 +748,6 @@ cc_library( cc_library( name = "vector_string_to_tensor_calculator", srcs = ["vector_string_to_tensor_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":vector_string_to_tensor_calculator_options_cc_proto", "//mediapipe/framework:calculator_framework", @@ -828,7 +761,6 @@ cc_library( cc_library( name = "unpack_yt8m_sequence_example_calculator", srcs = ["unpack_yt8m_sequence_example_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":lapped_tensor_buffer_calculator_cc_proto", "//mediapipe/framework:calculator_framework", diff --git a/mediapipe/calculators/tflite/BUILD b/mediapipe/calculators/tflite/BUILD index 8edaeee02..db2a27630 100644 --- a/mediapipe/calculators/tflite/BUILD +++ b/mediapipe/calculators/tflite/BUILD @@ -18,12 +18,11 @@ load("@bazel_skylib//lib:selects.bzl", "selects") licenses(["notice"]) -package(default_visibility = ["//visibility:private"]) +package(default_visibility = ["//visibility:public"]) mediapipe_proto_library( name = "ssd_anchors_calculator_proto", srcs = ["ssd_anchors_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -33,7 +32,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "tflite_custom_op_resolver_calculator_proto", srcs = ["tflite_custom_op_resolver_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -43,7 +41,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "tflite_inference_calculator_proto", srcs = ["tflite_inference_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", 
"//mediapipe/framework:calculator_proto", @@ -53,7 +50,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "tflite_converter_calculator_proto", srcs = ["tflite_converter_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -63,7 +59,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "tflite_tensors_to_segmentation_calculator_proto", srcs = ["tflite_tensors_to_segmentation_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -73,7 +68,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "tflite_tensors_to_detections_calculator_proto", srcs = ["tflite_tensors_to_detections_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -83,7 +77,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "tflite_tensors_to_classification_calculator_proto", srcs = ["tflite_tensors_to_classification_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -93,7 +86,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "tflite_tensors_to_landmarks_calculator_proto", srcs = ["tflite_tensors_to_landmarks_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -103,7 +95,6 @@ mediapipe_proto_library( cc_library( name = "ssd_anchors_calculator", srcs = ["ssd_anchors_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":ssd_anchors_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -117,7 +108,6 @@ cc_library( cc_library( name = "tflite_custom_op_resolver_calculator", srcs = ["tflite_custom_op_resolver_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":tflite_custom_op_resolver_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -208,7 +198,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":tflite_inference_calculator_cc_proto", "@com_google_absl//absl/memory", @@ -287,7 +276,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":tflite_converter_calculator_cc_proto", "//mediapipe/util/tflite:config", @@ -326,7 +314,6 @@ cc_library( cc_library( name = "tflite_model_calculator", srcs = ["tflite_model_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework:packet", @@ -340,7 +327,6 @@ cc_library( cc_library( name = "tflite_tensors_to_segmentation_calculator", srcs = ["tflite_tensors_to_segmentation_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":tflite_tensors_to_segmentation_calculator_cc_proto", "@com_google_absl//absl/strings:str_format", @@ -408,7 +394,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":tflite_tensors_to_detections_calculator_cc_proto", "//mediapipe/framework/formats:detection_cc_proto", @@ -444,7 +429,6 @@ cc_library( cc_library( name = "tflite_tensors_to_classification_calculator", srcs = ["tflite_tensors_to_classification_calculator.cc"], - visibility = ["//visibility:public"], deps = [ 
":tflite_tensors_to_classification_calculator_cc_proto", "@com_google_absl//absl/container:node_hash_map", @@ -476,7 +460,6 @@ cc_library( cc_library( name = "tflite_tensors_to_landmarks_calculator", srcs = ["tflite_tensors_to_landmarks_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":tflite_tensors_to_landmarks_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -490,7 +473,6 @@ cc_library( cc_library( name = "tflite_tensors_to_floats_calculator", srcs = ["tflite_tensors_to_floats_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:ret_check", diff --git a/mediapipe/calculators/util/BUILD b/mediapipe/calculators/util/BUILD index 24e976a73..43eadd53b 100644 --- a/mediapipe/calculators/util/BUILD +++ b/mediapipe/calculators/util/BUILD @@ -21,7 +21,6 @@ package(default_visibility = ["//visibility:public"]) cc_library( name = "alignment_points_to_rects_calculator", srcs = ["alignment_points_to_rects_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":detections_to_rects_calculator_cc_proto", "//mediapipe/calculators/util:detections_to_rects_calculator", @@ -39,7 +38,6 @@ cc_library( mediapipe_proto_library( name = "annotation_overlay_calculator_proto", srcs = ["annotation_overlay_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -50,7 +48,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "detection_label_id_to_text_calculator_proto", srcs = ["detection_label_id_to_text_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -61,7 +58,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "filter_detections_calculator_proto", srcs = ["filter_detections_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -71,7 +67,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "timed_box_list_id_to_label_calculator_proto", srcs = ["timed_box_list_id_to_label_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -81,13 +76,11 @@ mediapipe_proto_library( mediapipe_proto_library( name = "latency_proto", srcs = ["latency.proto"], - visibility = ["//visibility:public"], ) mediapipe_proto_library( name = "non_max_suppression_calculator_proto", srcs = ["non_max_suppression_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -97,13 +90,11 @@ mediapipe_proto_library( mediapipe_proto_library( name = "packet_frequency_proto", srcs = ["packet_frequency.proto"], - visibility = ["//visibility:public"], ) mediapipe_proto_library( name = "packet_frequency_calculator_proto", srcs = ["packet_frequency_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -113,7 +104,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "packet_latency_calculator_proto", srcs = ["packet_latency_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", 
"//mediapipe/framework:calculator_proto", @@ -123,7 +113,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "collection_has_min_size_calculator_proto", srcs = ["collection_has_min_size_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -133,7 +122,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "association_calculator_proto", srcs = ["association_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -143,7 +131,6 @@ mediapipe_proto_library( cc_library( name = "packet_frequency_calculator", srcs = ["packet_frequency_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/calculators/util:packet_frequency_calculator_cc_proto", "//mediapipe/calculators/util:packet_frequency_cc_proto", @@ -188,7 +175,6 @@ cc_test( cc_library( name = "packet_latency_calculator", srcs = ["packet_latency_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/calculators/util:latency_cc_proto", "//mediapipe/calculators/util:packet_latency_calculator_cc_proto", @@ -228,9 +214,6 @@ cc_test( cc_library( name = "clock_timestamp_calculator", srcs = ["clock_timestamp_calculator.cc"], - visibility = [ - "//visibility:public", - ], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework:timestamp", @@ -246,9 +229,6 @@ cc_library( cc_library( name = "clock_latency_calculator", srcs = ["clock_latency_calculator.cc"], - visibility = [ - "//visibility:public", - ], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework:timestamp", @@ -263,7 +243,6 @@ cc_library( cc_library( name = "annotation_overlay_calculator", srcs = ["annotation_overlay_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":annotation_overlay_calculator_cc_proto", "//mediapipe/framework/formats:image_format_cc_proto", @@ -296,7 +275,6 @@ cc_library( cc_library( name = "detection_label_id_to_text_calculator", srcs = ["detection_label_id_to_text_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":detection_label_id_to_text_calculator_cc_proto", "//mediapipe/framework/formats:detection_cc_proto", @@ -328,7 +306,6 @@ cc_library( cc_library( name = "timed_box_list_id_to_label_calculator", srcs = ["timed_box_list_id_to_label_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":timed_box_list_id_to_label_calculator_cc_proto", "@com_google_absl//absl/container:node_hash_map", @@ -357,7 +334,6 @@ cc_library( cc_library( name = "detection_transformation_calculator", srcs = ["detection_transformation_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:node", @@ -391,7 +367,6 @@ cc_test( cc_library( name = "non_max_suppression_calculator", srcs = ["non_max_suppression_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":non_max_suppression_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -408,7 +383,6 @@ cc_library( cc_library( name = "thresholding_calculator", srcs = ["thresholding_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":thresholding_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -421,7 +395,6 @@ cc_library( cc_library( name = "detection_to_landmarks_calculator", srcs = ["detection_to_landmarks_calculator.cc"], - 
visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:detection_cc_proto", @@ -436,7 +409,6 @@ cc_library( cc_library( name = "filter_detections_calculator", srcs = ["filter_detections_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":filter_detections_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -450,7 +422,6 @@ cc_library( cc_library( name = "landmarks_to_detection_calculator", srcs = ["landmarks_to_detection_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":landmarks_to_detection_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -471,7 +442,6 @@ cc_library( hdrs = [ "detections_to_rects_calculator.h", ], - visibility = ["//visibility:public"], deps = [ ":detections_to_rects_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -489,7 +459,6 @@ cc_library( cc_library( name = "rect_transformation_calculator", srcs = ["rect_transformation_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":rect_transformation_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -504,7 +473,6 @@ cc_library( cc_library( name = "rect_projection_calculator", srcs = ["rect_projection_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:rect_cc_proto", @@ -535,7 +503,6 @@ cc_test( mediapipe_proto_library( name = "rect_to_render_data_calculator_proto", srcs = ["rect_to_render_data_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -547,7 +514,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "rect_to_render_scale_calculator_proto", srcs = ["rect_to_render_scale_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -557,7 +523,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "detections_to_render_data_calculator_proto", srcs = ["detections_to_render_data_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -569,7 +534,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "landmarks_to_render_data_calculator_proto", srcs = ["landmarks_to_render_data_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -581,7 +545,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "timed_box_list_to_render_data_calculator_proto", srcs = ["timed_box_list_to_render_data_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -593,7 +556,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "labels_to_render_data_calculator_proto", srcs = ["labels_to_render_data_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -605,7 +567,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "thresholding_calculator_proto", srcs = ["thresholding_calculator.proto"], - visibility = ["//visibility:public"], deps = [ 
"//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -617,7 +578,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "detections_to_rects_calculator_proto", srcs = ["detections_to_rects_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -627,7 +587,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "landmark_projection_calculator_proto", srcs = ["landmark_projection_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -637,7 +596,6 @@ mediapipe_proto_library( cc_library( name = "landmark_visibility_calculator", srcs = ["landmark_visibility_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:landmark_cc_proto", @@ -649,7 +607,6 @@ cc_library( cc_library( name = "set_landmark_visibility_calculator", srcs = ["set_landmark_visibility_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:landmark_cc_proto", @@ -661,7 +618,6 @@ cc_library( mediapipe_proto_library( name = "landmarks_to_floats_calculator_proto", srcs = ["landmarks_to_floats_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -671,7 +627,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "rect_transformation_calculator_proto", srcs = ["rect_transformation_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -681,7 +636,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "landmarks_to_detection_calculator_proto", srcs = ["landmarks_to_detection_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -693,7 +647,6 @@ mediapipe_proto_library( cc_library( name = "detections_to_render_data_calculator", srcs = ["detections_to_render_data_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":detections_to_render_data_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -713,7 +666,6 @@ cc_library( name = "landmarks_to_render_data_calculator", srcs = ["landmarks_to_render_data_calculator.cc"], hdrs = ["landmarks_to_render_data_calculator.h"], - visibility = ["//visibility:public"], deps = [ ":landmarks_to_render_data_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -732,7 +684,6 @@ cc_library( cc_library( name = "timed_box_list_to_render_data_calculator", srcs = ["timed_box_list_to_render_data_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":timed_box_list_to_render_data_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -751,7 +702,6 @@ cc_library( cc_library( name = "labels_to_render_data_calculator", srcs = ["labels_to_render_data_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":labels_to_render_data_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -770,7 +720,6 @@ cc_library( cc_library( name = "rect_to_render_data_calculator", srcs = ["rect_to_render_data_calculator.cc"], - visibility = 
["//visibility:public"], deps = [ ":rect_to_render_data_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -785,7 +734,6 @@ cc_library( cc_library( name = "rect_to_render_scale_calculator", srcs = ["rect_to_render_scale_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":rect_to_render_scale_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -820,7 +768,6 @@ cc_test( cc_library( name = "detection_letterbox_removal_calculator", srcs = ["detection_letterbox_removal_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:detection_cc_proto", @@ -834,7 +781,6 @@ cc_library( cc_library( name = "detection_projection_calculator", srcs = ["detection_projection_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:detection_cc_proto", @@ -867,7 +813,6 @@ cc_test( cc_library( name = "landmark_letterbox_removal_calculator", srcs = ["landmark_letterbox_removal_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:landmark_cc_proto", @@ -881,7 +826,6 @@ cc_library( cc_library( name = "landmark_projection_calculator", srcs = ["landmark_projection_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":landmark_projection_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -914,7 +858,6 @@ cc_test( cc_library( name = "world_landmark_projection_calculator", srcs = ["world_landmark_projection_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:landmark_cc_proto", @@ -928,7 +871,6 @@ cc_library( mediapipe_proto_library( name = "landmarks_smoothing_calculator_proto", srcs = ["landmarks_smoothing_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -938,7 +880,6 @@ mediapipe_proto_library( cc_library( name = "landmarks_smoothing_calculator", srcs = ["landmarks_smoothing_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":landmarks_smoothing_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -956,7 +897,6 @@ cc_library( mediapipe_proto_library( name = "visibility_smoothing_calculator_proto", srcs = ["visibility_smoothing_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -966,7 +906,6 @@ mediapipe_proto_library( cc_library( name = "visibility_smoothing_calculator", srcs = ["visibility_smoothing_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":visibility_smoothing_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -982,7 +921,6 @@ cc_library( mediapipe_proto_library( name = "visibility_copy_calculator_proto", srcs = ["visibility_copy_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -992,7 +930,6 @@ mediapipe_proto_library( cc_library( name = "visibility_copy_calculator", srcs = ["visibility_copy_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":visibility_copy_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -1007,7 +944,6 @@ 
cc_library( cc_library( name = "landmarks_to_floats_calculator", srcs = ["landmarks_to_floats_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":landmarks_to_floats_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -1054,7 +990,6 @@ cc_test( mediapipe_proto_library( name = "top_k_scores_calculator_proto", srcs = ["top_k_scores_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -1064,7 +999,6 @@ mediapipe_proto_library( cc_library( name = "top_k_scores_calculator", srcs = ["top_k_scores_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":top_k_scores_calculator_cc_proto", "@com_google_absl//absl/container:node_hash_map", @@ -1108,7 +1042,6 @@ cc_test( mediapipe_proto_library( name = "local_file_contents_calculator_proto", srcs = ["local_file_contents_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -1118,7 +1051,6 @@ mediapipe_proto_library( cc_library( name = "local_file_contents_calculator", srcs = ["local_file_contents_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":local_file_contents_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -1132,7 +1064,6 @@ cc_library( cc_library( name = "local_file_pattern_contents_calculator", srcs = ["local_file_pattern_contents_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:file_helpers", @@ -1146,7 +1077,6 @@ cc_library( name = "filter_collection_calculator", srcs = ["filter_collection_calculator.cc"], hdrs = ["filter_collection_calculator.h"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:classification_cc_proto", @@ -1164,7 +1094,6 @@ cc_library( name = "collection_has_min_size_calculator", srcs = ["collection_has_min_size_calculator.cc"], hdrs = ["collection_has_min_size_calculator.h"], - visibility = ["//visibility:public"], deps = [ ":collection_has_min_size_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -1192,7 +1121,6 @@ cc_test( cc_library( name = "association_calculator", hdrs = ["association_calculator.h"], - visibility = ["//visibility:public"], deps = [ ":association_calculator_cc_proto", "//mediapipe/framework:calculator_context", @@ -1209,7 +1137,6 @@ cc_library( cc_library( name = "association_norm_rect_calculator", srcs = ["association_norm_rect_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":association_calculator", "//mediapipe/framework:calculator_context", @@ -1224,7 +1151,6 @@ cc_library( cc_library( name = "association_detection_calculator", srcs = ["association_detection_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":association_calculator", "//mediapipe/framework:calculator_context", @@ -1259,7 +1185,6 @@ cc_test( cc_library( name = "detections_to_timed_box_list_calculator", srcs = ["detections_to_timed_box_list_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:detection_cc_proto", @@ -1274,7 +1199,6 @@ cc_library( cc_library( name = "detection_unique_id_calculator", srcs = ["detection_unique_id_calculator.cc"], - visibility = ["//visibility:public"], deps = [ 
"//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:detection_cc_proto", @@ -1287,7 +1211,6 @@ cc_library( mediapipe_proto_library( name = "logic_calculator_proto", srcs = ["logic_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -1297,7 +1220,6 @@ mediapipe_proto_library( cc_library( name = "logic_calculator", srcs = ["logic_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":logic_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -1310,7 +1232,6 @@ cc_library( cc_library( name = "to_image_calculator", srcs = ["to_image_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework/formats:image_format_cc_proto", "//mediapipe/framework:calculator_options_cc_proto", @@ -1333,7 +1254,6 @@ cc_library( cc_library( name = "from_image_calculator", srcs = ["from_image_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework/formats:image_format_cc_proto", "//mediapipe/framework:calculator_options_cc_proto", @@ -1385,7 +1305,6 @@ cc_test( mediapipe_proto_library( name = "refine_landmarks_from_heatmap_calculator_proto", srcs = ["refine_landmarks_from_heatmap_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -1403,7 +1322,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":refine_landmarks_from_heatmap_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -1454,7 +1372,6 @@ cc_library( name = "inverse_matrix_calculator", srcs = ["inverse_matrix_calculator.cc"], hdrs = ["inverse_matrix_calculator.h"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:node", diff --git a/mediapipe/calculators/video/BUILD b/mediapipe/calculators/video/BUILD index 2db3ed252..f2b8135f2 100644 --- a/mediapipe/calculators/video/BUILD +++ b/mediapipe/calculators/video/BUILD @@ -21,19 +21,17 @@ load( licenses(["notice"]) -package(default_visibility = ["//visibility:private"]) +package(default_visibility = ["//visibility:public"]) proto_library( name = "flow_to_image_calculator_proto", srcs = ["flow_to_image_calculator.proto"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework:calculator_proto"], ) proto_library( name = "opencv_video_encoder_calculator_proto", srcs = ["opencv_video_encoder_calculator.proto"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework:calculator_proto"], ) @@ -58,7 +56,6 @@ proto_library( proto_library( name = "box_tracker_calculator_proto", srcs = ["box_tracker_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_proto", "//mediapipe/util/tracking:box_tracker_proto", @@ -68,7 +65,6 @@ proto_library( proto_library( name = "tracked_detection_manager_calculator_proto", srcs = ["tracked_detection_manager_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_proto", "//mediapipe/util/tracking:tracked_detection_manager_config_proto", @@ -78,7 +74,6 @@ proto_library( proto_library( name = "box_detector_calculator_proto", srcs = ["box_detector_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_proto", 
"//mediapipe/util/tracking:box_detector_proto", @@ -88,7 +83,6 @@ proto_library( proto_library( name = "video_pre_stream_calculator_proto", srcs = ["video_pre_stream_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_proto", ], @@ -101,7 +95,6 @@ mediapipe_cc_proto_library( "//mediapipe/framework:calculator_cc_proto", "//mediapipe/util/tracking:motion_analysis_cc_proto", ], - visibility = ["//visibility:public"], deps = [":motion_analysis_calculator_proto"], ) @@ -112,7 +105,6 @@ mediapipe_cc_proto_library( "//mediapipe/framework:calculator_cc_proto", "//mediapipe/util/tracking:flow_packager_cc_proto", ], - visibility = ["//visibility:public"], deps = [":flow_packager_calculator_proto"], ) @@ -123,7 +115,6 @@ mediapipe_cc_proto_library( "//mediapipe/framework:calculator_cc_proto", "//mediapipe/util/tracking:box_tracker_cc_proto", ], - visibility = ["//visibility:public"], deps = [":box_tracker_calculator_proto"], ) @@ -134,7 +125,6 @@ mediapipe_cc_proto_library( "//mediapipe/framework:calculator_cc_proto", "//mediapipe/util/tracking:tracked_detection_manager_config_cc_proto", ], - visibility = ["//visibility:public"], deps = [":tracked_detection_manager_calculator_proto"], ) @@ -145,7 +135,6 @@ mediapipe_cc_proto_library( "//mediapipe/framework:calculator_cc_proto", "//mediapipe/util/tracking:box_detector_cc_proto", ], - visibility = ["//visibility:public"], deps = [":box_detector_calculator_proto"], ) @@ -155,7 +144,6 @@ mediapipe_cc_proto_library( cc_deps = [ "//mediapipe/framework:calculator_cc_proto", ], - visibility = ["//visibility:public"], deps = [":video_pre_stream_calculator_proto"], ) @@ -163,7 +151,6 @@ mediapipe_cc_proto_library( name = "flow_to_image_calculator_cc_proto", srcs = ["flow_to_image_calculator.proto"], cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = ["//visibility:public"], deps = [":flow_to_image_calculator_proto"], ) @@ -171,14 +158,12 @@ mediapipe_cc_proto_library( name = "opencv_video_encoder_calculator_cc_proto", srcs = ["opencv_video_encoder_calculator.proto"], cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = ["//visibility:public"], deps = [":opencv_video_encoder_calculator_proto"], ) cc_library( name = "flow_to_image_calculator", srcs = ["flow_to_image_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":flow_to_image_calculator_cc_proto", "//mediapipe/calculators/video/tool:flow_quantizer_model", @@ -198,7 +183,6 @@ cc_library( cc_library( name = "opencv_video_decoder_calculator", srcs = ["opencv_video_decoder_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:image_format_cc_proto", @@ -217,7 +201,6 @@ cc_library( cc_library( name = "opencv_video_encoder_calculator", srcs = ["opencv_video_encoder_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":opencv_video_encoder_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -240,7 +223,6 @@ cc_library( cc_library( name = "tvl1_optical_flow_calculator", srcs = ["tvl1_optical_flow_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:image_frame", @@ -256,7 +238,6 @@ cc_library( cc_library( name = "motion_analysis_calculator", srcs = ["motion_analysis_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":motion_analysis_calculator_cc_proto", 
"//mediapipe/framework:calculator_framework", @@ -282,7 +263,6 @@ cc_library( cc_library( name = "flow_packager_calculator", srcs = ["flow_packager_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":flow_packager_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -300,7 +280,6 @@ cc_library( cc_library( name = "box_tracker_calculator", srcs = ["box_tracker_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":box_tracker_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -327,7 +306,6 @@ cc_library( cc_library( name = "box_detector_calculator", srcs = ["box_detector_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":box_detector_calculator_cc_proto", "@com_google_absl//absl/memory", @@ -369,7 +347,6 @@ cc_library( cc_library( name = "tracked_detection_manager_calculator", srcs = ["tracked_detection_manager_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":tracked_detection_manager_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -390,7 +367,6 @@ cc_library( cc_library( name = "video_pre_stream_calculator", srcs = ["video_pre_stream_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":video_pre_stream_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -407,7 +383,6 @@ filegroup( "testdata/format_MKV_VP8_VORBIS.video", "testdata/format_MP4_AVC720P_AAC.video", ], - visibility = ["//visibility:public"], ) cc_test( @@ -480,7 +455,6 @@ mediapipe_binary_graph( name = "parallel_tracker_binarypb", graph = "testdata/parallel_tracker_graph.pbtxt", output_name = "testdata/parallel_tracker.binarypb", - visibility = ["//visibility:public"], deps = [ ":box_tracker_calculator", ":flow_packager_calculator", @@ -494,7 +468,6 @@ mediapipe_binary_graph( name = "tracker_binarypb", graph = "testdata/tracker_graph.pbtxt", output_name = "testdata/tracker.binarypb", - visibility = ["//visibility:public"], deps = [ ":box_tracker_calculator", ":flow_packager_calculator", diff --git a/mediapipe/examples/desktop/hello_world/BUILD b/mediapipe/examples/desktop/hello_world/BUILD index edf98bf13..27aa088e7 100644 --- a/mediapipe/examples/desktop/hello_world/BUILD +++ b/mediapipe/examples/desktop/hello_world/BUILD @@ -14,12 +14,11 @@ licenses(["notice"]) -package(default_visibility = ["//mediapipe/examples:__subpackages__"]) +package(default_visibility = ["//visibility:public"]) cc_binary( name = "hello_world", srcs = ["hello_world.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/calculators/core:pass_through_calculator", "//mediapipe/framework:calculator_graph", diff --git a/mediapipe/framework/BUILD b/mediapipe/framework/BUILD index e3429f1e9..3cc72b4f1 100644 --- a/mediapipe/framework/BUILD +++ b/mediapipe/framework/BUILD @@ -139,7 +139,7 @@ mediapipe_proto_library( name = "test_calculators_proto", testonly = 1, srcs = ["test_calculators.proto"], - visibility = ["//visibility:public"], + visibility = [":mediapipe_internal"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", diff --git a/mediapipe/framework/formats/BUILD b/mediapipe/framework/formats/BUILD index 4276ffc3a..fdb698c48 100644 --- a/mediapipe/framework/formats/BUILD +++ b/mediapipe/framework/formats/BUILD @@ -17,7 +17,7 @@ load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") load("//mediapipe/framework:mediapipe_register_type.bzl", "mediapipe_register_type") package( - default_visibility = 
["//visibility:private"], + default_visibility = ["//visibility:public"], features = ["-layering_check"], ) @@ -26,7 +26,6 @@ licenses(["notice"]) mediapipe_proto_library( name = "detection_proto", srcs = ["detection.proto"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework/formats:location_data_proto"], ) @@ -45,7 +44,6 @@ mediapipe_register_type( mediapipe_proto_library( name = "classification_proto", srcs = ["classification.proto"], - visibility = ["//visibility:public"], ) mediapipe_register_type( @@ -64,46 +62,39 @@ mediapipe_register_type( mediapipe_proto_library( name = "image_format_proto", srcs = ["image_format.proto"], - visibility = ["//visibility:public"], ) mediapipe_proto_library( name = "matrix_data_proto", srcs = ["matrix_data.proto"], - visibility = ["//visibility:public"], ) mediapipe_proto_library( name = "location_data_proto", srcs = ["location_data.proto"], portable_deps = ["//mediapipe/framework/formats/annotation:rasterization_cc_proto"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework/formats/annotation:rasterization_proto"], ) mediapipe_proto_library( name = "affine_transform_data_proto", srcs = ["affine_transform_data.proto"], - visibility = ["//visibility:public"], ) mediapipe_proto_library( name = "time_series_header_proto", srcs = ["time_series_header.proto"], - visibility = ["//visibility:public"], ) mediapipe_proto_library( name = "image_file_properties_proto", srcs = ["image_file_properties.proto"], - visibility = ["//visibility:public"], ) cc_library( name = "deleting_file", srcs = ["deleting_file.cc"], hdrs = ["deleting_file.h"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework/port:logging", ], @@ -113,7 +104,6 @@ cc_library( name = "matrix", srcs = ["matrix.cc"], hdrs = ["matrix.h"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:port", "//mediapipe/framework/formats:matrix_data_cc_proto", @@ -129,9 +119,6 @@ cc_library( name = "affine_transform", srcs = ["affine_transform.cc"], hdrs = ["affine_transform.h"], - visibility = [ - "//visibility:public", - ], deps = [ ":affine_transform_data_cc_proto", "//mediapipe/framework:port", @@ -154,7 +141,6 @@ cc_library( name = "image_frame", srcs = ["image_frame.cc"], hdrs = ["image_frame.h"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework/formats:image_format_cc_proto", "@com_google_absl//absl/base", @@ -179,7 +165,6 @@ cc_library( name = "image_frame_opencv", srcs = ["image_frame_opencv.cc"], hdrs = ["image_frame_opencv.h"], - visibility = ["//visibility:public"], deps = [ ":image_frame", "//mediapipe/framework/formats:image_format_cc_proto", @@ -206,7 +191,6 @@ cc_library( name = "location", srcs = ["location.cc"], hdrs = ["location.h"], - visibility = ["//visibility:public"], deps = [ "@com_google_protobuf//:protobuf", "//mediapipe/framework/formats/annotation:locus_cc_proto", @@ -238,7 +222,6 @@ cc_library( name = "location_opencv", srcs = ["location_opencv.cc"], hdrs = ["location_opencv.h"], - visibility = ["//visibility:public"], deps = [ ":location", "//mediapipe/framework/formats/annotation:rasterization_cc_proto", @@ -261,7 +244,6 @@ cc_test( cc_library( name = "video_stream_header", hdrs = ["video_stream_header.h"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework/formats:image_format_cc_proto", ], @@ -270,7 +252,6 @@ cc_library( cc_library( name = "yuv_image", hdrs = ["yuv_image.h"], - visibility = ["//visibility:public"], deps = [ 
"//mediapipe/framework/port:integral_types", "@libyuv", @@ -294,7 +275,6 @@ cc_test( mediapipe_proto_library( name = "rect_proto", srcs = ["rect.proto"], - visibility = ["//visibility:public"], ) mediapipe_register_type( @@ -312,7 +292,6 @@ mediapipe_register_type( mediapipe_proto_library( name = "landmark_proto", srcs = ["landmark.proto"], - visibility = ["//visibility:public"], ) mediapipe_register_type( @@ -344,7 +323,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework/formats:image_format_cc_proto", "//mediapipe/framework/formats:image_frame", @@ -374,7 +352,6 @@ cc_library( name = "image_multi_pool", srcs = ["image_multi_pool.cc"], hdrs = ["image_multi_pool.h"], - visibility = ["//visibility:public"], deps = [ ":image", "//mediapipe/framework/formats:image_frame_pool", @@ -411,7 +388,6 @@ cc_library( hdrs = [ "image_opencv.h", ], - visibility = ["//visibility:public"], deps = [ ":image", "//mediapipe/framework/formats:image_format_cc_proto", @@ -425,7 +401,6 @@ cc_library( name = "image_frame_pool", srcs = ["image_frame_pool.cc"], hdrs = ["image_frame_pool.h"], - visibility = ["//visibility:public"], deps = [ ":image_frame", "@com_google_absl//absl/memory", @@ -476,7 +451,6 @@ cc_library( "-landroid", ], }), - visibility = ["//visibility:public"], deps = [ "@com_google_absl//absl/memory", "@com_google_absl//absl/synchronization", diff --git a/mediapipe/framework/formats/annotation/BUILD b/mediapipe/framework/formats/annotation/BUILD index 328001e85..9bcb7bccd 100644 --- a/mediapipe/framework/formats/annotation/BUILD +++ b/mediapipe/framework/formats/annotation/BUILD @@ -16,7 +16,7 @@ load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") -package(default_visibility = ["//visibility:private"]) +package(default_visibility = ["//visibility:public"]) licenses(["notice"]) @@ -24,12 +24,10 @@ mediapipe_proto_library( name = "locus_proto", srcs = ["locus.proto"], portable_deps = ["//mediapipe/framework/formats/annotation:rasterization_cc_proto"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework/formats/annotation:rasterization_proto"], ) mediapipe_proto_library( name = "rasterization_proto", srcs = ["rasterization.proto"], - visibility = ["//visibility:public"], ) diff --git a/mediapipe/framework/formats/motion/BUILD b/mediapipe/framework/formats/motion/BUILD index 9819d262c..f1bbc0289 100644 --- a/mediapipe/framework/formats/motion/BUILD +++ b/mediapipe/framework/formats/motion/BUILD @@ -20,18 +20,16 @@ load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library" licenses(["notice"]) -package(default_visibility = ["//visibility:private"]) +package(default_visibility = ["//visibility:public"]) proto_library( name = "optical_flow_field_data_proto", srcs = ["optical_flow_field_data.proto"], - visibility = ["//visibility:public"], ) mediapipe_cc_proto_library( name = "optical_flow_field_data_cc_proto", srcs = ["optical_flow_field_data.proto"], - visibility = ["//visibility:public"], deps = [":optical_flow_field_data_proto"], ) @@ -39,9 +37,6 @@ cc_library( name = "optical_flow_field", srcs = ["optical_flow_field.cc"], hdrs = ["optical_flow_field.h"], - visibility = [ - "//visibility:public", - ], deps = [ "//mediapipe/framework:type_map", "//mediapipe/framework/deps:mathutil", diff --git a/mediapipe/framework/formats/object_detection/BUILD b/mediapipe/framework/formats/object_detection/BUILD index 39940acdc..35292e1cc 100644 --- 
a/mediapipe/framework/formats/object_detection/BUILD +++ b/mediapipe/framework/formats/object_detection/BUILD @@ -19,17 +19,15 @@ load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library" licenses(["notice"]) -package(default_visibility = ["//visibility:private"]) +package(default_visibility = ["//visibility:public"]) proto_library( name = "anchor_proto", srcs = ["anchor.proto"], - visibility = ["//visibility:public"], ) mediapipe_cc_proto_library( name = "anchor_cc_proto", srcs = ["anchor.proto"], - visibility = ["//visibility:public"], deps = [":anchor_proto"], ) diff --git a/mediapipe/framework/stream_handler/BUILD b/mediapipe/framework/stream_handler/BUILD index 866a5120e..01ef6ee86 100644 --- a/mediapipe/framework/stream_handler/BUILD +++ b/mediapipe/framework/stream_handler/BUILD @@ -18,35 +18,31 @@ load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library" licenses(["notice"]) package( - default_visibility = ["//visibility:private"], + default_visibility = ["//visibility:public"], features = ["-layering_check"], ) proto_library( name = "default_input_stream_handler_proto", srcs = ["default_input_stream_handler.proto"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework:mediapipe_options_proto"], ) proto_library( name = "fixed_size_input_stream_handler_proto", srcs = ["fixed_size_input_stream_handler.proto"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework:mediapipe_options_proto"], ) proto_library( name = "sync_set_input_stream_handler_proto", srcs = ["sync_set_input_stream_handler.proto"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework:mediapipe_options_proto"], ) proto_library( name = "timestamp_align_input_stream_handler_proto", srcs = ["timestamp_align_input_stream_handler.proto"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework:mediapipe_options_proto"], ) @@ -54,7 +50,6 @@ mediapipe_cc_proto_library( name = "default_input_stream_handler_cc_proto", srcs = ["default_input_stream_handler.proto"], cc_deps = ["//mediapipe/framework:mediapipe_options_cc_proto"], - visibility = ["//visibility:public"], deps = [":default_input_stream_handler_proto"], ) @@ -62,7 +57,6 @@ mediapipe_cc_proto_library( name = "fixed_size_input_stream_handler_cc_proto", srcs = ["fixed_size_input_stream_handler.proto"], cc_deps = ["//mediapipe/framework:mediapipe_options_cc_proto"], - visibility = ["//visibility:public"], deps = [":fixed_size_input_stream_handler_proto"], ) @@ -70,7 +64,6 @@ mediapipe_cc_proto_library( name = "sync_set_input_stream_handler_cc_proto", srcs = ["sync_set_input_stream_handler.proto"], cc_deps = ["//mediapipe/framework:mediapipe_options_cc_proto"], - visibility = ["//visibility:public"], deps = [":sync_set_input_stream_handler_proto"], ) @@ -78,14 +71,12 @@ mediapipe_cc_proto_library( name = "timestamp_align_input_stream_handler_cc_proto", srcs = ["timestamp_align_input_stream_handler.proto"], cc_deps = ["//mediapipe/framework:mediapipe_options_cc_proto"], - visibility = ["//visibility:public"], deps = [":timestamp_align_input_stream_handler_proto"], ) cc_library( name = "barrier_input_stream_handler", srcs = ["barrier_input_stream_handler.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:input_stream_handler", ], @@ -96,7 +87,6 @@ cc_library( name = "default_input_stream_handler", srcs = ["default_input_stream_handler.cc"], hdrs = ["default_input_stream_handler.h"], - visibility = ["//visibility:public"], deps = [ 
"//mediapipe/framework:input_stream_handler", "//mediapipe/framework/stream_handler:default_input_stream_handler_cc_proto", @@ -108,7 +98,6 @@ cc_library( cc_library( name = "early_close_input_stream_handler", srcs = ["early_close_input_stream_handler.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:input_stream_handler", "@com_google_absl//absl/strings", @@ -119,7 +108,6 @@ cc_library( cc_library( name = "fixed_size_input_stream_handler", srcs = ["fixed_size_input_stream_handler.cc"], - visibility = ["//visibility:public"], deps = [ ":default_input_stream_handler", "//mediapipe/framework:input_stream_handler", @@ -131,7 +119,6 @@ cc_library( cc_library( name = "immediate_input_stream_handler", srcs = ["immediate_input_stream_handler.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:input_stream_handler", ], @@ -142,7 +129,6 @@ cc_library( name = "in_order_output_stream_handler", srcs = ["in_order_output_stream_handler.cc"], hdrs = ["in_order_output_stream_handler.h"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:collection", "//mediapipe/framework:collection_item_id", @@ -160,7 +146,6 @@ cc_library( cc_library( name = "mux_input_stream_handler", srcs = ["mux_input_stream_handler.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:input_stream_handler", "//mediapipe/framework/port:logging", @@ -173,7 +158,6 @@ cc_library( cc_library( name = "sync_set_input_stream_handler", srcs = ["sync_set_input_stream_handler.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:collection", "//mediapipe/framework:collection_item_id", @@ -192,7 +176,6 @@ cc_library( cc_library( name = "timestamp_align_input_stream_handler", srcs = ["timestamp_align_input_stream_handler.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:collection_item_id", "//mediapipe/framework:input_stream_handler", diff --git a/mediapipe/modules/holistic_landmark/calculators/BUILD b/mediapipe/modules/holistic_landmark/calculators/BUILD index c3c091924..bc00b697c 100644 --- a/mediapipe/modules/holistic_landmark/calculators/BUILD +++ b/mediapipe/modules/holistic_landmark/calculators/BUILD @@ -21,7 +21,6 @@ package(default_visibility = ["//visibility:public"]) cc_library( name = "hand_detections_from_pose_to_rects_calculator", srcs = ["hand_detections_from_pose_to_rects_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/calculators/util:detections_to_rects_calculator", "//mediapipe/calculators/util:detections_to_rects_calculator_cc_proto", @@ -39,7 +38,6 @@ cc_library( mediapipe_proto_library( name = "roi_tracking_calculator_proto", srcs = ["roi_tracking_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -49,7 +47,6 @@ mediapipe_proto_library( cc_library( name = "roi_tracking_calculator", srcs = ["roi_tracking_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":roi_tracking_calculator_cc_proto", "//mediapipe/framework:calculator_framework", diff --git a/mediapipe/util/tracking/BUILD b/mediapipe/util/tracking/BUILD index 3f1ebb353..6bca24446 100644 --- a/mediapipe/util/tracking/BUILD +++ b/mediapipe/util/tracking/BUILD @@ -134,7 +134,6 @@ proto_library( mediapipe_cc_proto_library( name = "tone_models_cc_proto", srcs = ["tone_models.proto"], - visibility = ["//visibility:public"], deps = [":tone_models_proto"], ) @@ -142,7 
+141,6 @@ mediapipe_cc_proto_library( name = "tone_estimation_cc_proto", srcs = ["tone_estimation.proto"], cc_deps = [":tone_models_cc_proto"], - visibility = ["//visibility:public"], deps = [":tone_estimation_proto"], ) @@ -153,21 +151,18 @@ mediapipe_cc_proto_library( ":tone_estimation_cc_proto", ":tone_models_cc_proto", ], - visibility = ["//visibility:public"], deps = [":region_flow_computation_proto"], ) mediapipe_cc_proto_library( name = "motion_saliency_cc_proto", srcs = ["motion_saliency.proto"], - visibility = ["//visibility:public"], deps = [":motion_saliency_proto"], ) mediapipe_cc_proto_library( name = "motion_estimation_cc_proto", srcs = ["motion_estimation.proto"], - visibility = ["//visibility:public"], deps = [":motion_estimation_proto"], ) @@ -179,7 +174,6 @@ mediapipe_cc_proto_library( ":motion_saliency_cc_proto", ":region_flow_computation_cc_proto", ], - visibility = ["//visibility:public"], deps = [":motion_analysis_proto"], ) @@ -187,14 +181,12 @@ mediapipe_cc_proto_library( name = "region_flow_cc_proto", srcs = ["region_flow.proto"], cc_deps = [":motion_models_cc_proto"], - visibility = ["//visibility:public"], deps = [":region_flow_proto"], ) mediapipe_cc_proto_library( name = "motion_models_cc_proto", srcs = ["motion_models.proto"], - visibility = ["//visibility:public"], deps = [":motion_models_proto"], ) @@ -202,21 +194,18 @@ mediapipe_cc_proto_library( name = "camera_motion_cc_proto", srcs = ["camera_motion.proto"], cc_deps = [":motion_models_cc_proto"], - visibility = ["//visibility:public"], deps = [":camera_motion_proto"], ) mediapipe_cc_proto_library( name = "push_pull_filtering_cc_proto", srcs = ["push_pull_filtering.proto"], - visibility = ["//visibility:public"], deps = [":push_pull_filtering_proto"], ) mediapipe_cc_proto_library( name = "frame_selection_solution_evaluator_cc_proto", srcs = ["frame_selection_solution_evaluator.proto"], - visibility = ["//visibility:public"], deps = [":frame_selection_solution_evaluator_proto"], ) @@ -228,7 +217,6 @@ mediapipe_cc_proto_library( ":frame_selection_solution_evaluator_cc_proto", ":region_flow_cc_proto", ], - visibility = ["//visibility:public"], deps = [":frame_selection_proto"], ) @@ -239,7 +227,6 @@ mediapipe_cc_proto_library( ":motion_models_cc_proto", ":region_flow_cc_proto", ], - visibility = ["//visibility:public"], deps = [":flow_packager_proto"], ) @@ -247,7 +234,6 @@ mediapipe_cc_proto_library( name = "tracking_cc_proto", srcs = ["tracking.proto"], cc_deps = [":motion_models_cc_proto"], - visibility = ["//visibility:public"], deps = [":tracking_proto"], ) @@ -255,14 +241,12 @@ mediapipe_cc_proto_library( name = "box_tracker_cc_proto", srcs = ["box_tracker.proto"], cc_deps = [":tracking_cc_proto"], - visibility = ["//visibility:public"], deps = [":box_tracker_proto"], ) mediapipe_cc_proto_library( name = "tracked_detection_manager_config_cc_proto", srcs = ["tracked_detection_manager_config.proto"], - visibility = ["//visibility:public"], deps = [":tracked_detection_manager_config_proto"], ) @@ -273,7 +257,6 @@ mediapipe_cc_proto_library( ":box_tracker_cc_proto", ":region_flow_cc_proto", ], - visibility = ["//visibility:public"], deps = [":box_detector_proto"], ) From 09740130e874560957b154bbb51ae4c90dcd64ca Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Tue, 29 Nov 2022 11:32:44 -0800 Subject: [PATCH 140/469] Use naturalWidth and naturalHeight for image data PiperOrigin-RevId: 491694147 --- mediapipe/web/graph_runner/graph_runner.ts | 4 ++++ 1 file changed, 4 insertions(+) diff --git 
a/mediapipe/web/graph_runner/graph_runner.ts b/mediapipe/web/graph_runner/graph_runner.ts index 9a0f7148c..9a8101659 100644 --- a/mediapipe/web/graph_runner/graph_runner.ts +++ b/mediapipe/web/graph_runner/graph_runner.ts @@ -325,6 +325,10 @@ export class GraphRunner { if ((imageSource as HTMLVideoElement).videoWidth) { width = (imageSource as HTMLVideoElement).videoWidth; height = (imageSource as HTMLVideoElement).videoHeight; + } else if ((imageSource as HTMLImageElement).naturalWidth) { + // TODO: Ensure this works with SVG images + width = (imageSource as HTMLImageElement).naturalWidth; + height = (imageSource as HTMLImageElement).naturalHeight; } else { width = imageSource.width; height = imageSource.height; From 88173948eed970b3cc5c215ec3541fcc08b1723c Mon Sep 17 00:00:00 2001 From: Michael Hays Date: Tue, 29 Nov 2022 13:37:18 -0800 Subject: [PATCH 141/469] Internal change PiperOrigin-RevId: 491724816 --- mediapipe/web/graph_runner/graph_runner.ts | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mediapipe/web/graph_runner/graph_runner.ts b/mediapipe/web/graph_runner/graph_runner.ts index 9a8101659..a9bb979af 100644 --- a/mediapipe/web/graph_runner/graph_runner.ts +++ b/mediapipe/web/graph_runner/graph_runner.ts @@ -1085,8 +1085,8 @@ async function runScript(scriptUrl: string) { */ export async function createMediaPipeLib( constructorFcn: WasmMediaPipeConstructor, - wasmLoaderScript?: string, - assetLoaderScript?: string, + wasmLoaderScript?: string|null, + assetLoaderScript?: string|null, glCanvas?: HTMLCanvasElement|OffscreenCanvas|null, fileLocator?: FileLocator): Promise { const scripts = []; From fcd2d2c5af18dc4ebf16116a4f472b4bdb5e52a0 Mon Sep 17 00:00:00 2001 From: Camillo Lugaresi Date: Tue, 29 Nov 2022 14:12:14 -0800 Subject: [PATCH 142/469] Internal change PiperOrigin-RevId: 491733850 --- mediapipe/gpu/BUILD | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/mediapipe/gpu/BUILD b/mediapipe/gpu/BUILD index 9cc670fb6..7a8aa6557 100644 --- a/mediapipe/gpu/BUILD +++ b/mediapipe/gpu/BUILD @@ -176,6 +176,16 @@ cc_library( "-fobjc-arc", # enable reference-counting ], }), + linkopts = select({ + "//conditions:default": [], + "//mediapipe:ios": [ + "-framework OpenGLES", + ], + "//mediapipe:macos": [ + "-framework OpenGL", + "-framework AppKit", + ], + }), visibility = ["//visibility:public"], deps = [ ":attachments", @@ -204,8 +214,10 @@ cc_library( }) + select({ "//conditions:default": [ ], - "//mediapipe:ios": [], - "//mediapipe:macos": [], + "//mediapipe:ios": [ + ], + "//mediapipe:macos": [ + ], }), ) From 460aee7933f255c749bda69673174ec91a9be017 Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Wed, 30 Nov 2022 20:40:00 -0800 Subject: [PATCH 143/469] Make mediapipe_tasks_aar's android_library depend on "//third_party:androidx_annotation". 
PiperOrigin-RevId: 492092487 --- .../java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl | 1 + 1 file changed, 1 insertion(+) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl b/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl index 762184842..6ca67c096 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl @@ -289,6 +289,7 @@ def _mediapipe_tasks_aar(name, srcs, manifest, java_proto_lite_targets, native_l "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:embedderoptions", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/utils:cosinesimilarity", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", + "//third_party:androidx_annotation", "//third_party:autovalue", "@maven//:com_google_guava_guava", ] + select({ From 29c7702984fd0309fbadf64347fdd7cb5604b52f Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Thu, 1 Dec 2022 05:50:46 -0800 Subject: [PATCH 144/469] Inline formerly nested 'ClassifierOptions' in Java classifier APIs. PiperOrigin-RevId: 492173060 --- .../com/google/mediapipe/tasks/audio/BUILD | 2 +- .../audioclassifier/AudioClassifier.java | 84 ++++++++++++++--- .../com/google/mediapipe/tasks/text/BUILD | 2 +- .../text/textclassifier/TextClassifier.java | 90 ++++++++++++++++--- .../com/google/mediapipe/tasks/vision/BUILD | 2 +- .../imageclassifier/ImageClassifier.java | 82 ++++++++++++++--- .../textclassifier/TextClassifierTest.java | 31 +++++++ .../imageclassifier/ImageClassifierTest.java | 81 +++++++++++------ 8 files changed, 305 insertions(+), 69 deletions(-) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/BUILD index 6771335ad..2afc75ec0 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/BUILD @@ -66,10 +66,10 @@ android_library( "//mediapipe/java/com/google/mediapipe/framework:android_framework", "//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_graph_options_java_proto_lite", "//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite", + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_java_proto_lite", "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:audiodata", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classificationresult", - "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:classifieroptions", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", "//third_party:autovalue", "@maven//:com_google_guava_guava", diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioclassifier/AudioClassifier.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioclassifier/AudioClassifier.java index 0f3374175..d78685fe3 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioclassifier/AudioClassifier.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioclassifier/AudioClassifier.java @@ -27,7 +27,7 @@ import com.google.mediapipe.tasks.audio.core.BaseAudioTaskApi; import com.google.mediapipe.tasks.audio.core.RunningMode; import com.google.mediapipe.tasks.components.containers.AudioData; import 
com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto; -import com.google.mediapipe.tasks.components.processors.ClassifierOptions; +import com.google.mediapipe.tasks.components.processors.proto.ClassifierOptionsProto; import com.google.mediapipe.tasks.core.BaseOptions; import com.google.mediapipe.tasks.core.ErrorListener; import com.google.mediapipe.tasks.core.OutputHandler; @@ -266,7 +266,7 @@ public final class AudioClassifier extends BaseAudioTaskApi { /* * Sends audio data (a block in a continuous audio stream) to perform audio classification, and - * the results will be available via the {@link ResultListener} provided in the + * the results will be available via the {@link ResultListener} provided in the * {@link AudioClassifierOptions}. Only use this method when the AudioClassifier is created with * the audio stream mode. * @@ -320,10 +320,42 @@ public final class AudioClassifier extends BaseAudioTaskApi { public abstract Builder setRunningMode(RunningMode runningMode); /** - * Sets the optional {@link ClassifierOptions} controling classification behavior, such as - * score threshold, number of results, etc. + * Sets the optional locale to use for display names specified through the TFLite Model + * Metadata, if any. */ - public abstract Builder setClassifierOptions(ClassifierOptions classifierOptions); + public abstract Builder setDisplayNamesLocale(String locale); + + /** + * Sets the optional maximum number of top-scored classification results to return. + * + *
<p>If not set, all available results are returned. If set, must be > 0. + */ public abstract Builder setMaxResults(Integer maxResults); + + /** + * Sets the optional score threshold. Results with score below this value are rejected. + * + *
<p>Overrides the score threshold specified in the TFLite Model Metadata, if any. + */ public abstract Builder setScoreThreshold(Float scoreThreshold); + + /** + * Sets the optional allowlist of category names. + * + *
<p>If non-empty, detection results whose category name is not in this set will be filtered + * out. Duplicate or unknown category names are ignored. Mutually exclusive with {@code + * categoryDenylist}. + */ public abstract Builder setCategoryAllowlist(List<String> categoryAllowlist); + + /** + * Sets the optional denylist of category names. + * + *
<p>
If non-empty, detection results whose category name is in this set will be filtered out. + * Duplicate or unknown category names are ignored. Mutually exclusive with {@code + * categoryAllowlist}. + */ + public abstract Builder setCategoryDenylist(List categoryDenylist); /** * Sets the {@link ResultListener} to receive the classification results asynchronously when @@ -340,9 +372,7 @@ public final class AudioClassifier extends BaseAudioTaskApi { /** * Validates and builds the {@link AudioClassifierOptions} instance. * - * @throws IllegalArgumentException if the result listener and the running mode are not - * properly configured. The result listener should only be set when the audio classifier - * is in the audio stream mode. + * @throws IllegalArgumentException if any of the set options are invalid. */ public final AudioClassifierOptions build() { AudioClassifierOptions options = autoBuild(); @@ -357,6 +387,13 @@ public final class AudioClassifier extends BaseAudioTaskApi { "The audio classifier is in the audio clips mode, a user-defined result listener" + " shouldn't be provided in AudioClassifierOptions."); } + if (options.maxResults().isPresent() && options.maxResults().get() <= 0) { + throw new IllegalArgumentException("If specified, maxResults must be > 0."); + } + if (!options.categoryAllowlist().isEmpty() && !options.categoryDenylist().isEmpty()) { + throw new IllegalArgumentException( + "Category allowlist and denylist are mutually exclusive."); + } return options; } } @@ -365,7 +402,15 @@ public final class AudioClassifier extends BaseAudioTaskApi { abstract RunningMode runningMode(); - abstract Optional classifierOptions(); + abstract Optional displayNamesLocale(); + + abstract Optional maxResults(); + + abstract Optional scoreThreshold(); + + abstract List categoryAllowlist(); + + abstract List categoryDenylist(); abstract Optional> resultListener(); @@ -373,7 +418,9 @@ public final class AudioClassifier extends BaseAudioTaskApi { public static Builder builder() { return new AutoValue_AudioClassifier_AudioClassifierOptions.Builder() - .setRunningMode(RunningMode.AUDIO_CLIPS); + .setRunningMode(RunningMode.AUDIO_CLIPS) + .setCategoryAllowlist(Collections.emptyList()) + .setCategoryDenylist(Collections.emptyList()); } /** @@ -385,12 +432,21 @@ public final class AudioClassifier extends BaseAudioTaskApi { BaseOptionsProto.BaseOptions.newBuilder(); baseOptionsBuilder.setUseStreamMode(runningMode() == RunningMode.AUDIO_STREAM); baseOptionsBuilder.mergeFrom(convertBaseOptionsToProto(baseOptions())); + ClassifierOptionsProto.ClassifierOptions.Builder classifierOptionsBuilder = + ClassifierOptionsProto.ClassifierOptions.newBuilder(); + displayNamesLocale().ifPresent(classifierOptionsBuilder::setDisplayNamesLocale); + maxResults().ifPresent(classifierOptionsBuilder::setMaxResults); + scoreThreshold().ifPresent(classifierOptionsBuilder::setScoreThreshold); + if (!categoryAllowlist().isEmpty()) { + classifierOptionsBuilder.addAllCategoryAllowlist(categoryAllowlist()); + } + if (!categoryDenylist().isEmpty()) { + classifierOptionsBuilder.addAllCategoryDenylist(categoryDenylist()); + } AudioClassifierGraphOptionsProto.AudioClassifierGraphOptions.Builder taskOptionsBuilder = AudioClassifierGraphOptionsProto.AudioClassifierGraphOptions.newBuilder() - .setBaseOptions(baseOptionsBuilder); - if (classifierOptions().isPresent()) { - taskOptionsBuilder.setClassifierOptions(classifierOptions().get().convertToProto()); - } + .setBaseOptions(baseOptionsBuilder) + 
.setClassifierOptions(classifierOptionsBuilder); return CalculatorOptions.newBuilder() .setExtension( AudioClassifierGraphOptionsProto.AudioClassifierGraphOptions.ext, diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD index 023a1f286..f9c8e7c76 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD @@ -49,10 +49,10 @@ android_library( "//mediapipe/framework:calculator_options_java_proto_lite", "//mediapipe/java/com/google/mediapipe/framework:android_framework", "//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite", + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_java_proto_lite", "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_java_proto_lite", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classificationresult", - "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:classifieroptions", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", "//mediapipe/tasks/java/com/google/mediapipe/tasks/text:libmediapipe_tasks_text_jni_lib", "//third_party:autovalue", diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textclassifier/TextClassifier.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textclassifier/TextClassifier.java index 341d6bf91..0ea91a9f8 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textclassifier/TextClassifier.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textclassifier/TextClassifier.java @@ -24,7 +24,7 @@ import com.google.mediapipe.framework.PacketGetter; import com.google.mediapipe.framework.ProtoUtil; import com.google.mediapipe.tasks.components.containers.ClassificationResult; import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto; -import com.google.mediapipe.tasks.components.processors.ClassifierOptions; +import com.google.mediapipe.tasks.components.processors.proto.ClassifierOptionsProto; import com.google.mediapipe.tasks.core.BaseOptions; import com.google.mediapipe.tasks.core.OutputHandler; import com.google.mediapipe.tasks.core.TaskInfo; @@ -216,20 +216,79 @@ public final class TextClassifier implements AutoCloseable { public abstract Builder setBaseOptions(BaseOptions value); /** - * Sets the optional {@link ClassifierOptions} controling classification behavior, such as - * score threshold, number of results, etc. + * Sets the optional locale to use for display names specified through the TFLite Model + * Metadata, if any. */ - public abstract Builder setClassifierOptions(ClassifierOptions classifierOptions); + public abstract Builder setDisplayNamesLocale(String locale); - public abstract TextClassifierOptions build(); + /** + * Sets the optional maximum number of top-scored classification results to return. + * + *
<p>
If not set, all available results are returned. If set, must be > 0. + */ + public abstract Builder setMaxResults(Integer maxResults); + + /** + * Sets the optional score threshold. Results with score below this value are rejected. + * + *
<p>
Overrides the score threshold specified in the TFLite Model Metadata, if any. + */ + public abstract Builder setScoreThreshold(Float scoreThreshold); + + /** + * Sets the optional allowlist of category names. + * + *
<p>
If non-empty, detection results whose category name is not in this set will be filtered + * out. Duplicate or unknown category names are ignored. Mutually exclusive with {@code + * categoryDenylist}. + */ + public abstract Builder setCategoryAllowlist(List categoryAllowlist); + + /** + * Sets the optional denylist of category names. + * + *
<p>
If non-empty, detection results whose category name is in this set will be filtered out. + * Duplicate or unknown category names are ignored. Mutually exclusive with {@code + * categoryAllowlist}. + */ + public abstract Builder setCategoryDenylist(List categoryDenylist); + + abstract TextClassifierOptions autoBuild(); + + /** + * Validates and builds the {@link TextClassifierOptions} instance. + * + * @throws IllegalArgumentException if any of the set options are invalid. + */ + public final TextClassifierOptions build() { + TextClassifierOptions options = autoBuild(); + if (options.maxResults().isPresent() && options.maxResults().get() <= 0) { + throw new IllegalArgumentException("If specified, maxResults must be > 0."); + } + if (!options.categoryAllowlist().isEmpty() && !options.categoryDenylist().isEmpty()) { + throw new IllegalArgumentException( + "Category allowlist and denylist are mutually exclusive."); + } + return options; + } } abstract BaseOptions baseOptions(); - abstract Optional classifierOptions(); + abstract Optional displayNamesLocale(); + + abstract Optional maxResults(); + + abstract Optional scoreThreshold(); + + abstract List categoryAllowlist(); + + abstract List categoryDenylist(); public static Builder builder() { - return new AutoValue_TextClassifier_TextClassifierOptions.Builder(); + return new AutoValue_TextClassifier_TextClassifierOptions.Builder() + .setCategoryAllowlist(Collections.emptyList()) + .setCategoryDenylist(Collections.emptyList()); } /** Converts a {@link TextClassifierOptions} to a {@link CalculatorOptions} protobuf message. */ @@ -238,12 +297,21 @@ public final class TextClassifier implements AutoCloseable { BaseOptionsProto.BaseOptions.Builder baseOptionsBuilder = BaseOptionsProto.BaseOptions.newBuilder(); baseOptionsBuilder.mergeFrom(convertBaseOptionsToProto(baseOptions())); + ClassifierOptionsProto.ClassifierOptions.Builder classifierOptionsBuilder = + ClassifierOptionsProto.ClassifierOptions.newBuilder(); + displayNamesLocale().ifPresent(classifierOptionsBuilder::setDisplayNamesLocale); + maxResults().ifPresent(classifierOptionsBuilder::setMaxResults); + scoreThreshold().ifPresent(classifierOptionsBuilder::setScoreThreshold); + if (!categoryAllowlist().isEmpty()) { + classifierOptionsBuilder.addAllCategoryAllowlist(categoryAllowlist()); + } + if (!categoryDenylist().isEmpty()) { + classifierOptionsBuilder.addAllCategoryDenylist(categoryDenylist()); + } TextClassifierGraphOptionsProto.TextClassifierGraphOptions.Builder taskOptionsBuilder = TextClassifierGraphOptionsProto.TextClassifierGraphOptions.newBuilder() - .setBaseOptions(baseOptionsBuilder); - if (classifierOptions().isPresent()) { - taskOptionsBuilder.setClassifierOptions(classifierOptions().get().convertToProto()); - } + .setBaseOptions(baseOptionsBuilder) + .setClassifierOptions(classifierOptionsBuilder); return CalculatorOptions.newBuilder() .setExtension( TextClassifierGraphOptionsProto.TextClassifierGraphOptions.ext, diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD index b7febb118..2d130ff05 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD @@ -98,10 +98,10 @@ android_library( "//mediapipe/java/com/google/mediapipe/framework:android_framework", "//mediapipe/java/com/google/mediapipe/framework/image", "//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite", + 
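# ClassifierOptions is now consumed as a plain proto, so the hand-written Java
# wrapper dependency is swapped out below for classifier_options_java_proto_lite.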
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_java_proto_lite", "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_java_proto_lite", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classificationresult", - "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:classifieroptions", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", "//third_party:autovalue", "@maven//:com_google_guava_guava", diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifier.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifier.java index 5e278804b..8990f46fd 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifier.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifier.java @@ -27,7 +27,7 @@ import com.google.mediapipe.framework.image.BitmapImageBuilder; import com.google.mediapipe.framework.image.MPImage; import com.google.mediapipe.tasks.components.containers.ClassificationResult; import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto; -import com.google.mediapipe.tasks.components.processors.ClassifierOptions; +import com.google.mediapipe.tasks.components.processors.proto.ClassifierOptionsProto; import com.google.mediapipe.tasks.core.BaseOptions; import com.google.mediapipe.tasks.core.ErrorListener; import com.google.mediapipe.tasks.core.OutputHandler; @@ -376,10 +376,42 @@ public final class ImageClassifier extends BaseVisionTaskApi { public abstract Builder setRunningMode(RunningMode runningMode); /** - * Sets the optional {@link ClassifierOptions} controling classification behavior, such as - * score threshold, number of results, etc. + * Sets the optional locale to use for display names specified through the TFLite Model + * Metadata, if any. */ - public abstract Builder setClassifierOptions(ClassifierOptions classifierOptions); + public abstract Builder setDisplayNamesLocale(String locale); + + /** + * Sets the optional maximum number of top-scored classification results to return. + * + *
<p>
If not set, all available results are returned. If set, must be > 0. + */ + public abstract Builder setMaxResults(Integer maxResults); + + /** + * Sets the optional score threshold. Results with score below this value are rejected. + * + *
<p>
Overrides the score threshold specified in the TFLite Model Metadata, if any. + */ + public abstract Builder setScoreThreshold(Float scoreThreshold); + + /** + * Sets the optional allowlist of category names. + * + *
<p>
If non-empty, detection results whose category name is not in this set will be filtered + * out. Duplicate or unknown category names are ignored. Mutually exclusive with {@code + * categoryDenylist}. + */ + public abstract Builder setCategoryAllowlist(List categoryAllowlist); + + /** + * Sets the optional denylist of category names. + * + *
<p>
If non-empty, detection results whose category name is in this set will be filtered out. + * Duplicate or unknown category names are ignored. Mutually exclusive with {@code + * categoryAllowlist}. + */ + public abstract Builder setCategoryDenylist(List categoryDenylist); /** * Sets the {@link ResultListener} to receive the classification results asynchronously when @@ -396,9 +428,7 @@ public final class ImageClassifier extends BaseVisionTaskApi { /** * Validates and builds the {@link ImageClassifierOptions} instance. * * - * @throws IllegalArgumentException if the result listener and the running mode are not - * properly configured. The result listener should only be set when the image classifier - * is in the live stream mode. + * @throws IllegalArgumentException if any of the set options are invalid. */ public final ImageClassifierOptions build() { ImageClassifierOptions options = autoBuild(); @@ -413,6 +443,13 @@ public final class ImageClassifier extends BaseVisionTaskApi { "The image classifier is in the image or video mode, a user-defined result listener" + " shouldn't be provided in ImageClassifierOptions."); } + if (options.maxResults().isPresent() && options.maxResults().get() <= 0) { + throw new IllegalArgumentException("If specified, maxResults must be > 0."); + } + if (!options.categoryAllowlist().isEmpty() && !options.categoryDenylist().isEmpty()) { + throw new IllegalArgumentException( + "Category allowlist and denylist are mutually exclusive."); + } return options; } } @@ -421,7 +458,15 @@ public final class ImageClassifier extends BaseVisionTaskApi { abstract RunningMode runningMode(); - abstract Optional classifierOptions(); + abstract Optional displayNamesLocale(); + + abstract Optional maxResults(); + + abstract Optional scoreThreshold(); + + abstract List categoryAllowlist(); + + abstract List categoryDenylist(); abstract Optional> resultListener(); @@ -429,7 +474,9 @@ public final class ImageClassifier extends BaseVisionTaskApi { public static Builder builder() { return new AutoValue_ImageClassifier_ImageClassifierOptions.Builder() - .setRunningMode(RunningMode.IMAGE); + .setRunningMode(RunningMode.IMAGE) + .setCategoryAllowlist(Collections.emptyList()) + .setCategoryDenylist(Collections.emptyList()); } /** @@ -441,12 +488,21 @@ public final class ImageClassifier extends BaseVisionTaskApi { BaseOptionsProto.BaseOptions.newBuilder(); baseOptionsBuilder.setUseStreamMode(runningMode() != RunningMode.IMAGE); baseOptionsBuilder.mergeFrom(convertBaseOptionsToProto(baseOptions())); + ClassifierOptionsProto.ClassifierOptions.Builder classifierOptionsBuilder = + ClassifierOptionsProto.ClassifierOptions.newBuilder(); + displayNamesLocale().ifPresent(classifierOptionsBuilder::setDisplayNamesLocale); + maxResults().ifPresent(classifierOptionsBuilder::setMaxResults); + scoreThreshold().ifPresent(classifierOptionsBuilder::setScoreThreshold); + if (!categoryAllowlist().isEmpty()) { + classifierOptionsBuilder.addAllCategoryAllowlist(categoryAllowlist()); + } + if (!categoryDenylist().isEmpty()) { + classifierOptionsBuilder.addAllCategoryDenylist(categoryDenylist()); + } ImageClassifierGraphOptionsProto.ImageClassifierGraphOptions.Builder taskOptionsBuilder = ImageClassifierGraphOptionsProto.ImageClassifierGraphOptions.newBuilder() - .setBaseOptions(baseOptionsBuilder); - if (classifierOptions().isPresent()) { - taskOptionsBuilder.setClassifierOptions(classifierOptions().get().convertToProto()); - } + .setBaseOptions(baseOptionsBuilder) + 
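// Same packing as in the audio and text classifiers above: all of the
// flattened options travel to the graph inside a single ClassifierOptions proto.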
.setClassifierOptions(classifierOptionsBuilder); return CalculatorOptions.newBuilder() .setExtension( ImageClassifierGraphOptionsProto.ImageClassifierGraphOptions.ext, diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textclassifier/TextClassifierTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textclassifier/TextClassifierTest.java index 5e03d2a4c..5ed413f6a 100644 --- a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textclassifier/TextClassifierTest.java +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textclassifier/TextClassifierTest.java @@ -40,6 +40,37 @@ public class TextClassifierTest { private static final String NEGATIVE_TEXT = "unflinchingly bleak and desperate"; private static final String POSITIVE_TEXT = "it's a charming and often affecting journey"; + @Test + public void options_failsWithNegativeMaxResults() throws Exception { + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> + TextClassifierOptions.builder() + .setBaseOptions( + BaseOptions.builder().setModelAssetPath(BERT_MODEL_FILE).build()) + .setMaxResults(-1) + .build()); + assertThat(exception).hasMessageThat().contains("If specified, maxResults must be > 0"); + } + + @Test + public void options_failsWithBothAllowlistAndDenylist() throws Exception { + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> + TextClassifierOptions.builder() + .setBaseOptions( + BaseOptions.builder().setModelAssetPath(BERT_MODEL_FILE).build()) + .setCategoryAllowlist(Arrays.asList("foo")) + .setCategoryDenylist(Arrays.asList("bar")) + .build()); + assertThat(exception) + .hasMessageThat() + .contains("Category allowlist and denylist are mutually exclusive"); + } + @Test public void create_failsWithMissingModel() throws Exception { String nonExistentFile = "/path/to/non/existent/file"; diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifierTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifierTest.java index 69820ce2d..dac11bf02 100644 --- a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifierTest.java +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifierTest.java @@ -26,7 +26,6 @@ import com.google.mediapipe.framework.MediaPipeException; import com.google.mediapipe.framework.image.BitmapImageBuilder; import com.google.mediapipe.framework.image.MPImage; import com.google.mediapipe.tasks.components.containers.Category; -import com.google.mediapipe.tasks.components.processors.ClassifierOptions; import com.google.mediapipe.tasks.core.BaseOptions; import com.google.mediapipe.tasks.core.TestUtils; import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions; @@ -55,6 +54,37 @@ public class ImageClassifierTest { @RunWith(AndroidJUnit4.class) public static final class General extends ImageClassifierTest { + @Test + public void options_failsWithNegativeMaxResults() throws Exception { + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> + ImageClassifierOptions.builder() + .setBaseOptions( + BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) + .setMaxResults(-1) + .build()); + assertThat(exception).hasMessageThat().contains("If specified, maxResults must be > 0"); + } + + @Test + public void options_failsWithBothAllowlistAndDenylist() throws Exception 
{ + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> + ImageClassifierOptions.builder() + .setBaseOptions( + BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) + .setCategoryAllowlist(Arrays.asList("foo")) + .setCategoryDenylist(Arrays.asList("bar")) + .build()); + assertThat(exception) + .hasMessageThat() + .contains("Category allowlist and denylist are mutually exclusive"); + } + @Test public void create_failsWithMissingModel() throws Exception { String nonExistentFile = "/path/to/non/existent/file"; @@ -105,7 +135,7 @@ public class ImageClassifierTest { ImageClassifierOptions options = ImageClassifierOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) - .setClassifierOptions(ClassifierOptions.builder().setMaxResults(3).build()) + .setMaxResults(3) .build(); ImageClassifier imageClassifier = ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); @@ -125,7 +155,7 @@ public class ImageClassifierTest { ImageClassifierOptions options = ImageClassifierOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(QUANTIZED_MODEL_FILE).build()) - .setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build()) + .setMaxResults(1) .build(); ImageClassifier imageClassifier = ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); @@ -141,7 +171,7 @@ public class ImageClassifierTest { ImageClassifierOptions options = ImageClassifierOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) - .setClassifierOptions(ClassifierOptions.builder().setScoreThreshold(0.02f).build()) + .setScoreThreshold(0.02f) .build(); ImageClassifier imageClassifier = ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); @@ -160,10 +190,7 @@ public class ImageClassifierTest { ImageClassifierOptions options = ImageClassifierOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) - .setClassifierOptions( - ClassifierOptions.builder() - .setCategoryAllowlist(Arrays.asList("cheeseburger", "guacamole", "meat loaf")) - .build()) + .setCategoryAllowlist(Arrays.asList("cheeseburger", "guacamole", "meat loaf")) .build(); ImageClassifier imageClassifier = ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); @@ -183,11 +210,8 @@ public class ImageClassifierTest { ImageClassifierOptions options = ImageClassifierOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) - .setClassifierOptions( - ClassifierOptions.builder() - .setMaxResults(3) - .setCategoryDenylist(Arrays.asList("bagel")) - .build()) + .setMaxResults(3) + .setCategoryDenylist(Arrays.asList("bagel")) .build(); ImageClassifier imageClassifier = ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); @@ -207,7 +231,7 @@ public class ImageClassifierTest { ImageClassifierOptions options = ImageClassifierOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) - .setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build()) + .setMaxResults(1) .build(); ImageClassifier imageClassifier = ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); @@ -228,7 +252,7 @@ public class ImageClassifierTest { ImageClassifierOptions options = 
ImageClassifierOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) - .setClassifierOptions(ClassifierOptions.builder().setMaxResults(3).build()) + .setMaxResults(3) .build(); ImageClassifier imageClassifier = ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); @@ -251,7 +275,7 @@ public class ImageClassifierTest { ImageClassifierOptions options = ImageClassifierOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) - .setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build()) + .setMaxResults(1) .build(); ImageClassifier imageClassifier = ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); @@ -322,14 +346,14 @@ public class ImageClassifierTest { MediaPipeException.class, () -> imageClassifier.classifyForVideo( - getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 0)); + getImageFromAsset(BURGER_IMAGE), /* timestampMs= */ 0)); assertThat(exception).hasMessageThat().contains("not initialized with the video mode"); exception = assertThrows( MediaPipeException.class, () -> imageClassifier.classifyAsync( - getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 0)); + getImageFromAsset(BURGER_IMAGE), /* timestampMs= */ 0)); assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode"); } @@ -353,7 +377,7 @@ public class ImageClassifierTest { MediaPipeException.class, () -> imageClassifier.classifyAsync( - getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 0)); + getImageFromAsset(BURGER_IMAGE), /* timestampMs= */ 0)); assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode"); } @@ -379,7 +403,7 @@ public class ImageClassifierTest { MediaPipeException.class, () -> imageClassifier.classifyForVideo( - getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 0)); + getImageFromAsset(BURGER_IMAGE), /* timestampMs= */ 0)); assertThat(exception).hasMessageThat().contains("not initialized with the video mode"); } @@ -388,7 +412,7 @@ public class ImageClassifierTest { ImageClassifierOptions options = ImageClassifierOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) - .setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build()) + .setMaxResults(1) .build(); ImageClassifier imageClassifier = ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); @@ -405,13 +429,14 @@ public class ImageClassifierTest { ImageClassifierOptions options = ImageClassifierOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) - .setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build()) + .setMaxResults(1) .setRunningMode(RunningMode.VIDEO) .build(); ImageClassifier imageClassifier = ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); for (int i = 0; i < 3; i++) { - ImageClassifierResult results = imageClassifier.classifyForVideo(image, /*timestampMs=*/ i); + ImageClassifierResult results = + imageClassifier.classifyForVideo(image, /* timestampMs= */ i); assertHasOneHead(results); assertCategoriesAre( results, Arrays.asList(Category.create(0.7952058f, 934, "cheeseburger", ""))); @@ -424,7 +449,7 @@ public class ImageClassifierTest { ImageClassifierOptions options = ImageClassifierOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) - 
.setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build()) + .setMaxResults(1) .setRunningMode(RunningMode.LIVE_STREAM) .setResultListener( (imageClassificationResult, inputImage) -> { @@ -436,11 +461,11 @@ public class ImageClassifierTest { .build(); try (ImageClassifier imageClassifier = ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options)) { - imageClassifier.classifyAsync(getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 1); + imageClassifier.classifyAsync(getImageFromAsset(BURGER_IMAGE), /* timestampMs= */ 1); MediaPipeException exception = assertThrows( MediaPipeException.class, - () -> imageClassifier.classifyAsync(image, /*timestampMs=*/ 0)); + () -> imageClassifier.classifyAsync(image, /* timestampMs= */ 0)); assertThat(exception) .hasMessageThat() .contains("having a smaller timestamp than the processed timestamp"); @@ -453,7 +478,7 @@ public class ImageClassifierTest { ImageClassifierOptions options = ImageClassifierOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) - .setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build()) + .setMaxResults(1) .setRunningMode(RunningMode.LIVE_STREAM) .setResultListener( (imageClassificationResult, inputImage) -> { @@ -466,7 +491,7 @@ public class ImageClassifierTest { try (ImageClassifier imageClassifier = ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options)) { for (int i = 0; i < 3; ++i) { - imageClassifier.classifyAsync(image, /*timestampMs=*/ i); + imageClassifier.classifyAsync(image, /* timestampMs= */ i); } } } From 01010fa24887e50f1bb851e9758847f6f340bea3 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Thu, 1 Dec 2022 07:15:52 -0800 Subject: [PATCH 145/469] Internal change PiperOrigin-RevId: 492188196 --- .../com/google/mediapipe/tasks/audio/BUILD | 2 +- .../audio/audioembedder/AudioEmbedder.java | 40 ++++++++--- .../tasks/components/processors/BUILD | 13 ---- .../processors/EmbedderOptions.java | 68 ------------------ .../com/google/mediapipe/tasks/text/BUILD | 2 +- .../tasks/text/textembedder/TextEmbedder.java | 41 ++++++++--- .../com/google/mediapipe/tasks/vision/BUILD | 2 +- .../vision/imageembedder/ImageEmbedder.java | 40 ++++++++--- .../imageembedder/ImageEmbedderTest.java | 69 +++++++++---------- 9 files changed, 126 insertions(+), 151 deletions(-) delete mode 100644 mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors/EmbedderOptions.java diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/BUILD index 2afc75ec0..2d29ccf23 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/BUILD @@ -92,12 +92,12 @@ android_library( "//mediapipe/java/com/google/mediapipe/framework:android_framework", "//mediapipe/tasks/cc/audio/audio_embedder/proto:audio_embedder_graph_options_java_proto_lite", "//mediapipe/tasks/cc/components/containers/proto:embeddings_java_proto_lite", + "//mediapipe/tasks/cc/components/processors/proto:embedder_options_java_proto_lite", "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", "//mediapipe/tasks/java/com/google/mediapipe/tasks/audio:libmediapipe_tasks_audio_jni_lib", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:audiodata", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:embedding", 
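# As with ClassifierOptions earlier in the series, the hand-written
# EmbedderOptions wrapper is deleted and tasks depend on the
# embedder_options_java_proto_lite proto directly.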
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:embeddingresult", - "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:embedderoptions", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/utils:cosinesimilarity", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", "//third_party:autovalue", diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioembedder/AudioEmbedder.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioembedder/AudioEmbedder.java index c0bc04a4e..4bc505d84 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioembedder/AudioEmbedder.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioembedder/AudioEmbedder.java @@ -28,7 +28,7 @@ import com.google.mediapipe.tasks.audio.core.RunningMode; import com.google.mediapipe.tasks.components.containers.AudioData; import com.google.mediapipe.tasks.components.containers.Embedding; import com.google.mediapipe.tasks.components.containers.proto.EmbeddingsProto; -import com.google.mediapipe.tasks.components.processors.EmbedderOptions; +import com.google.mediapipe.tasks.components.processors.proto.EmbedderOptionsProto; import com.google.mediapipe.tasks.components.utils.CosineSimilarity; import com.google.mediapipe.tasks.core.BaseOptions; import com.google.mediapipe.tasks.core.ErrorListener; @@ -309,10 +309,24 @@ public final class AudioEmbedder extends BaseAudioTaskApi { public abstract Builder setRunningMode(RunningMode runningMode); /** - * Sets the optional {@link EmbedderOptions} controling embedding behavior, such as score - * threshold, number of results, etc. + * Sets whether L2 normalization should be performed on the returned embeddings. Use this + * option only if the model does not already contain a native L2_NORMALIZATION TF + * Lite Op. In most cases, this is already the case and L2 norm is thus achieved through TF + * Lite inference. + * + *
<p>
False by default. */ - public abstract Builder setEmbedderOptions(EmbedderOptions embedderOptions); + public abstract Builder setL2Normalize(boolean l2Normalize); + + /** + * Sets whether the returned embedding should be quantized to bytes via scalar quantization. + * Embeddings are implicitly assumed to be unit-norm and therefore any dimension is + * guaranteed to have a value in [-1.0, 1.0]. Use {@link #setL2Normalize(boolean)} + * if this is not the case. + * + *
<p>
False by default. + */ + public abstract Builder setQuantize(boolean quantize); /** * Sets the {@link ResultListener} to receive the embedding results asynchronously when the @@ -354,7 +368,9 @@ public final class AudioEmbedder extends BaseAudioTaskApi { abstract RunningMode runningMode(); - abstract Optional embedderOptions(); + abstract boolean l2Normalize(); + + abstract boolean quantize(); abstract Optional> resultListener(); @@ -362,7 +378,9 @@ public final class AudioEmbedder extends BaseAudioTaskApi { public static Builder builder() { return new AutoValue_AudioEmbedder_AudioEmbedderOptions.Builder() - .setRunningMode(RunningMode.AUDIO_CLIPS); + .setRunningMode(RunningMode.AUDIO_CLIPS) + .setL2Normalize(false) + .setQuantize(false); } /** Converts a {@link AudioEmbedderOptions} to a {@link CalculatorOptions} protobuf message. */ @@ -372,12 +390,14 @@ public final class AudioEmbedder extends BaseAudioTaskApi { BaseOptionsProto.BaseOptions.newBuilder(); baseOptionsBuilder.setUseStreamMode(runningMode() == RunningMode.AUDIO_STREAM); baseOptionsBuilder.mergeFrom(convertBaseOptionsToProto(baseOptions())); + EmbedderOptionsProto.EmbedderOptions.Builder embedderOptionsBuilder = + EmbedderOptionsProto.EmbedderOptions.newBuilder(); + embedderOptionsBuilder.setL2Normalize(l2Normalize()); + embedderOptionsBuilder.setQuantize(quantize()); AudioEmbedderGraphOptionsProto.AudioEmbedderGraphOptions.Builder taskOptionsBuilder = AudioEmbedderGraphOptionsProto.AudioEmbedderGraphOptions.newBuilder() - .setBaseOptions(baseOptionsBuilder); - if (embedderOptions().isPresent()) { - taskOptionsBuilder.setEmbedderOptions(embedderOptions().get().convertToProto()); - } + .setBaseOptions(baseOptionsBuilder) + .setEmbedderOptions(embedderOptionsBuilder); return CalculatorOptions.newBuilder() .setExtension( AudioEmbedderGraphOptionsProto.AudioEmbedderGraphOptions.ext, diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors/BUILD index e61e59390..1f99f1612 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors/BUILD @@ -29,19 +29,6 @@ android_library( ], ) -android_library( - name = "embedderoptions", - srcs = ["EmbedderOptions.java"], - javacopts = [ - "-Xep:AndroidJdkLibsChecker:OFF", - ], - deps = [ - "//mediapipe/tasks/cc/components/processors/proto:embedder_options_java_proto_lite", - "//third_party:autovalue", - "@maven//:com_google_guava_guava", - ], -) - # Expose the java source files for building mediapipe tasks core AAR. filegroup( name = "java_src", diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors/EmbedderOptions.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors/EmbedderOptions.java deleted file mode 100644 index 3cd197234..000000000 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors/EmbedderOptions.java +++ /dev/null @@ -1,68 +0,0 @@ -// Copyright 2022 The MediaPipe Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package com.google.mediapipe.tasks.components.processors; - -import com.google.auto.value.AutoValue; -import com.google.mediapipe.tasks.components.processors.proto.EmbedderOptionsProto; - -/** Embedder options shared across MediaPipe Java embedding tasks. */ -@AutoValue -public abstract class EmbedderOptions { - - /** Builder for {@link EmbedderOptions} */ - @AutoValue.Builder - public abstract static class Builder { - /** - * Sets whether L2 normalization should be performed on the returned embeddings. Use this option - * only if the model does not already contain a native L2_NORMALIZATION TF Lite Op. - * In most cases, this is already the case and L2 norm is thus achieved through TF Lite - * inference. - * - *
<p>
False by default. - */ - public abstract Builder setL2Normalize(boolean l2Normalize); - - /** - * Sets whether the returned embedding should be quantized to bytes via scalar quantization. - * Embeddings are implicitly assumed to be unit-norm and therefore any dimensions is guaranteed - * to have value in [-1.0, 1.0]. Use {@link #setL2Normalize(boolean)} if this is - * not the case. - * - *
<p>
False by default. - */ - public abstract Builder setQuantize(boolean quantize); - - public abstract EmbedderOptions build(); - } - - public abstract boolean l2Normalize(); - - public abstract boolean quantize(); - - public static Builder builder() { - return new AutoValue_EmbedderOptions.Builder().setL2Normalize(false).setQuantize(false); - } - - /** - * Converts an {@link EmbedderOptions} object to an {@link EmbedderOptionsProto.EmbedderOptions} - * protobuf message. - */ - public EmbedderOptionsProto.EmbedderOptions convertToProto() { - return EmbedderOptionsProto.EmbedderOptions.newBuilder() - .setL2Normalize(l2Normalize()) - .setQuantize(quantize()) - .build(); - } -} diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD index f9c8e7c76..5b10e9aab 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD @@ -74,11 +74,11 @@ android_library( "//mediapipe/framework:calculator_options_java_proto_lite", "//mediapipe/java/com/google/mediapipe/framework:android_framework", "//mediapipe/tasks/cc/components/containers/proto:embeddings_java_proto_lite", + "//mediapipe/tasks/cc/components/processors/proto:embedder_options_java_proto_lite", "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", "//mediapipe/tasks/cc/text/text_embedder/proto:text_embedder_graph_options_java_proto_lite", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:embedding", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:embeddingresult", - "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:embedderoptions", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/utils:cosinesimilarity", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", "//mediapipe/tasks/java/com/google/mediapipe/tasks/text:libmediapipe_tasks_text_jni_lib", diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textembedder/TextEmbedder.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textembedder/TextEmbedder.java index 95fa1f087..9b464d0e8 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textembedder/TextEmbedder.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textembedder/TextEmbedder.java @@ -25,7 +25,7 @@ import com.google.mediapipe.framework.ProtoUtil; import com.google.mediapipe.tasks.components.containers.Embedding; import com.google.mediapipe.tasks.components.containers.EmbeddingResult; import com.google.mediapipe.tasks.components.containers.proto.EmbeddingsProto; -import com.google.mediapipe.tasks.components.processors.EmbedderOptions; +import com.google.mediapipe.tasks.components.processors.proto.EmbedderOptionsProto; import com.google.mediapipe.tasks.components.utils.CosineSimilarity; import com.google.mediapipe.tasks.core.BaseOptions; import com.google.mediapipe.tasks.core.OutputHandler; @@ -41,7 +41,6 @@ import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.Optional; /** * Performs embedding extraction on text. @@ -218,20 +217,38 @@ public final class TextEmbedder implements AutoCloseable { public abstract Builder setBaseOptions(BaseOptions value); /** - * Sets the optional {@link EmbedderOptions} controling embedder behavior, such as - * L2-normalization and scalar quantization. + * Sets whether L2 normalization should be performed on the returned embeddings. 
Use this + * option only if the model does not already contain a native L2_NORMALIZATION TF + * Lite Op. In most cases, this is already the case and L2 norm is thus achieved through TF + * Lite inference. + * + *
<p>
False by default. + */ - public abstract Builder setEmbedderOptions(EmbedderOptions embedderOptions); + public abstract Builder setL2Normalize(boolean l2Normalize); + + /** + * Sets whether the returned embedding should be quantized to bytes via scalar quantization. + * Embeddings are implicitly assumed to be unit-norm and therefore any dimension is + * guaranteed to have a value in [-1.0, 1.0]. Use {@link #setL2Normalize(boolean)} + * if this is not the case. + * + *
<p>
False by default. + */ + public abstract Builder setQuantize(boolean quantize); public abstract TextEmbedderOptions build(); } abstract BaseOptions baseOptions(); - abstract Optional embedderOptions(); + abstract boolean l2Normalize(); + + abstract boolean quantize(); public static Builder builder() { - return new AutoValue_TextEmbedder_TextEmbedderOptions.Builder(); + return new AutoValue_TextEmbedder_TextEmbedderOptions.Builder() + .setL2Normalize(false) + .setQuantize(false); } /** Converts a {@link TextEmbedderOptions} to a {@link CalculatorOptions} protobuf message. */ @@ -240,12 +257,14 @@ public final class TextEmbedder implements AutoCloseable { BaseOptionsProto.BaseOptions.Builder baseOptionsBuilder = BaseOptionsProto.BaseOptions.newBuilder(); baseOptionsBuilder.mergeFrom(convertBaseOptionsToProto(baseOptions())); + EmbedderOptionsProto.EmbedderOptions.Builder embedderOptionsBuilder = + EmbedderOptionsProto.EmbedderOptions.newBuilder(); + embedderOptionsBuilder.setL2Normalize(l2Normalize()); + embedderOptionsBuilder.setQuantize(quantize()); TextEmbedderGraphOptionsProto.TextEmbedderGraphOptions.Builder taskOptionsBuilder = TextEmbedderGraphOptionsProto.TextEmbedderGraphOptions.newBuilder() - .setBaseOptions(baseOptionsBuilder); - if (embedderOptions().isPresent()) { - taskOptionsBuilder.setEmbedderOptions(embedderOptions().get().convertToProto()); - } + .setBaseOptions(baseOptionsBuilder) + .setEmbedderOptions(embedderOptionsBuilder); return CalculatorOptions.newBuilder() .setExtension( TextEmbedderGraphOptionsProto.TextEmbedderGraphOptions.ext, diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD index 2d130ff05..b61c174fe 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD @@ -190,11 +190,11 @@ android_library( "//mediapipe/java/com/google/mediapipe/framework:android_framework", "//mediapipe/java/com/google/mediapipe/framework/image", "//mediapipe/tasks/cc/components/containers/proto:embeddings_java_proto_lite", + "//mediapipe/tasks/cc/components/processors/proto:embedder_options_java_proto_lite", "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", "//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_java_proto_lite", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:embedding", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:embeddingresult", - "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:embedderoptions", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/utils:cosinesimilarity", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", "//third_party:autovalue", diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedder.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedder.java index 0d8ecd5c3..af053d860 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedder.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedder.java @@ -28,7 +28,7 @@ import com.google.mediapipe.framework.image.MPImage; import com.google.mediapipe.tasks.components.containers.Embedding; import com.google.mediapipe.tasks.components.containers.EmbeddingResult; import com.google.mediapipe.tasks.components.containers.proto.EmbeddingsProto; -import 
com.google.mediapipe.tasks.components.processors.EmbedderOptions; +import com.google.mediapipe.tasks.components.processors.proto.EmbedderOptionsProto; import com.google.mediapipe.tasks.components.utils.CosineSimilarity; import com.google.mediapipe.tasks.core.BaseOptions; import com.google.mediapipe.tasks.core.ErrorListener; @@ -369,10 +369,24 @@ public final class ImageEmbedder extends BaseVisionTaskApi { public abstract Builder setRunningMode(RunningMode runningMode); /** - * Sets the optional {@link EmbedderOptions} controling embedder behavior, such as - * L2-normalization and scalar quantization. + * Sets whether L2 normalization should be performed on the returned embeddings. Use this + * option only if the model does not already contain a native L2_NORMALIZATION TF + * Lite Op. In most cases, this is already the case and L2 norm is thus achieved through TF + * Lite inference. + * + *
<p>
False by default. */ - public abstract Builder setEmbedderOptions(EmbedderOptions embedderOptions); + public abstract Builder setL2Normalize(boolean l2Normalize); + + /** + * Sets whether the returned embedding should be quantized to bytes via scalar quantization. + * Embeddings are implicitly assumed to be unit-norm and therefore any dimension is + * guaranteed to have a value in [-1.0, 1.0]. Use {@link #setL2Normalize(boolean)} + * if this is not the case. + * + *
<p>
False by default. + */ + public abstract Builder setQuantize(boolean quantize); /** * Sets the {@link ResultListener} to receive the embedding results asynchronously when the @@ -414,7 +428,9 @@ public final class ImageEmbedder extends BaseVisionTaskApi { abstract RunningMode runningMode(); - abstract Optional embedderOptions(); + abstract boolean l2Normalize(); + + abstract boolean quantize(); abstract Optional> resultListener(); @@ -422,7 +438,9 @@ public final class ImageEmbedder extends BaseVisionTaskApi { public static Builder builder() { return new AutoValue_ImageEmbedder_ImageEmbedderOptions.Builder() - .setRunningMode(RunningMode.IMAGE); + .setRunningMode(RunningMode.IMAGE) + .setL2Normalize(false) + .setQuantize(false); } /** Converts a {@link ImageEmbedderOptions} to a {@link CalculatorOptions} protobuf message. */ @@ -432,12 +450,14 @@ public final class ImageEmbedder extends BaseVisionTaskApi { BaseOptionsProto.BaseOptions.newBuilder(); baseOptionsBuilder.setUseStreamMode(runningMode() != RunningMode.IMAGE); baseOptionsBuilder.mergeFrom(convertBaseOptionsToProto(baseOptions())); + EmbedderOptionsProto.EmbedderOptions.Builder embedderOptionsBuilder = + EmbedderOptionsProto.EmbedderOptions.newBuilder(); + embedderOptionsBuilder.setL2Normalize(l2Normalize()); + embedderOptionsBuilder.setQuantize(quantize()); ImageEmbedderGraphOptionsProto.ImageEmbedderGraphOptions.Builder taskOptionsBuilder = ImageEmbedderGraphOptionsProto.ImageEmbedderGraphOptions.newBuilder() - .setBaseOptions(baseOptionsBuilder); - if (embedderOptions().isPresent()) { - taskOptionsBuilder.setEmbedderOptions(embedderOptions().get().convertToProto()); - } + .setBaseOptions(baseOptionsBuilder) + .setEmbedderOptions(embedderOptionsBuilder); return CalculatorOptions.newBuilder() .setExtension( ImageEmbedderGraphOptionsProto.ImageEmbedderGraphOptions.ext, diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedderTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedderTest.java index 56249ead9..8dec6f80b 100644 --- a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedderTest.java +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedderTest.java @@ -25,7 +25,6 @@ import androidx.test.ext.junit.runners.AndroidJUnit4; import com.google.mediapipe.framework.MediaPipeException; import com.google.mediapipe.framework.image.BitmapImageBuilder; import com.google.mediapipe.framework.image.MPImage; -import com.google.mediapipe.tasks.components.processors.EmbedderOptions; import com.google.mediapipe.tasks.core.BaseOptions; import com.google.mediapipe.tasks.core.TestUtils; import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions; @@ -92,8 +91,8 @@ public class ImageEmbedderTest { ImageEmbedderResult resultCrop = imageEmbedder.embed(getImageFromAsset(BURGER_CROP_IMAGE)); // Check results. - assertHasOneHeadAndCorrectDimension(result, /*quantized=*/ false); - assertHasOneHeadAndCorrectDimension(resultCrop, /*quantized=*/ false); + assertHasOneHeadAndCorrectDimension(result, /* quantized= */ false); + assertHasOneHeadAndCorrectDimension(resultCrop, /* quantized= */ false); // Check similarity. 
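// cosineSimilarity(u, v) computes dot(u, v) / (||u|| * ||v||); a value close to
// 1.0 means the full burger image and its crop embed in nearly the same direction.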
double similarity = ImageEmbedder.cosineSimilarity( @@ -105,12 +104,8 @@ public class ImageEmbedderTest { @Test public void embed_succeedsWithL2Normalization() throws Exception { BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(MOBILENET_EMBEDDER).build(); - EmbedderOptions embedderOptions = EmbedderOptions.builder().setL2Normalize(true).build(); ImageEmbedderOptions options = - ImageEmbedderOptions.builder() - .setBaseOptions(baseOptions) - .setEmbedderOptions(embedderOptions) - .build(); + ImageEmbedderOptions.builder().setBaseOptions(baseOptions).setL2Normalize(true).build(); ImageEmbedder imageEmbedder = ImageEmbedder.createFromOptions(ApplicationProvider.getApplicationContext(), options); @@ -118,8 +113,8 @@ public class ImageEmbedderTest { ImageEmbedderResult resultCrop = imageEmbedder.embed(getImageFromAsset(BURGER_CROP_IMAGE)); // Check results. - assertHasOneHeadAndCorrectDimension(result, /*quantized=*/ false); - assertHasOneHeadAndCorrectDimension(resultCrop, /*quantized=*/ false); + assertHasOneHeadAndCorrectDimension(result, /* quantized= */ false); + assertHasOneHeadAndCorrectDimension(resultCrop, /* quantized= */ false); // Check similarity. double similarity = ImageEmbedder.cosineSimilarity( @@ -131,12 +126,8 @@ public class ImageEmbedderTest { @Test public void embed_succeedsWithQuantization() throws Exception { BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(MOBILENET_EMBEDDER).build(); - EmbedderOptions embedderOptions = EmbedderOptions.builder().setQuantize(true).build(); ImageEmbedderOptions options = - ImageEmbedderOptions.builder() - .setBaseOptions(baseOptions) - .setEmbedderOptions(embedderOptions) - .build(); + ImageEmbedderOptions.builder().setBaseOptions(baseOptions).setQuantize(true).build(); ImageEmbedder imageEmbedder = ImageEmbedder.createFromOptions(ApplicationProvider.getApplicationContext(), options); @@ -144,8 +135,8 @@ public class ImageEmbedderTest { ImageEmbedderResult resultCrop = imageEmbedder.embed(getImageFromAsset(BURGER_CROP_IMAGE)); // Check results. - assertHasOneHeadAndCorrectDimension(result, /*quantized=*/ true); - assertHasOneHeadAndCorrectDimension(resultCrop, /*quantized=*/ true); + assertHasOneHeadAndCorrectDimension(result, /* quantized= */ true); + assertHasOneHeadAndCorrectDimension(resultCrop, /* quantized= */ true); // Check similarity. double similarity = ImageEmbedder.cosineSimilarity( @@ -168,8 +159,8 @@ public class ImageEmbedderTest { ImageEmbedderResult resultCrop = imageEmbedder.embed(getImageFromAsset(BURGER_CROP_IMAGE)); // Check results. - assertHasOneHeadAndCorrectDimension(resultRoi, /*quantized=*/ false); - assertHasOneHeadAndCorrectDimension(resultCrop, /*quantized=*/ false); + assertHasOneHeadAndCorrectDimension(resultRoi, /* quantized= */ false); + assertHasOneHeadAndCorrectDimension(resultCrop, /* quantized= */ false); // Check similarity. double similarity = ImageEmbedder.cosineSimilarity( @@ -190,8 +181,8 @@ public class ImageEmbedderTest { imageEmbedder.embed(getImageFromAsset(BURGER_ROTATED_IMAGE), imageProcessingOptions); // Check results. - assertHasOneHeadAndCorrectDimension(result, /*quantized=*/ false); - assertHasOneHeadAndCorrectDimension(resultRotated, /*quantized=*/ false); + assertHasOneHeadAndCorrectDimension(result, /* quantized= */ false); + assertHasOneHeadAndCorrectDimension(resultRotated, /* quantized= */ false); // Check similarity. 
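// The rotation requested via imageProcessingOptions is undone before inference,
// so the rotated burger is expected to embed close to the upright original.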
double similarity = ImageEmbedder.cosineSimilarity( @@ -214,8 +205,8 @@ public class ImageEmbedderTest { ImageEmbedderResult resultCrop = imageEmbedder.embed(getImageFromAsset(BURGER_CROP_IMAGE)); // Check results. - assertHasOneHeadAndCorrectDimension(resultRoiRotated, /*quantized=*/ false); - assertHasOneHeadAndCorrectDimension(resultCrop, /*quantized=*/ false); + assertHasOneHeadAndCorrectDimension(resultRoiRotated, /* quantized= */ false); + assertHasOneHeadAndCorrectDimension(resultCrop, /* quantized= */ false); // Check similarity. double similarity = ImageEmbedder.cosineSimilarity( @@ -277,12 +268,14 @@ public class ImageEmbedderTest { assertThrows( MediaPipeException.class, () -> - imageEmbedder.embedForVideo(getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 0)); + imageEmbedder.embedForVideo( + getImageFromAsset(BURGER_IMAGE), /* timestampMs= */ 0)); assertThat(exception).hasMessageThat().contains("not initialized with the video mode"); exception = assertThrows( MediaPipeException.class, - () -> imageEmbedder.embedAsync(getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 0)); + () -> + imageEmbedder.embedAsync(getImageFromAsset(BURGER_IMAGE), /* timestampMs= */ 0)); assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode"); } @@ -303,7 +296,8 @@ public class ImageEmbedderTest { exception = assertThrows( MediaPipeException.class, - () -> imageEmbedder.embedAsync(getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 0)); + () -> + imageEmbedder.embedAsync(getImageFromAsset(BURGER_IMAGE), /* timestampMs= */ 0)); assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode"); } @@ -327,7 +321,8 @@ public class ImageEmbedderTest { assertThrows( MediaPipeException.class, () -> - imageEmbedder.embedForVideo(getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 0)); + imageEmbedder.embedForVideo( + getImageFromAsset(BURGER_IMAGE), /* timestampMs= */ 0)); assertThat(exception).hasMessageThat().contains("not initialized with the video mode"); } @@ -340,8 +335,8 @@ public class ImageEmbedderTest { ImageEmbedderResult resultCrop = imageEmbedder.embed(getImageFromAsset(BURGER_CROP_IMAGE)); // Check results. - assertHasOneHeadAndCorrectDimension(result, /*quantized=*/ false); - assertHasOneHeadAndCorrectDimension(resultCrop, /*quantized=*/ false); + assertHasOneHeadAndCorrectDimension(result, /* quantized= */ false); + assertHasOneHeadAndCorrectDimension(resultCrop, /* quantized= */ false); // Check similarity. 
double similarity = ImageEmbedder.cosineSimilarity( @@ -363,8 +358,8 @@ public class ImageEmbedderTest { for (int i = 0; i < 3; ++i) { ImageEmbedderResult result = - imageEmbedder.embedForVideo(getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ i); - assertHasOneHeadAndCorrectDimension(result, /*quantized=*/ false); + imageEmbedder.embedForVideo(getImageFromAsset(BURGER_IMAGE), /* timestampMs= */ i); + assertHasOneHeadAndCorrectDimension(result, /* quantized= */ false); } } @@ -378,17 +373,18 @@ public class ImageEmbedderTest { .setRunningMode(RunningMode.LIVE_STREAM) .setResultListener( (imageEmbedderResult, inputImage) -> { - assertHasOneHeadAndCorrectDimension(imageEmbedderResult, /*quantized=*/ false); + assertHasOneHeadAndCorrectDimension( + imageEmbedderResult, /* quantized= */ false); assertImageSizeIsExpected(inputImage); }) .build(); try (ImageEmbedder imageEmbedder = ImageEmbedder.createFromOptions(ApplicationProvider.getApplicationContext(), options)) { - imageEmbedder.embedAsync(getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 1); + imageEmbedder.embedAsync(getImageFromAsset(BURGER_IMAGE), /* timestampMs= */ 1); MediaPipeException exception = assertThrows( MediaPipeException.class, - () -> imageEmbedder.embedAsync(image, /*timestampMs=*/ 0)); + () -> imageEmbedder.embedAsync(image, /* timestampMs= */ 0)); assertThat(exception) .hasMessageThat() .contains("having a smaller timestamp than the processed timestamp"); @@ -405,14 +401,15 @@ public class ImageEmbedderTest { .setRunningMode(RunningMode.LIVE_STREAM) .setResultListener( (imageEmbedderResult, inputImage) -> { - assertHasOneHeadAndCorrectDimension(imageEmbedderResult, /*quantized=*/ false); + assertHasOneHeadAndCorrectDimension( + imageEmbedderResult, /* quantized= */ false); assertImageSizeIsExpected(inputImage); }) .build(); try (ImageEmbedder imageEmbedder = ImageEmbedder.createFromOptions(ApplicationProvider.getApplicationContext(), options)) { for (int i = 0; i < 3; ++i) { - imageEmbedder.embedAsync(image, /*timestampMs=*/ i); + imageEmbedder.embedAsync(image, /* timestampMs= */ i); } } } From a430939fe4b333ddb31a254f6a08b072f7dfff57 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Thu, 1 Dec 2022 07:42:55 -0800 Subject: [PATCH 146/469] Document RunningMode PiperOrigin-RevId: 492193299 --- .../vision/gesture_recognizer/gesture_recognizer.ts | 8 ++++++-- .../web/vision/hand_landmarker/hand_landmarker.ts | 8 ++++++-- .../web/vision/image_classifier/image_classifier.ts | 6 ++++-- .../tasks/web/vision/image_embedder/image_embedder.ts | 11 ++++------- .../web/vision/object_detector/object_detector.ts | 8 ++++++-- 5 files changed, 26 insertions(+), 15 deletions(-) diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts index 7441911c1..9ec63b07a 100644 --- a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts +++ b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts @@ -225,7 +225,9 @@ export class GestureRecognizer extends /** * Performs gesture recognition on the provided single image and waits - * synchronously for the response. + * synchronously for the response. Only use this method when the + * GestureRecognizer is created with running mode `image`. + * * @param image A single image to process. * @return The detected gestures. 
*/ @@ -235,7 +237,9 @@ export class GestureRecognizer extends /** * Performs gesture recognition on the provided video frame and waits - * synchronously for the response. + * synchronously for the response. Only use this method when the + * GestureRecognizer is created with running mode `video`. + * * @param videoFrame A video frame to process. * @param timestamp The timestamp of the current frame, in ms. * @return The detected gestures. diff --git a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts index 6d69d568c..290f49455 100644 --- a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts +++ b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts @@ -177,7 +177,9 @@ export class HandLandmarker extends VisionTaskRunner { /** * Performs hand landmarks detection on the provided single image and waits - * synchronously for the response. + * synchronously for the response. Only use this method when the + * HandLandmarker is created with running mode `image`. + * * @param image An image to process. * @return The detected hand landmarks. */ @@ -187,7 +189,9 @@ export class HandLandmarker extends VisionTaskRunner { /** * Performs hand landmarks detection on the provided video frame and waits - * synchronously for the response. + * synchronously for the response. Only use this method when the + * HandLandmarker is created with running mode `video`. + * * @param videoFrame A video frame to process. * @param timestamp The timestamp of the current frame, in ms. * @return The detected hand landmarks. diff --git a/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts b/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts index 604795f9f..185ddf9ea 100644 --- a/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts +++ b/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts @@ -120,7 +120,8 @@ export class ImageClassifier extends VisionTaskRunner { /** * Performs image classification on the provided single image and waits - * synchronously for the response. + * synchronously for the response. Only use this method when the + * ImageClassifier is created with running mode `image`. * * @param image An image to process. * @return The classification result of the image @@ -131,7 +132,8 @@ export class ImageClassifier extends VisionTaskRunner { /** * Performs image classification on the provided video frame and waits - * synchronously for the response. + * synchronously for the response. Only use this method when the + * ImageClassifier is created with running mode `video`. * * @param videoFrame A video frame to process. * @param timestamp The timestamp of the current frame, in ms. diff --git a/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts b/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts index 68068db6d..91352e934 100644 --- a/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts +++ b/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts @@ -122,10 +122,8 @@ export class ImageEmbedder extends VisionTaskRunner { /** * Performs embedding extraction on the provided single image and waits - * synchronously for the response. - * - * Only use this method when the `useStreamMode` option is not set or - * expliclity set to `false`. + * synchronously for the response. Only use this method when the + * ImageEmbedder is created with running mode `image`. * * @param image The image to process. 
 * @return The classification result of the image @@ -136,9 +134,8 @@ export class ImageEmbedder extends VisionTaskRunner { /** * Performs embedding extraction on the provided video frame and waits - * synchronously for the response. - * - * Only use this method when the `useStreamMode` option is set to `true`. + * synchronously for the response. Only use this method when the + * ImageEmbedder is created with running mode `video`. * * @param imageFrame The image frame to process. * @param timestamp The timestamp of the current frame, in ms. diff --git a/mediapipe/tasks/web/vision/object_detector/object_detector.ts b/mediapipe/tasks/web/vision/object_detector/object_detector.ts index 0f039acb2..7711c39e9 100644 --- a/mediapipe/tasks/web/vision/object_detector/object_detector.ts +++ b/mediapipe/tasks/web/vision/object_detector/object_detector.ts @@ -151,7 +151,9 @@ export class ObjectDetector extends VisionTaskRunner { /** * Performs object detection on the provided single image and waits - synchronously for the response. + synchronously for the response. Only use this method when the + ObjectDetector is created with running mode `image`. + * @param image An image to process. * @return The list of detected objects */ @@ -161,7 +163,9 @@ export class ObjectDetector extends VisionTaskRunner { /** * Performs object detection on the provided video frame and waits - synchronously for the response. + synchronously for the response. Only use this method when the + ObjectDetector is created with running mode `video`. + * @param videoFrame A video frame to process. * @param timestamp The timestamp of the current frame, in ms. * @return The list of detected objects From e7eee27c1c78649e126d197ec338b779ff72d356 Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Thu, 1 Dec 2022 08:14:53 -0800 Subject: [PATCH 147/469] Remove the deleted library "mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:embedderoptions" from mediapipe_tasks_aar's android_library deps list. PiperOrigin-RevId: 492200061 --- .../java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl | 1 - 1 file changed, 1 deletion(-) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl b/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl index 6ca67c096..d91c03cc2 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl @@ -286,7 +286,6 @@ def _mediapipe_tasks_aar(name, srcs, manifest, java_proto_lite_targets, native_l "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:embeddingresult", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:landmark", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:classifieroptions", - "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:embedderoptions", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/utils:cosinesimilarity", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", "//third_party:androidx_annotation", From 3ee37800e2d63092d8f8ded69619380eb55ad9ea Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Thu, 1 Dec 2022 08:41:33 -0800 Subject: [PATCH 148/469] Depend on "inference_calculator_cpu" when the MediaPipe tasks can only support CPU inference.
PiperOrigin-RevId: 492205954 --- mediapipe/tasks/cc/audio/audio_classifier/BUILD | 2 +- mediapipe/tasks/cc/audio/audio_embedder/BUILD | 2 +- mediapipe/tasks/cc/text/text_classifier/BUILD | 2 +- mediapipe/tasks/cc/text/text_embedder/BUILD | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/mediapipe/tasks/cc/audio/audio_classifier/BUILD b/mediapipe/tasks/cc/audio/audio_classifier/BUILD index a817bcc3b..f61472413 100644 --- a/mediapipe/tasks/cc/audio/audio_classifier/BUILD +++ b/mediapipe/tasks/cc/audio/audio_classifier/BUILD @@ -55,7 +55,7 @@ cc_library( "//mediapipe/calculators/core:side_packet_to_stream_calculator", "//mediapipe/calculators/tensor:audio_to_tensor_calculator", "//mediapipe/calculators/tensor:audio_to_tensor_calculator_cc_proto", - "//mediapipe/calculators/tensor:inference_calculator", + "//mediapipe/calculators/tensor:inference_calculator_cpu", "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:builder", diff --git a/mediapipe/tasks/cc/audio/audio_embedder/BUILD b/mediapipe/tasks/cc/audio/audio_embedder/BUILD index adba28e6a..6a0f627b2 100644 --- a/mediapipe/tasks/cc/audio/audio_embedder/BUILD +++ b/mediapipe/tasks/cc/audio/audio_embedder/BUILD @@ -56,7 +56,7 @@ cc_library( "//mediapipe/calculators/core:side_packet_to_stream_calculator", "//mediapipe/calculators/tensor:audio_to_tensor_calculator", "//mediapipe/calculators/tensor:audio_to_tensor_calculator_cc_proto", - "//mediapipe/calculators/tensor:inference_calculator", + "//mediapipe/calculators/tensor:inference_calculator_cpu", "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:builder", diff --git a/mediapipe/tasks/cc/text/text_classifier/BUILD b/mediapipe/tasks/cc/text/text_classifier/BUILD index 61395cf4e..3c9c3fc0e 100644 --- a/mediapipe/tasks/cc/text/text_classifier/BUILD +++ b/mediapipe/tasks/cc/text/text_classifier/BUILD @@ -47,7 +47,7 @@ cc_library( name = "text_classifier_graph", srcs = ["text_classifier_graph.cc"], deps = [ - "//mediapipe/calculators/tensor:inference_calculator", + "//mediapipe/calculators/tensor:inference_calculator_cpu", "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:builder", diff --git a/mediapipe/tasks/cc/text/text_embedder/BUILD b/mediapipe/tasks/cc/text/text_embedder/BUILD index f19af35be..4c970159e 100644 --- a/mediapipe/tasks/cc/text/text_embedder/BUILD +++ b/mediapipe/tasks/cc/text/text_embedder/BUILD @@ -48,8 +48,8 @@ cc_library( name = "text_embedder_graph", srcs = ["text_embedder_graph.cc"], deps = [ - "//mediapipe/calculators/tensor:inference_calculator", "//mediapipe/calculators/tensor:inference_calculator_cc_proto", + "//mediapipe/calculators/tensor:inference_calculator_cpu", "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:builder", From e685ac93446e22d31a6bc269416ff13dece6edbe Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Thu, 1 Dec 2022 08:45:47 -0800 Subject: [PATCH 149/469] Re-use classifier options for ObjectDetector PiperOrigin-RevId: 492206856 --- .../web/components/utils/cosine_similarity.ts | 1 + .../tasks/web/vision/object_detector/BUILD | 1 + .../object_detector_options.d.ts | 33 ++----------------- 3 files changed, 5 insertions(+), 30 deletions(-) diff --git a/mediapipe/tasks/web/components/utils/cosine_similarity.ts 
b/mediapipe/tasks/web/components/utils/cosine_similarity.ts index fb1d0c185..1f483b9b6 100644 --- a/mediapipe/tasks/web/components/utils/cosine_similarity.ts +++ b/mediapipe/tasks/web/components/utils/cosine_similarity.ts @@ -36,6 +36,7 @@ export function computeCosineSimilarity(u: Embedding, v: Embedding): number { throw new Error( 'Cannot compute cosine similarity between quantized and float embeddings.'); } + function convertToBytes(data: Uint8Array): number[] { return Array.from(data, v => v - 128); } diff --git a/mediapipe/tasks/web/vision/object_detector/BUILD b/mediapipe/tasks/web/vision/object_detector/BUILD index b6bef6bfa..198585258 100644 --- a/mediapipe/tasks/web/vision/object_detector/BUILD +++ b/mediapipe/tasks/web/vision/object_detector/BUILD @@ -35,6 +35,7 @@ mediapipe_ts_declaration( deps = [ "//mediapipe/tasks/web/components/containers:category", "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:classifier_options", "//mediapipe/tasks/web/vision/core:vision_task_options", ], ) diff --git a/mediapipe/tasks/web/vision/object_detector/object_detector_options.d.ts b/mediapipe/tasks/web/vision/object_detector/object_detector_options.d.ts index 1d20ce1e2..7564e7760 100644 --- a/mediapipe/tasks/web/vision/object_detector/object_detector_options.d.ts +++ b/mediapipe/tasks/web/vision/object_detector/object_detector_options.d.ts @@ -14,36 +14,9 @@ * limitations under the License. */ +import {ClassifierOptions} from '../../../../tasks/web/core/classifier_options'; import {VisionTaskOptions} from '../../../../tasks/web/vision/core/vision_task_options'; /** Options to configure the MediaPipe Object Detector Task */ -export interface ObjectDetectorOptions extends VisionTaskOptions { - /** - * The locale to use for display names specified through the TFLite Model - * Metadata, if any. Defaults to English. - */ - displayNamesLocale?: string|undefined; - - /** The maximum number of top-scored detection results to return. */ - maxResults?: number|undefined; - - /** - * Overrides the value provided in the model metadata. Results below this - * value are rejected. - */ - scoreThreshold?: number|undefined; - - /** - * Allowlist of category names. If non-empty, detection results whose category - * name is not in this set will be filtered out. Duplicate or unknown category - * names are ignored. Mutually exclusive with `categoryDenylist`. - */ - categoryAllowlist?: string[]|undefined; - - /** - * Denylist of category names. If non-empty, detection results whose category - * name is in this set will be filtered out. Duplicate or unknown category - * names are ignored. Mutually exclusive with `categoryAllowlist`. 
- */ - categoryDenylist?: string[]|undefined; -} +export interface ObjectDetectorOptions extends VisionTaskOptions, + ClassifierOptions {} From 02aa162c9e953b05153f68d13e55a06b34571a0f Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Thu, 1 Dec 2022 11:09:26 -0800 Subject: [PATCH 150/469] Rename gesture_recognizer test_data to testdata to be consistent with rest of model_maker PiperOrigin-RevId: 492246728 --- .../python/vision/gesture_recognizer/BUILD | 12 ++++++------ .../gesture_recognizer/gesture_recognizer_demo.py | 2 +- .../gesture_recognizer/gesture_recognizer_test.py | 2 +- .../gesture_recognizer/metadata_writer_test.py | 2 +- .../metadata/custom_gesture_classifier.tflite | Bin .../metadata/custom_gesture_classifier_meta.json | 0 .../call/0413d5c5-f5ba-476f-a921-ea5e967692a9.jpg | Bin .../call/044768ad-1709-44ba-b041-c2f8cbe4c166.jpg | Bin .../call/0e022ee9-74fd-44fe-adad-60c11835e44f.jpg | Bin .../call/143f8b21-1dc3-4383-bf36-0a54244dfbc0.jpg | Bin .../call/172ba7f6-c6ba-4398-89a2-25375dccfefa.jpg | Bin .../call/17b3aa02-dc4d-448d-8601-e2b67193d436.jpg | Bin .../call/17d804b5-7118-462d-8191-58d764f591b8.jpg | Bin .../call/1d65a858-623a-4984-9420-958c7e870c3e.jpg | Bin .../call/1f5fb137-c7a9-435b-85dd-6d7b63ea233a.jpg | Bin .../call/21de0cfe-af9f-42c2-95d4-aa3d852e7dad.jpg | Bin .../four/06aa70cc-a12a-4b1e-85cf-e54d44c19a3a.jpg | Bin .../four/077fa4bf-a99e-496b-b895-709afc614eec.jpg | Bin .../four/07a5a144-c635-4441-aedb-5c8e9da79aac.jpg | Bin .../four/07fdea90-1102-4419-a3af-b394cb29531b.jpg | Bin .../four/0c960166-75b0-4c1b-a3cc-2ddbd5a21703.jpg | Bin .../four/105f8f8e-ccd6-45a0-b22a-e314930bc13e.jpg | Bin .../four/116292ef-5947-4d6c-a479-630ebb8a1050.jpg | Bin .../four/15a73593-b13e-4a1b-99bb-51775cfdfc42.jpg | Bin .../four/249c5023-6106-447a-84ac-17eb4713731b.jpg | Bin .../four/25bb4c45-e40b-482c-b588-04db60b7e450.jpg | Bin .../none/00af1db1-7c86-4e9b-9383-1fbd06c3492d.jpg | Bin .../none/00b85ea4-8c5d-4302-b847-0a5de1d7dab2.jpg | Bin .../none/00c84257-800d-4032-9e64-e47eb97005f5.jpg | Bin .../none/0a038096-c14f-46ac-9155-980161ebc440.jpg | Bin .../none/0a0ef3d2-2560-4a93-904d-437189fffbf2.jpg | Bin .../none/0a272153-56c7-42d5-a17d-cd307a1cd6d4.jpg | Bin .../none/0a4a8907-1950-4e43-9a03-1740e78224ef.jpg | Bin .../none/0a4bc2da-f5b3-48cd-8f0d-c61dbd08ba53.jpg | Bin .../none/0a71a6e8-bb06-4ed0-a60b-c2a602fce261.jpg | Bin .../none/0a787971-9377-4888-803f-aef21863ef7d.jpg | Bin .../rock/026fd791-8f64-4fae-8cb0-0e01dc4362ce.jpg | Bin .../rock/055f8be9-f7fd-4c7f-ad3f-7b404b6489c3.jpg | Bin .../rock/09a619ab-cdf7-4a66-911f-347113f050f1.jpg | Bin .../rock/0c6628ea-4a8c-49c9-b7cf-c30aef18dc3d.jpg | Bin .../rock/0cc7ad09-ae5f-45a8-b264-4216176369b6.jpg | Bin .../rock/10eacf4b-8aaf-46d9-be21-7fb8d8353005.jpg | Bin .../rock/15cb4e8b-ba1d-46f1-8456-247016a599a4.jpg | Bin .../rock/18e20af8-8fe1-48d4-bd0e-83fa9e2db88e.jpg | Bin .../rock/1bed937b-7ae4-4070-891c-daf69415da41.jpg | Bin .../rock/20e2164d-3473-4d42-8755-22cdbd4417ba.jpg | Bin 46 files changed, 9 insertions(+), 9 deletions(-) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/metadata/custom_gesture_classifier.tflite (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/metadata/custom_gesture_classifier_meta.json (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/call/0413d5c5-f5ba-476f-a921-ea5e967692a9.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => 
testdata}/raw_data/call/044768ad-1709-44ba-b041-c2f8cbe4c166.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/call/0e022ee9-74fd-44fe-adad-60c11835e44f.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/call/143f8b21-1dc3-4383-bf36-0a54244dfbc0.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/call/172ba7f6-c6ba-4398-89a2-25375dccfefa.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/call/17b3aa02-dc4d-448d-8601-e2b67193d436.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/call/17d804b5-7118-462d-8191-58d764f591b8.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/call/1d65a858-623a-4984-9420-958c7e870c3e.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/call/1f5fb137-c7a9-435b-85dd-6d7b63ea233a.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/call/21de0cfe-af9f-42c2-95d4-aa3d852e7dad.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/four/06aa70cc-a12a-4b1e-85cf-e54d44c19a3a.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/four/077fa4bf-a99e-496b-b895-709afc614eec.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/four/07a5a144-c635-4441-aedb-5c8e9da79aac.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/four/07fdea90-1102-4419-a3af-b394cb29531b.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/four/0c960166-75b0-4c1b-a3cc-2ddbd5a21703.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/four/105f8f8e-ccd6-45a0-b22a-e314930bc13e.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/four/116292ef-5947-4d6c-a479-630ebb8a1050.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/four/15a73593-b13e-4a1b-99bb-51775cfdfc42.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/four/249c5023-6106-447a-84ac-17eb4713731b.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/four/25bb4c45-e40b-482c-b588-04db60b7e450.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/none/00af1db1-7c86-4e9b-9383-1fbd06c3492d.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/none/00b85ea4-8c5d-4302-b847-0a5de1d7dab2.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/none/00c84257-800d-4032-9e64-e47eb97005f5.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/none/0a038096-c14f-46ac-9155-980161ebc440.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/none/0a0ef3d2-2560-4a93-904d-437189fffbf2.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => 
testdata}/raw_data/none/0a272153-56c7-42d5-a17d-cd307a1cd6d4.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/none/0a4a8907-1950-4e43-9a03-1740e78224ef.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/none/0a4bc2da-f5b3-48cd-8f0d-c61dbd08ba53.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/none/0a71a6e8-bb06-4ed0-a60b-c2a602fce261.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/none/0a787971-9377-4888-803f-aef21863ef7d.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/rock/026fd791-8f64-4fae-8cb0-0e01dc4362ce.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/rock/055f8be9-f7fd-4c7f-ad3f-7b404b6489c3.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/rock/09a619ab-cdf7-4a66-911f-347113f050f1.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/rock/0c6628ea-4a8c-49c9-b7cf-c30aef18dc3d.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/rock/0cc7ad09-ae5f-45a8-b264-4216176369b6.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/rock/10eacf4b-8aaf-46d9-be21-7fb8d8353005.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/rock/15cb4e8b-ba1d-46f1-8456-247016a599a4.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/rock/18e20af8-8fe1-48d4-bd0e-83fa9e2db88e.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/rock/1bed937b-7ae4-4070-891c-daf69415da41.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/rock/20e2164d-3473-4d42-8755-22cdbd4417ba.jpg (100%) diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/BUILD b/mediapipe/model_maker/python/vision/gesture_recognizer/BUILD index b9425a181..256447a8d 100644 --- a/mediapipe/model_maker/python/vision/gesture_recognizer/BUILD +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/BUILD @@ -24,9 +24,9 @@ package( # TODO: Remove the unnecessary test data once the demo data are moved to an open-sourced # directory.
filegroup( - name = "test_data", + name = "testdata", srcs = glob([ - "test_data/**", + "testdata/**", ]), ) @@ -53,7 +53,7 @@ py_test( name = "dataset_test", srcs = ["dataset_test.py"], data = [ - ":test_data", + ":testdata", "//mediapipe/model_maker/models/gesture_recognizer:models", ], deps = [ @@ -136,7 +136,7 @@ py_test( size = "large", srcs = ["gesture_recognizer_test.py"], data = [ - ":test_data", + ":testdata", "//mediapipe/model_maker/models/gesture_recognizer:models", ], shard_count = 2, @@ -151,7 +151,7 @@ py_test( name = "metadata_writer_test", srcs = ["metadata_writer_test.py"], data = [ - ":test_data", + ":testdata", ], deps = [ ":metadata_writer", @@ -164,7 +164,7 @@ py_binary( name = "gesture_recognizer_demo", srcs = ["gesture_recognizer_demo.py"], data = [ - ":test_data", + ":testdata", "//mediapipe/model_maker/models/gesture_recognizer:models", ], python_version = "PY3", diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_demo.py b/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_demo.py index 06075fbc6..1cf9f0619 100644 --- a/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_demo.py +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_demo.py @@ -31,7 +31,7 @@ FLAGS = flags.FLAGS # TODO: Move hand gesture recognizer demo dataset to an # open-sourced directory. -TEST_DATA_DIR = 'mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data' +TEST_DATA_DIR = 'mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data' def define_flags(): diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py b/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py index 9cee88362..280fc6a82 100644 --- a/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py @@ -25,7 +25,7 @@ from mediapipe.model_maker.python.core.utils import test_util from mediapipe.model_maker.python.vision import gesture_recognizer from mediapipe.tasks.python.test import test_utils -_TEST_DATA_DIR = 'mediapipe/model_maker/python/vision/gesture_recognizer/test_data' +_TEST_DATA_DIR = 'mediapipe/model_maker/python/vision/gesture_recognizer/testdata' tf.keras.backend.experimental.enable_tf_random_generator() diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/metadata_writer_test.py b/mediapipe/model_maker/python/vision/gesture_recognizer/metadata_writer_test.py index e1101e066..83998141d 100644 --- a/mediapipe/model_maker/python/vision/gesture_recognizer/metadata_writer_test.py +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/metadata_writer_test.py @@ -23,7 +23,7 @@ from mediapipe.model_maker.python.vision.gesture_recognizer import metadata_writ from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer as base_metadata_writer from mediapipe.tasks.python.test import test_utils -_TEST_DATA_DIR = "mediapipe/model_maker/python/vision/gesture_recognizer/test_data/metadata" +_TEST_DATA_DIR = "mediapipe/model_maker/python/vision/gesture_recognizer/testdata/metadata" _EXPECTED_JSON = test_utils.get_test_data_path( os.path.join(_TEST_DATA_DIR, "custom_gesture_classifier_meta.json")) diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/metadata/custom_gesture_classifier.tflite 
b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/metadata/custom_gesture_classifier.tflite similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/metadata/custom_gesture_classifier.tflite rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/metadata/custom_gesture_classifier.tflite diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/metadata/custom_gesture_classifier_meta.json b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/metadata/custom_gesture_classifier_meta.json similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/metadata/custom_gesture_classifier_meta.json rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/metadata/custom_gesture_classifier_meta.json diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/0413d5c5-f5ba-476f-a921-ea5e967692a9.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/call/0413d5c5-f5ba-476f-a921-ea5e967692a9.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/0413d5c5-f5ba-476f-a921-ea5e967692a9.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/call/0413d5c5-f5ba-476f-a921-ea5e967692a9.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/044768ad-1709-44ba-b041-c2f8cbe4c166.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/call/044768ad-1709-44ba-b041-c2f8cbe4c166.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/044768ad-1709-44ba-b041-c2f8cbe4c166.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/call/044768ad-1709-44ba-b041-c2f8cbe4c166.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/0e022ee9-74fd-44fe-adad-60c11835e44f.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/call/0e022ee9-74fd-44fe-adad-60c11835e44f.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/0e022ee9-74fd-44fe-adad-60c11835e44f.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/call/0e022ee9-74fd-44fe-adad-60c11835e44f.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/143f8b21-1dc3-4383-bf36-0a54244dfbc0.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/call/143f8b21-1dc3-4383-bf36-0a54244dfbc0.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/143f8b21-1dc3-4383-bf36-0a54244dfbc0.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/call/143f8b21-1dc3-4383-bf36-0a54244dfbc0.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/172ba7f6-c6ba-4398-89a2-25375dccfefa.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/call/172ba7f6-c6ba-4398-89a2-25375dccfefa.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/172ba7f6-c6ba-4398-89a2-25375dccfefa.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/call/172ba7f6-c6ba-4398-89a2-25375dccfefa.jpg diff 
--git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/17b3aa02-dc4d-448d-8601-e2b67193d436.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/call/17b3aa02-dc4d-448d-8601-e2b67193d436.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/17b3aa02-dc4d-448d-8601-e2b67193d436.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/call/17b3aa02-dc4d-448d-8601-e2b67193d436.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/17d804b5-7118-462d-8191-58d764f591b8.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/call/17d804b5-7118-462d-8191-58d764f591b8.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/17d804b5-7118-462d-8191-58d764f591b8.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/call/17d804b5-7118-462d-8191-58d764f591b8.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/1d65a858-623a-4984-9420-958c7e870c3e.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/call/1d65a858-623a-4984-9420-958c7e870c3e.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/1d65a858-623a-4984-9420-958c7e870c3e.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/call/1d65a858-623a-4984-9420-958c7e870c3e.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/1f5fb137-c7a9-435b-85dd-6d7b63ea233a.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/call/1f5fb137-c7a9-435b-85dd-6d7b63ea233a.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/1f5fb137-c7a9-435b-85dd-6d7b63ea233a.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/call/1f5fb137-c7a9-435b-85dd-6d7b63ea233a.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/21de0cfe-af9f-42c2-95d4-aa3d852e7dad.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/call/21de0cfe-af9f-42c2-95d4-aa3d852e7dad.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/21de0cfe-af9f-42c2-95d4-aa3d852e7dad.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/call/21de0cfe-af9f-42c2-95d4-aa3d852e7dad.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/06aa70cc-a12a-4b1e-85cf-e54d44c19a3a.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/four/06aa70cc-a12a-4b1e-85cf-e54d44c19a3a.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/06aa70cc-a12a-4b1e-85cf-e54d44c19a3a.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/four/06aa70cc-a12a-4b1e-85cf-e54d44c19a3a.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/077fa4bf-a99e-496b-b895-709afc614eec.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/four/077fa4bf-a99e-496b-b895-709afc614eec.jpg similarity index 100% rename from 
mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/077fa4bf-a99e-496b-b895-709afc614eec.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/four/077fa4bf-a99e-496b-b895-709afc614eec.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/07a5a144-c635-4441-aedb-5c8e9da79aac.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/four/07a5a144-c635-4441-aedb-5c8e9da79aac.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/07a5a144-c635-4441-aedb-5c8e9da79aac.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/four/07a5a144-c635-4441-aedb-5c8e9da79aac.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/07fdea90-1102-4419-a3af-b394cb29531b.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/four/07fdea90-1102-4419-a3af-b394cb29531b.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/07fdea90-1102-4419-a3af-b394cb29531b.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/four/07fdea90-1102-4419-a3af-b394cb29531b.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/0c960166-75b0-4c1b-a3cc-2ddbd5a21703.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/four/0c960166-75b0-4c1b-a3cc-2ddbd5a21703.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/0c960166-75b0-4c1b-a3cc-2ddbd5a21703.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/four/0c960166-75b0-4c1b-a3cc-2ddbd5a21703.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/105f8f8e-ccd6-45a0-b22a-e314930bc13e.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/four/105f8f8e-ccd6-45a0-b22a-e314930bc13e.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/105f8f8e-ccd6-45a0-b22a-e314930bc13e.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/four/105f8f8e-ccd6-45a0-b22a-e314930bc13e.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/116292ef-5947-4d6c-a479-630ebb8a1050.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/four/116292ef-5947-4d6c-a479-630ebb8a1050.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/116292ef-5947-4d6c-a479-630ebb8a1050.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/four/116292ef-5947-4d6c-a479-630ebb8a1050.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/15a73593-b13e-4a1b-99bb-51775cfdfc42.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/four/15a73593-b13e-4a1b-99bb-51775cfdfc42.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/15a73593-b13e-4a1b-99bb-51775cfdfc42.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/four/15a73593-b13e-4a1b-99bb-51775cfdfc42.jpg diff --git 
a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/249c5023-6106-447a-84ac-17eb4713731b.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/four/249c5023-6106-447a-84ac-17eb4713731b.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/249c5023-6106-447a-84ac-17eb4713731b.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/four/249c5023-6106-447a-84ac-17eb4713731b.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/25bb4c45-e40b-482c-b588-04db60b7e450.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/four/25bb4c45-e40b-482c-b588-04db60b7e450.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/25bb4c45-e40b-482c-b588-04db60b7e450.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/four/25bb4c45-e40b-482c-b588-04db60b7e450.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/00af1db1-7c86-4e9b-9383-1fbd06c3492d.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/none/00af1db1-7c86-4e9b-9383-1fbd06c3492d.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/00af1db1-7c86-4e9b-9383-1fbd06c3492d.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/none/00af1db1-7c86-4e9b-9383-1fbd06c3492d.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/00b85ea4-8c5d-4302-b847-0a5de1d7dab2.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/none/00b85ea4-8c5d-4302-b847-0a5de1d7dab2.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/00b85ea4-8c5d-4302-b847-0a5de1d7dab2.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/none/00b85ea4-8c5d-4302-b847-0a5de1d7dab2.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/00c84257-800d-4032-9e64-e47eb97005f5.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/none/00c84257-800d-4032-9e64-e47eb97005f5.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/00c84257-800d-4032-9e64-e47eb97005f5.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/none/00c84257-800d-4032-9e64-e47eb97005f5.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/0a038096-c14f-46ac-9155-980161ebc440.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/none/0a038096-c14f-46ac-9155-980161ebc440.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/0a038096-c14f-46ac-9155-980161ebc440.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/none/0a038096-c14f-46ac-9155-980161ebc440.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/0a0ef3d2-2560-4a93-904d-437189fffbf2.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/none/0a0ef3d2-2560-4a93-904d-437189fffbf2.jpg similarity index 100% rename from 
mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/0a0ef3d2-2560-4a93-904d-437189fffbf2.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/none/0a0ef3d2-2560-4a93-904d-437189fffbf2.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/0a272153-56c7-42d5-a17d-cd307a1cd6d4.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/none/0a272153-56c7-42d5-a17d-cd307a1cd6d4.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/0a272153-56c7-42d5-a17d-cd307a1cd6d4.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/none/0a272153-56c7-42d5-a17d-cd307a1cd6d4.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/0a4a8907-1950-4e43-9a03-1740e78224ef.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/none/0a4a8907-1950-4e43-9a03-1740e78224ef.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/0a4a8907-1950-4e43-9a03-1740e78224ef.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/none/0a4a8907-1950-4e43-9a03-1740e78224ef.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/0a4bc2da-f5b3-48cd-8f0d-c61dbd08ba53.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/none/0a4bc2da-f5b3-48cd-8f0d-c61dbd08ba53.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/0a4bc2da-f5b3-48cd-8f0d-c61dbd08ba53.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/none/0a4bc2da-f5b3-48cd-8f0d-c61dbd08ba53.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/0a71a6e8-bb06-4ed0-a60b-c2a602fce261.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/none/0a71a6e8-bb06-4ed0-a60b-c2a602fce261.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/0a71a6e8-bb06-4ed0-a60b-c2a602fce261.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/none/0a71a6e8-bb06-4ed0-a60b-c2a602fce261.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/0a787971-9377-4888-803f-aef21863ef7d.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/none/0a787971-9377-4888-803f-aef21863ef7d.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/0a787971-9377-4888-803f-aef21863ef7d.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/none/0a787971-9377-4888-803f-aef21863ef7d.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/026fd791-8f64-4fae-8cb0-0e01dc4362ce.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/rock/026fd791-8f64-4fae-8cb0-0e01dc4362ce.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/026fd791-8f64-4fae-8cb0-0e01dc4362ce.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/rock/026fd791-8f64-4fae-8cb0-0e01dc4362ce.jpg diff --git 
a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/055f8be9-f7fd-4c7f-ad3f-7b404b6489c3.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/rock/055f8be9-f7fd-4c7f-ad3f-7b404b6489c3.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/055f8be9-f7fd-4c7f-ad3f-7b404b6489c3.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/rock/055f8be9-f7fd-4c7f-ad3f-7b404b6489c3.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/09a619ab-cdf7-4a66-911f-347113f050f1.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/rock/09a619ab-cdf7-4a66-911f-347113f050f1.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/09a619ab-cdf7-4a66-911f-347113f050f1.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/rock/09a619ab-cdf7-4a66-911f-347113f050f1.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/0c6628ea-4a8c-49c9-b7cf-c30aef18dc3d.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/rock/0c6628ea-4a8c-49c9-b7cf-c30aef18dc3d.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/0c6628ea-4a8c-49c9-b7cf-c30aef18dc3d.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/rock/0c6628ea-4a8c-49c9-b7cf-c30aef18dc3d.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/0cc7ad09-ae5f-45a8-b264-4216176369b6.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/rock/0cc7ad09-ae5f-45a8-b264-4216176369b6.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/0cc7ad09-ae5f-45a8-b264-4216176369b6.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/rock/0cc7ad09-ae5f-45a8-b264-4216176369b6.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/10eacf4b-8aaf-46d9-be21-7fb8d8353005.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/rock/10eacf4b-8aaf-46d9-be21-7fb8d8353005.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/10eacf4b-8aaf-46d9-be21-7fb8d8353005.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/rock/10eacf4b-8aaf-46d9-be21-7fb8d8353005.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/15cb4e8b-ba1d-46f1-8456-247016a599a4.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/rock/15cb4e8b-ba1d-46f1-8456-247016a599a4.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/15cb4e8b-ba1d-46f1-8456-247016a599a4.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/rock/15cb4e8b-ba1d-46f1-8456-247016a599a4.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/18e20af8-8fe1-48d4-bd0e-83fa9e2db88e.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/rock/18e20af8-8fe1-48d4-bd0e-83fa9e2db88e.jpg similarity index 100% rename from 
mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/18e20af8-8fe1-48d4-bd0e-83fa9e2db88e.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/rock/18e20af8-8fe1-48d4-bd0e-83fa9e2db88e.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/1bed937b-7ae4-4070-891c-daf69415da41.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/rock/1bed937b-7ae4-4070-891c-daf69415da41.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/1bed937b-7ae4-4070-891c-daf69415da41.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/rock/1bed937b-7ae4-4070-891c-daf69415da41.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/20e2164d-3473-4d42-8755-22cdbd4417ba.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/rock/20e2164d-3473-4d42-8755-22cdbd4417ba.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/20e2164d-3473-4d42-8755-22cdbd4417ba.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/rock/20e2164d-3473-4d42-8755-22cdbd4417ba.jpg From 1e2cb2b35968100e6ec6cd974c2ec01e7bf6be9e Mon Sep 17 00:00:00 2001 From: Hadon Nash Date: Thu, 1 Dec 2022 11:33:15 -0800 Subject: [PATCH 151/469] Internal change PiperOrigin-RevId: 492253867 --- mediapipe/framework/input_stream_handler.cc | 4 +- .../immediate_input_stream_handler_test.cc | 37 +++++++++++++++++++ 2 files changed, 40 insertions(+), 1 deletion(-) diff --git a/mediapipe/framework/input_stream_handler.cc b/mediapipe/framework/input_stream_handler.cc index d1dffa414..a7bd9ef43 100644 --- a/mediapipe/framework/input_stream_handler.cc +++ b/mediapipe/framework/input_stream_handler.cc @@ -354,7 +354,9 @@ NodeReadiness SyncSet::GetReadiness(Timestamp* min_stream_timestamp) { } } *min_stream_timestamp = std::min(min_packet, min_bound); - if (*min_stream_timestamp == Timestamp::Done()) { + if (*min_stream_timestamp >= Timestamp::OneOverPostStream()) { + // Either OneOverPostStream or Done indicates no more packets. + *min_stream_timestamp = Timestamp::Done(); last_processed_ts_ = Timestamp::Done().PreviousAllowedInStream(); return NodeReadiness::kReadyForClose; } diff --git a/mediapipe/framework/stream_handler/immediate_input_stream_handler_test.cc b/mediapipe/framework/stream_handler/immediate_input_stream_handler_test.cc index e721afb02..e5de7f0c9 100644 --- a/mediapipe/framework/stream_handler/immediate_input_stream_handler_test.cc +++ b/mediapipe/framework/stream_handler/immediate_input_stream_handler_test.cc @@ -230,6 +230,43 @@ TEST_F(ImmediateInputStreamHandlerTest, StreamDoneReady) { input_stream_handler_->ClearCurrentInputs(cc_); } +// This test checks that the state is ReadyForClose after all streams reach +// Timestamp::Max. +TEST_F(ImmediateInputStreamHandlerTest, ReadyForCloseAfterTimestampMax) { + Timestamp min_stream_timestamp; + std::list<Packet> packets; + + // One packet arrives, ready for process.
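+  // (Note on the bounds set further below: Timestamp::Max().NextAllowedInStream()
+  // equals Timestamp::OneOverPostStream(), so once every stream's bound reaches
+  // it, the `>= Timestamp::OneOverPostStream()` check added in
+  // input_stream_handler.cc above reports kReadyForClose.)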
+ packets.push_back(Adopt(new std::string("packet 1")).At(Timestamp(10))); + input_stream_handler_->AddPackets(name_to_id_["input_a"], packets); + EXPECT_TRUE(input_stream_handler_->ScheduleInvocations( + /*max_allowance=*/1, &min_stream_timestamp)); + EXPECT_EQ(Timestamp(10), cc_->InputTimestamp()); + input_stream_handler_->FinalizeInputSet(cc_->InputTimestamp(), + &cc_->Inputs()); + input_stream_handler_->ClearCurrentInputs(cc_); + + // No packets arrive, not ready. + EXPECT_FALSE(input_stream_handler_->ScheduleInvocations( + /*max_allowance=*/1, &min_stream_timestamp)); + EXPECT_EQ(Timestamp::Unset(), cc_->InputTimestamp()); + + // Timestamp::Max arrives, ready for close. + input_stream_handler_->SetNextTimestampBound( + name_to_id_["input_a"], Timestamp::Max().NextAllowedInStream()); + input_stream_handler_->SetNextTimestampBound( + name_to_id_["input_b"], Timestamp::Max().NextAllowedInStream()); + input_stream_handler_->SetNextTimestampBound( + name_to_id_["input_c"], Timestamp::Max().NextAllowedInStream()); + + EXPECT_TRUE(input_stream_handler_->ScheduleInvocations( + /*max_allowance=*/1, &min_stream_timestamp)); + EXPECT_EQ(Timestamp::Done(), cc_->InputTimestamp()); + input_stream_handler_->FinalizeInputSet(cc_->InputTimestamp(), + &cc_->Inputs()); + input_stream_handler_->ClearCurrentInputs(cc_); +} + // This test checks that when any stream is done, the state is ready to close. TEST_F(ImmediateInputStreamHandlerTest, ReadyForClose) { Timestamp min_stream_timestamp; From 40eb0e63858bd6c8746f4d5127a76ebef1f71cf7 Mon Sep 17 00:00:00 2001 From: Camillo Lugaresi Date: Thu, 1 Dec 2022 12:57:07 -0800 Subject: [PATCH 152/469] Internal change PiperOrigin-RevId: 492276913 --- mediapipe/gpu/multi_pool.h | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mediapipe/gpu/multi_pool.h b/mediapipe/gpu/multi_pool.h index 8a3cf6be0..e677c3bbf 100644 --- a/mediapipe/gpu/multi_pool.h +++ b/mediapipe/gpu/multi_pool.h @@ -59,6 +59,8 @@ class MultiPool { MultiPool(SimplePoolFactory factory = DefaultMakeSimplePool, MultiPoolOptions options = kDefaultMultiPoolOptions) : create_simple_pool_(factory), options_(options) {} + explicit MultiPool(MultiPoolOptions options) + : MultiPool(DefaultMakeSimplePool, options) {} // Obtains an item. May either be reused or created anew. Item Get(const Spec& spec); From fd79f18aeb41d78966a91dbd38107534c3fb29e8 Mon Sep 17 00:00:00 2001 From: Khanh LeViet Date: Thu, 1 Dec 2022 14:13:01 -0800 Subject: [PATCH 153/469] Make BaseOptions pass an absolute path to the C++ layer.
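A minimal sketch of the intended behavior (illustrative only; the relative model path below is hypothetical):

    import os
    from mediapipe.tasks.python.core import base_options

    # A relative path is expanded before it reaches the C++ layer, so graph
    # initialization no longer depends on the process working directory.
    options = base_options.BaseOptions(model_asset_path='models/embedder.tflite')
    proto = options.to_pb2()
    assert os.path.isabs(proto.model_asset.file_name)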
PiperOrigin-RevId: 492296573 --- mediapipe/tasks/python/core/base_options.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/mediapipe/tasks/python/core/base_options.py b/mediapipe/tasks/python/core/base_options.py index 122dc620f..b48fa2ccc 100644 --- a/mediapipe/tasks/python/core/base_options.py +++ b/mediapipe/tasks/python/core/base_options.py @@ -14,6 +14,7 @@ """Base options for MediaPipe Task APIs.""" import dataclasses +import os from typing import Any, Optional from mediapipe.tasks.cc.core.proto import base_options_pb2 @@ -49,10 +50,14 @@ class BaseOptions: @doc_controls.do_not_generate_docs def to_pb2(self) -> _BaseOptionsProto: """Generates a BaseOptions protobuf object.""" + if self.model_asset_path is not None: + full_path = os.path.abspath(self.model_asset_path) + else: + full_path = None + return _BaseOptionsProto( model_asset=_ExternalFileProto( - file_name=self.model_asset_path, - file_content=self.model_asset_buffer)) + file_name=full_path, file_content=self.model_asset_buffer)) @classmethod @doc_controls.do_not_generate_docs From af990c3da1633f164ccf8f75edb0683079b0c005 Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Thu, 1 Dec 2022 14:58:30 -0800 Subject: [PATCH 154/469] Open up the visibility of "//mediapipe/java/com/google/mediapipe/framework/image:image". PiperOrigin-RevId: 492308109 --- mediapipe/java/com/google/mediapipe/framework/image/BUILD | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/mediapipe/java/com/google/mediapipe/framework/image/BUILD b/mediapipe/java/com/google/mediapipe/framework/image/BUILD index bb3be318d..d9508c1f7 100644 --- a/mediapipe/java/com/google/mediapipe/framework/image/BUILD +++ b/mediapipe/java/com/google/mediapipe/framework/image/BUILD @@ -20,9 +20,7 @@ android_library( name = "image", srcs = glob(["*.java"]), manifest = "AndroidManifest.xml", - visibility = [ - "//mediapipe:__subpackages__", - ], + visibility = ["//visibility:public"], deps = [ "//third_party:androidx_legacy_support_v4", "//third_party:autovalue", From ead41132a856379a9a7d22f29abe471dc11f2b4a Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Thu, 1 Dec 2022 15:00:00 -0800 Subject: [PATCH 155/469] Load model file content from model file path with the help of GetResourceContents in browsers. This handles model files that are provided via a custom ResourceProviderFn. PiperOrigin-RevId: 492308453 --- mediapipe/tasks/cc/core/model_resources.cc | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/mediapipe/tasks/cc/core/model_resources.cc b/mediapipe/tasks/cc/core/model_resources.cc index 618761f32..d5c12ee95 100644 --- a/mediapipe/tasks/cc/core/model_resources.cc +++ b/mediapipe/tasks/cc/core/model_resources.cc @@ -99,11 +99,21 @@ const tflite::Model* ModelResources::GetTfLiteModel() const { absl::Status ModelResources::BuildModelFromExternalFileProto() { if (model_file_->has_file_name()) { +#ifdef __EMSCRIPTEN__ + // In browsers, the model file may require a custom ResourceProviderFn to + // provide the model content. The open() method may not work in this case. + // Thus, load the model content from the model file path in advance with + // the help of GetResourceContents. + MP_RETURN_IF_ERROR(mediapipe::GetResourceContents( + model_file_->file_name(), model_file_->mutable_file_content())); + model_file_->clear_file_name(); +#else // If the model file name is a relative path, searches the file in a // platform-specific location and returns the absolute path on success.
ASSIGN_OR_RETURN(std::string path_to_resource, mediapipe::PathToResourceAsFile(model_file_->file_name())); model_file_->set_file_name(path_to_resource); +#endif  // __EMSCRIPTEN__ } ASSIGN_OR_RETURN( model_file_handler_, From 768d2dc548f123246d34fe258d6ab75d05c51d3e Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Thu, 1 Dec 2022 16:47:05 -0800 Subject: [PATCH 156/469] Separate the web and Java API landmark and world landmark into two classes. This makes the platform interfaces consistent. PiperOrigin-RevId: 492332990 --- .../tasks/components/containers/BUILD | 9 +++ .../tasks/components/containers/Landmark.java | 20 +++--- .../containers/NormalizedLandmark.java | 63 +++++++++++++++++++ .../com/google/mediapipe/tasks/vision/BUILD | 2 + .../GestureRecognizerResult.java | 45 ++++++------- .../handlandmarker/HandLandmarkerResult.java | 52 +++++++-------- .../GestureRecognizerTest.java | 4 +- .../handlandmarker/HandLandmarkerTest.java | 4 +- .../web/components/containers/landmark.d.ts | 28 ++++++--- .../gesture_recognizer/gesture_recognizer.ts | 12 ++-- .../gesture_recognizer_result.d.ts | 4 +- .../vision/hand_landmarker/hand_landmarker.ts | 10 ++- .../hand_landmarker_result.d.ts | 4 +- 13 files changed, 161 insertions(+), 96 deletions(-) create mode 100644 mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/NormalizedLandmark.java diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD index d6e6ac740..ad17d5552 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD @@ -83,6 +83,15 @@ android_library( ], ) +android_library( + name = "normalized_landmark", + srcs = ["NormalizedLandmark.java"], + deps = [ + "//third_party:autovalue", + "@maven//:com_google_guava_guava", + ], +) + # Expose the java source files for building mediapipe tasks core AAR. filegroup( name = "java_src", diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/Landmark.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/Landmark.java index e45866190..7fb1b99d0 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/Landmark.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/Landmark.java @@ -18,16 +18,16 @@ import com.google.auto.value.AutoValue; import java.util.Objects; /** - * Landmark represents a point in 3D space with x, y, z coordinates. If normalized is true, the - * landmark coordinates is normalized respect to the dimension of image, and the coordinates values - * are in the range of [0,1]. Otherwise, it represenet a point in world coordinates. + * Landmark represents a point in 3D space with x, y, z coordinates. The landmark coordinates are in + * meters. z represents the landmark depth, and the smaller the value the closer the world landmark + * is to the camera. */ @AutoValue public abstract class Landmark { private static final float TOLERANCE = 1e-6f; - public static Landmark create(float x, float y, float z, boolean normalized) { - return new AutoValue_Landmark(x, y, z, normalized); + public static Landmark create(float x, float y, float z) { + return new AutoValue_Landmark(x, y, z); } // The x coordinates of the landmark. @@ -39,28 +39,24 @@ public abstract class Landmark { // The z coordinates of the landmark.
public abstract float z(); - // Whether this landmark is normalized with respect to the image size. - public abstract boolean normalized(); - @Override public final boolean equals(Object o) { if (!(o instanceof Landmark)) { return false; } Landmark other = (Landmark) o; - return other.normalized() == this.normalized() - && Math.abs(other.x() - this.x()) < TOLERANCE + return Math.abs(other.x() - this.x()) < TOLERANCE && Math.abs(other.y() - this.y()) < TOLERANCE && Math.abs(other.z() - this.z()) < TOLERANCE; } @Override public final int hashCode() { - return Objects.hash(x(), y(), z(), normalized()); + return Objects.hash(x(), y(), z()); } @Override public final String toString() { - return "<Landmark (x=" + x() + " y=" + y() + " z=" + z() + " normalized=" + normalized() + ")>"; + return "<Landmark (x=" + x() + " y=" + y() + " z=" + z() + ")>"; } } diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/NormalizedLandmark.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/NormalizedLandmark.java new file mode 100644 index 000000000..e77f3c3d4 --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/NormalizedLandmark.java @@ -0,0 +1,63 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.tasks.components.containers; + +import com.google.auto.value.AutoValue; +import java.util.Objects; + +/** + * Normalized Landmark represents a point in 3D space with x, y, z coordinates. x and y are + * normalized to [0.0, 1.0] by the image width and height respectively. z represents the landmark + * depth, and the smaller the value the closer the landmark is to the camera. The magnitude of z + * uses roughly the same scale as x. + */ +@AutoValue +public abstract class NormalizedLandmark { + private static final float TOLERANCE = 1e-6f; + + public static NormalizedLandmark create(float x, float y, float z) { + return new AutoValue_NormalizedLandmark(x, y, z); + } + + // The x coordinates of the normalized landmark. + public abstract float x(); + + // The y coordinates of the normalized landmark. + public abstract float y(); + + // The z coordinates of the normalized landmark.
+  public abstract float z();
+
+  @Override
+  public final boolean equals(Object o) {
+    if (!(o instanceof NormalizedLandmark)) {
+      return false;
+    }
+    NormalizedLandmark other = (NormalizedLandmark) o;
+    return Math.abs(other.x() - this.x()) < TOLERANCE
+        && Math.abs(other.y() - this.y()) < TOLERANCE
+        && Math.abs(other.z() - this.z()) < TOLERANCE;
+  }
+
+  @Override
+  public final int hashCode() {
+    return Objects.hash(x(), y(), z());
+  }
+
+  @Override
+  public final String toString() {
+    return "<Normalized Landmark (x=" + x() + " y=" + y() + " z=" + z() + ")>";
+  }
+}
diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD
index b61c174fe..6161fe032 100644
--- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD
+++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD
@@ -135,6 +135,7 @@ android_library(
         "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_java_proto_lite",
         "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:category",
         "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:landmark",
+        "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:normalized_landmark",
         "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:classifieroptions",
         "//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
         "//third_party:autovalue",
@@ -167,6 +168,7 @@ android_library(
         "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_java_proto_lite",
         "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:category",
         "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:landmark",
+        "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:normalized_landmark",
         "//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
         "//third_party:autovalue",
         "@maven//:androidx_annotation_annotation",
diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizerResult.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizerResult.java
index ef76bf226..90b92175d 100644
--- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizerResult.java
+++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizerResult.java
@@ -15,13 +15,12 @@
 package com.google.mediapipe.tasks.vision.gesturerecognizer;

 import com.google.auto.value.AutoValue;
-import com.google.mediapipe.formats.proto.LandmarkProto.Landmark;
-import com.google.mediapipe.formats.proto.LandmarkProto.LandmarkList;
-import com.google.mediapipe.formats.proto.LandmarkProto.NormalizedLandmark;
-import com.google.mediapipe.formats.proto.LandmarkProto.NormalizedLandmarkList;
+import com.google.mediapipe.formats.proto.LandmarkProto;
 import com.google.mediapipe.formats.proto.ClassificationProto.Classification;
 import com.google.mediapipe.formats.proto.ClassificationProto.ClassificationList;
 import com.google.mediapipe.tasks.components.containers.Category;
+import com.google.mediapipe.tasks.components.containers.Landmark;
+import com.google.mediapipe.tasks.components.containers.NormalizedLandmark;
 import com.google.mediapipe.tasks.core.TaskResult;
 import java.util.ArrayList;
 import java.util.Collections;
@@ -43,41 +42,36 @@ public abstract class GestureRecognizerResult implements TaskResult {
    * @param gesturesProto a List of {@link ClassificationList}
    */
  static GestureRecognizerResult create(
-      List<NormalizedLandmarkList> landmarksProto,
-      List<LandmarkList> worldLandmarksProto,
+      List<LandmarkProto.NormalizedLandmarkList> landmarksProto,
+      List<LandmarkProto.LandmarkList> worldLandmarksProto,
       List<ClassificationList> handednessesProto,
       List<ClassificationList> gesturesProto,
       long timestampMs) {
-    List<List<com.google.mediapipe.tasks.components.containers.Landmark>> multiHandLandmarks =
-        new ArrayList<>();
-    List<List<com.google.mediapipe.tasks.components.containers.Landmark>> multiHandWorldLandmarks =
-        new ArrayList<>();
+    List<List<NormalizedLandmark>> multiHandLandmarks = new ArrayList<>();
+    List<List<Landmark>> multiHandWorldLandmarks = new ArrayList<>();
     List<List<Category>> multiHandHandednesses = new ArrayList<>();
     List<List<Category>> multiHandGestures = new ArrayList<>();
-    for (NormalizedLandmarkList handLandmarksProto : landmarksProto) {
-      List<com.google.mediapipe.tasks.components.containers.Landmark> handLandmarks =
-          new ArrayList<>();
+    for (LandmarkProto.NormalizedLandmarkList handLandmarksProto : landmarksProto) {
+      List<NormalizedLandmark> handLandmarks = new ArrayList<>();
       multiHandLandmarks.add(handLandmarks);
-      for (NormalizedLandmark handLandmarkProto : handLandmarksProto.getLandmarkList()) {
+      for (LandmarkProto.NormalizedLandmark handLandmarkProto :
+          handLandmarksProto.getLandmarkList()) {
         handLandmarks.add(
-            com.google.mediapipe.tasks.components.containers.Landmark.create(
-                handLandmarkProto.getX(),
-                handLandmarkProto.getY(),
-                handLandmarkProto.getZ(),
-                true));
+            com.google.mediapipe.tasks.components.containers.NormalizedLandmark.create(
+                handLandmarkProto.getX(), handLandmarkProto.getY(), handLandmarkProto.getZ()));
       }
     }
-    for (LandmarkList handWorldLandmarksProto : worldLandmarksProto) {
+    for (LandmarkProto.LandmarkList handWorldLandmarksProto : worldLandmarksProto) {
       List<Landmark> handWorldLandmarks = new ArrayList<>();
       multiHandWorldLandmarks.add(handWorldLandmarks);
-      for (Landmark handWorldLandmarkProto : handWorldLandmarksProto.getLandmarkList()) {
+      for (LandmarkProto.Landmark handWorldLandmarkProto :
+          handWorldLandmarksProto.getLandmarkList()) {
         handWorldLandmarks.add(
             com.google.mediapipe.tasks.components.containers.Landmark.create(
                 handWorldLandmarkProto.getX(),
                 handWorldLandmarkProto.getY(),
-                handWorldLandmarkProto.getZ(),
-                false));
+                handWorldLandmarkProto.getZ()));
       }
     }
     for (ClassificationList handednessProto : handednessesProto) {
@@ -118,11 +112,10 @@ public abstract class GestureRecognizerResult implements TaskResult {
   public abstract long timestampMs();

   /** Hand landmarks of detected hands. */
-  public abstract List<List<com.google.mediapipe.tasks.components.containers.Landmark>> landmarks();
+  public abstract List<List<NormalizedLandmark>> landmarks();

   /** Hand landmarks in world coordinates of detected hands. */
-  public abstract List<List<com.google.mediapipe.tasks.components.containers.Landmark>>
-      worldLandmarks();
+  public abstract List<List<Landmark>> worldLandmarks();

   /** Handedness of detected hands. */
   public abstract List<List<Category>> handednesses();

diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmarkerResult.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmarkerResult.java
index 2889b0e0b..9092c0a2d 100644
--- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmarkerResult.java
+++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmarkerResult.java
@@ -15,13 +15,12 @@
 package com.google.mediapipe.tasks.vision.handlandmarker;

 import com.google.auto.value.AutoValue;
-import com.google.mediapipe.formats.proto.LandmarkProto.Landmark;
-import com.google.mediapipe.formats.proto.LandmarkProto.LandmarkList;
-import com.google.mediapipe.formats.proto.LandmarkProto.NormalizedLandmark;
-import com.google.mediapipe.formats.proto.LandmarkProto.NormalizedLandmarkList;
+import com.google.mediapipe.formats.proto.LandmarkProto;
 import com.google.mediapipe.formats.proto.ClassificationProto.Classification;
 import com.google.mediapipe.formats.proto.ClassificationProto.ClassificationList;
 import com.google.mediapipe.tasks.components.containers.Category;
+import com.google.mediapipe.tasks.components.containers.Landmark;
+import com.google.mediapipe.tasks.components.containers.NormalizedLandmark;
 import com.google.mediapipe.tasks.core.TaskResult;
 import java.util.ArrayList;
 import java.util.Collections;
@@ -32,47 +31,41 @@ import java.util.List;
 public abstract class HandLandmarkerResult implements TaskResult {

   /**
-   * Creates a {@link HandLandmarkerResult} instance from the lists of landmarks and
-   * handedness protobuf messages.
+   * Creates a {@link HandLandmarkerResult} instance from the lists of landmarks and handedness
+   * protobuf messages.
    *
    * @param landmarksProto a List of {@link NormalizedLandmarkList}
    * @param worldLandmarksProto a List of {@link LandmarkList}
    * @param handednessesProto a List of {@link ClassificationList}
    */
   static HandLandmarkerResult create(
-      List<NormalizedLandmarkList> landmarksProto,
-      List<LandmarkList> worldLandmarksProto,
+      List<LandmarkProto.NormalizedLandmarkList> landmarksProto,
+      List<LandmarkProto.LandmarkList> worldLandmarksProto,
       List<ClassificationList> handednessesProto,
       long timestampMs) {
-    List<List<com.google.mediapipe.tasks.components.containers.Landmark>> multiHandLandmarks =
-        new ArrayList<>();
-    List<List<com.google.mediapipe.tasks.components.containers.Landmark>> multiHandWorldLandmarks =
-        new ArrayList<>();
+    List<List<NormalizedLandmark>> multiHandLandmarks = new ArrayList<>();
+    List<List<Landmark>> multiHandWorldLandmarks = new ArrayList<>();
     List<List<Category>> multiHandHandednesses = new ArrayList<>();
-    for (NormalizedLandmarkList handLandmarksProto : landmarksProto) {
-      List<com.google.mediapipe.tasks.components.containers.Landmark> handLandmarks =
-          new ArrayList<>();
+    for (LandmarkProto.NormalizedLandmarkList handLandmarksProto : landmarksProto) {
+      List<NormalizedLandmark> handLandmarks = new ArrayList<>();
       multiHandLandmarks.add(handLandmarks);
-      for (NormalizedLandmark handLandmarkProto : handLandmarksProto.getLandmarkList()) {
+      for (LandmarkProto.NormalizedLandmark handLandmarkProto :
+          handLandmarksProto.getLandmarkList()) {
         handLandmarks.add(
-            com.google.mediapipe.tasks.components.containers.Landmark.create(
-                handLandmarkProto.getX(),
-                handLandmarkProto.getY(),
-                handLandmarkProto.getZ(),
-                true));
+            NormalizedLandmark.create(
+                handLandmarkProto.getX(), handLandmarkProto.getY(), handLandmarkProto.getZ()));
       }
     }
-    for (LandmarkList handWorldLandmarksProto : worldLandmarksProto) {
-      List<com.google.mediapipe.tasks.components.containers.Landmark> handWorldLandmarks =
-          new ArrayList<>();
+    for (LandmarkProto.LandmarkList handWorldLandmarksProto : worldLandmarksProto) {
+      List<Landmark> handWorldLandmarks = new ArrayList<>();
       multiHandWorldLandmarks.add(handWorldLandmarks);
-      for (Landmark handWorldLandmarkProto : handWorldLandmarksProto.getLandmarkList()) {
+      for (LandmarkProto.Landmark handWorldLandmarkProto :
+          handWorldLandmarksProto.getLandmarkList()) {
         handWorldLandmarks.add(
             com.google.mediapipe.tasks.components.containers.Landmark.create(
                 handWorldLandmarkProto.getX(),
                 handWorldLandmarkProto.getY(),
-                handWorldLandmarkProto.getZ(),
-                false));
+                handWorldLandmarkProto.getZ()));
       }
     }
     for (ClassificationList handednessProto : handednessesProto) {
@@ -98,11 +91,10 @@ public abstract class HandLandmarkerResult implements TaskResult {
   public abstract long timestampMs();

   /** Hand landmarks of detected hands. */
-  public abstract List<List<com.google.mediapipe.tasks.components.containers.Landmark>> landmarks();
+  public abstract List<List<NormalizedLandmark>> landmarks();

   /** Hand landmarks in world coordinates of detected hands. */
-  public abstract List<List<com.google.mediapipe.tasks.components.containers.Landmark>>
-      worldLandmarks();
+  public abstract List<List<Landmark>> worldLandmarks();

   /** Handedness of detected hands. */
   public abstract List<List<Category>> handednesses();
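
Both result classes now encode the coordinate space in the type instead of a boolean flag. The web API below receives the same split; a minimal TypeScript sketch of how the two interfaces added later in this patch are meant to be consumed (all values are illustrative, not from the patch):

    // NormalizedLandmark is image-space: x and y are fractions of the frame.
    const tip: NormalizedLandmark = {x: 0.52, y: 0.31, z: -0.02};
    // Landmark is world-space, in meters relative to the camera.
    const tipWorld: Landmark = {x: 0.01, y: -0.04, z: 0.35};
    // Projecting a normalized landmark onto a 640x480 frame:
    const px = {x: tip.x * 640, y: tip.y * 480};  // -> {x: 332.8, y: 148.8}
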

diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizerTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizerTest.java
index c0be4cffe..5821b36cc 100644
--- a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizerTest.java
+++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizerTest.java
@@ -28,7 +28,7 @@ import com.google.mediapipe.framework.MediaPipeException;
 import com.google.mediapipe.framework.image.BitmapImageBuilder;
 import com.google.mediapipe.framework.image.MPImage;
 import com.google.mediapipe.tasks.components.containers.Category;
-import com.google.mediapipe.tasks.components.containers.Landmark;
+import com.google.mediapipe.tasks.components.containers.NormalizedLandmark;
 import com.google.mediapipe.tasks.components.containers.proto.LandmarksDetectionResultProto.LandmarksDetectionResult;
 import com.google.mediapipe.tasks.components.processors.ClassifierOptions;
 import com.google.mediapipe.tasks.core.BaseOptions;
@@ -603,7 +603,7 @@ public class GestureRecognizerTest {
     assertThat(actualResult.landmarks().get(0))
         .comparingElementsUsing(
             Correspondence.from(
-                (Correspondence.BinaryPredicate<Landmark, Landmark>)
+                (Correspondence.BinaryPredicate<NormalizedLandmark, NormalizedLandmark>)
                     (actual, expected) -> {
                       return Correspondence.tolerance(LANDMARKS_ERROR_TOLERANCE)
                               .compare(actual.x(), expected.x())
diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmarkerTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmarkerTest.java
index 9e12d210f..c313d385d 100644
--- a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmarkerTest.java
+++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmarkerTest.java
@@ -27,7 +27,7 @@ import com.google.mediapipe.framework.MediaPipeException;
 import com.google.mediapipe.framework.image.BitmapImageBuilder;
 import com.google.mediapipe.framework.image.MPImage;
 import com.google.mediapipe.tasks.components.containers.Category;
-import com.google.mediapipe.tasks.components.containers.Landmark;
+import com.google.mediapipe.tasks.components.containers.NormalizedLandmark;
 import com.google.mediapipe.tasks.components.containers.proto.LandmarksDetectionResultProto.LandmarksDetectionResult;
 import com.google.mediapipe.tasks.core.BaseOptions;
 import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions;
@@ -399,7 +399,7 @@ public class HandLandmarkerTest {
     assertThat(actualResult.landmarks().get(0))
         .comparingElementsUsing(
             Correspondence.from(
-                (Correspondence.BinaryPredicate<Landmark, Landmark>)
+                (Correspondence.BinaryPredicate<NormalizedLandmark, NormalizedLandmark>)
                     (actual, expected) -> {
                       return Correspondence.tolerance(LANDMARKS_ERROR_TOLERANCE)
                               .compare(actual.x(), expected.x())
diff --git a/mediapipe/tasks/web/components/containers/landmark.d.ts b/mediapipe/tasks/web/components/containers/landmark.d.ts
index c887303d0..0f916bf88 100644
--- a/mediapipe/tasks/web/components/containers/landmark.d.ts
+++ b/mediapipe/tasks/web/components/containers/landmark.d.ts
@@ -15,10 +15,27 @@
  */

 /**
- * Landmark represents a point in 3D space with x, y, z coordinates. If
- * normalized is true, the landmark coordinates is normalized respect to the
- * dimension of image, and the coordinates values are in the range of [0,1].
- * Otherwise, it represenet a point in world coordinates.
+ * Normalized Landmark represents a point in 3D space with x, y, z coordinates. + * x and y are normalized to [0.0, 1.0] by the image width and height + * respectively. z represents the landmark depth, and the smaller the value the + * closer the landmark is to the camera. The magnitude of z uses roughly the + * same scale as x. + */ +export declare interface NormalizedLandmark { + /** The x coordinates of the normalized landmark. */ + x: number; + + /** The y coordinates of the normalized landmark. */ + y: number; + + /** The z coordinates of the normalized landmark. */ + z: number; +} + +/** + * Landmark represents a point in 3D space with x, y, z coordinates. The + * landmark coordinates are in meters. z represents the landmark depth, + * and the smaller the value the closer the world landmark is to the camera. */ export declare interface Landmark { /** The x coordinates of the landmark. */ @@ -29,7 +46,4 @@ export declare interface Landmark { /** The z coordinates of the landmark. */ z: number; - - /** Whether this landmark is normalized with respect to the image size. */ - normalized: boolean; } diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts index 9ec63b07a..15b6acb1a 100644 --- a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts +++ b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts @@ -27,7 +27,7 @@ import {HandDetectorGraphOptions} from '../../../../tasks/cc/vision/hand_detecto import {HandLandmarkerGraphOptions} from '../../../../tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options_pb'; import {HandLandmarksDetectorGraphOptions} from '../../../../tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options_pb'; import {Category} from '../../../../tasks/web/components/containers/category'; -import {Landmark} from '../../../../tasks/web/components/containers/landmark'; +import {Landmark, NormalizedLandmark} from '../../../../tasks/web/components/containers/landmark'; import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/processors/classifier_options'; import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; import {VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; @@ -67,7 +67,7 @@ FULL_IMAGE_RECT.setHeight(1); export class GestureRecognizer extends VisionTaskRunner { private gestures: Category[][] = []; - private landmarks: Landmark[][] = []; + private landmarks: NormalizedLandmark[][] = []; private worldLandmarks: Landmark[][] = []; private handednesses: Category[][] = []; @@ -306,13 +306,12 @@ export class GestureRecognizer extends for (const binaryProto of data) { const handLandmarksProto = NormalizedLandmarkList.deserializeBinary(binaryProto); - const landmarks: Landmark[] = []; + const landmarks: NormalizedLandmark[] = []; for (const handLandmarkProto of handLandmarksProto.getLandmarkList()) { landmarks.push({ x: handLandmarkProto.getX() ?? 0, y: handLandmarkProto.getY() ?? 0, - z: handLandmarkProto.getZ() ?? 0, - normalized: true + z: handLandmarkProto.getZ() ?? 0 }); } this.landmarks.push(landmarks); @@ -333,8 +332,7 @@ export class GestureRecognizer extends worldLandmarks.push({ x: handWorldLandmarkProto.getX() ?? 0, y: handWorldLandmarkProto.getY() ?? 0, - z: handWorldLandmarkProto.getZ() ?? 0, - normalized: false + z: handWorldLandmarkProto.getZ() ?? 
0
         });
       }
       this.worldLandmarks.push(worldLandmarks);
diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_result.d.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_result.d.ts
index 7c295c9e9..e570270b2 100644
--- a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_result.d.ts
+++ b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_result.d.ts
@@ -15,14 +15,14 @@
  */

 import {Category} from '../../../../tasks/web/components/containers/category';
-import {Landmark} from '../../../../tasks/web/components/containers/landmark';
+import {Landmark, NormalizedLandmark} from '../../../../tasks/web/components/containers/landmark';

 /**
 * Represents the gesture recognition results generated by `GestureRecognizer`.
 */
 export declare interface GestureRecognizerResult {
   /** Hand landmarks of detected hands. */
-  landmarks: Landmark[][];
+  landmarks: NormalizedLandmark[][];

   /** Hand landmarks in world coordinates of detected hands. */
   worldLandmarks: Landmark[][];
diff --git a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts
index 290f49455..c657275bf 100644
--- a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts
+++ b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts
@@ -24,7 +24,7 @@ import {HandDetectorGraphOptions} from '../../../../tasks/cc/vision/hand_detector/proto/hand_detector_graph_options_pb';
 import {HandLandmarkerGraphOptions} from '../../../../tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options_pb';
 import {HandLandmarksDetectorGraphOptions} from '../../../../tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options_pb';
 import {Category} from '../../../../tasks/web/components/containers/category';
-import {Landmark} from '../../../../tasks/web/components/containers/landmark';
+import {Landmark, NormalizedLandmark} from '../../../../tasks/web/components/containers/landmark';
 import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
 import {VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
 import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner';
@@ -59,7 +59,7 @@ FULL_IMAGE_RECT.setHeight(1);

 /** Performs hand landmarks detection on images. */
 export class HandLandmarker extends VisionTaskRunner<HandLandmarkerResult> {
-  private landmarks: Landmark[][] = [];
+  private landmarks: NormalizedLandmark[][] = [];
   private worldLandmarks: Landmark[][] = [];
   private handednesses: Category[][] = [];

@@ -255,13 +255,12 @@ export class HandLandmarker extends VisionTaskRunner<HandLandmarkerResult> {
     for (const binaryProto of data) {
       const handLandmarksProto =
           NormalizedLandmarkList.deserializeBinary(binaryProto);
-      const landmarks: Landmark[] = [];
+      const landmarks: NormalizedLandmark[] = [];
       for (const handLandmarkProto of handLandmarksProto.getLandmarkList()) {
         landmarks.push({
           x: handLandmarkProto.getX() ?? 0,
           y: handLandmarkProto.getY() ?? 0,
           z: handLandmarkProto.getZ() ?? 0,
-          normalized: true
         });
       }
       this.landmarks.push(landmarks);
@@ -269,7 +268,7 @@ export class HandLandmarker extends VisionTaskRunner<HandLandmarkerResult> {
   }

   /**
-   * Converts raw data into a landmark, and adds it to our worldLandmarks
+   * Converts raw data into a world landmark, and adds it to our worldLandmarks
    * list.
    */
   private adddJsWorldLandmarks(data: Uint8Array[]): void {
@@ -283,7 +282,6 @@ export class HandLandmarker extends VisionTaskRunner<HandLandmarkerResult> {
           x: handWorldLandmarkProto.getX() ?? 0,
           y: handWorldLandmarkProto.getY() ?? 0,
           z: handWorldLandmarkProto.getZ() ?? 0,
-          normalized: false
         });
       }
       this.worldLandmarks.push(worldLandmarks);
diff --git a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_result.d.ts b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_result.d.ts
index 044bdfbe7..89f867d69 100644
--- a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_result.d.ts
+++ b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_result.d.ts
@@ -15,14 +15,14 @@
  */

 import {Category} from '../../../../tasks/web/components/containers/category';
-import {Landmark} from '../../../../tasks/web/components/containers/landmark';
+import {Landmark, NormalizedLandmark} from '../../../../tasks/web/components/containers/landmark';

 /**
 * Represents the hand landmarks detection results generated by `HandLandmarker`.
 */
 export declare interface HandLandmarkerResult {
   /** Hand landmarks of detected hands. */
-  landmarks: Landmark[][];
+  landmarks: NormalizedLandmark[][];

   /** Hand landmarks in world coordinates of detected hands. */
   worldLandmarks: Landmark[][];

From dabc2af15baad67d92ac5e9d1b2b2a588167664f Mon Sep 17 00:00:00 2001
From: Sebastian Schmidt
Date: Fri, 2 Dec 2022 12:04:06 -0800
Subject: [PATCH 157/469] Fix base path loading in Fileset resolver

PiperOrigin-RevId: 492526041
---
 mediapipe/tasks/web/core/fileset_resolver.ts | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/mediapipe/tasks/web/core/fileset_resolver.ts b/mediapipe/tasks/web/core/fileset_resolver.ts
index 7d68dbc16..d4691243b 100644
--- a/mediapipe/tasks/web/core/fileset_resolver.ts
+++ b/mediapipe/tasks/web/core/fileset_resolver.ts
@@ -48,16 +48,16 @@ async function createFileset(
   if (await isSimdSupported()) {
     return {
       wasmLoaderPath:
-          `/${basePath}/${taskName}_wasm_internal.js`,
+          `${basePath}/${taskName}_wasm_internal.js`,
       wasmBinaryPath:
-          `/${basePath}/${taskName}_wasm_internal.wasm`,
+          `${basePath}/${taskName}_wasm_internal.wasm`,
     };
   } else {
     return {
       wasmLoaderPath:
-          `/${basePath}/${taskName}_wasm_nosimd_internal.js`,
-      wasmBinaryPath: `/${basePath}/${
-          taskName}_wasm_nosimd_internal.wasm`,
+          `${basePath}/${taskName}_wasm_nosimd_internal.js`,
+      wasmBinaryPath:
+          `${basePath}/${taskName}_wasm_nosimd_internal.wasm`,
     };
   }
 }
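
The one-character fix above matters because `basePath` may already be absolute. A minimal sketch of the before/after URL construction, using a hypothetical CDN base path (the values are illustrative, not from the patch):

    const basePath = 'https://cdn.example.com/wasm';  // hypothetical location
    const taskName = 'audio';
    // Old template prepended '/', producing the invalid URL
    // '/https://cdn.example.com/wasm/audio_wasm_internal.js'.
    // New template:
    const wasmLoaderPath = `${basePath}/${taskName}_wasm_internal.js`;
    // -> 'https://cdn.example.com/wasm/audio_wasm_internal.js'
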

From da9587033d118eb58672f25c8f2e541ba7037209 Mon Sep 17 00:00:00 2001
From: Sebastian Schmidt
Date: Fri, 2 Dec 2022 12:40:59 -0800
Subject: [PATCH 158/469] Move shared code to TaskRunner

PiperOrigin-RevId: 492534879
---
 .../tasks/web/audio/audio_classifier/BUILD    |  3 +-
 .../audio_classifier/audio_classifier.ts      | 38 ++++++++------
 .../audio_classifier_options.d.ts             |  4 +-
 .../tasks/web/audio/audio_embedder/BUILD      |  1 -
 .../audio/audio_embedder/audio_embedder.ts    | 48 ++++++++---------
 .../audio_embedder_options.d.ts               |  4 +-
 mediapipe/tasks/web/audio/core/BUILD          | 13 +----
 .../web/audio/core/audio_task_options.d.ts    | 23 ---------
 .../tasks/web/audio/core/audio_task_runner.ts | 17 +------
 .../tasks/web/components/processors/BUILD     |  1 -
 .../web/components/processors/base_options.ts |  2 +-
 mediapipe/tasks/web/core/BUILD                |  8 +--
 .../tasks/web/core/classifier_options.d.ts    |  2 -
 .../tasks/web/core/embedder_options.d.ts      |  2 -
 mediapipe/tasks/web/core/task_runner.ts       | 43 ++++++++++------
 ..._options.d.ts => task_runner_options.d.ts} |  8 ++-
 mediapipe/tasks/web/text/core/BUILD           | 11 ----
 .../web/text/core/text_task_options.d.ts      | 23 ---------
 .../tasks/web/text/text_classifier/BUILD      |  5 +-
 .../text/text_classifier/text_classifier.ts   | 51 +++++++++++--------
 .../text_classifier_options.d.ts              |  4 +-
 mediapipe/tasks/web/text/text_embedder/BUILD  |  4 +-
 .../web/text/text_embedder/text_embedder.ts   | 51 +++++++++++--------
 .../text_embedder/text_embedder_options.d.ts  |  4 +-
 mediapipe/tasks/web/vision/core/BUILD         |  2 -
 .../web/vision/core/vision_task_options.d.ts  |  8 +--
 .../web/vision/core/vision_task_runner.ts     | 15 ++----
 .../gesture_recognizer/gesture_recognizer.ts  | 30 +++++------
 .../vision/hand_landmarker/hand_landmarker.ts | 30 +++++------
 .../image_classifier/image_classifier.ts      | 38 ++++++++------
 .../vision/image_embedder/image_embedder.ts   | 38 ++++++++------
 .../vision/object_detector/object_detector.ts | 36 +++++++------
 32 files changed, 262 insertions(+), 305 deletions(-)
 delete mode 100644 mediapipe/tasks/web/audio/core/audio_task_options.d.ts
 rename mediapipe/tasks/web/core/{base_options.d.ts => task_runner_options.d.ts} (85%)
 delete mode 100644 mediapipe/tasks/web/text/core/BUILD
 delete mode 100644 mediapipe/tasks/web/text/core/text_task_options.d.ts

diff --git a/mediapipe/tasks/web/audio/audio_classifier/BUILD b/mediapipe/tasks/web/audio/audio_classifier/BUILD
index c419d3b98..6f785dd0d 100644
--- a/mediapipe/tasks/web/audio/audio_classifier/BUILD
+++ b/mediapipe/tasks/web/audio/audio_classifier/BUILD
@@ -25,7 +25,7 @@ mediapipe_ts_library(
         "//mediapipe/tasks/web/components/processors:classifier_result",
         "//mediapipe/tasks/web/core",
         "//mediapipe/tasks/web/core:classifier_options",
-        "//mediapipe/tasks/web/core:task_runner",
+        "//mediapipe/web/graph_runner:graph_runner_ts",
     ],
 )

@@ -36,7 +36,6 @@ mediapipe_ts_declaration(
         "audio_classifier_result.d.ts",
     ],
     deps = [
-        "//mediapipe/tasks/web/audio/core:audio_task_options",
         "//mediapipe/tasks/web/components/containers:category",
         "//mediapipe/tasks/web/components/containers:classification_result",
         "//mediapipe/tasks/web/core",
diff --git a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts
index e606019f2..4e12780d2 100644
--- a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts
+++ b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts
@@ -22,8 +22,8 @@ import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb';
 import {AudioTaskRunner} from '../../../../tasks/web/audio/core/audio_task_runner';
 import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/processors/classifier_options';
 import {convertFromClassificationResultProto} from '../../../../tasks/web/components/processors/classifier_result';
-import {TaskRunner} from '../../../../tasks/web/core/task_runner';
 import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
+import {WasmModule} from '../../../../web/graph_runner/graph_runner';
 // Placeholder for internal dependency on trusted resource url

 import {AudioClassifierOptions} from './audio_classifier_options';
@@ -56,13 +56,12 @@ export class AudioClassifier extends AudioTaskRunner<AudioClassifierResult> {
    * that either a path to the model asset or a model buffer needs to be
    * provided (via `baseOptions`).
*/ - static async createFromOptions( + static createFromOptions( wasmFileset: WasmFileset, audioClassifierOptions: AudioClassifierOptions): Promise { - const classifier = await TaskRunner.createInstance( - AudioClassifier, /* initializeCanvas= */ false, wasmFileset); - await classifier.setOptions(audioClassifierOptions); - return classifier; + return AudioTaskRunner.createInstance( + AudioClassifier, /* initializeCanvas= */ false, wasmFileset, + audioClassifierOptions); } /** @@ -75,8 +74,9 @@ export class AudioClassifier extends AudioTaskRunner { static createFromModelBuffer( wasmFileset: WasmFileset, modelAssetBuffer: Uint8Array): Promise { - return AudioClassifier.createFromOptions( - wasmFileset, {baseOptions: {modelAssetBuffer}}); + return AudioTaskRunner.createInstance( + AudioClassifier, /* initializeCanvas= */ false, wasmFileset, + {baseOptions: {modelAssetBuffer}}); } /** @@ -86,20 +86,26 @@ export class AudioClassifier extends AudioTaskRunner { * Wasm binary and its loader. * @param modelAssetPath The path to the model asset. */ - static async createFromModelPath( + static createFromModelPath( wasmFileset: WasmFileset, modelAssetPath: string): Promise { - const response = await fetch(modelAssetPath.toString()); - const graphData = await response.arrayBuffer(); - return AudioClassifier.createFromModelBuffer( - wasmFileset, new Uint8Array(graphData)); + return AudioTaskRunner.createInstance( + AudioClassifier, /* initializeCanvas= */ false, wasmFileset, + {baseOptions: {modelAssetPath}}); } - protected override get baseOptions(): BaseOptionsProto|undefined { - return this.options.getBaseOptions(); + constructor( + wasmModule: WasmModule, + glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { + super(wasmModule, glCanvas); + this.options.setBaseOptions(new BaseOptionsProto()); } - protected override set baseOptions(proto: BaseOptionsProto|undefined) { + protected override get baseOptions(): BaseOptionsProto { + return this.options.getBaseOptions()!; + } + + protected override set baseOptions(proto: BaseOptionsProto) { this.options.setBaseOptions(proto); } diff --git a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier_options.d.ts b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier_options.d.ts index 975b1e315..dc3c494bf 100644 --- a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier_options.d.ts +++ b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier_options.d.ts @@ -14,9 +14,9 @@ * limitations under the License. 
*/ -import {AudioTaskOptions} from '../../../../tasks/web/audio/core/audio_task_options'; import {ClassifierOptions} from '../../../../tasks/web/core/classifier_options'; +import {TaskRunnerOptions} from '../../../../tasks/web/core/task_runner_options'; /** Options to configure the MediaPipe Audio Classifier Task */ export declare interface AudioClassifierOptions extends ClassifierOptions, - AudioTaskOptions {} + TaskRunnerOptions {} diff --git a/mediapipe/tasks/web/audio/audio_embedder/BUILD b/mediapipe/tasks/web/audio/audio_embedder/BUILD index 1a66464bd..0555bb639 100644 --- a/mediapipe/tasks/web/audio/audio_embedder/BUILD +++ b/mediapipe/tasks/web/audio/audio_embedder/BUILD @@ -36,7 +36,6 @@ mediapipe_ts_declaration( "audio_embedder_result.d.ts", ], deps = [ - "//mediapipe/tasks/web/audio/core:audio_task_options", "//mediapipe/tasks/web/components/containers:embedding_result", "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:embedder_options", diff --git a/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts index c87aceabe..d08eb4791 100644 --- a/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts +++ b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts @@ -25,7 +25,7 @@ import {convertEmbedderOptionsToProto} from '../../../../tasks/web/components/pr import {convertFromEmbeddingResultProto} from '../../../../tasks/web/components/processors/embedder_result'; import {computeCosineSimilarity} from '../../../../tasks/web/components/utils/cosine_similarity'; import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; -import {createMediaPipeLib, FileLocator} from '../../../../web/graph_runner/graph_runner'; +import {WasmModule} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url import {AudioEmbedderOptions} from './audio_embedder_options'; @@ -58,23 +58,12 @@ export class AudioEmbedder extends AudioTaskRunner { * either a path to the TFLite model or the model itself needs to be * provided (via `baseOptions`). */ - static async createFromOptions( + static createFromOptions( wasmFileset: WasmFileset, audioEmbedderOptions: AudioEmbedderOptions): Promise { - // Create a file locator based on the loader options - const fileLocator: FileLocator = { - locateFile() { - // The only file we load is the Wasm binary - return wasmFileset.wasmBinaryPath.toString(); - } - }; - - const embedder = await createMediaPipeLib( - AudioEmbedder, wasmFileset.wasmLoaderPath, - /* assetLoaderScript= */ undefined, - /* glCanvas= */ undefined, fileLocator); - await embedder.setOptions(audioEmbedderOptions); - return embedder; + return AudioTaskRunner.createInstance( + AudioEmbedder, /* initializeCanvas= */ false, wasmFileset, + audioEmbedderOptions); } /** @@ -87,8 +76,9 @@ export class AudioEmbedder extends AudioTaskRunner { static createFromModelBuffer( wasmFileset: WasmFileset, modelAssetBuffer: Uint8Array): Promise { - return AudioEmbedder.createFromOptions( - wasmFileset, {baseOptions: {modelAssetBuffer}}); + return AudioTaskRunner.createInstance( + AudioEmbedder, /* initializeCanvas= */ false, wasmFileset, + {baseOptions: {modelAssetBuffer}}); } /** @@ -98,20 +88,26 @@ export class AudioEmbedder extends AudioTaskRunner { * Wasm binary and its loader. * @param modelAssetPath The path to the TFLite model. 
*/ - static async createFromModelPath( + static createFromModelPath( wasmFileset: WasmFileset, modelAssetPath: string): Promise { - const response = await fetch(modelAssetPath.toString()); - const graphData = await response.arrayBuffer(); - return AudioEmbedder.createFromModelBuffer( - wasmFileset, new Uint8Array(graphData)); + return AudioTaskRunner.createInstance( + AudioEmbedder, /* initializeCanvas= */ false, wasmFileset, + {baseOptions: {modelAssetPath}}); } - protected override get baseOptions(): BaseOptionsProto|undefined { - return this.options.getBaseOptions(); + constructor( + wasmModule: WasmModule, + glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { + super(wasmModule, glCanvas); + this.options.setBaseOptions(new BaseOptionsProto()); } - protected override set baseOptions(proto: BaseOptionsProto|undefined) { + protected override get baseOptions(): BaseOptionsProto { + return this.options.getBaseOptions()!; + } + + protected override set baseOptions(proto: BaseOptionsProto) { this.options.setBaseOptions(proto); } diff --git a/mediapipe/tasks/web/audio/audio_embedder/audio_embedder_options.d.ts b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder_options.d.ts index 98f412d0f..ac22728ab 100644 --- a/mediapipe/tasks/web/audio/audio_embedder/audio_embedder_options.d.ts +++ b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder_options.d.ts @@ -14,9 +14,9 @@ * limitations under the License. */ -import {AudioTaskOptions} from '../../../../tasks/web/audio/core/audio_task_options'; import {EmbedderOptions} from '../../../../tasks/web/core/embedder_options'; +import {TaskRunnerOptions} from '../../../../tasks/web/core/task_runner_options'; /** Options to configure the MediaPipe Audio Embedder Task */ export declare interface AudioEmbedderOptions extends EmbedderOptions, - AudioTaskOptions {} + TaskRunnerOptions {} diff --git a/mediapipe/tasks/web/audio/core/BUILD b/mediapipe/tasks/web/audio/core/BUILD index 91ebbf524..9ab6c7bee 100644 --- a/mediapipe/tasks/web/audio/core/BUILD +++ b/mediapipe/tasks/web/audio/core/BUILD @@ -1,24 +1,13 @@ # This package contains options shared by all MediaPipe Audio Tasks for Web. -load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library") +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") package(default_visibility = ["//mediapipe/tasks:internal"]) -mediapipe_ts_declaration( - name = "audio_task_options", - srcs = ["audio_task_options.d.ts"], - deps = [ - "//mediapipe/tasks/web/core", - ], -) - mediapipe_ts_library( name = "audio_task_runner", srcs = ["audio_task_runner.ts"], deps = [ - ":audio_task_options", - "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", - "//mediapipe/tasks/web/components/processors:base_options", "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:task_runner", ], diff --git a/mediapipe/tasks/web/audio/core/audio_task_options.d.ts b/mediapipe/tasks/web/audio/core/audio_task_options.d.ts deleted file mode 100644 index e3068625d..000000000 --- a/mediapipe/tasks/web/audio/core/audio_task_options.d.ts +++ /dev/null @@ -1,23 +0,0 @@ -/** - * Copyright 2022 The MediaPipe Authors. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -import {BaseOptions} from '../../../../tasks/web/core/base_options'; - -/** The options for configuring a MediaPipe Audio Task. */ -export declare interface AudioTaskOptions { - /** Options to configure the loading of the model assets. */ - baseOptions?: BaseOptions; -} diff --git a/mediapipe/tasks/web/audio/core/audio_task_runner.ts b/mediapipe/tasks/web/audio/core/audio_task_runner.ts index ceff3895b..00cfe0253 100644 --- a/mediapipe/tasks/web/audio/core/audio_task_runner.ts +++ b/mediapipe/tasks/web/audio/core/audio_task_runner.ts @@ -14,26 +14,13 @@ * limitations under the License. */ -import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; -import {convertBaseOptionsToProto} from '../../../../tasks/web/components/processors/base_options'; import {TaskRunner} from '../../../../tasks/web/core/task_runner'; - -import {AudioTaskOptions} from './audio_task_options'; +import {TaskRunnerOptions} from '../../../../tasks/web/core/task_runner_options'; /** Base class for all MediaPipe Audio Tasks. */ -export abstract class AudioTaskRunner extends TaskRunner { - protected abstract baseOptions?: BaseOptionsProto|undefined; +export abstract class AudioTaskRunner extends TaskRunner { private defaultSampleRate = 48000; - /** Configures the shared options of an audio task. */ - async setOptions(options: AudioTaskOptions): Promise { - this.baseOptions = this.baseOptions ?? new BaseOptionsProto(); - if (options.baseOptions) { - this.baseOptions = await convertBaseOptionsToProto( - options.baseOptions, this.baseOptions); - } - } - /** * Sets the sample rate for API calls that omit an explicit sample rate. * `48000` is used as a default if this method is not called. 
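
From the caller's side, the surviving sample-rate logic behaves as follows. A minimal sketch, assuming an audio classifier built on this runner, that the setter under the comment above is named `setDefaultSampleRate`, and that `classify` accepts an optional explicit rate (model path and rates are illustrative):

    const classifier = await AudioClassifier.createFromModelPath(
        wasmFileset, 'model.tflite');            // hypothetical model path
    classifier.setDefaultSampleRate(16000);      // replaces the 48000 default
    const a = classifier.classify(audioData);          // interpreted as 16 kHz
    const b = classifier.classify(audioData, 44100);   // explicit rate wins
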
diff --git a/mediapipe/tasks/web/components/processors/BUILD b/mediapipe/tasks/web/components/processors/BUILD index 1b56bf4c9..86e743928 100644 --- a/mediapipe/tasks/web/components/processors/BUILD +++ b/mediapipe/tasks/web/components/processors/BUILD @@ -17,7 +17,6 @@ mediapipe_ts_library( name = "classifier_result", srcs = ["classifier_result.ts"], deps = [ - "//mediapipe/framework/formats:classification_jspb_proto", "//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto", "//mediapipe/tasks/web/components/containers:classification_result", ], diff --git a/mediapipe/tasks/web/components/processors/base_options.ts b/mediapipe/tasks/web/components/processors/base_options.ts index ac24a8db6..16d562262 100644 --- a/mediapipe/tasks/web/components/processors/base_options.ts +++ b/mediapipe/tasks/web/components/processors/base_options.ts @@ -18,7 +18,7 @@ import {InferenceCalculatorOptions} from '../../../../calculators/tensor/inferen import {Acceleration} from '../../../../tasks/cc/core/proto/acceleration_pb'; import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; import {ExternalFile} from '../../../../tasks/cc/core/proto/external_file_pb'; -import {BaseOptions} from '../../../../tasks/web/core/base_options'; +import {BaseOptions} from '../../../../tasks/web/core/task_runner_options'; // The OSS JS API does not support the builder pattern. // tslint:disable:jspb-use-builder-pattern diff --git a/mediapipe/tasks/web/core/BUILD b/mediapipe/tasks/web/core/BUILD index d709e3409..de429690d 100644 --- a/mediapipe/tasks/web/core/BUILD +++ b/mediapipe/tasks/web/core/BUILD @@ -7,18 +7,18 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) mediapipe_ts_declaration( name = "core", srcs = [ - "base_options.d.ts", + "task_runner_options.d.ts", "wasm_fileset.d.ts", ], ) mediapipe_ts_library( name = "task_runner", - srcs = [ - "task_runner.ts", - ], + srcs = ["task_runner.ts"], deps = [ ":core", + "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", + "//mediapipe/tasks/web/components/processors:base_options", "//mediapipe/web/graph_runner:graph_runner_image_lib_ts", "//mediapipe/web/graph_runner:graph_runner_ts", "//mediapipe/web/graph_runner:register_model_resources_graph_service_ts", diff --git a/mediapipe/tasks/web/core/classifier_options.d.ts b/mediapipe/tasks/web/core/classifier_options.d.ts index 1d804d629..08e7a7664 100644 --- a/mediapipe/tasks/web/core/classifier_options.d.ts +++ b/mediapipe/tasks/web/core/classifier_options.d.ts @@ -14,8 +14,6 @@ * limitations under the License. */ -import {BaseOptions} from '../../../tasks/web/core/base_options'; - /** Options to configure a MediaPipe Classifier Task. */ export declare interface ClassifierOptions { /** diff --git a/mediapipe/tasks/web/core/embedder_options.d.ts b/mediapipe/tasks/web/core/embedder_options.d.ts index 3ec2a170c..8669acfcb 100644 --- a/mediapipe/tasks/web/core/embedder_options.d.ts +++ b/mediapipe/tasks/web/core/embedder_options.d.ts @@ -14,8 +14,6 @@ * limitations under the License. */ -import {BaseOptions} from '../../../tasks/web/core/base_options'; - /** Options to configure a MediaPipe Embedder Task */ export declare interface EmbedderOptions { /** diff --git a/mediapipe/tasks/web/core/task_runner.ts b/mediapipe/tasks/web/core/task_runner.ts index 4085be697..c2691fc76 100644 --- a/mediapipe/tasks/web/core/task_runner.ts +++ b/mediapipe/tasks/web/core/task_runner.ts @@ -14,6 +14,9 @@ * limitations under the License. 
 */

+import {BaseOptions as BaseOptionsProto} from '../../../tasks/cc/core/proto/base_options_pb';
+import {convertBaseOptionsToProto} from '../../../tasks/web/components/processors/base_options';
+import {TaskRunnerOptions} from '../../../tasks/web/core/task_runner_options';
 import {createMediaPipeLib, FileLocator, GraphRunner, WasmMediaPipeConstructor, WasmModule} from '../../../web/graph_runner/graph_runner';
 import {SupportImage} from '../../../web/graph_runner/graph_runner_image_lib';
 import {SupportModelResourcesGraphService} from '../../../web/graph_runner/register_model_resources_graph_service';
@@ -28,7 +31,9 @@ const WasmMediaPipeImageLib =
     SupportModelResourcesGraphService(SupportImage(GraphRunner));

 /** Base class for all MediaPipe Tasks. */
-export abstract class TaskRunner extends WasmMediaPipeImageLib {
+export abstract class TaskRunner<O extends TaskRunnerOptions> extends
+    WasmMediaPipeImageLib {
+  protected abstract baseOptions: BaseOptionsProto;
   private processingErrors: Error[] = [];

   /**
@@ -36,9 +41,10 @@ export abstract class TaskRunner extends WasmMediaPipeImageLib {
    * supported and loads the relevant WASM binary.
    * @return A fully instantiated instance of `T`.
    */
-  protected static async createInstance<T>(
+  protected static async createInstance<T extends TaskRunner<O>,
+                                        O extends TaskRunnerOptions>(
       type: WasmMediaPipeConstructor<T>, initializeCanvas: boolean,
-      fileset: WasmFileset): Promise<T> {
+      fileset: WasmFileset, options: O): Promise<T> {
     const fileLocator: FileLocator = {
       locateFile() {
         // The only file loaded with this mechanism is the Wasm binary
         return fileset.wasmBinaryPath.toString();
       }
     };

-    if (initializeCanvas) {
-      // Fall back to an OffscreenCanvas created by the GraphRunner if
-      // OffscreenCanvas is available
-      const canvas = typeof OffscreenCanvas === 'undefined' ?
-          document.createElement('canvas') :
-          undefined;
-      return createMediaPipeLib(
-          type, fileset.wasmLoaderPath, NO_ASSETS, canvas, fileLocator);
-    } else {
-      return createMediaPipeLib(
-          type, fileset.wasmLoaderPath, NO_ASSETS, /* glCanvas= */ null,
-          fileLocator);
-    }
+    // Initialize a canvas if requested. If OffscreenCanvas is available, we
+    // let the graph runner initialize it by passing `undefined`.
+    const canvas = initializeCanvas ? (typeof OffscreenCanvas === 'undefined' ?
+                                           document.createElement('canvas') :
+                                           undefined) :
+                                      null;
+    const instance = await createMediaPipeLib(
+        type, fileset.wasmLoaderPath, NO_ASSETS, canvas, fileLocator);
+    await instance.setOptions(options);
+    return instance;
   }

   constructor(
@@ -74,6 +77,14 @@ export abstract class TaskRunner extends WasmMediaPipeImageLib {
     this.registerModelResourcesGraphService();
   }

+  /** Configures the shared options of a MediaPipe Task. */
+  async setOptions(options: O): Promise<void> {
+    if (options.baseOptions) {
+      this.baseOptions = await convertBaseOptionsToProto(
+          options.baseOptions, this.baseOptions);
+    }
+  }
+
   /**
    * Takes the raw data from a MediaPipe graph, and passes it to C++ to be run
   * over the video stream.
Will replace the previously running MediaPipe graph, diff --git a/mediapipe/tasks/web/core/base_options.d.ts b/mediapipe/tasks/web/core/task_runner_options.d.ts similarity index 85% rename from mediapipe/tasks/web/core/base_options.d.ts rename to mediapipe/tasks/web/core/task_runner_options.d.ts index 86635b8c7..aa0b4a028 100644 --- a/mediapipe/tasks/web/core/base_options.d.ts +++ b/mediapipe/tasks/web/core/task_runner_options.d.ts @@ -16,7 +16,7 @@ // Placeholder for internal dependency on trusted resource url -/** Options to configure MediaPipe Tasks in general. */ +/** Options to configure MediaPipe model loading and processing. */ export declare interface BaseOptions { /** * The model path to the model asset file. Only one of `modelAssetPath` or @@ -33,3 +33,9 @@ export declare interface BaseOptions { /** Overrides the default backend to use for the provided model. */ delegate?: 'cpu'|'gpu'|undefined; } + +/** Options to configure MediaPipe Tasks in general. */ +export declare interface TaskRunnerOptions { + /** Options to configure the loading of the model assets. */ + baseOptions?: BaseOptions; +} diff --git a/mediapipe/tasks/web/text/core/BUILD b/mediapipe/tasks/web/text/core/BUILD deleted file mode 100644 index 3e7faec93..000000000 --- a/mediapipe/tasks/web/text/core/BUILD +++ /dev/null @@ -1,11 +0,0 @@ -# This package contains options shared by all MediaPipe Texxt Tasks for Web. - -load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration") - -package(default_visibility = ["//mediapipe/tasks:internal"]) - -mediapipe_ts_declaration( - name = "text_task_options", - srcs = ["text_task_options.d.ts"], - deps = ["//mediapipe/tasks/web/core"], -) diff --git a/mediapipe/tasks/web/text/core/text_task_options.d.ts b/mediapipe/tasks/web/text/core/text_task_options.d.ts deleted file mode 100644 index 4874e35bf..000000000 --- a/mediapipe/tasks/web/text/core/text_task_options.d.ts +++ /dev/null @@ -1,23 +0,0 @@ -/** - * Copyright 2022 The MediaPipe Authors. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -import {BaseOptions} from '../../../../tasks/web/core/base_options'; - -/** The options for configuring a MediaPipe Text task. */ -export declare interface TextTaskOptions { - /** Options to configure the loading of the model assets. 
*/ - baseOptions?: BaseOptions; -} diff --git a/mediapipe/tasks/web/text/text_classifier/BUILD b/mediapipe/tasks/web/text/text_classifier/BUILD index f3d272daa..2a7de21d6 100644 --- a/mediapipe/tasks/web/text/text_classifier/BUILD +++ b/mediapipe/tasks/web/text/text_classifier/BUILD @@ -17,15 +17,16 @@ mediapipe_ts_library( "//mediapipe/framework:calculator_jspb_proto", "//mediapipe/framework:calculator_options_jspb_proto", "//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto", + "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_jspb_proto", "//mediapipe/tasks/web/components/containers:category", "//mediapipe/tasks/web/components/containers:classification_result", - "//mediapipe/tasks/web/components/processors:base_options", "//mediapipe/tasks/web/components/processors:classifier_options", "//mediapipe/tasks/web/components/processors:classifier_result", "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:classifier_options", "//mediapipe/tasks/web/core:task_runner", + "//mediapipe/web/graph_runner:graph_runner_ts", ], ) @@ -38,7 +39,7 @@ mediapipe_ts_declaration( deps = [ "//mediapipe/tasks/web/components/containers:category", "//mediapipe/tasks/web/components/containers:classification_result", + "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:classifier_options", - "//mediapipe/tasks/web/text/core:text_task_options", ], ) diff --git a/mediapipe/tasks/web/text/text_classifier/text_classifier.ts b/mediapipe/tasks/web/text/text_classifier/text_classifier.ts index 197869a36..bd2a207ce 100644 --- a/mediapipe/tasks/web/text/text_classifier/text_classifier.ts +++ b/mediapipe/tasks/web/text/text_classifier/text_classifier.ts @@ -17,12 +17,13 @@ import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; import {CalculatorOptions} from '../../../../framework/calculator_options_pb'; import {ClassificationResult} from '../../../../tasks/cc/components/containers/proto/classifications_pb'; +import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; import {TextClassifierGraphOptions} from '../../../../tasks/cc/text/text_classifier/proto/text_classifier_graph_options_pb'; -import {convertBaseOptionsToProto} from '../../../../tasks/web/components/processors/base_options'; import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/processors/classifier_options'; import {convertFromClassificationResultProto} from '../../../../tasks/web/components/processors/classifier_result'; import {TaskRunner} from '../../../../tasks/web/core/task_runner'; import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; +import {WasmModule} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url import {TextClassifierOptions} from './text_classifier_options'; @@ -40,7 +41,7 @@ const TEXT_CLASSIFIER_GRAPH = // tslint:disable:jspb-use-builder-pattern /** Performs Natural Language classification. */ -export class TextClassifier extends TaskRunner { +export class TextClassifier extends TaskRunner { private classificationResult: TextClassifierResult = {classifications: []}; private readonly options = new TextClassifierGraphOptions(); @@ -53,13 +54,12 @@ export class TextClassifier extends TaskRunner { * either a path to the TFLite model or the model itself needs to be * provided (via `baseOptions`). 
*/ - static async createFromOptions( + static createFromOptions( wasmFileset: WasmFileset, textClassifierOptions: TextClassifierOptions): Promise { - const classifier = await TaskRunner.createInstance( - TextClassifier, /* initializeCanvas= */ false, wasmFileset); - await classifier.setOptions(textClassifierOptions); - return classifier; + return TaskRunner.createInstance( + TextClassifier, /* initializeCanvas= */ false, wasmFileset, + textClassifierOptions); } /** @@ -72,8 +72,9 @@ export class TextClassifier extends TaskRunner { static createFromModelBuffer( wasmFileset: WasmFileset, modelAssetBuffer: Uint8Array): Promise { - return TextClassifier.createFromOptions( - wasmFileset, {baseOptions: {modelAssetBuffer}}); + return TaskRunner.createInstance( + TextClassifier, /* initializeCanvas= */ false, wasmFileset, + {baseOptions: {modelAssetBuffer}}); } /** @@ -83,13 +84,19 @@ export class TextClassifier extends TaskRunner { * Wasm binary and its loader. * @param modelAssetPath The path to the model asset. */ - static async createFromModelPath( + static createFromModelPath( wasmFileset: WasmFileset, modelAssetPath: string): Promise { - const response = await fetch(modelAssetPath.toString()); - const graphData = await response.arrayBuffer(); - return TextClassifier.createFromModelBuffer( - wasmFileset, new Uint8Array(graphData)); + return TaskRunner.createInstance( + TextClassifier, /* initializeCanvas= */ false, wasmFileset, + {baseOptions: {modelAssetPath}}); + } + + constructor( + wasmModule: WasmModule, + glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { + super(wasmModule, glCanvas); + this.options.setBaseOptions(new BaseOptionsProto()); } /** @@ -101,18 +108,20 @@ export class TextClassifier extends TaskRunner { * * @param options The options for the text classifier. 
*/ - async setOptions(options: TextClassifierOptions): Promise { - if (options.baseOptions) { - const baseOptionsProto = await convertBaseOptionsToProto( - options.baseOptions, this.options.getBaseOptions()); - this.options.setBaseOptions(baseOptionsProto); - } - + override async setOptions(options: TextClassifierOptions): Promise { + await super.setOptions(options); this.options.setClassifierOptions(convertClassifierOptionsToProto( options, this.options.getClassifierOptions())); this.refreshGraph(); } + protected override get baseOptions(): BaseOptionsProto { + return this.options.getBaseOptions()!; + } + + protected override set baseOptions(proto: BaseOptionsProto) { + this.options.setBaseOptions(proto); + } /** * Performs Natural Language classification on the provided text and waits diff --git a/mediapipe/tasks/web/text/text_classifier/text_classifier_options.d.ts b/mediapipe/tasks/web/text/text_classifier/text_classifier_options.d.ts index b50767e1a..25592deb5 100644 --- a/mediapipe/tasks/web/text/text_classifier/text_classifier_options.d.ts +++ b/mediapipe/tasks/web/text/text_classifier/text_classifier_options.d.ts @@ -15,8 +15,8 @@ */ import {ClassifierOptions} from '../../../../tasks/web/core/classifier_options'; -import {TextTaskOptions} from '../../../../tasks/web/text/core/text_task_options'; +import {TaskRunnerOptions} from '../../../../tasks/web/core/task_runner_options'; /** Options to configure the MediaPipe Text Classifier Task */ export declare interface TextClassifierOptions extends ClassifierOptions, - TextTaskOptions {} + TaskRunnerOptions {} diff --git a/mediapipe/tasks/web/text/text_embedder/BUILD b/mediapipe/tasks/web/text/text_embedder/BUILD index b858f6b83..17d105258 100644 --- a/mediapipe/tasks/web/text/text_embedder/BUILD +++ b/mediapipe/tasks/web/text/text_embedder/BUILD @@ -17,15 +17,16 @@ mediapipe_ts_library( "//mediapipe/framework:calculator_jspb_proto", "//mediapipe/framework:calculator_options_jspb_proto", "//mediapipe/tasks/cc/components/containers/proto:embeddings_jspb_proto", + "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", "//mediapipe/tasks/cc/text/text_embedder/proto:text_embedder_graph_options_jspb_proto", "//mediapipe/tasks/web/components/containers:embedding_result", - "//mediapipe/tasks/web/components/processors:base_options", "//mediapipe/tasks/web/components/processors:embedder_options", "//mediapipe/tasks/web/components/processors:embedder_result", "//mediapipe/tasks/web/components/utils:cosine_similarity", "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:embedder_options", "//mediapipe/tasks/web/core:task_runner", + "//mediapipe/web/graph_runner:graph_runner_ts", ], ) @@ -39,6 +40,5 @@ mediapipe_ts_declaration( "//mediapipe/tasks/web/components/containers:embedding_result", "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:embedder_options", - "//mediapipe/tasks/web/text/core:text_task_options", ], ) diff --git a/mediapipe/tasks/web/text/text_embedder/text_embedder.ts b/mediapipe/tasks/web/text/text_embedder/text_embedder.ts index 511fd2411..d2899fbe2 100644 --- a/mediapipe/tasks/web/text/text_embedder/text_embedder.ts +++ b/mediapipe/tasks/web/text/text_embedder/text_embedder.ts @@ -17,14 +17,15 @@ import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; import {CalculatorOptions} from '../../../../framework/calculator_options_pb'; import {EmbeddingResult} from '../../../../tasks/cc/components/containers/proto/embeddings_pb'; +import {BaseOptions as BaseOptionsProto} from 
'../../../../tasks/cc/core/proto/base_options_pb'; import {TextEmbedderGraphOptions as TextEmbedderGraphOptionsProto} from '../../../../tasks/cc/text/text_embedder/proto/text_embedder_graph_options_pb'; import {Embedding} from '../../../../tasks/web/components/containers/embedding_result'; -import {convertBaseOptionsToProto} from '../../../../tasks/web/components/processors/base_options'; import {convertEmbedderOptionsToProto} from '../../../../tasks/web/components/processors/embedder_options'; import {convertFromEmbeddingResultProto} from '../../../../tasks/web/components/processors/embedder_result'; import {computeCosineSimilarity} from '../../../../tasks/web/components/utils/cosine_similarity'; import {TaskRunner} from '../../../../tasks/web/core/task_runner'; import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; +import {WasmModule} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url import {TextEmbedderOptions} from './text_embedder_options'; @@ -44,7 +45,7 @@ const TEXT_EMBEDDER_CALCULATOR = /** * Performs embedding extraction on text. */ -export class TextEmbedder extends TaskRunner { +export class TextEmbedder extends TaskRunner { private embeddingResult: TextEmbedderResult = {embeddings: []}; private readonly options = new TextEmbedderGraphOptionsProto(); @@ -57,13 +58,12 @@ export class TextEmbedder extends TaskRunner { * either a path to the TFLite model or the model itself needs to be * provided (via `baseOptions`). */ - static async createFromOptions( + static createFromOptions( wasmFileset: WasmFileset, textEmbedderOptions: TextEmbedderOptions): Promise { - const embedder = await TaskRunner.createInstance( - TextEmbedder, /* initializeCanvas= */ false, wasmFileset); - await embedder.setOptions(textEmbedderOptions); - return embedder; + return TaskRunner.createInstance( + TextEmbedder, /* initializeCanvas= */ false, wasmFileset, + textEmbedderOptions); } /** @@ -76,8 +76,9 @@ export class TextEmbedder extends TaskRunner { static createFromModelBuffer( wasmFileset: WasmFileset, modelAssetBuffer: Uint8Array): Promise { - return TextEmbedder.createFromOptions( - wasmFileset, {baseOptions: {modelAssetBuffer}}); + return TaskRunner.createInstance( + TextEmbedder, /* initializeCanvas= */ false, wasmFileset, + {baseOptions: {modelAssetBuffer}}); } /** @@ -87,13 +88,19 @@ export class TextEmbedder extends TaskRunner { * Wasm binary and its loader. * @param modelAssetPath The path to the TFLite model. */ - static async createFromModelPath( + static createFromModelPath( wasmFileset: WasmFileset, modelAssetPath: string): Promise { - const response = await fetch(modelAssetPath.toString()); - const graphData = await response.arrayBuffer(); - return TextEmbedder.createFromModelBuffer( - wasmFileset, new Uint8Array(graphData)); + return TaskRunner.createInstance( + TextEmbedder, /* initializeCanvas= */ false, wasmFileset, + {baseOptions: {modelAssetPath}}); + } + + constructor( + wasmModule: WasmModule, + glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { + super(wasmModule, glCanvas); + this.options.setBaseOptions(new BaseOptionsProto()); } /** @@ -105,17 +112,21 @@ export class TextEmbedder extends TaskRunner { * * @param options The options for the text embedder. 
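   *
   * Because base options are now applied by the shared `super.setOptions()`,
   * the model and the embedder behavior can be reconfigured in one call. A
   * sketch (the file name and flag value are illustrative):
   *
   *   await embedder.setOptions({
   *     baseOptions: {modelAssetPath: '/models/embedder.tflite'},
   *     l2Normalize: true,
   *   });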
*/ - async setOptions(options: TextEmbedderOptions): Promise { - if (options.baseOptions) { - const baseOptionsProto = await convertBaseOptionsToProto( - options.baseOptions, this.options.getBaseOptions()); - this.options.setBaseOptions(baseOptionsProto); - } + override async setOptions(options: TextEmbedderOptions): Promise { + await super.setOptions(options); this.options.setEmbedderOptions(convertEmbedderOptionsToProto( options, this.options.getEmbedderOptions())); this.refreshGraph(); } + protected override get baseOptions(): BaseOptionsProto { + return this.options.getBaseOptions()!; + } + + protected override set baseOptions(proto: BaseOptionsProto) { + this.options.setBaseOptions(proto); + } + /** * Performs embeding extraction on the provided text and waits synchronously * for the response. diff --git a/mediapipe/tasks/web/text/text_embedder/text_embedder_options.d.ts b/mediapipe/tasks/web/text/text_embedder/text_embedder_options.d.ts index 9ea570304..7689ee0c1 100644 --- a/mediapipe/tasks/web/text/text_embedder/text_embedder_options.d.ts +++ b/mediapipe/tasks/web/text/text_embedder/text_embedder_options.d.ts @@ -15,8 +15,8 @@ */ import {EmbedderOptions} from '../../../../tasks/web/core/embedder_options'; -import {TextTaskOptions} from '../../../../tasks/web/text/core/text_task_options'; +import {TaskRunnerOptions} from '../../../../tasks/web/core/task_runner_options'; /** Options to configure the MediaPipe Text Embedder Task */ export declare interface TextEmbedderOptions extends EmbedderOptions, - TextTaskOptions {} + TaskRunnerOptions {} diff --git a/mediapipe/tasks/web/vision/core/BUILD b/mediapipe/tasks/web/vision/core/BUILD index 1d8944f14..b389a9b01 100644 --- a/mediapipe/tasks/web/vision/core/BUILD +++ b/mediapipe/tasks/web/vision/core/BUILD @@ -17,8 +17,6 @@ mediapipe_ts_library( srcs = ["vision_task_runner.ts"], deps = [ ":vision_task_options", - "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", - "//mediapipe/tasks/web/components/processors:base_options", "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:task_runner", "//mediapipe/web/graph_runner:graph_runner_ts", diff --git a/mediapipe/tasks/web/vision/core/vision_task_options.d.ts b/mediapipe/tasks/web/vision/core/vision_task_options.d.ts index e04eb6596..76c0177a0 100644 --- a/mediapipe/tasks/web/vision/core/vision_task_options.d.ts +++ b/mediapipe/tasks/web/vision/core/vision_task_options.d.ts @@ -14,7 +14,7 @@ * limitations under the License. */ -import {BaseOptions} from '../../../../tasks/web/core/base_options'; +import {TaskRunnerOptions} from '../../../../tasks/web/core/task_runner_options'; /** * The two running modes of a vision task. @@ -23,12 +23,8 @@ import {BaseOptions} from '../../../../tasks/web/core/base_options'; */ export type RunningMode = 'image'|'video'; - /** The options for configuring a MediaPipe vision task. */ -export declare interface VisionTaskOptions { - /** Options to configure the loading of the model assets. */ - baseOptions?: BaseOptions; - +export declare interface VisionTaskOptions extends TaskRunnerOptions { /** * The running mode of the task. Default to the image mode. * Vision tasks have two running modes: diff --git a/mediapipe/tasks/web/vision/core/vision_task_runner.ts b/mediapipe/tasks/web/vision/core/vision_task_runner.ts index 79ff45156..78b4859f2 100644 --- a/mediapipe/tasks/web/vision/core/vision_task_runner.ts +++ b/mediapipe/tasks/web/vision/core/vision_task_runner.ts @@ -14,24 +14,17 @@ * limitations under the License. 
*/ -import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; -import {convertBaseOptionsToProto} from '../../../../tasks/web/components/processors/base_options'; import {TaskRunner} from '../../../../tasks/web/core/task_runner'; import {ImageSource} from '../../../../web/graph_runner/graph_runner'; import {VisionTaskOptions} from './vision_task_options'; /** Base class for all MediaPipe Vision Tasks. */ -export abstract class VisionTaskRunner extends TaskRunner { - protected abstract baseOptions?: BaseOptionsProto|undefined; - +export abstract class VisionTaskRunner extends + TaskRunner { /** Configures the shared options of a vision task. */ - async setOptions(options: VisionTaskOptions): Promise { - this.baseOptions = this.baseOptions ?? new BaseOptionsProto(); - if (options.baseOptions) { - this.baseOptions = await convertBaseOptionsToProto( - options.baseOptions, this.baseOptions); - } + override async setOptions(options: VisionTaskOptions): Promise { + await super.setOptions(options); if ('runningMode' in options) { const useStreamMode = !!options.runningMode && options.runningMode !== 'image'; diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts index 15b6acb1a..8baee5ce3 100644 --- a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts +++ b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts @@ -88,14 +88,13 @@ export class GestureRecognizer extends * Note that either a path to the model asset or a model buffer needs to * be provided (via `baseOptions`). */ - static async createFromOptions( + static createFromOptions( wasmFileset: WasmFileset, gestureRecognizerOptions: GestureRecognizerOptions): Promise { - const recognizer = await VisionTaskRunner.createInstance( - GestureRecognizer, /* initializeCanvas= */ true, wasmFileset); - await recognizer.setOptions(gestureRecognizerOptions); - return recognizer; + return VisionTaskRunner.createInstance( + GestureRecognizer, /* initializeCanvas= */ true, wasmFileset, + gestureRecognizerOptions); } /** @@ -108,8 +107,9 @@ export class GestureRecognizer extends static createFromModelBuffer( wasmFileset: WasmFileset, modelAssetBuffer: Uint8Array): Promise { - return GestureRecognizer.createFromOptions( - wasmFileset, {baseOptions: {modelAssetBuffer}}); + return VisionTaskRunner.createInstance( + GestureRecognizer, /* initializeCanvas= */ true, wasmFileset, + {baseOptions: {modelAssetBuffer}}); } /** @@ -119,13 +119,12 @@ export class GestureRecognizer extends * Wasm binary and its loader. * @param modelAssetPath The path to the model asset. 
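   *
   * Note that the model is no longer fetched here; the path is forwarded as
   * `baseOptions.modelAssetPath` and loaded by the shared task runner. A
   * sketch of the resulting call site (the file name is illustrative):
   *
   *   const recognizer = await GestureRecognizer.createFromModelPath(
   *       wasmFileset, '/models/gesture_recognizer.task');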
*/ - static async createFromModelPath( + static createFromModelPath( wasmFileset: WasmFileset, modelAssetPath: string): Promise { - const response = await fetch(modelAssetPath.toString()); - const graphData = await response.arrayBuffer(); - return GestureRecognizer.createFromModelBuffer( - wasmFileset, new Uint8Array(graphData)); + return VisionTaskRunner.createInstance( + GestureRecognizer, /* initializeCanvas= */ true, wasmFileset, + {baseOptions: {modelAssetPath}}); } constructor( @@ -134,6 +133,7 @@ export class GestureRecognizer extends super(wasmModule, glCanvas); this.options = new GestureRecognizerGraphOptions(); + this.options.setBaseOptions(new BaseOptionsProto()); this.handLandmarkerGraphOptions = new HandLandmarkerGraphOptions(); this.options.setHandLandmarkerGraphOptions(this.handLandmarkerGraphOptions); this.handLandmarksDetectorGraphOptions = @@ -151,11 +151,11 @@ export class GestureRecognizer extends this.initDefaults(); } - protected override get baseOptions(): BaseOptionsProto|undefined { - return this.options.getBaseOptions(); + protected override get baseOptions(): BaseOptionsProto { + return this.options.getBaseOptions()!; } - protected override set baseOptions(proto: BaseOptionsProto|undefined) { + protected override set baseOptions(proto: BaseOptionsProto) { this.options.setBaseOptions(proto); } diff --git a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts index c657275bf..263ed4b48 100644 --- a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts +++ b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts @@ -77,13 +77,12 @@ export class HandLandmarker extends VisionTaskRunner { * Note that either a path to the model asset or a model buffer needs to * be provided (via `baseOptions`). */ - static async createFromOptions( + static createFromOptions( wasmFileset: WasmFileset, handLandmarkerOptions: HandLandmarkerOptions): Promise { - const landmarker = await VisionTaskRunner.createInstance( - HandLandmarker, /* initializeCanvas= */ true, wasmFileset); - await landmarker.setOptions(handLandmarkerOptions); - return landmarker; + return VisionTaskRunner.createInstance( + HandLandmarker, /* initializeCanvas= */ true, wasmFileset, + handLandmarkerOptions); } /** @@ -96,8 +95,9 @@ export class HandLandmarker extends VisionTaskRunner { static createFromModelBuffer( wasmFileset: WasmFileset, modelAssetBuffer: Uint8Array): Promise { - return HandLandmarker.createFromOptions( - wasmFileset, {baseOptions: {modelAssetBuffer}}); + return VisionTaskRunner.createInstance( + HandLandmarker, /* initializeCanvas= */ true, wasmFileset, + {baseOptions: {modelAssetBuffer}}); } /** @@ -107,13 +107,12 @@ export class HandLandmarker extends VisionTaskRunner { * Wasm binary and its loader. * @param modelAssetPath The path to the model asset. 
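   *
   * Callers that need custom fetching can still download the model
   * themselves and pass it to `createFromModelBuffer`, mirroring the logic
   * this change removes (a sketch; `modelUrl` is illustrative):
   *
   *   const response = await fetch(modelUrl);
   *   const landmarker = await HandLandmarker.createFromModelBuffer(
   *       wasmFileset, new Uint8Array(await response.arrayBuffer()));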
*/ - static async createFromModelPath( + static createFromModelPath( wasmFileset: WasmFileset, modelAssetPath: string): Promise { - const response = await fetch(modelAssetPath.toString()); - const graphData = await response.arrayBuffer(); - return HandLandmarker.createFromModelBuffer( - wasmFileset, new Uint8Array(graphData)); + return VisionTaskRunner.createInstance( + HandLandmarker, /* initializeCanvas= */ true, wasmFileset, + {baseOptions: {modelAssetPath}}); } constructor( @@ -122,6 +121,7 @@ export class HandLandmarker extends VisionTaskRunner { super(wasmModule, glCanvas); this.options = new HandLandmarkerGraphOptions(); + this.options.setBaseOptions(new BaseOptionsProto()); this.handLandmarksDetectorGraphOptions = new HandLandmarksDetectorGraphOptions(); this.options.setHandLandmarksDetectorGraphOptions( @@ -132,11 +132,11 @@ export class HandLandmarker extends VisionTaskRunner { this.initDefaults(); } - protected override get baseOptions(): BaseOptionsProto|undefined { - return this.options.getBaseOptions(); + protected override get baseOptions(): BaseOptionsProto { + return this.options.getBaseOptions()!; } - protected override set baseOptions(proto: BaseOptionsProto|undefined) { + protected override set baseOptions(proto: BaseOptionsProto) { this.options.setBaseOptions(proto); } diff --git a/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts b/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts index 185ddf9ea..90dbf9798 100644 --- a/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts +++ b/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts @@ -23,7 +23,7 @@ import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/ import {convertFromClassificationResultProto} from '../../../../tasks/web/components/processors/classifier_result'; import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; import {VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; -import {ImageSource} from '../../../../web/graph_runner/graph_runner'; +import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url import {ImageClassifierOptions} from './image_classifier_options'; @@ -55,13 +55,12 @@ export class ImageClassifier extends VisionTaskRunner { * that either a path to the model asset or a model buffer needs to be * provided (via `baseOptions`). */ - static async createFromOptions( + static createFromOptions( wasmFileset: WasmFileset, imageClassifierOptions: ImageClassifierOptions): Promise { - const classifier = await VisionTaskRunner.createInstance( - ImageClassifier, /* initializeCanvas= */ true, wasmFileset); - await classifier.setOptions(imageClassifierOptions); - return classifier; + return VisionTaskRunner.createInstance( + ImageClassifier, /* initializeCanvas= */ true, wasmFileset, + imageClassifierOptions); } /** @@ -74,8 +73,9 @@ export class ImageClassifier extends VisionTaskRunner { static createFromModelBuffer( wasmFileset: WasmFileset, modelAssetBuffer: Uint8Array): Promise { - return ImageClassifier.createFromOptions( - wasmFileset, {baseOptions: {modelAssetBuffer}}); + return VisionTaskRunner.createInstance( + ImageClassifier, /* initializeCanvas= */ true, wasmFileset, + {baseOptions: {modelAssetBuffer}}); } /** @@ -85,20 +85,26 @@ export class ImageClassifier extends VisionTaskRunner { * Wasm binary and its loader. * @param modelAssetPath The path to the model asset. 
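   *
   * For anything beyond the model path, use `createFromOptions`, which takes
   * the same base options plus task options (a sketch; the values are
   * illustrative):
   *
   *   const imageClassifier = await ImageClassifier.createFromOptions(
   *       wasmFileset,
   *       {baseOptions: {modelAssetPath}, runningMode: 'image', maxResults: 1});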
*/ - static async createFromModelPath( + static createFromModelPath( wasmFileset: WasmFileset, modelAssetPath: string): Promise { - const response = await fetch(modelAssetPath.toString()); - const graphData = await response.arrayBuffer(); - return ImageClassifier.createFromModelBuffer( - wasmFileset, new Uint8Array(graphData)); + return VisionTaskRunner.createInstance( + ImageClassifier, /* initializeCanvas= */ true, wasmFileset, + {baseOptions: {modelAssetPath}}); } - protected override get baseOptions(): BaseOptionsProto|undefined { - return this.options.getBaseOptions(); + constructor( + wasmModule: WasmModule, + glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { + super(wasmModule, glCanvas); + this.options.setBaseOptions(new BaseOptionsProto()); } - protected override set baseOptions(proto: BaseOptionsProto|undefined) { + protected override get baseOptions(): BaseOptionsProto { + return this.options.getBaseOptions()!; + } + + protected override set baseOptions(proto: BaseOptionsProto) { this.options.setBaseOptions(proto); } diff --git a/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts b/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts index 91352e934..559332650 100644 --- a/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts +++ b/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts @@ -25,7 +25,7 @@ import {convertFromEmbeddingResultProto} from '../../../../tasks/web/components/ import {computeCosineSimilarity} from '../../../../tasks/web/components/utils/cosine_similarity'; import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; import {VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; -import {ImageSource} from '../../../../web/graph_runner/graph_runner'; +import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url import {ImageEmbedderOptions} from './image_embedder_options'; @@ -57,13 +57,12 @@ export class ImageEmbedder extends VisionTaskRunner { * either a path to the TFLite model or the model itself needs to be * provided (via `baseOptions`). */ - static async createFromOptions( + static createFromOptions( wasmFileset: WasmFileset, imageEmbedderOptions: ImageEmbedderOptions): Promise { - const embedder = await VisionTaskRunner.createInstance( - ImageEmbedder, /* initializeCanvas= */ true, wasmFileset); - await embedder.setOptions(imageEmbedderOptions); - return embedder; + return VisionTaskRunner.createInstance( + ImageEmbedder, /* initializeCanvas= */ true, wasmFileset, + imageEmbedderOptions); } /** @@ -76,8 +75,9 @@ export class ImageEmbedder extends VisionTaskRunner { static createFromModelBuffer( wasmFileset: WasmFileset, modelAssetBuffer: Uint8Array): Promise { - return ImageEmbedder.createFromOptions( - wasmFileset, {baseOptions: {modelAssetBuffer}}); + return VisionTaskRunner.createInstance( + ImageEmbedder, /* initializeCanvas= */ true, wasmFileset, + {baseOptions: {modelAssetBuffer}}); } /** @@ -87,20 +87,26 @@ export class ImageEmbedder extends VisionTaskRunner { * Wasm binary and its loader. * @param modelAssetPath The path to the TFLite model. 
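   *
   * A note on the accessor pattern used throughout this change: the
   * generated `getBaseOptions()` returns `undefined` until a value is set,
   * so each constructor now seeds the proto eagerly, which is what makes the
   * non-null assertion in the `baseOptions` getter safe:
   *
   *   this.options.setBaseOptions(new BaseOptionsProto());  // in constructor
   *   this.options.getBaseOptions()!;                       // then safe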
*/ - static async createFromModelPath( + static createFromModelPath( wasmFileset: WasmFileset, modelAssetPath: string): Promise { - const response = await fetch(modelAssetPath.toString()); - const graphData = await response.arrayBuffer(); - return ImageEmbedder.createFromModelBuffer( - wasmFileset, new Uint8Array(graphData)); + return VisionTaskRunner.createInstance( + ImageEmbedder, /* initializeCanvas= */ true, wasmFileset, + {baseOptions: {modelAssetPath}}); } - protected override get baseOptions(): BaseOptionsProto|undefined { - return this.options.getBaseOptions(); + constructor( + wasmModule: WasmModule, + glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { + super(wasmModule, glCanvas); + this.options.setBaseOptions(new BaseOptionsProto()); } - protected override set baseOptions(proto: BaseOptionsProto|undefined) { + protected override get baseOptions(): BaseOptionsProto { + return this.options.getBaseOptions()!; + } + + protected override set baseOptions(proto: BaseOptionsProto) { this.options.setBaseOptions(proto); } diff --git a/mediapipe/tasks/web/vision/object_detector/object_detector.ts b/mediapipe/tasks/web/vision/object_detector/object_detector.ts index 7711c39e9..03171003f 100644 --- a/mediapipe/tasks/web/vision/object_detector/object_detector.ts +++ b/mediapipe/tasks/web/vision/object_detector/object_detector.ts @@ -21,7 +21,7 @@ import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/b import {ObjectDetectorOptions as ObjectDetectorOptionsProto} from '../../../../tasks/cc/vision/object_detector/proto/object_detector_options_pb'; import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; import {VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; -import {ImageSource} from '../../../../web/graph_runner/graph_runner'; +import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url import {ObjectDetectorOptions} from './object_detector_options'; @@ -54,13 +54,12 @@ export class ObjectDetector extends VisionTaskRunner { * either a path to the model asset or a model buffer needs to be * provided (via `baseOptions`). 
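   *
   * Example usage, as a sketch (the option values are illustrative, and the
   * detection API itself is unchanged by this refactoring):
   *
   *   const detector = await ObjectDetector.createFromOptions(
   *       wasmFileset, {baseOptions: {modelAssetBuffer}, scoreThreshold: 0.5});
   *   const detections = detector.detect(image);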
*/ - static async createFromOptions( + static createFromOptions( wasmFileset: WasmFileset, objectDetectorOptions: ObjectDetectorOptions): Promise { - const detector = await VisionTaskRunner.createInstance( - ObjectDetector, /* initializeCanvas= */ true, wasmFileset); - await detector.setOptions(objectDetectorOptions); - return detector; + return VisionTaskRunner.createInstance( + ObjectDetector, /* initializeCanvas= */ true, wasmFileset, + objectDetectorOptions); } /** @@ -73,8 +72,9 @@ export class ObjectDetector extends VisionTaskRunner { static createFromModelBuffer( wasmFileset: WasmFileset, modelAssetBuffer: Uint8Array): Promise { - return ObjectDetector.createFromOptions( - wasmFileset, {baseOptions: {modelAssetBuffer}}); + return VisionTaskRunner.createInstance( + ObjectDetector, /* initializeCanvas= */ true, wasmFileset, + {baseOptions: {modelAssetBuffer}}); } /** @@ -87,17 +87,23 @@ export class ObjectDetector extends VisionTaskRunner { static async createFromModelPath( wasmFileset: WasmFileset, modelAssetPath: string): Promise { - const response = await fetch(modelAssetPath.toString()); - const graphData = await response.arrayBuffer(); - return ObjectDetector.createFromModelBuffer( - wasmFileset, new Uint8Array(graphData)); + return VisionTaskRunner.createInstance( + ObjectDetector, /* initializeCanvas= */ true, wasmFileset, + {baseOptions: {modelAssetPath}}); } - protected override get baseOptions(): BaseOptionsProto|undefined { - return this.options.getBaseOptions(); + constructor( + wasmModule: WasmModule, + glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { + super(wasmModule, glCanvas); + this.options.setBaseOptions(new BaseOptionsProto()); } - protected override set baseOptions(proto: BaseOptionsProto|undefined) { + protected override get baseOptions(): BaseOptionsProto { + return this.options.getBaseOptions()!; + } + + protected override set baseOptions(proto: BaseOptionsProto) { this.options.setBaseOptions(proto); } From e457039fc6350fbd2e75aa2d034f9b68af6d3410 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Fri, 2 Dec 2022 16:16:34 -0800 Subject: [PATCH 159/469] Don't inherit from GraphRunner PiperOrigin-RevId: 492584486 --- .../audio_classifier/audio_classifier.ts | 9 +++-- .../audio/audio_embedder/audio_embedder.ts | 25 ++++++++------ mediapipe/tasks/web/core/task_runner.ts | 24 +++++++------- .../text/text_classifier/text_classifier.ts | 11 ++++--- .../web/text/text_embedder/text_embedder.ts | 4 +-- .../gesture_recognizer/gesture_recognizer.ts | 33 +++++++++++-------- .../vision/hand_landmarker/hand_landmarker.ts | 26 ++++++++------- .../image_classifier/image_classifier.ts | 11 ++++--- .../vision/image_embedder/image_embedder.ts | 4 +-- .../vision/object_detector/object_detector.ts | 9 ++--- .../graph_runner/graph_runner_image_lib.ts | 2 +- .../register_model_resources_graph_service.ts | 4 +-- 12 files changed, 92 insertions(+), 70 deletions(-) diff --git a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts index 4e12780d2..265ba2b33 100644 --- a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts +++ b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts @@ -145,8 +145,11 @@ export class AudioClassifier extends AudioTaskRunner { protected override process( audioData: Float32Array, sampleRate: number, timestampMs: number): AudioClassifierResult[] { - this.addDoubleToStream(sampleRate, SAMPLE_RATE_STREAM, timestampMs); - this.addAudioToStreamWithShape(audioData, 
/* numChannels= */ 1, /* numSamples= */ audioData.length, AUDIO_STREAM, timestampMs); + this.graphRunner.addDoubleToStream( + sampleRate, SAMPLE_RATE_STREAM, timestampMs); + this.graphRunner.addAudioToStreamWithShape( + audioData, /* numChannels= */ 1, /* numSamples= */ audioData.length, + AUDIO_STREAM, timestampMs); this.classificationResults = []; this.finishProcessing(); @@ -189,7 +192,7 @@ export class AudioClassifier extends AudioTaskRunner { graphConfig.addNode(classifierNode); - this.attachProtoVectorListener( + this.graphRunner.attachProtoVectorListener( TIMESTAMPED_CLASSIFICATIONS_STREAM, binaryProtos => { this.addJsAudioClassificationResults(binaryProtos); }); diff --git a/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts index d08eb4791..445dd5172 100644 --- a/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts +++ b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts @@ -158,8 +158,11 @@ export class AudioEmbedder extends AudioTaskRunner { protected override process( audioData: Float32Array, sampleRate: number, timestampMs: number): AudioEmbedderResult[] { - this.addDoubleToStream(sampleRate, SAMPLE_RATE_STREAM, timestampMs); - this.addAudioToStreamWithShape(audioData, /* numChannels= */ 1, /* numSamples= */ audioData.length, AUDIO_STREAM, timestampMs); + this.graphRunner.addDoubleToStream( + sampleRate, SAMPLE_RATE_STREAM, timestampMs); + this.graphRunner.addAudioToStreamWithShape( + audioData, /* numChannels= */ 1, /* numSamples= */ audioData.length, + AUDIO_STREAM, timestampMs); this.embeddingResults = []; this.finishProcessing(); @@ -189,19 +192,21 @@ export class AudioEmbedder extends AudioTaskRunner { graphConfig.addNode(embedderNode); - this.attachProtoListener(EMBEDDINGS_STREAM, binaryProto => { + this.graphRunner.attachProtoListener(EMBEDDINGS_STREAM, binaryProto => { const embeddingResult = EmbeddingResult.deserializeBinary(binaryProto); this.embeddingResults.push( convertFromEmbeddingResultProto(embeddingResult)); }); - this.attachProtoVectorListener(TIMESTAMPED_EMBEDDINGS_STREAM, data => { - for (const binaryProto of data) { - const embeddingResult = EmbeddingResult.deserializeBinary(binaryProto); - this.embeddingResults.push( - convertFromEmbeddingResultProto(embeddingResult)); - } - }); + this.graphRunner.attachProtoVectorListener( + TIMESTAMPED_EMBEDDINGS_STREAM, data => { + for (const binaryProto of data) { + const embeddingResult = + EmbeddingResult.deserializeBinary(binaryProto); + this.embeddingResults.push( + convertFromEmbeddingResultProto(embeddingResult)); + } + }); const binaryGraph = graphConfig.serializeBinary(); this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true); diff --git a/mediapipe/tasks/web/core/task_runner.ts b/mediapipe/tasks/web/core/task_runner.ts index c2691fc76..d769139bc 100644 --- a/mediapipe/tasks/web/core/task_runner.ts +++ b/mediapipe/tasks/web/core/task_runner.ts @@ -27,13 +27,15 @@ import {WasmFileset} from './wasm_fileset'; const NO_ASSETS = undefined; // tslint:disable-next-line:enforce-name-casing -const WasmMediaPipeImageLib = +const GraphRunnerImageLibType = SupportModelResourcesGraphService(SupportImage(GraphRunner)); +/** An implementation of the GraphRunner that supports image operations */ +export class GraphRunnerImageLib extends GraphRunnerImageLibType {} /** Base class for all MediaPipe Tasks. 
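 *
 * TaskRunner now wraps a GraphRunnerImageLib instance instead of extending
 * it; subclasses reach graph functionality through the `graphRunner` member,
 * e.g. (a sketch; the stream name is illustrative):
 *
 *   this.graphRunner.addStringToStream(text, 'text_in', performance.now());
 *   this.finishProcessing();  // still on TaskRunner; wraps error handling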
*/ -export abstract class TaskRunner extends - WasmMediaPipeImageLib { +export abstract class TaskRunner { protected abstract baseOptions: BaseOptionsProto; + protected graphRunner: GraphRunnerImageLib; private processingErrors: Error[] = []; /** @@ -67,14 +69,14 @@ export abstract class TaskRunner extends constructor( wasmModule: WasmModule, glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { - super(wasmModule, glCanvas); + this.graphRunner = new GraphRunnerImageLib(wasmModule, glCanvas); // Disables the automatic render-to-screen code, which allows for pure // CPU processing. - this.setAutoRenderToScreen(false); + this.graphRunner.setAutoRenderToScreen(false); // Enables use of our model resource caching graph service. - this.registerModelResourcesGraphService(); + this.graphRunner.registerModelResourcesGraphService(); } /** Configures the shared options of a MediaPipe Task. */ @@ -95,11 +97,11 @@ export abstract class TaskRunner extends * @param isBinary This should be set to true if the graph is in * binary format, and false if it is in human-readable text format. */ - override setGraph(graphData: Uint8Array, isBinary: boolean): void { - this.attachErrorListener((code, message) => { + protected setGraph(graphData: Uint8Array, isBinary: boolean): void { + this.graphRunner.attachErrorListener((code, message) => { this.processingErrors.push(new Error(message)); }); - super.setGraph(graphData, isBinary); + this.graphRunner.setGraph(graphData, isBinary); this.handleErrors(); } @@ -108,8 +110,8 @@ export abstract class TaskRunner extends * far as possible, performing all processing until no more processing can be * done. */ - override finishProcessing(): void { - super.finishProcessing(); + protected finishProcessing(): void { + this.graphRunner.finishProcessing(); this.handleErrors(); } diff --git a/mediapipe/tasks/web/text/text_classifier/text_classifier.ts b/mediapipe/tasks/web/text/text_classifier/text_classifier.ts index bd2a207ce..8810d4b42 100644 --- a/mediapipe/tasks/web/text/text_classifier/text_classifier.ts +++ b/mediapipe/tasks/web/text/text_classifier/text_classifier.ts @@ -133,7 +133,7 @@ export class TextClassifier extends TaskRunner { classify(text: string): TextClassifierResult { // Get classification result by running our MediaPipe graph. 
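    // (Both calls below run synchronously; any errors raised while the graph
    // processes are collected by the error listener that TaskRunner attaches
    // in setGraph() and are rethrown from finishProcessing().)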
this.classificationResult = {classifications: []}; - this.addStringToStream( + this.graphRunner.addStringToStream( text, INPUT_STREAM, /* timestamp= */ performance.now()); this.finishProcessing(); return this.classificationResult; @@ -157,10 +157,11 @@ export class TextClassifier extends TaskRunner { graphConfig.addNode(classifierNode); - this.attachProtoListener(CLASSIFICATIONS_STREAM, binaryProto => { - this.classificationResult = convertFromClassificationResultProto( - ClassificationResult.deserializeBinary(binaryProto)); - }); + this.graphRunner.attachProtoListener( + CLASSIFICATIONS_STREAM, binaryProto => { + this.classificationResult = convertFromClassificationResultProto( + ClassificationResult.deserializeBinary(binaryProto)); + }); const binaryGraph = graphConfig.serializeBinary(); this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true); diff --git a/mediapipe/tasks/web/text/text_embedder/text_embedder.ts b/mediapipe/tasks/web/text/text_embedder/text_embedder.ts index d2899fbe2..62f9b06db 100644 --- a/mediapipe/tasks/web/text/text_embedder/text_embedder.ts +++ b/mediapipe/tasks/web/text/text_embedder/text_embedder.ts @@ -136,7 +136,7 @@ export class TextEmbedder extends TaskRunner { */ embed(text: string): TextEmbedderResult { // Get text embeddings by running our MediaPipe graph. - this.addStringToStream( + this.graphRunner.addStringToStream( text, INPUT_STREAM, /* timestamp= */ performance.now()); this.finishProcessing(); return this.embeddingResult; @@ -173,7 +173,7 @@ export class TextEmbedder extends TaskRunner { graphConfig.addNode(embedderNode); - this.attachProtoListener(EMBEDDINGS_STREAM, binaryProto => { + this.graphRunner.attachProtoListener(EMBEDDINGS_STREAM, binaryProto => { const embeddingResult = EmbeddingResult.deserializeBinary(binaryProto); this.embeddingResult = convertFromEmbeddingResultProto(embeddingResult); }); diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts index 8baee5ce3..69a8118a6 100644 --- a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts +++ b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts @@ -257,8 +257,9 @@ export class GestureRecognizer extends this.worldLandmarks = []; this.handednesses = []; - this.addGpuBufferAsImageToStream(imageSource, IMAGE_STREAM, timestamp); - this.addProtoToStream( + this.graphRunner.addGpuBufferAsImageToStream( + imageSource, IMAGE_STREAM, timestamp); + this.graphRunner.addProtoToStream( FULL_IMAGE_RECT.serializeBinary(), 'mediapipe.NormalizedRect', NORM_RECT_STREAM, timestamp); this.finishProcessing(); @@ -365,18 +366,22 @@ export class GestureRecognizer extends graphConfig.addNode(recognizerNode); - this.attachProtoVectorListener(LANDMARKS_STREAM, binaryProto => { - this.addJsLandmarks(binaryProto); - }); - this.attachProtoVectorListener(WORLD_LANDMARKS_STREAM, binaryProto => { - this.adddJsWorldLandmarks(binaryProto); - }); - this.attachProtoVectorListener(HAND_GESTURES_STREAM, binaryProto => { - this.gestures.push(...this.toJsCategories(binaryProto)); - }); - this.attachProtoVectorListener(HANDEDNESS_STREAM, binaryProto => { - this.handednesses.push(...this.toJsCategories(binaryProto)); - }); + this.graphRunner.attachProtoVectorListener( + LANDMARKS_STREAM, binaryProto => { + this.addJsLandmarks(binaryProto); + }); + this.graphRunner.attachProtoVectorListener( + WORLD_LANDMARKS_STREAM, binaryProto => { + this.adddJsWorldLandmarks(binaryProto); + }); + 
this.graphRunner.attachProtoVectorListener( + HAND_GESTURES_STREAM, binaryProto => { + this.gestures.push(...this.toJsCategories(binaryProto)); + }); + this.graphRunner.attachProtoVectorListener( + HANDEDNESS_STREAM, binaryProto => { + this.handednesses.push(...this.toJsCategories(binaryProto)); + }); const binaryGraph = graphConfig.serializeBinary(); this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true); diff --git a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts index 263ed4b48..9a0823f23 100644 --- a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts +++ b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts @@ -208,8 +208,9 @@ export class HandLandmarker extends VisionTaskRunner { this.worldLandmarks = []; this.handednesses = []; - this.addGpuBufferAsImageToStream(imageSource, IMAGE_STREAM, timestamp); - this.addProtoToStream( + this.graphRunner.addGpuBufferAsImageToStream( + imageSource, IMAGE_STREAM, timestamp); + this.graphRunner.addProtoToStream( FULL_IMAGE_RECT.serializeBinary(), 'mediapipe.NormalizedRect', NORM_RECT_STREAM, timestamp); this.finishProcessing(); @@ -312,15 +313,18 @@ export class HandLandmarker extends VisionTaskRunner { graphConfig.addNode(landmarkerNode); - this.attachProtoVectorListener(LANDMARKS_STREAM, binaryProto => { - this.addJsLandmarks(binaryProto); - }); - this.attachProtoVectorListener(WORLD_LANDMARKS_STREAM, binaryProto => { - this.adddJsWorldLandmarks(binaryProto); - }); - this.attachProtoVectorListener(HANDEDNESS_STREAM, binaryProto => { - this.handednesses.push(...this.toJsCategories(binaryProto)); - }); + this.graphRunner.attachProtoVectorListener( + LANDMARKS_STREAM, binaryProto => { + this.addJsLandmarks(binaryProto); + }); + this.graphRunner.attachProtoVectorListener( + WORLD_LANDMARKS_STREAM, binaryProto => { + this.adddJsWorldLandmarks(binaryProto); + }); + this.graphRunner.attachProtoVectorListener( + HANDEDNESS_STREAM, binaryProto => { + this.handednesses.push(...this.toJsCategories(binaryProto)); + }); const binaryGraph = graphConfig.serializeBinary(); this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true); diff --git a/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts b/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts index 90dbf9798..40e8b5099 100644 --- a/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts +++ b/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts @@ -155,7 +155,7 @@ export class ImageClassifier extends VisionTaskRunner { ImageClassifierResult { // Get classification result by running our MediaPipe graph. this.classificationResult = {classifications: []}; - this.addGpuBufferAsImageToStream( + this.graphRunner.addGpuBufferAsImageToStream( imageSource, INPUT_STREAM, timestamp ?? 
performance.now()); this.finishProcessing(); return this.classificationResult; @@ -181,10 +181,11 @@ export class ImageClassifier extends VisionTaskRunner { graphConfig.addNode(classifierNode); - this.attachProtoListener(CLASSIFICATIONS_STREAM, binaryProto => { - this.classificationResult = convertFromClassificationResultProto( - ClassificationResult.deserializeBinary(binaryProto)); - }); + this.graphRunner.attachProtoListener( + CLASSIFICATIONS_STREAM, binaryProto => { + this.classificationResult = convertFromClassificationResultProto( + ClassificationResult.deserializeBinary(binaryProto)); + }); const binaryGraph = graphConfig.serializeBinary(); this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true); diff --git a/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts b/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts index 559332650..f8b0204ee 100644 --- a/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts +++ b/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts @@ -169,7 +169,7 @@ export class ImageEmbedder extends VisionTaskRunner { protected process(image: ImageSource, timestamp: number): ImageEmbedderResult { // Get embeddings by running our MediaPipe graph. - this.addGpuBufferAsImageToStream( + this.graphRunner.addGpuBufferAsImageToStream( image, INPUT_STREAM, timestamp ?? performance.now()); this.finishProcessing(); return this.embeddings; @@ -201,7 +201,7 @@ export class ImageEmbedder extends VisionTaskRunner { graphConfig.addNode(embedderNode); - this.attachProtoListener(EMBEDDINGS_STREAM, binaryProto => { + this.graphRunner.attachProtoListener(EMBEDDINGS_STREAM, binaryProto => { this.addJsImageEmdedding(binaryProto); }); diff --git a/mediapipe/tasks/web/vision/object_detector/object_detector.ts b/mediapipe/tasks/web/vision/object_detector/object_detector.ts index 03171003f..e2cfe0575 100644 --- a/mediapipe/tasks/web/vision/object_detector/object_detector.ts +++ b/mediapipe/tasks/web/vision/object_detector/object_detector.ts @@ -185,7 +185,7 @@ export class ObjectDetector extends VisionTaskRunner { Detection[] { // Get detections by running our MediaPipe graph. this.detections = []; - this.addGpuBufferAsImageToStream( + this.graphRunner.addGpuBufferAsImageToStream( imageSource, INPUT_STREAM, timestamp ?? performance.now()); this.finishProcessing(); return [...this.detections]; @@ -242,9 +242,10 @@ export class ObjectDetector extends VisionTaskRunner { graphConfig.addNode(detectorNode); - this.attachProtoVectorListener(DETECTIONS_STREAM, binaryProto => { - this.addJsObjectDetections(binaryProto); - }); + this.graphRunner.attachProtoVectorListener( + DETECTIONS_STREAM, binaryProto => { + this.addJsObjectDetections(binaryProto); + }); const binaryGraph = graphConfig.serializeBinary(); this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true); diff --git a/mediapipe/web/graph_runner/graph_runner_image_lib.ts b/mediapipe/web/graph_runner/graph_runner_image_lib.ts index e886999cb..7a4ea09e2 100644 --- a/mediapipe/web/graph_runner/graph_runner_image_lib.ts +++ b/mediapipe/web/graph_runner/graph_runner_image_lib.ts @@ -22,7 +22,7 @@ export declare interface WasmImageModule { * An implementation of GraphRunner that supports binding GPU image data as * `mediapipe::Image` instances. We implement as a proper TS mixin, to allow for * effective multiple inheritance. 
Example usage: - * `const WasmMediaPipeImageLib = SupportImage(GraphRunner);` + * `const GraphRunnerImageLib = SupportImage(GraphRunner);` */ // tslint:disable-next-line:enforce-name-casing export function SupportImage(Base: TBase) { diff --git a/mediapipe/web/graph_runner/register_model_resources_graph_service.ts b/mediapipe/web/graph_runner/register_model_resources_graph_service.ts index bc9c93e8a..9f2791d80 100644 --- a/mediapipe/web/graph_runner/register_model_resources_graph_service.ts +++ b/mediapipe/web/graph_runner/register_model_resources_graph_service.ts @@ -20,8 +20,8 @@ export declare interface WasmModuleRegisterModelResources { * An implementation of GraphRunner that supports registering model * resources to a cache, in the form of a GraphService C++-side. We implement as * a proper TS mixin, to allow for effective multiple inheritance. Sample usage: - * `const WasmMediaPipeImageLib = SupportModelResourcesGraphService( - * GraphRunner);` + * `const GraphRunnerWithModelResourcesLib = + * SupportModelResourcesGraphService(GraphRunner);` */ // tslint:disable:enforce-name-casing export function SupportModelResourcesGraphService( From 35bb18945f21856f62cd99027f7702b92411dfc5 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Mon, 5 Dec 2022 07:22:51 -0800 Subject: [PATCH 160/469] Better handling of empty packets in vector calculators. PiperOrigin-RevId: 493000695 --- .../core/get_vector_item_calculator.h | 9 +++-- .../core/get_vector_item_calculator.proto | 3 ++ .../core/get_vector_item_calculator_test.cc | 34 ++++++++++++++----- .../core/merge_to_vector_calculator.cc | 4 +++ .../core/merge_to_vector_calculator.h | 15 ++++++-- 5 files changed, 51 insertions(+), 14 deletions(-) diff --git a/mediapipe/calculators/core/get_vector_item_calculator.h b/mediapipe/calculators/core/get_vector_item_calculator.h index dc98ccfe7..25d90bfe6 100644 --- a/mediapipe/calculators/core/get_vector_item_calculator.h +++ b/mediapipe/calculators/core/get_vector_item_calculator.h @@ -65,6 +65,7 @@ class GetVectorItemCalculator : public Node { MEDIAPIPE_NODE_CONTRACT(kIn, kIdx, kOut); absl::Status Open(CalculatorContext* cc) final { + cc->SetOffset(mediapipe::TimestampDiff(0)); auto& options = cc->Options(); RET_CHECK(kIdx(cc).IsConnected() || options.has_item_index()); return absl::OkStatus(); @@ -90,8 +91,12 @@ class GetVectorItemCalculator : public Node { return absl::OkStatus(); } - RET_CHECK(idx >= 0 && idx < items.size()); - kOut(cc).Send(items[idx]); + RET_CHECK(idx >= 0); + RET_CHECK(options.output_empty_on_oob() || idx < items.size()); + + if (idx < items.size()) { + kOut(cc).Send(items[idx]); + } return absl::OkStatus(); } diff --git a/mediapipe/calculators/core/get_vector_item_calculator.proto b/mediapipe/calculators/core/get_vector_item_calculator.proto index c406283e4..9cfb579e4 100644 --- a/mediapipe/calculators/core/get_vector_item_calculator.proto +++ b/mediapipe/calculators/core/get_vector_item_calculator.proto @@ -26,4 +26,7 @@ message GetVectorItemCalculatorOptions { // Index of vector item to get. INDEX input stream can be used instead, or to // override. optional int32 item_index = 1; + + // Set to true to output an empty packet when the index is out of bounds. 
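+  // A graph config can enable it like this (a sketch mirroring the test
+  // helper below; the index value is illustrative):
+  //   options {
+  //     [mediapipe.GetVectorItemCalculatorOptions.ext] {
+  //       item_index: 3
+  //       output_empty_on_oob: true
+  //     }
+  //   }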
+ optional bool output_empty_on_oob = 2; } diff --git a/mediapipe/calculators/core/get_vector_item_calculator_test.cc b/mediapipe/calculators/core/get_vector_item_calculator_test.cc index c148aa9d1..c2974e20a 100644 --- a/mediapipe/calculators/core/get_vector_item_calculator_test.cc +++ b/mediapipe/calculators/core/get_vector_item_calculator_test.cc @@ -32,18 +32,21 @@ CalculatorRunner MakeRunnerWithStream() { )"); } -CalculatorRunner MakeRunnerWithOptions(int set_index) { - return CalculatorRunner(absl::StrFormat(R"( +CalculatorRunner MakeRunnerWithOptions(int set_index, + bool output_empty_on_oob = false) { + return CalculatorRunner( + absl::StrFormat(R"( calculator: "TestGetIntVectorItemCalculator" input_stream: "VECTOR:vector_stream" output_stream: "ITEM:item_stream" options { [mediapipe.GetVectorItemCalculatorOptions.ext] { item_index: %d + output_empty_on_oob: %s } } )", - set_index)); + set_index, output_empty_on_oob ? "true" : "false")); } void AddInputVector(CalculatorRunner& runner, const std::vector& inputs, @@ -140,8 +143,7 @@ TEST(TestGetIntVectorItemCalculatorTest, StreamIndexBoundsCheckFail1) { absl::Status status = runner.Run(); ASSERT_FALSE(status.ok()); - EXPECT_THAT(status.message(), - testing::HasSubstr("idx >= 0 && idx < items.size()")); + EXPECT_THAT(status.message(), testing::HasSubstr("idx >= 0")); } TEST(TestGetIntVectorItemCalculatorTest, StreamIndexBoundsCheckFail2) { @@ -155,7 +157,8 @@ TEST(TestGetIntVectorItemCalculatorTest, StreamIndexBoundsCheckFail2) { absl::Status status = runner.Run(); ASSERT_FALSE(status.ok()); EXPECT_THAT(status.message(), - testing::HasSubstr("idx >= 0 && idx < items.size()")); + testing::HasSubstr( + "options.output_empty_on_oob() || idx < items.size()")); } TEST(TestGetIntVectorItemCalculatorTest, OptionsIndexBoundsCheckFail1) { @@ -167,8 +170,7 @@ TEST(TestGetIntVectorItemCalculatorTest, OptionsIndexBoundsCheckFail1) { absl::Status status = runner.Run(); ASSERT_FALSE(status.ok()); - EXPECT_THAT(status.message(), - testing::HasSubstr("idx >= 0 && idx < items.size()")); + EXPECT_THAT(status.message(), testing::HasSubstr("idx >= 0")); } TEST(TestGetIntVectorItemCalculatorTest, OptionsIndexBoundsCheckFail2) { @@ -181,7 +183,21 @@ TEST(TestGetIntVectorItemCalculatorTest, OptionsIndexBoundsCheckFail2) { absl::Status status = runner.Run(); ASSERT_FALSE(status.ok()); EXPECT_THAT(status.message(), - testing::HasSubstr("idx >= 0 && idx < items.size()")); + testing::HasSubstr( + "options.output_empty_on_oob() || idx < items.size()")); +} + +TEST(TestGetIntVectorItemCalculatorTest, OptionsIndexBoundsCheckFail3) { + const int try_index = 3; + CalculatorRunner runner = MakeRunnerWithOptions(try_index, true); + const std::vector inputs = {1, 2, 3}; + + AddInputVector(runner, inputs, 1); + + MP_ASSERT_OK(runner.Run()); + + const std::vector& outputs = runner.Outputs().Tag("ITEM").packets; + EXPECT_THAT(outputs, testing::ElementsAre()); } TEST(TestGetIntVectorItemCalculatorTest, IndexStreamTwoTimestamps) { diff --git a/mediapipe/calculators/core/merge_to_vector_calculator.cc b/mediapipe/calculators/core/merge_to_vector_calculator.cc index cca64bc9a..5f05ad725 100644 --- a/mediapipe/calculators/core/merge_to_vector_calculator.cc +++ b/mediapipe/calculators/core/merge_to_vector_calculator.cc @@ -23,5 +23,9 @@ namespace api2 { typedef MergeToVectorCalculator MergeImagesToVectorCalculator; MEDIAPIPE_REGISTER_NODE(MergeImagesToVectorCalculator); +typedef MergeToVectorCalculator + MergeGpuBuffersToVectorCalculator; 
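+// (Each typedef above instantiates the template for a concrete payload type,
+// and registering it makes the calculator addressable by that name from graph
+// configs. With this change, empty input packets are skipped rather than
+// forwarded into the output vector.)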
+MEDIAPIPE_REGISTER_NODE(MergeGpuBuffersToVectorCalculator); + } // namespace api2 } // namespace mediapipe diff --git a/mediapipe/calculators/core/merge_to_vector_calculator.h b/mediapipe/calculators/core/merge_to_vector_calculator.h index bed616695..f63d86ee4 100644 --- a/mediapipe/calculators/core/merge_to_vector_calculator.h +++ b/mediapipe/calculators/core/merge_to_vector_calculator.h @@ -42,11 +42,20 @@ class MergeToVectorCalculator : public Node { return absl::OkStatus(); } + absl::Status Open(::mediapipe::CalculatorContext* cc) { + cc->SetOffset(::mediapipe::TimestampDiff(0)); + return absl::OkStatus(); + } + absl::Status Process(CalculatorContext* cc) { const int input_num = kIn(cc).Count(); - std::vector output_vector(input_num); - std::transform(kIn(cc).begin(), kIn(cc).end(), output_vector.begin(), - [](const auto& elem) -> T { return elem.Get(); }); + std::vector output_vector; + for (auto it = kIn(cc).begin(); it != kIn(cc).end(); it++) { + const auto& elem = *it; + if (!elem.IsEmpty()) { + output_vector.push_back(elem.Get()); + } + } kOut(cc).Send(output_vector); return absl::OkStatus(); } From 4f8eaee20f5c02d932b8bacecd1afb0655d84130 Mon Sep 17 00:00:00 2001 From: Adam Cozzette Date: Mon, 5 Dec 2022 11:33:21 -0800 Subject: [PATCH 161/469] Internal change PiperOrigin-RevId: 493065632 --- mediapipe/graphs/iris_tracking/calculators/BUILD | 1 - mediapipe/java/com/google/mediapipe/framework/jni/BUILD | 7 +++---- mediapipe/modules/hand_landmark/calculators/BUILD | 1 - mediapipe/modules/objectron/calculators/BUILD | 4 ---- mediapipe/util/tracking/BUILD | 1 - 5 files changed, 3 insertions(+), 11 deletions(-) diff --git a/mediapipe/graphs/iris_tracking/calculators/BUILD b/mediapipe/graphs/iris_tracking/calculators/BUILD index 3a3d57a0f..f5124b464 100644 --- a/mediapipe/graphs/iris_tracking/calculators/BUILD +++ b/mediapipe/graphs/iris_tracking/calculators/BUILD @@ -97,7 +97,6 @@ cc_library( visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", - "//mediapipe/framework/formats:image_file_properties_cc_proto", "//mediapipe/framework/formats:landmark_cc_proto", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/BUILD b/mediapipe/java/com/google/mediapipe/framework/jni/BUILD index 4926e2f3c..4540f63a6 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/BUILD +++ b/mediapipe/java/com/google/mediapipe/framework/jni/BUILD @@ -84,12 +84,11 @@ cc_library( deps = [ ":class_registry", ":jni_util", - "//mediapipe/framework:calculator_framework", - "//mediapipe/framework:calculator_profile_cc_proto", - "//mediapipe/framework/tool:calculator_graph_template_cc_proto", "//mediapipe/framework/formats:image_format_cc_proto", - "//mediapipe/framework/formats:matrix_data_cc_proto", "//mediapipe/framework/formats:time_series_header_cc_proto", + "//mediapipe/framework:calculator_cc_proto", + "//mediapipe/framework:calculator_profile_cc_proto", + "//mediapipe/framework:calculator_framework", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", diff --git a/mediapipe/modules/hand_landmark/calculators/BUILD b/mediapipe/modules/hand_landmark/calculators/BUILD index b2a8efe37..b42ec94de 100644 --- a/mediapipe/modules/hand_landmark/calculators/BUILD +++ b/mediapipe/modules/hand_landmark/calculators/BUILD @@ -24,7 +24,6 @@ cc_library( "//mediapipe/framework:calculator_framework", 
"//mediapipe/framework:calculator_options_cc_proto", "//mediapipe/framework/formats:landmark_cc_proto", - "//mediapipe/framework/formats:location_data_cc_proto", "//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", diff --git a/mediapipe/modules/objectron/calculators/BUILD b/mediapipe/modules/objectron/calculators/BUILD index eeeaee5f4..14cea526f 100644 --- a/mediapipe/modules/objectron/calculators/BUILD +++ b/mediapipe/modules/objectron/calculators/BUILD @@ -275,7 +275,6 @@ cc_library( ":tflite_tensors_to_objects_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/deps:file_path", - "//mediapipe/framework/formats:detection_cc_proto", "//mediapipe/framework/port:opencv_core", "//mediapipe/framework/port:ret_check", "@com_google_absl//absl/memory", @@ -299,7 +298,6 @@ cc_library( ":tensors_to_objects_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/deps:file_path", - "//mediapipe/framework/formats:detection_cc_proto", "//mediapipe/framework/port:opencv_core", "//mediapipe/framework/port:ret_check", "@com_google_absl//absl/memory", @@ -316,13 +314,11 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":annotation_cc_proto", - ":belief_decoder_config_cc_proto", ":decoder", ":lift_2d_frame_annotation_to_3d_calculator_cc_proto", ":tensor_util", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/deps:file_path", - "//mediapipe/framework/formats:detection_cc_proto", "//mediapipe/framework/port:opencv_core", "//mediapipe/framework/port:ret_check", "@com_google_absl//absl/memory", diff --git a/mediapipe/util/tracking/BUILD b/mediapipe/util/tracking/BUILD index 6bca24446..816af2533 100644 --- a/mediapipe/util/tracking/BUILD +++ b/mediapipe/util/tracking/BUILD @@ -282,7 +282,6 @@ cc_library( srcs = ["motion_models_cv.cc"], hdrs = ["motion_models_cv.h"], deps = [ - ":camera_motion_cc_proto", ":motion_models", ":motion_models_cc_proto", "//mediapipe/framework/port:opencv_core", From 69b27b246a3f11e775791805eb2c2b4858ed9412 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Mon, 5 Dec 2022 14:14:13 -0800 Subject: [PATCH 162/469] Adds a public function for creating TaskRunner instances. PiperOrigin-RevId: 493109736 --- mediapipe/tasks/web/core/task_runner.ts | 46 ++++++++++++++++--------- 1 file changed, 29 insertions(+), 17 deletions(-) diff --git a/mediapipe/tasks/web/core/task_runner.ts b/mediapipe/tasks/web/core/task_runner.ts index d769139bc..e2ab42e31 100644 --- a/mediapipe/tasks/web/core/task_runner.ts +++ b/mediapipe/tasks/web/core/task_runner.ts @@ -32,6 +32,34 @@ const GraphRunnerImageLibType = /** An implementation of the GraphRunner that supports image operations */ export class GraphRunnerImageLib extends GraphRunnerImageLibType {} +/** + * Creates a new instance of a Mediapipe Task. Determines if SIMD is + * supported and loads the relevant WASM binary. + * @return A fully instantiated instance of `T`. + */ +export async function +createTaskRunner, O extends TaskRunnerOptions>( + type: WasmMediaPipeConstructor, initializeCanvas: boolean, + fileset: WasmFileset, options: O): Promise { + const fileLocator: FileLocator = { + locateFile() { + // The only file loaded with this mechanism is the Wasm binary + return fileset.wasmBinaryPath.toString(); + } + }; + + // Initialize a canvas if requested. If OffscreenCanvas is availble, we + // let the graph runner initialize it by passing `undefined`. 
+ const canvas = initializeCanvas ? (typeof OffscreenCanvas === 'undefined' ? + document.createElement('canvas') : + undefined) : + null; + const instance = await createMediaPipeLib( + type, fileset.wasmLoaderPath, NO_ASSETS, canvas, fileLocator); + await instance.setOptions(options); + return instance; +} + /** Base class for all MediaPipe Tasks. */ export abstract class TaskRunner { protected abstract baseOptions: BaseOptionsProto; @@ -47,23 +75,7 @@ export abstract class TaskRunner { O extends TaskRunnerOptions>( type: WasmMediaPipeConstructor, initializeCanvas: boolean, fileset: WasmFileset, options: O): Promise { - const fileLocator: FileLocator = { - locateFile() { - // The only file loaded with this mechanism is the Wasm binary - return fileset.wasmBinaryPath.toString(); - } - }; - - // Initialize a canvas if requested. If OffscreenCanvas is availble, we - // let the graph runner initialize it by passing `undefined`. - const canvas = initializeCanvas ? (typeof OffscreenCanvas === 'undefined' ? - document.createElement('canvas') : - undefined) : - null; - const instance = await createMediaPipeLib( - type, fileset.wasmLoaderPath, NO_ASSETS, canvas, fileLocator); - await instance.setOptions(options); - return instance; + return createTaskRunner(type, initializeCanvas, fileset, options); } constructor( From 3ad03bee0be95376cf4606d39b201dab5a0afcb5 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Mon, 5 Dec 2022 14:48:07 -0800 Subject: [PATCH 163/469] Adds missing visibility rule. PiperOrigin-RevId: 493118880 --- mediapipe/calculators/tensor/BUILD | 3 +++ 1 file changed, 3 insertions(+) diff --git a/mediapipe/calculators/tensor/BUILD b/mediapipe/calculators/tensor/BUILD index 645189a07..577ac4111 100644 --- a/mediapipe/calculators/tensor/BUILD +++ b/mediapipe/calculators/tensor/BUILD @@ -366,6 +366,9 @@ cc_test( cc_library( name = "universal_sentence_encoder_preprocessor_calculator", srcs = ["universal_sentence_encoder_preprocessor_calculator.cc"], + visibility = [ + "//mediapipe/framework:mediapipe_internal", + ], deps = [ "//mediapipe/framework:calculator_context", "//mediapipe/framework:calculator_framework", From 99d1dd6fbb130f9f262365ae334b2ca22c819478 Mon Sep 17 00:00:00 2001 From: Mark McDonald Date: Mon, 5 Dec 2022 15:28:52 -0800 Subject: [PATCH 164/469] Internal change PiperOrigin-RevId: 493129643 --- docs/build_py_api_docs.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/docs/build_py_api_docs.py b/docs/build_py_api_docs.py index fe706acd3..46546012d 100644 --- a/docs/build_py_api_docs.py +++ b/docs/build_py_api_docs.py @@ -26,7 +26,6 @@ from absl import app from absl import flags from tensorflow_docs.api_generator import generate_lib -from tensorflow_docs.api_generator import public_api try: # mediapipe has not been set up to work with bazel yet, so catch & report. @@ -68,10 +67,7 @@ def gen_api_docs(): code_url_prefix=_URL_PREFIX.value, search_hints=_SEARCH_HINTS.value, site_path=_SITE_PATH.value, - # This callback ensures that docs are only generated for objects that - # are explicitly imported in your __init__.py files. There are other - # options but this is a good starting point. 
- callbacks=[public_api.explicit_package_contents_filter], + callbacks=[], ) doc_generator.build(_OUTPUT_DIR.value) From 1e76d47a71602ba0ac4a089f625bbd667a7f184b Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Mon, 5 Dec 2022 16:18:36 -0800 Subject: [PATCH 165/469] Checks if a custom global resource provider is used as the first step of loading the model resources on all platforms. PiperOrigin-RevId: 493141519 --- mediapipe/tasks/cc/core/BUILD | 1 + mediapipe/tasks/cc/core/model_resources.cc | 30 +++++++++++----------- mediapipe/util/resource_util.cc | 2 ++ mediapipe/util/resource_util_custom.h | 3 +++ 4 files changed, 21 insertions(+), 15 deletions(-) diff --git a/mediapipe/tasks/cc/core/BUILD b/mediapipe/tasks/cc/core/BUILD index 202f3ea3c..f8004d257 100644 --- a/mediapipe/tasks/cc/core/BUILD +++ b/mediapipe/tasks/cc/core/BUILD @@ -117,6 +117,7 @@ cc_library_with_tflite( "//mediapipe/tasks/cc/core/proto:external_file_cc_proto", "//mediapipe/tasks/cc/metadata:metadata_extractor", "//mediapipe/util:resource_util", + "//mediapipe/util:resource_util_custom", "//mediapipe/util/tflite:error_reporter", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", diff --git a/mediapipe/tasks/cc/core/model_resources.cc b/mediapipe/tasks/cc/core/model_resources.cc index d5c12ee95..7819f6213 100644 --- a/mediapipe/tasks/cc/core/model_resources.cc +++ b/mediapipe/tasks/cc/core/model_resources.cc @@ -33,6 +33,7 @@ limitations under the License. #include "mediapipe/tasks/cc/core/proto/external_file.pb.h" #include "mediapipe/tasks/cc/metadata/metadata_extractor.h" #include "mediapipe/util/resource_util.h" +#include "mediapipe/util/resource_util_custom.h" #include "mediapipe/util/tflite/error_reporter.h" #include "tensorflow/lite/core/api/error_reporter.h" #include "tensorflow/lite/core/api/op_resolver.h" @@ -99,21 +100,20 @@ const tflite::Model* ModelResources::GetTfLiteModel() const { absl::Status ModelResources::BuildModelFromExternalFileProto() { if (model_file_->has_file_name()) { -#ifdef __EMSCRIPTEN__ - // In browsers, the model file may require a custom ResourceProviderFn to - // provide the model content. The open() method may not work in this case. - // Thus, loading the model content from the model file path in advance with - // the help of GetResourceContents. - MP_RETURN_IF_ERROR(mediapipe::GetResourceContents( - model_file_->file_name(), model_file_->mutable_file_content())); - model_file_->clear_file_name(); -#else - // If the model file name is a relative path, searches the file in a - // platform-specific location and returns the absolute path on success. - ASSIGN_OR_RETURN(std::string path_to_resource, - mediapipe::PathToResourceAsFile(model_file_->file_name())); - model_file_->set_file_name(path_to_resource); -#endif // __EMSCRIPTEN__ + if (HasCustomGlobalResourceProvider()) { + // If the model contents are provided via a custom ResourceProviderFn, the + // open() method may not work. Thus, loads the model content from the + // model file path in advance with the help of GetResourceContents. + MP_RETURN_IF_ERROR(GetResourceContents( + model_file_->file_name(), model_file_->mutable_file_content())); + model_file_->clear_file_name(); + } else { + // If the model file name is a relative path, searches the file in a + // platform-specific location and returns the absolute path on success. 
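+ // The resolved file is then opened by the external file handler that is
+ // created right below.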
+ ASSIGN_OR_RETURN(std::string path_to_resource,
+ PathToResourceAsFile(model_file_->file_name()));
+ model_file_->set_file_name(path_to_resource);
+ }
 }
 ASSIGN_OR_RETURN(
 model_file_handler_,
diff --git a/mediapipe/util/resource_util.cc b/mediapipe/util/resource_util.cc
index 8f40154a0..38636f32e 100644
--- a/mediapipe/util/resource_util.cc
+++ b/mediapipe/util/resource_util.cc
@@ -37,6 +37,8 @@ absl::Status GetResourceContents(const std::string& path, std::string* output,
 return internal::DefaultGetResourceContents(path, output, read_as_binary);
 }

+bool HasCustomGlobalResourceProvider() { return resource_provider_ != nullptr; }
+
 void SetCustomGlobalResourceProvider(ResourceProviderFn fn) {
 resource_provider_ = std::move(fn);
 }
diff --git a/mediapipe/util/resource_util_custom.h b/mediapipe/util/resource_util_custom.h
index 6bc1513c6..e74af8b2e 100644
--- a/mediapipe/util/resource_util_custom.h
+++ b/mediapipe/util/resource_util_custom.h
@@ -10,6 +10,9 @@ namespace mediapipe {

 typedef std::function<absl::Status(const std::string&, std::string*)> ResourceProviderFn;

+// Returns true if files are provided via a custom resource provider.
+bool HasCustomGlobalResourceProvider();
+
 // Overrides the behavior of GetResourceContents.
 void SetCustomGlobalResourceProvider(ResourceProviderFn fn);

From 3174b20fbe8225c35433d86f3a82d29645bb82bb Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Mon, 5 Dec 2022 17:32:57 -0800
Subject: [PATCH 166/469] Move segmentation calculator and options out of
 'components' folder.

PiperOrigin-RevId: 493157929
---
 mediapipe/tasks/cc/components/proto/BUILD | 24 ------------------
 .../tasks/cc/vision/image_segmenter/BUILD | 8 +++----
 .../image_segmenter/calculators}/BUILD | 4 ++--
 .../tensors_to_segmentation_calculator.cc | 14 +++++------
 .../tensors_to_segmentation_calculator.proto | 6 +++--
 ...tensors_to_segmentation_calculator_test.cc | 4 +---
 .../vision/image_segmenter/image_segmenter.cc | 4 ++--
 .../image_segmenter/image_segmenter_graph.cc | 6 ++---
 .../image_segmenter/image_segmenter_test.cc | 2 +-
 .../cc/vision/image_segmenter/proto/BUILD | 7 +++++-
 .../proto/image_segmenter_graph_options.proto | 4 ++--
 .../proto/segmenter_options.proto | 4 ++--
 mediapipe/tasks/python/vision/BUILD | 2 +-
 .../tasks/python/vision/image_segmenter.py | 2 +-
 14 files changed, 35 insertions(+), 56 deletions(-)
 delete mode 100644 mediapipe/tasks/cc/components/proto/BUILD
 create mode 100644 mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.proto
 rename mediapipe/tasks/cc/{components/calculators/tensor => vision/image_segmenter/calculators}/BUILD (94%)
 rename mediapipe/tasks/cc/{components/calculators/tensor => vision/image_segmenter/calculators}/tensors_to_segmentation_calculator.cc (95%)
 rename mediapipe/tasks/cc/{components/calculators/tensor => vision/image_segmenter/calculators}/tensors_to_segmentation_calculator.proto (82%)
 rename mediapipe/tasks/cc/{components/calculators/tensor => vision/image_segmenter/calculators}/tensors_to_segmentation_calculator_test.cc (99%)
 rename mediapipe/tasks/cc/{components => vision/image_segmenter}/proto/segmenter_options.proto (92%)

diff --git a/mediapipe/tasks/cc/components/proto/BUILD b/mediapipe/tasks/cc/components/proto/BUILD
deleted file mode 100644
index 569023753..000000000
--- a/mediapipe/tasks/cc/components/proto/BUILD
+++ /dev/null
@@ -1,24 +0,0 @@
-# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") - -package(default_visibility = ["//mediapipe/tasks:internal"]) - -licenses(["notice"]) - -mediapipe_proto_library( - name = "segmenter_options_proto", - srcs = ["segmenter_options.proto"], -) diff --git a/mediapipe/tasks/cc/vision/image_segmenter/BUILD b/mediapipe/tasks/cc/vision/image_segmenter/BUILD index 2124fe6e0..4c9c6e69c 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/BUILD +++ b/mediapipe/tasks/cc/vision/image_segmenter/BUILD @@ -28,7 +28,6 @@ cc_library( "//mediapipe/framework/api2:builder", "//mediapipe/framework/formats:image", "//mediapipe/framework/formats:rect_cc_proto", - "//mediapipe/tasks/cc/components/proto:segmenter_options_cc_proto", "//mediapipe/tasks/cc/core:base_options", "//mediapipe/tasks/cc/core:utils", "//mediapipe/tasks/cc/vision/core:base_vision_task_api", @@ -36,6 +35,7 @@ cc_library( "//mediapipe/tasks/cc/vision/core:running_mode", "//mediapipe/tasks/cc/vision/core:vision_task_api_factory", "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_cc_proto", "@com_google_absl//absl/memory", "@com_google_absl//absl/status:statusor", "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", @@ -56,17 +56,17 @@ cc_library( "//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework/port:status", "//mediapipe/tasks/cc:common", - "//mediapipe/tasks/cc/components/calculators/tensor:tensors_to_segmentation_calculator", - "//mediapipe/tasks/cc/components/calculators/tensor:tensors_to_segmentation_calculator_cc_proto", "//mediapipe/tasks/cc/components/processors:image_preprocessing_graph", "//mediapipe/tasks/cc/components/processors/proto:image_preprocessing_graph_options_cc_proto", - "//mediapipe/tasks/cc/components/proto:segmenter_options_cc_proto", "//mediapipe/tasks/cc/core:model_resources", "//mediapipe/tasks/cc/core:model_task_graph", "//mediapipe/tasks/cc/core/proto:acceleration_cc_proto", "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", "//mediapipe/tasks/cc/metadata:metadata_extractor", + "//mediapipe/tasks/cc/vision/image_segmenter/calculators:tensors_to_segmentation_calculator", + "//mediapipe/tasks/cc/vision/image_segmenter/calculators:tensors_to_segmentation_calculator_cc_proto", "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_cc_proto", "//mediapipe/tasks/metadata:metadata_schema_cc", "//mediapipe/util:label_map_cc_proto", "//mediapipe/util:label_map_util", diff --git a/mediapipe/tasks/cc/components/calculators/tensor/BUILD b/mediapipe/tasks/cc/vision/image_segmenter/calculators/BUILD similarity index 94% rename from mediapipe/tasks/cc/components/calculators/tensor/BUILD rename to mediapipe/tasks/cc/vision/image_segmenter/calculators/BUILD index 6e4322a8f..dcd7fb407 100644 --- a/mediapipe/tasks/cc/components/calculators/tensor/BUILD +++ b/mediapipe/tasks/cc/vision/image_segmenter/calculators/BUILD @@ -25,7 +25,7 @@ 
mediapipe_proto_library(
 "//mediapipe/framework:calculator_options_proto",
 "//mediapipe/framework:calculator_proto",
 "//mediapipe/framework/formats:image_format_proto",
- "//mediapipe/tasks/cc/components/proto:segmenter_options_proto",
+ "//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_proto",
 "//mediapipe/util:label_map_proto",
 ],
)
@@ -45,7 +45,7 @@ cc_library(
 "//mediapipe/framework/port:opencv_core",
 "//mediapipe/framework/port:opencv_imgproc",
 "//mediapipe/framework/port:status",
- "//mediapipe/tasks/cc/components/proto:segmenter_options_cc_proto",
+ "//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_cc_proto",
 "//mediapipe/tasks/cc/vision/utils:image_utils",
 "//mediapipe/util:label_map_cc_proto",
 "@com_google_absl//absl/status",
diff --git a/mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.cc b/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.cc
similarity index 95%
rename from mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.cc
rename to mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.cc
index 40585848f..668de0057 100644
--- a/mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.cc
+++ b/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.cc
@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-// TODO consolidate TensorsToSegmentationCalculator.
 #include
 #include
 #include
@@ -35,14 +34,14 @@ limitations under the License.
 #include "mediapipe/framework/port/opencv_core_inc.h"
 #include "mediapipe/framework/port/opencv_imgproc_inc.h"
 #include "mediapipe/framework/port/status_macros.h"
-#include "mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.pb.h"
-#include "mediapipe/tasks/cc/components/proto/segmenter_options.pb.h"
+#include "mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.pb.h"
+#include "mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.pb.h"
 #include "mediapipe/tasks/cc/vision/utils/image_utils.h"
 #include "mediapipe/util/label_map.pb.h"

+// TODO: consolidate TensorToSegmentationCalculator.
 namespace mediapipe {
 namespace tasks {
-
 namespace {

 using ::mediapipe::Image;
@@ -51,9 +50,9 @@ using ::mediapipe::api2::Input;
 using ::mediapipe::api2::Node;
 using ::mediapipe::api2::Output;
 using ::mediapipe::tasks::TensorsToSegmentationCalculatorOptions;
-using ::mediapipe::tasks::components::proto::SegmenterOptions;
 using ::mediapipe::tasks::vision::GetImageLikeTensorShape;
 using ::mediapipe::tasks::vision::Shape;
+using ::mediapipe::tasks::vision::image_segmenter::proto::SegmenterOptions;

 void StableSoftmax(absl::Span<const float> values,
 absl::Span<float> activated_values) {
@@ -90,7 +89,7 @@ void Sigmoid(absl::Span<const float> values,
 // the size to resize masks to.
 //
 // Output:
-// Segmentation: Segmenation proto.
+// Segmentation: Segmentation proto.
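+// (one category mask or a set of confidence masks, depending on the
+// output_type set in SegmenterOptions)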
 //
 // Options:
 // See tensors_to_segmentation_calculator.proto
@@ -132,8 +131,7 @@ class TensorsToSegmentationCalculator : public Node {

 absl::Status TensorsToSegmentationCalculator::Open(
 mediapipe::CalculatorContext* cc) {
- options_ =
- cc->Options<TensorsToSegmentationCalculatorOptions>();
+ options_ = cc->Options<TensorsToSegmentationCalculatorOptions>();
 RET_CHECK_NE(options_.segmenter_options().output_type(),
 SegmenterOptions::UNSPECIFIED)
 << "Must specify output_type as one of [CONFIDENCE_MASK|CATEGORY_MASK].";
diff --git a/mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.proto b/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.proto
similarity index 82%
rename from mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.proto
rename to mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.proto
index c26cf910a..b0fdfdd32 100644
--- a/mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.proto
+++ b/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.proto
@@ -15,10 +15,11 @@ limitations under the License.

 syntax = "proto2";

+// TODO: consolidate TensorToSegmentationCalculator.
 package mediapipe.tasks;

 import "mediapipe/framework/calculator.proto";
-import "mediapipe/tasks/cc/components/proto/segmenter_options.proto";
+import "mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.proto";
 import "mediapipe/util/label_map.proto";

 message TensorsToSegmentationCalculatorOptions {
@@ -26,7 +27,8 @@ message TensorsToSegmentationCalculatorOptions {
 optional TensorsToSegmentationCalculatorOptions ext = 458105876;
 }

- optional components.proto.SegmenterOptions segmenter_options = 1;
+ optional mediapipe.tasks.vision.image_segmenter.proto.SegmenterOptions
+ segmenter_options = 1;

 // Identifying information for each classification label.
 map<int64, LabelMapItem> label_items = 2;
diff --git a/mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator_test.cc b/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator_test.cc
similarity index 99%
rename from mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator_test.cc
rename to mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator_test.cc
index 55e46d72b..54fb9b816 100644
--- a/mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator_test.cc
+++ b/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator_test.cc
@@ -33,10 +33,9 @@ limitations under the License.
#include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/parse_text_proto.h" #include "mediapipe/framework/port/status_matchers.h" -#include "mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.pb.h" +#include "mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.pb.h" namespace mediapipe { -namespace api2 { namespace { @@ -374,5 +373,4 @@ TEST(TensorsToSegmentationCalculatorTest, SucceedsCategoryMaskResize) { expected_index, buffer_indices))); } -} // namespace api2 } // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc index 6dce1b4ea..bbee714c6 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc @@ -18,12 +18,12 @@ limitations under the License. #include "mediapipe/framework/api2/builder.h" #include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/rect.pb.h" -#include "mediapipe/tasks/cc/components/proto/segmenter_options.pb.h" #include "mediapipe/tasks/cc/core/utils.h" #include "mediapipe/tasks/cc/vision/core/image_processing_options.h" #include "mediapipe/tasks/cc/vision/core/running_mode.h" #include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h" #include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.pb.h" namespace mediapipe { namespace tasks { @@ -44,7 +44,7 @@ constexpr int kMicroSecondsPerMilliSecond = 1000; using ::mediapipe::CalculatorGraphConfig; using ::mediapipe::Image; -using ::mediapipe::tasks::components::proto::SegmenterOptions; +using ::mediapipe::tasks::vision::image_segmenter::proto::SegmenterOptions; using ImageSegmenterGraphOptionsProto = ::mediapipe::tasks::vision:: image_segmenter::proto::ImageSegmenterGraphOptions; diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc index d5eb5af0d..5531968c1 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc @@ -26,16 +26,16 @@ limitations under the License. 
#include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/framework/port/status_macros.h" #include "mediapipe/tasks/cc/common.h" -#include "mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.pb.h" #include "mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h" #include "mediapipe/tasks/cc/components/processors/proto/image_preprocessing_graph_options.pb.h" -#include "mediapipe/tasks/cc/components/proto/segmenter_options.pb.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_task_graph.h" #include "mediapipe/tasks/cc/core/proto/acceleration.pb.h" #include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h" #include "mediapipe/tasks/cc/metadata/metadata_extractor.h" +#include "mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.pb.h" #include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.pb.h" #include "mediapipe/tasks/metadata/metadata_schema_generated.h" #include "mediapipe/util/label_map.pb.h" #include "mediapipe/util/label_map_util.h" @@ -54,10 +54,10 @@ using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::Graph; using ::mediapipe::api2::builder::MultiSource; using ::mediapipe::api2::builder::Source; -using ::mediapipe::tasks::components::proto::SegmenterOptions; using ::mediapipe::tasks::metadata::ModelMetadataExtractor; using ::mediapipe::tasks::vision::image_segmenter::proto:: ImageSegmenterGraphOptions; +using ::mediapipe::tasks::vision::image_segmenter::proto::SegmenterOptions; using ::tflite::Tensor; using ::tflite::TensorMetadata; using LabelItems = mediapipe::proto_ns::Map; diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc index 752a116dd..d5ea088a1 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc @@ -28,11 +28,11 @@ limitations under the License. 
#include "mediapipe/framework/port/opencv_core_inc.h" #include "mediapipe/framework/port/opencv_imgcodecs_inc.h" #include "mediapipe/framework/port/status_matchers.h" -#include "mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.pb.h" #include "mediapipe/tasks/cc/components/containers/rect.h" #include "mediapipe/tasks/cc/core/proto/base_options.pb.h" #include "mediapipe/tasks/cc/core/proto/external_file.pb.h" #include "mediapipe/tasks/cc/vision/core/image_processing_options.h" +#include "mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.pb.h" #include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h" #include "mediapipe/tasks/cc/vision/utils/image_utils.h" #include "tensorflow/lite/core/shims/cc/shims_test_util.h" diff --git a/mediapipe/tasks/cc/vision/image_segmenter/proto/BUILD b/mediapipe/tasks/cc/vision/image_segmenter/proto/BUILD index 3b14060f1..9523dd679 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/proto/BUILD +++ b/mediapipe/tasks/cc/vision/image_segmenter/proto/BUILD @@ -18,13 +18,18 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) licenses(["notice"]) +mediapipe_proto_library( + name = "segmenter_options_proto", + srcs = ["segmenter_options.proto"], +) + mediapipe_proto_library( name = "image_segmenter_graph_options_proto", srcs = ["image_segmenter_graph_options.proto"], deps = [ + ":segmenter_options_proto", "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", - "//mediapipe/tasks/cc/components/proto:segmenter_options_proto", "//mediapipe/tasks/cc/core/proto:base_options_proto", ], ) diff --git a/mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.proto b/mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.proto index 166e2e8e0..4d8100842 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.proto +++ b/mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.proto @@ -18,8 +18,8 @@ syntax = "proto2"; package mediapipe.tasks.vision.image_segmenter.proto; import "mediapipe/framework/calculator.proto"; -import "mediapipe/tasks/cc/components/proto/segmenter_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; +import "mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.proto"; option java_package = "com.google.mediapipe.tasks.vision.imagesegmenter.proto"; option java_outer_classname = "ImageSegmenterGraphOptionsProto"; @@ -37,5 +37,5 @@ message ImageSegmenterGraphOptions { optional string display_names_locale = 2 [default = "en"]; // Segmentation output options. - optional components.proto.SegmenterOptions segmenter_options = 3; + optional SegmenterOptions segmenter_options = 3; } diff --git a/mediapipe/tasks/cc/components/proto/segmenter_options.proto b/mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.proto similarity index 92% rename from mediapipe/tasks/cc/components/proto/segmenter_options.proto rename to mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.proto index ca9986707..be2b8a589 100644 --- a/mediapipe/tasks/cc/components/proto/segmenter_options.proto +++ b/mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.proto @@ -15,9 +15,9 @@ limitations under the License. 
syntax = "proto2"; -package mediapipe.tasks.components.proto; +package mediapipe.tasks.vision.image_segmenter.proto; -option java_package = "com.google.mediapipe.tasks.components.proto"; +option java_package = "com.google.mediapipe.tasks.vision.imagesegmenter.proto"; option java_outer_classname = "SegmenterOptionsProto"; // Shared options used by image segmentation tasks. diff --git a/mediapipe/tasks/python/vision/BUILD b/mediapipe/tasks/python/vision/BUILD index e94507eed..29e7577e8 100644 --- a/mediapipe/tasks/python/vision/BUILD +++ b/mediapipe/tasks/python/vision/BUILD @@ -69,8 +69,8 @@ py_library( "//mediapipe/python:_framework_bindings", "//mediapipe/python:packet_creator", "//mediapipe/python:packet_getter", - "//mediapipe/tasks/cc/components/proto:segmenter_options_py_pb2", "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_py_pb2", + "//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_py_pb2", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/core:optional_dependencies", "//mediapipe/tasks/python/core:task_info", diff --git a/mediapipe/tasks/python/vision/image_segmenter.py b/mediapipe/tasks/python/vision/image_segmenter.py index 62fc8bb7c..22a37cb3e 100644 --- a/mediapipe/tasks/python/vision/image_segmenter.py +++ b/mediapipe/tasks/python/vision/image_segmenter.py @@ -21,8 +21,8 @@ from mediapipe.python import packet_creator from mediapipe.python import packet_getter from mediapipe.python._framework_bindings import image as image_module from mediapipe.python._framework_bindings import packet -from mediapipe.tasks.cc.components.proto import segmenter_options_pb2 from mediapipe.tasks.cc.vision.image_segmenter.proto import image_segmenter_graph_options_pb2 +from mediapipe.tasks.cc.vision.image_segmenter.proto import segmenter_options_pb2 from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.core import task_info as task_info_module from mediapipe.tasks.python.core.optional_dependencies import doc_controls From af43687f2e3c774ff7b0f1f4881d456952a6aadd Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Mon, 5 Dec 2022 20:07:29 -0800 Subject: [PATCH 167/469] Open-sources a unit test. PiperOrigin-RevId: 493184055 --- .../text_classifier/text_classifier_test.cc | 51 +++++++++++++++---- 1 file changed, 42 insertions(+), 9 deletions(-) diff --git a/mediapipe/tasks/cc/text/text_classifier/text_classifier_test.cc b/mediapipe/tasks/cc/text/text_classifier/text_classifier_test.cc index 8f73914fc..799885eac 100644 --- a/mediapipe/tasks/cc/text/text_classifier/text_classifier_test.cc +++ b/mediapipe/tasks/cc/text/text_classifier/text_classifier_test.cc @@ -38,10 +38,7 @@ limitations under the License. 
#include "mediapipe/tasks/cc/text/text_classifier/text_classifier_test_utils.h" #include "tensorflow/lite/core/shims/cc/shims_test_util.h" -namespace mediapipe { -namespace tasks { -namespace text { -namespace text_classifier { +namespace mediapipe::tasks::text::text_classifier { namespace { using ::mediapipe::file::JoinPath; @@ -88,6 +85,8 @@ void ExpectApproximatelyEqual(const TextClassifierResult& actual, } } +} // namespace + class TextClassifierTest : public tflite_shims::testing::Test {}; TEST_F(TextClassifierTest, CreateSucceedsWithBertModel) { @@ -217,8 +216,42 @@ TEST_F(TextClassifierTest, TextClassifierWithStringToBool) { MP_ASSERT_OK(classifier->Close()); } -} // namespace -} // namespace text_classifier -} // namespace text -} // namespace tasks -} // namespace mediapipe +TEST_F(TextClassifierTest, BertLongPositive) { + std::stringstream ss_for_positive_review; + ss_for_positive_review + << "it's a charming and often affecting journey and this is a long"; + for (int i = 0; i < kMaxSeqLen; ++i) { + ss_for_positive_review << " long"; + } + ss_for_positive_review << " movie review"; + auto options = std::make_unique(); + options->base_options.model_asset_path = GetFullPath(kTestBertModelPath); + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr classifier, + TextClassifier::Create(std::move(options))); + MP_ASSERT_OK_AND_ASSIGN(TextClassifierResult result, + classifier->Classify(ss_for_positive_review.str())); + TextClassifierResult expected; + std::vector categories; + +// Predicted scores are slightly different on Mac OS. +#ifdef __APPLE__ + categories.push_back( + {/*index=*/1, /*score=*/0.974181, /*category_name=*/"positive"}); + categories.push_back( + {/*index=*/0, /*score=*/0.025819, /*category_name=*/"negative"}); +#else + categories.push_back( + {/*index=*/1, /*score=*/0.985889, /*category_name=*/"positive"}); + categories.push_back( + {/*index=*/0, /*score=*/0.014112, /*category_name=*/"negative"}); +#endif // __APPLE__ + + expected.classifications.emplace_back( + Classifications{/*categories=*/categories, + /*head_index=*/0, + /*head_name=*/"probability"}); + ExpectApproximatelyEqual(result, expected); + MP_ASSERT_OK(classifier->Close()); +} + +} // namespace mediapipe::tasks::text::text_classifier From 1761cdcef4ff6fd37d04d15de765eccd7c0a5bcc Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Mon, 5 Dec 2022 22:11:00 -0800 Subject: [PATCH 168/469] Fix aar breakage caused by missing "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:normalized_landmark". 
PiperOrigin-RevId: 493204770
---
 .../java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl | 1 +
 1 file changed, 1 insertion(+)

diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl b/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl
index d91c03cc2..c6aba3c66 100644
--- a/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl
+++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl
@@ -285,6 +285,7 @@ def _mediapipe_tasks_aar(name, srcs, manifest, java_proto_lite_targets, native_l
 "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:embedding",
 "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:embeddingresult",
 "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:landmark",
+ "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:normalized_landmark",
 "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:classifieroptions",
 "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/utils:cosinesimilarity",
 "//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
From b0b38a0013c819a6db4156330cbbe2e0dab11bd8 Mon Sep 17 00:00:00 2001
From: Camillo Lugaresi
Date: Tue, 6 Dec 2022 08:29:35 -0800
Subject: [PATCH 169/469] Internal change

PiperOrigin-RevId: 493313240
---
 mediapipe/gpu/metal_shared_resources.h | 40 +++++++++++
 mediapipe/gpu/metal_shared_resources.mm | 73 ++++++++++++++++++++
 mediapipe/gpu/metal_shared_resources_test.mm | 49 +++++++++++++
 3 files changed, 162 insertions(+)
 create mode 100644 mediapipe/gpu/metal_shared_resources.h
 create mode 100644 mediapipe/gpu/metal_shared_resources.mm
 create mode 100644 mediapipe/gpu/metal_shared_resources_test.mm

diff --git a/mediapipe/gpu/metal_shared_resources.h b/mediapipe/gpu/metal_shared_resources.h
new file mode 100644
index 000000000..341860a2d
--- /dev/null
+++ b/mediapipe/gpu/metal_shared_resources.h
@@ -0,0 +1,40 @@
+#ifndef MEDIAPIPE_GPU_METAL_SHARED_RESOURCES_H_
+#define MEDIAPIPE_GPU_METAL_SHARED_RESOURCES_H_
+
+#import <CoreVideo/CVMetalTextureCache.h>
+#import <CoreVideo/CoreVideo.h>
+#import <Foundation/Foundation.h>
+#import <Metal/Metal.h>
+
+#ifndef __OBJC__
+#error This class must be built as Objective-C++.
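+// (This header declares both an Objective-C interface and a C++ wrapper
+// class, so it can only be compiled as Objective-C++.)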
+#endif // !__OBJC__
+
+@interface MPPMetalSharedResources : NSObject {
+}
+
+- (instancetype)init NS_DESIGNATED_INITIALIZER;
+
+@property(readonly) id<MTLDevice> mtlDevice;
+@property(readonly) id<MTLCommandQueue> mtlCommandQueue;
+#if COREVIDEO_SUPPORTS_METAL
+@property(readonly) CVMetalTextureCacheRef mtlTextureCache;
+#endif
+
+@end
+
+namespace mediapipe {
+
+class MetalSharedResources {
+ public:
+ MetalSharedResources();
+ ~MetalSharedResources();
+ MPPMetalSharedResources* resources() { return resources_; }
+
+ private:
+ MPPMetalSharedResources* resources_;
+};
+
+} // namespace mediapipe
+
+#endif // MEDIAPIPE_GPU_METAL_SHARED_RESOURCES_H_
diff --git a/mediapipe/gpu/metal_shared_resources.mm b/mediapipe/gpu/metal_shared_resources.mm
new file mode 100644
index 000000000..80d755a01
--- /dev/null
+++ b/mediapipe/gpu/metal_shared_resources.mm
@@ -0,0 +1,73 @@
+#import "mediapipe/gpu/metal_shared_resources.h"
+
+@interface MPPMetalSharedResources ()
+@end
+
+@implementation MPPMetalSharedResources {
+}
+
+@synthesize mtlDevice = _mtlDevice;
+@synthesize mtlCommandQueue = _mtlCommandQueue;
+#if COREVIDEO_SUPPORTS_METAL
+@synthesize mtlTextureCache = _mtlTextureCache;
+#endif
+
+- (instancetype)init {
+ self = [super init];
+ if (self) {
+ }
+ return self;
+}
+
+- (void)dealloc {
+#if COREVIDEO_SUPPORTS_METAL
+ if (_mtlTextureCache) {
+ CFRelease(_mtlTextureCache);
+ _mtlTextureCache = NULL;
+ }
+#endif
+}
+
+- (id<MTLDevice>)mtlDevice {
+ @synchronized(self) {
+ if (!_mtlDevice) {
+ _mtlDevice = MTLCreateSystemDefaultDevice();
+ }
+ }
+ return _mtlDevice;
+}
+
+- (id<MTLCommandQueue>)mtlCommandQueue {
+ @synchronized(self) {
+ if (!_mtlCommandQueue) {
+ _mtlCommandQueue = [self.mtlDevice newCommandQueue];
+ }
+ }
+ return _mtlCommandQueue;
+}
+
+#if COREVIDEO_SUPPORTS_METAL
+- (CVMetalTextureCacheRef)mtlTextureCache {
+ @synchronized(self) {
+ if (!_mtlTextureCache) {
+ CVReturn __unused err =
+ CVMetalTextureCacheCreate(NULL, NULL, self.mtlDevice, NULL, &_mtlTextureCache);
+ NSAssert(err == kCVReturnSuccess, @"Error at CVMetalTextureCacheCreate %d ; device %@", err,
+ self.mtlDevice);
+ // TODO: register and flush metal caches too.
+ }
+ }
+ return _mtlTextureCache;
+}
+#endif
+
+@end
+
+namespace mediapipe {
+
+MetalSharedResources::MetalSharedResources() {
+ resources_ = [[MPPMetalSharedResources alloc] init];
+}
+MetalSharedResources::~MetalSharedResources() {}
+
+} // namespace mediapipe
diff --git a/mediapipe/gpu/metal_shared_resources_test.mm b/mediapipe/gpu/metal_shared_resources_test.mm
new file mode 100644
index 000000000..9eb53a9b7
--- /dev/null
+++ b/mediapipe/gpu/metal_shared_resources_test.mm
@@ -0,0 +1,49 @@
+// Copyright 2019 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
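+//
+// The lifetime checks in this test rely on weak references: once the
+// autorelease pool drains and the owning GpuResources shared_ptr goes away,
+// both the Objective-C and the C++ weak handles must clear.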
+
+#import <Foundation/Foundation.h>
+#import <XCTest/XCTest.h>
+
+#include <memory>
+
+#include "absl/memory/memory.h"
+#include "mediapipe/framework/port/threadpool.h"
+
+#import "mediapipe/gpu/gpu_shared_data_internal.h"
+#import "mediapipe/gpu/metal_shared_resources.h"
+
+@interface MPPMetalSharedResourcesTests : XCTestCase {
+}
+@end
+
+@implementation MPPMetalSharedResourcesTests
+
+// This test verifies that the internal Objective-C object is correctly
+// released when the C++ wrapper is released.
+- (void)testCorrectlyReleased {
+ __weak id metalRes = nil;
+ std::weak_ptr<mediapipe::GpuResources> weakGpuRes;
+ @autoreleasepool {
+ auto maybeGpuRes = mediapipe::GpuResources::Create();
+ XCTAssertTrue(maybeGpuRes.ok());
+ weakGpuRes = *maybeGpuRes;
+ metalRes = (**maybeGpuRes).metal_shared().resources();
+ XCTAssertNotEqual(weakGpuRes.lock(), nullptr);
+ XCTAssertNotNil(metalRes);
+ }
+ XCTAssertEqual(weakGpuRes.lock(), nullptr);
+ XCTAssertNil(metalRes);
+}
+
+@end
From fb0b96115f148c8c293f6cc3ddc7b3ed67b8043c Mon Sep 17 00:00:00 2001
From: Jiuqiang Tang
Date: Tue, 6 Dec 2022 08:33:51 -0800
Subject: [PATCH 170/469] Open up mediapipe core calculators' visibility.

PiperOrigin-RevId: 493314353
---
 mediapipe/calculators/core/BUILD | 88 +---------------------------
 mediapipe/calculators/internal/BUILD | 6 +-
 2 files changed, 4 insertions(+), 90 deletions(-)

diff --git a/mediapipe/calculators/core/BUILD b/mediapipe/calculators/core/BUILD
index 3b658eb5b..29bca5fa6 100644
--- a/mediapipe/calculators/core/BUILD
+++ b/mediapipe/calculators/core/BUILD
@@ -17,12 +17,11 @@ load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library")

 licenses(["notice"])

-package(default_visibility = ["//visibility:private"])
+package(default_visibility = ["//visibility:public"])

 mediapipe_proto_library(
 name = "concatenate_vector_calculator_proto",
 srcs = ["concatenate_vector_calculator.proto"],
- visibility = ["//visibility:public"],
 deps = [
 "//mediapipe/framework:calculator_options_proto",
 "//mediapipe/framework:calculator_proto",
@@ -32,7 +31,6 @@ mediapipe_proto_library(
 name = "dequantize_byte_array_calculator_proto",
 srcs = ["dequantize_byte_array_calculator.proto"],
- visibility = ["//visibility:public"],
 deps = [
 "//mediapipe/framework:calculator_options_proto",
 "//mediapipe/framework:calculator_proto",
@@ -42,7 +40,6 @@ mediapipe_proto_library(
 name = "packet_cloner_calculator_proto",
 srcs = ["packet_cloner_calculator.proto"],
- visibility = ["//visibility:public"],
 deps = [
 "//mediapipe/framework:calculator_options_proto",
 "//mediapipe/framework:calculator_proto",
@@ -52,7 +49,6 @@ mediapipe_proto_library(
 name = "packet_resampler_calculator_proto",
 srcs = ["packet_resampler_calculator.proto"],
- visibility = ["//visibility:public"],
 deps = [
 "//mediapipe/framework:calculator_options_proto",
 "//mediapipe/framework:calculator_proto",
@@ -62,7 +58,6 @@ mediapipe_proto_library(
 name = "packet_thinner_calculator_proto",
 srcs = ["packet_thinner_calculator.proto"],
- visibility = ["//visibility:public"],
 deps = [
 "//mediapipe/framework:calculator_options_proto",
 "//mediapipe/framework:calculator_proto",
@@ -72,7 +67,6 @@ mediapipe_proto_library(
 name = "split_vector_calculator_proto",
 srcs = ["split_vector_calculator.proto"],
- visibility = ["//visibility:public"],
 deps = [
 "//mediapipe/framework:calculator_options_proto",
 "//mediapipe/framework:calculator_proto",
@@ -82,7 +76,6 @@ mediapipe_proto_library(
 name =
"quantize_float_vector_calculator_proto", srcs = ["quantize_float_vector_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -92,7 +85,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "sequence_shift_calculator_proto", srcs = ["sequence_shift_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -102,7 +94,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "gate_calculator_proto", srcs = ["gate_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -112,7 +103,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "constant_side_packet_calculator_proto", srcs = ["constant_side_packet_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -124,7 +114,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "clip_vector_size_calculator_proto", srcs = ["clip_vector_size_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -134,7 +123,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "flow_limiter_calculator_proto", srcs = ["flow_limiter_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -144,7 +132,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "graph_profile_calculator_proto", srcs = ["graph_profile_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -154,7 +141,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "get_vector_item_calculator_proto", srcs = ["get_vector_item_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -164,7 +150,6 @@ mediapipe_proto_library( cc_library( name = "add_header_calculator", srcs = ["add_header_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:node", @@ -193,7 +178,6 @@ cc_library( name = "begin_loop_calculator", srcs = ["begin_loop_calculator.cc"], hdrs = ["begin_loop_calculator.h"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_context", "//mediapipe/framework:calculator_contract", @@ -216,7 +200,6 @@ cc_library( name = "end_loop_calculator", srcs = ["end_loop_calculator.cc"], hdrs = ["end_loop_calculator.h"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_context", "//mediapipe/framework:calculator_contract", @@ -258,7 +241,6 @@ cc_test( cc_library( name = "concatenate_vector_calculator_hdr", hdrs = ["concatenate_vector_calculator.h"], - visibility = ["//visibility:public"], deps = [ ":concatenate_vector_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -284,7 +266,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":concatenate_vector_calculator_cc_proto", 
"//mediapipe/framework/api2:node", @@ -311,7 +292,6 @@ cc_library( cc_library( name = "concatenate_detection_vector_calculator", srcs = ["concatenate_detection_vector_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":concatenate_vector_calculator", "//mediapipe/framework:calculator_framework", @@ -323,7 +303,6 @@ cc_library( cc_library( name = "concatenate_proto_list_calculator", srcs = ["concatenate_proto_list_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":concatenate_vector_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -372,7 +351,6 @@ cc_library( name = "clip_vector_size_calculator", srcs = ["clip_vector_size_calculator.cc"], hdrs = ["clip_vector_size_calculator.h"], - visibility = ["//visibility:public"], deps = [ ":clip_vector_size_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -388,7 +366,6 @@ cc_library( cc_library( name = "clip_detection_vector_size_calculator", srcs = ["clip_detection_vector_size_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":clip_vector_size_calculator", "//mediapipe/framework:calculator_framework", @@ -415,9 +392,6 @@ cc_test( cc_library( name = "counting_source_calculator", srcs = ["counting_source_calculator.cc"], - visibility = [ - "//visibility:public", - ], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:ret_check", @@ -430,9 +404,6 @@ cc_library( cc_library( name = "make_pair_calculator", srcs = ["make_pair_calculator.cc"], - visibility = [ - "//visibility:public", - ], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:node", @@ -461,9 +432,6 @@ cc_test( cc_library( name = "matrix_multiply_calculator", srcs = ["matrix_multiply_calculator.cc"], - visibility = [ - "//visibility:public", - ], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:node", @@ -477,9 +445,6 @@ cc_library( cc_library( name = "matrix_subtract_calculator", srcs = ["matrix_subtract_calculator.cc"], - visibility = [ - "//visibility:public", - ], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:node", @@ -493,9 +458,6 @@ cc_library( cc_library( name = "mux_calculator", srcs = ["mux_calculator.cc"], - visibility = [ - "//visibility:public", - ], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:node", @@ -508,9 +470,6 @@ cc_library( cc_library( name = "non_zero_calculator", srcs = ["non_zero_calculator.cc"], - visibility = [ - "//visibility:public", - ], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:node", @@ -556,9 +515,6 @@ cc_test( cc_library( name = "packet_cloner_calculator", srcs = ["packet_cloner_calculator.cc"], - visibility = [ - "//visibility:public", - ], deps = [ ":packet_cloner_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -587,7 +543,6 @@ cc_test( cc_library( name = "packet_inner_join_calculator", srcs = ["packet_inner_join_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:ret_check", @@ -611,7 +566,6 @@ cc_test( cc_library( name = "packet_thinner_calculator", srcs = ["packet_thinner_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/calculators/core:packet_thinner_calculator_cc_proto", "//mediapipe/framework:calculator_context", @@ -643,9 +597,6 @@ cc_test( cc_library( name = "pass_through_calculator", srcs = 
["pass_through_calculator.cc"], - visibility = [ - "//visibility:public", - ], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:status", @@ -656,9 +607,6 @@ cc_library( cc_library( name = "round_robin_demux_calculator", srcs = ["round_robin_demux_calculator.cc"], - visibility = [ - "//visibility:public", - ], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:node", @@ -670,9 +618,6 @@ cc_library( cc_library( name = "immediate_mux_calculator", srcs = ["immediate_mux_calculator.cc"], - visibility = [ - "//visibility:public", - ], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:ret_check", @@ -684,7 +629,6 @@ cc_library( cc_library( name = "packet_presence_calculator", srcs = ["packet_presence_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework:packet", @@ -713,7 +657,6 @@ cc_test( cc_library( name = "previous_loopback_calculator", srcs = ["previous_loopback_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework:packet", @@ -729,7 +672,6 @@ cc_library( cc_library( name = "flow_limiter_calculator", srcs = ["flow_limiter_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":flow_limiter_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -746,7 +688,6 @@ cc_library( cc_library( name = "string_to_int_calculator", srcs = ["string_to_int_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:integral_types", @@ -759,7 +700,6 @@ cc_library( cc_library( name = "default_side_packet_calculator", srcs = ["default_side_packet_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:ret_check", @@ -771,7 +711,6 @@ cc_library( cc_library( name = "side_packet_to_stream_calculator", srcs = ["side_packet_to_stream_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:logging", @@ -822,9 +761,6 @@ cc_library( name = "packet_resampler_calculator", srcs = ["packet_resampler_calculator.cc"], hdrs = ["packet_resampler_calculator.h"], - visibility = [ - "//visibility:public", - ], deps = [ "//mediapipe/calculators/core:packet_resampler_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -884,7 +820,6 @@ cc_test( cc_test( name = "matrix_multiply_calculator_test", srcs = ["matrix_multiply_calculator_test.cc"], - visibility = ["//visibility:private"], deps = [ ":matrix_multiply_calculator", "//mediapipe/framework:calculator_framework", @@ -900,7 +835,6 @@ cc_test( cc_test( name = "matrix_subtract_calculator_test", srcs = ["matrix_subtract_calculator_test.cc"], - visibility = ["//visibility:private"], deps = [ ":matrix_subtract_calculator", "//mediapipe/framework:calculator_framework", @@ -950,7 +884,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":split_vector_calculator_cc_proto", "//mediapipe/framework/formats:detection_cc_proto", @@ -996,7 +929,6 @@ cc_test( cc_library( name = "split_proto_list_calculator", srcs = ["split_proto_list_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":split_vector_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -1028,7 +960,6 @@ 
cc_test( cc_library( name = "dequantize_byte_array_calculator", srcs = ["dequantize_byte_array_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":dequantize_byte_array_calculator_cc_proto", "//mediapipe/framework:calculator_context", @@ -1054,7 +985,6 @@ cc_test( cc_library( name = "quantize_float_vector_calculator", srcs = ["quantize_float_vector_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":quantize_float_vector_calculator_cc_proto", "//mediapipe/framework:calculator_context", @@ -1080,7 +1010,6 @@ cc_test( cc_library( name = "sequence_shift_calculator", srcs = ["sequence_shift_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":sequence_shift_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -1105,7 +1034,6 @@ cc_test( cc_library( name = "gate_calculator", srcs = ["gate_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":gate_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -1131,7 +1059,6 @@ cc_test( cc_library( name = "matrix_to_vector_calculator", srcs = ["matrix_to_vector_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:node", @@ -1167,7 +1094,6 @@ cc_test( cc_library( name = "merge_calculator", srcs = ["merge_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:node", @@ -1193,7 +1119,6 @@ cc_test( cc_library( name = "stream_to_side_packet_calculator", srcs = ["stream_to_side_packet_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework:timestamp", @@ -1219,7 +1144,6 @@ cc_test( cc_library( name = "constant_side_packet_calculator", srcs = ["constant_side_packet_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":constant_side_packet_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -1249,7 +1173,6 @@ cc_test( cc_library( name = "graph_profile_calculator", srcs = ["graph_profile_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":graph_profile_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -1291,7 +1214,6 @@ cc_library( name = "get_vector_item_calculator", srcs = ["get_vector_item_calculator.cc"], hdrs = ["get_vector_item_calculator.h"], - visibility = ["//visibility:public"], deps = [ ":get_vector_item_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -1325,7 +1247,6 @@ cc_library( name = "vector_indices_calculator", srcs = ["vector_indices_calculator.cc"], hdrs = ["vector_indices_calculator.h"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:node", @@ -1351,7 +1272,6 @@ cc_library( name = "vector_size_calculator", srcs = ["vector_size_calculator.cc"], hdrs = ["vector_size_calculator.h"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:node", @@ -1365,9 +1285,6 @@ cc_library( cc_library( name = "packet_sequencer_calculator", srcs = ["packet_sequencer_calculator.cc"], - visibility = [ - "//visibility:public", - ], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:contract", @@ -1402,7 +1319,6 @@ cc_library( name = "merge_to_vector_calculator", srcs = ["merge_to_vector_calculator.cc"], hdrs = ["merge_to_vector_calculator.h"], - visibility = 
["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:node", @@ -1416,7 +1332,6 @@ cc_library( mediapipe_proto_library( name = "bypass_calculator_proto", srcs = ["bypass_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -1426,7 +1341,6 @@ mediapipe_proto_library( cc_library( name = "bypass_calculator", srcs = ["bypass_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":bypass_calculator_cc_proto", "//mediapipe/framework:calculator_framework", diff --git a/mediapipe/calculators/internal/BUILD b/mediapipe/calculators/internal/BUILD index 54b6c20f1..caade2dc3 100644 --- a/mediapipe/calculators/internal/BUILD +++ b/mediapipe/calculators/internal/BUILD @@ -21,7 +21,7 @@ package(default_visibility = ["//visibility:private"]) proto_library( name = "callback_packet_calculator_proto", srcs = ["callback_packet_calculator.proto"], - visibility = ["//visibility:public"], + visibility = ["//mediapipe/framework:__subpackages__"], deps = ["//mediapipe/framework:calculator_proto"], ) @@ -29,14 +29,14 @@ mediapipe_cc_proto_library( name = "callback_packet_calculator_cc_proto", srcs = ["callback_packet_calculator.proto"], cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = ["//visibility:public"], + visibility = ["//mediapipe/framework:__subpackages__"], deps = [":callback_packet_calculator_proto"], ) cc_library( name = "callback_packet_calculator", srcs = ["callback_packet_calculator.cc"], - visibility = ["//visibility:public"], + visibility = ["//mediapipe/framework:__subpackages__"], deps = [ ":callback_packet_calculator_cc_proto", "//mediapipe/framework:calculator_base", From ab0b0ab558c633bc996c41923f9325269cc76e3c Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Tue, 6 Dec 2022 10:22:31 -0800 Subject: [PATCH 171/469] Change visibility for MP Tasks Web to public PiperOrigin-RevId: 493343996 --- mediapipe/tasks/web/audio/BUILD | 1 + mediapipe/tasks/web/audio/audio_classifier/BUILD | 2 ++ mediapipe/tasks/web/audio/audio_embedder/BUILD | 2 ++ mediapipe/tasks/web/core/BUILD | 1 + mediapipe/tasks/web/text/BUILD | 1 + mediapipe/tasks/web/text/text_classifier/BUILD | 2 ++ mediapipe/tasks/web/text/text_embedder/BUILD | 2 ++ mediapipe/tasks/web/vision/BUILD | 1 + mediapipe/tasks/web/vision/gesture_recognizer/BUILD | 2 ++ mediapipe/tasks/web/vision/hand_landmarker/BUILD | 2 ++ mediapipe/tasks/web/vision/image_classifier/BUILD | 2 ++ mediapipe/tasks/web/vision/image_embedder/BUILD | 2 ++ mediapipe/tasks/web/vision/object_detector/BUILD | 2 ++ 13 files changed, 22 insertions(+) diff --git a/mediapipe/tasks/web/audio/BUILD b/mediapipe/tasks/web/audio/BUILD index d08602521..9d26f1118 100644 --- a/mediapipe/tasks/web/audio/BUILD +++ b/mediapipe/tasks/web/audio/BUILD @@ -7,6 +7,7 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) mediapipe_ts_library( name = "audio_lib", srcs = ["index.ts"], + visibility = ["//visibility:public"], deps = [ "//mediapipe/tasks/web/audio/audio_classifier", "//mediapipe/tasks/web/audio/audio_embedder", diff --git a/mediapipe/tasks/web/audio/audio_classifier/BUILD b/mediapipe/tasks/web/audio/audio_classifier/BUILD index 6f785dd0d..dc82a4a24 100644 --- a/mediapipe/tasks/web/audio/audio_classifier/BUILD +++ b/mediapipe/tasks/web/audio/audio_classifier/BUILD @@ -11,6 +11,7 @@ licenses(["notice"]) mediapipe_ts_library( name = "audio_classifier", srcs = 
["audio_classifier.ts"], + visibility = ["//visibility:public"], deps = [ ":audio_classifier_types", "//mediapipe/framework:calculator_jspb_proto", @@ -35,6 +36,7 @@ mediapipe_ts_declaration( "audio_classifier_options.d.ts", "audio_classifier_result.d.ts", ], + visibility = ["//visibility:public"], deps = [ "//mediapipe/tasks/web/components/containers:category", "//mediapipe/tasks/web/components/containers:classification_result", diff --git a/mediapipe/tasks/web/audio/audio_embedder/BUILD b/mediapipe/tasks/web/audio/audio_embedder/BUILD index 0555bb639..dc84d0cd6 100644 --- a/mediapipe/tasks/web/audio/audio_embedder/BUILD +++ b/mediapipe/tasks/web/audio/audio_embedder/BUILD @@ -11,6 +11,7 @@ licenses(["notice"]) mediapipe_ts_library( name = "audio_embedder", srcs = ["audio_embedder.ts"], + visibility = ["//visibility:public"], deps = [ ":audio_embedder_types", "//mediapipe/framework:calculator_jspb_proto", @@ -35,6 +36,7 @@ mediapipe_ts_declaration( "audio_embedder_options.d.ts", "audio_embedder_result.d.ts", ], + visibility = ["//visibility:public"], deps = [ "//mediapipe/tasks/web/components/containers:embedding_result", "//mediapipe/tasks/web/core", diff --git a/mediapipe/tasks/web/core/BUILD b/mediapipe/tasks/web/core/BUILD index de429690d..be1b71f5d 100644 --- a/mediapipe/tasks/web/core/BUILD +++ b/mediapipe/tasks/web/core/BUILD @@ -28,6 +28,7 @@ mediapipe_ts_library( mediapipe_ts_library( name = "fileset_resolver", srcs = ["fileset_resolver.ts"], + visibility = ["//visibility:public"], deps = [":core"], ) diff --git a/mediapipe/tasks/web/text/BUILD b/mediapipe/tasks/web/text/BUILD index 159db1a0d..32f43d4b6 100644 --- a/mediapipe/tasks/web/text/BUILD +++ b/mediapipe/tasks/web/text/BUILD @@ -7,6 +7,7 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) mediapipe_ts_library( name = "text_lib", srcs = ["index.ts"], + visibility = ["//visibility:public"], deps = [ "//mediapipe/tasks/web/core:fileset_resolver", "//mediapipe/tasks/web/text/text_classifier", diff --git a/mediapipe/tasks/web/text/text_classifier/BUILD b/mediapipe/tasks/web/text/text_classifier/BUILD index 2a7de21d6..07f78ac20 100644 --- a/mediapipe/tasks/web/text/text_classifier/BUILD +++ b/mediapipe/tasks/web/text/text_classifier/BUILD @@ -12,6 +12,7 @@ licenses(["notice"]) mediapipe_ts_library( name = "text_classifier", srcs = ["text_classifier.ts"], + visibility = ["//visibility:public"], deps = [ ":text_classifier_types", "//mediapipe/framework:calculator_jspb_proto", @@ -36,6 +37,7 @@ mediapipe_ts_declaration( "text_classifier_options.d.ts", "text_classifier_result.d.ts", ], + visibility = ["//visibility:public"], deps = [ "//mediapipe/tasks/web/components/containers:category", "//mediapipe/tasks/web/components/containers:classification_result", diff --git a/mediapipe/tasks/web/text/text_embedder/BUILD b/mediapipe/tasks/web/text/text_embedder/BUILD index 17d105258..7d796fb7e 100644 --- a/mediapipe/tasks/web/text/text_embedder/BUILD +++ b/mediapipe/tasks/web/text/text_embedder/BUILD @@ -12,6 +12,7 @@ licenses(["notice"]) mediapipe_ts_library( name = "text_embedder", srcs = ["text_embedder.ts"], + visibility = ["//visibility:public"], deps = [ ":text_embedder_types", "//mediapipe/framework:calculator_jspb_proto", @@ -36,6 +37,7 @@ mediapipe_ts_declaration( "text_embedder_options.d.ts", "text_embedder_result.d.ts", ], + visibility = ["//visibility:public"], deps = [ "//mediapipe/tasks/web/components/containers:embedding_result", "//mediapipe/tasks/web/core", diff --git a/mediapipe/tasks/web/vision/BUILD 
b/mediapipe/tasks/web/vision/BUILD index 42bc0a494..93493e873 100644 --- a/mediapipe/tasks/web/vision/BUILD +++ b/mediapipe/tasks/web/vision/BUILD @@ -7,6 +7,7 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) mediapipe_ts_library( name = "vision_lib", srcs = ["index.ts"], + visibility = ["//visibility:public"], deps = [ "//mediapipe/tasks/web/core:fileset_resolver", "//mediapipe/tasks/web/vision/gesture_recognizer", diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/BUILD b/mediapipe/tasks/web/vision/gesture_recognizer/BUILD index ddfd1a327..6e2e56196 100644 --- a/mediapipe/tasks/web/vision/gesture_recognizer/BUILD +++ b/mediapipe/tasks/web/vision/gesture_recognizer/BUILD @@ -12,6 +12,7 @@ licenses(["notice"]) mediapipe_ts_library( name = "gesture_recognizer", srcs = ["gesture_recognizer.ts"], + visibility = ["//visibility:public"], deps = [ ":gesture_recognizer_types", "//mediapipe/framework:calculator_jspb_proto", @@ -42,6 +43,7 @@ mediapipe_ts_declaration( "gesture_recognizer_options.d.ts", "gesture_recognizer_result.d.ts", ], + visibility = ["//visibility:public"], deps = [ "//mediapipe/tasks/web/components/containers:category", "//mediapipe/tasks/web/components/containers:landmark", diff --git a/mediapipe/tasks/web/vision/hand_landmarker/BUILD b/mediapipe/tasks/web/vision/hand_landmarker/BUILD index fc3e6ef1f..520898e34 100644 --- a/mediapipe/tasks/web/vision/hand_landmarker/BUILD +++ b/mediapipe/tasks/web/vision/hand_landmarker/BUILD @@ -12,6 +12,7 @@ licenses(["notice"]) mediapipe_ts_library( name = "hand_landmarker", srcs = ["hand_landmarker.ts"], + visibility = ["//visibility:public"], deps = [ ":hand_landmarker_types", "//mediapipe/framework:calculator_jspb_proto", @@ -38,6 +39,7 @@ mediapipe_ts_declaration( "hand_landmarker_options.d.ts", "hand_landmarker_result.d.ts", ], + visibility = ["//visibility:public"], deps = [ "//mediapipe/tasks/web/components/containers:category", "//mediapipe/tasks/web/components/containers:landmark", diff --git a/mediapipe/tasks/web/vision/image_classifier/BUILD b/mediapipe/tasks/web/vision/image_classifier/BUILD index ebe64ecf4..848c162ae 100644 --- a/mediapipe/tasks/web/vision/image_classifier/BUILD +++ b/mediapipe/tasks/web/vision/image_classifier/BUILD @@ -11,6 +11,7 @@ licenses(["notice"]) mediapipe_ts_library( name = "image_classifier", srcs = ["image_classifier.ts"], + visibility = ["//visibility:public"], deps = [ ":image_classifier_types", "//mediapipe/framework:calculator_jspb_proto", @@ -35,6 +36,7 @@ mediapipe_ts_declaration( "image_classifier_options.d.ts", "image_classifier_result.d.ts", ], + visibility = ["//visibility:public"], deps = [ "//mediapipe/tasks/web/components/containers:category", "//mediapipe/tasks/web/components/containers:classification_result", diff --git a/mediapipe/tasks/web/vision/image_embedder/BUILD b/mediapipe/tasks/web/vision/image_embedder/BUILD index 2f012dc5e..6c9d80fb1 100644 --- a/mediapipe/tasks/web/vision/image_embedder/BUILD +++ b/mediapipe/tasks/web/vision/image_embedder/BUILD @@ -11,6 +11,7 @@ licenses(["notice"]) mediapipe_ts_library( name = "image_embedder", srcs = ["image_embedder.ts"], + visibility = ["//visibility:public"], deps = [ ":image_embedder_types", "//mediapipe/framework:calculator_jspb_proto", @@ -36,6 +37,7 @@ mediapipe_ts_declaration( "image_embedder_options.d.ts", "image_embedder_result.d.ts", ], + visibility = ["//visibility:public"], deps = [ "//mediapipe/tasks/web/components/containers:embedding_result", "//mediapipe/tasks/web/core", diff --git 
a/mediapipe/tasks/web/vision/object_detector/BUILD b/mediapipe/tasks/web/vision/object_detector/BUILD index 198585258..f73790895 100644 --- a/mediapipe/tasks/web/vision/object_detector/BUILD +++ b/mediapipe/tasks/web/vision/object_detector/BUILD @@ -12,6 +12,7 @@ licenses(["notice"]) mediapipe_ts_library( name = "object_detector", srcs = ["object_detector.ts"], + visibility = ["//visibility:public"], deps = [ ":object_detector_types", "//mediapipe/framework:calculator_jspb_proto", @@ -32,6 +33,7 @@ mediapipe_ts_declaration( "object_detector_options.d.ts", "object_detector_result.d.ts", ], + visibility = ["//visibility:public"], deps = [ "//mediapipe/tasks/web/components/containers:category", "//mediapipe/tasks/web/core", From c6e6f9e0b9b35d055cd83016e468a8c30a7b153b Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Tue, 6 Dec 2022 11:05:47 -0800 Subject: [PATCH 172/469] Fix aar breakage caused by missing "//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_java_proto_lite". PiperOrigin-RevId: 493357585 --- .../java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl b/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl index c6aba3c66..727d020a6 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl @@ -21,7 +21,6 @@ _CORE_TASKS_JAVA_PROTO_LITE_TARGETS = [ "//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite", "//mediapipe/tasks/cc/components/containers/proto:embeddings_java_proto_lite", "//mediapipe/tasks/cc/components/containers/proto:landmarks_detection_result_java_proto_lite", - "//mediapipe/tasks/cc/components/proto:segmenter_options_java_proto_lite", "//mediapipe/tasks/cc/components/processors/proto:classifier_options_java_proto_lite", "//mediapipe/tasks/cc/components/processors/proto:embedder_options_java_proto_lite", "//mediapipe/tasks/cc/core/proto:acceleration_java_proto_lite", @@ -43,6 +42,7 @@ _VISION_TASKS_JAVA_PROTO_LITE_TARGETS = [ "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_java_proto_lite", + "//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_java_proto_lite", "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_java_proto_lite", From 6deef1a5f13c4af5e38abe96f8aabbba733dcdcb Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Tue, 6 Dec 2022 12:07:51 -0800 Subject: [PATCH 173/469] Allow specifying tag_suffix in the templated CreateModelResources method. 
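
A hypothetical usage sketch (not part of this patch; the graph options types and
tag suffixes below are placeholders): a task graph that loads two models could
cache each ModelResources under a distinct service tag, e.g.

    // Illustrative only -- option proto types and tag suffixes are placeholders.
    ASSIGN_OR_RETURN(const auto* detector_resources,
                     CreateModelResources<HandDetectorGraphOptions>(
                         sc, /*tag_suffix=*/"_DETECTOR"));
    ASSIGN_OR_RETURN(const auto* landmarker_resources,
                     CreateModelResources<HandLandmarkerGraphOptions>(
                         sc, /*tag_suffix=*/"_LANDMARKER"));

Omitting tag_suffix keeps the previous single-model behavior, since the
parameter defaults to an empty string.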
PiperOrigin-RevId: 493375701 --- mediapipe/tasks/cc/core/model_task_graph.cc | 2 +- mediapipe/tasks/cc/core/model_task_graph.h | 8 +++++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/mediapipe/tasks/cc/core/model_task_graph.cc b/mediapipe/tasks/cc/core/model_task_graph.cc index 66434483b..0cb556ec2 100644 --- a/mediapipe/tasks/cc/core/model_task_graph.cc +++ b/mediapipe/tasks/cc/core/model_task_graph.cc @@ -186,7 +186,7 @@ absl::StatusOr ModelTaskGraph::CreateModelResources( absl::StatusOr ModelTaskGraph::CreateModelAssetBundleResources( SubgraphContext* sc, std::unique_ptr external_file, - const std::string tag_suffix) { + std::string tag_suffix) { auto model_resources_cache_service = sc->Service(kModelResourcesCacheService); bool has_file_pointer_meta = external_file->has_file_pointer_meta(); // if external file is set by file pointer, no need to add the model asset diff --git a/mediapipe/tasks/cc/core/model_task_graph.h b/mediapipe/tasks/cc/core/model_task_graph.h index 50dcc903b..3068b2c46 100644 --- a/mediapipe/tasks/cc/core/model_task_graph.h +++ b/mediapipe/tasks/cc/core/model_task_graph.h @@ -59,14 +59,16 @@ class ModelTaskGraph : public Subgraph { // creates a local model resources object that can only be used in the graph // construction stage. The returned model resources pointer will provide graph // authors with the access to the metadata extractor and the tflite model. + // If more than one model resources object is created in a graph, the model + // resources graph service adds the tag_suffix to support multiple resources. template absl::StatusOr CreateModelResources( - SubgraphContext* sc) { + SubgraphContext* sc, std::string tag_suffix = "") { auto external_file = std::make_unique(); external_file->Swap(sc->MutableOptions() ->mutable_base_options() ->mutable_model_asset()); - return CreateModelResources(sc, std::move(external_file)); + return CreateModelResources(sc, std::move(external_file), tag_suffix); } // If the model resources graph service is available, creates a model @@ -83,7 +85,7 @@ class ModelTaskGraph : public Subgraph { // resources. absl::StatusOr CreateModelResources( SubgraphContext* sc, std::unique_ptr external_file, - const std::string tag_suffix = ""); + std::string tag_suffix = ""); // If the model resources graph service is available, creates a model asset // bundle resources object from the subgraph context, and caches the created From cdc14522e2821a60ec1ee208430e364917e21985 Mon Sep 17 00:00:00 2001 From: Khanh LeViet Date: Tue, 6 Dec 2022 13:01:06 -0800 Subject: [PATCH 174/469] Added issue templates for MP Preview. PiperOrigin-RevId: 493389856 --- .../ISSUE_TEMPLATE/11-tasks-issue.md | 25 +++++++++++++++++++ .../ISSUE_TEMPLATE/12-model-maker-issue.md | 25 +++++++++++++++++++ 2 files changed, 50 insertions(+) create mode 100644 mediapipe/opensource_only/ISSUE_TEMPLATE/11-tasks-issue.md create mode 100644 mediapipe/opensource_only/ISSUE_TEMPLATE/12-model-maker-issue.md diff --git a/mediapipe/opensource_only/ISSUE_TEMPLATE/11-tasks-issue.md b/mediapipe/opensource_only/ISSUE_TEMPLATE/11-tasks-issue.md new file mode 100644 index 000000000..ab7b38368 --- /dev/null +++ b/mediapipe/opensource_only/ISSUE_TEMPLATE/11-tasks-issue.md @@ -0,0 +1,25 @@ +--- +name: "Tasks Issue" +about: Use this template for assistance with a specific task, such as "Gesture Recognition" or "Object Detection", including inference model usage/training etc.
+labels: type:support + +--- +Please make sure that this is a [Tasks](https://developers.google.com/mediapipe/solutions) issue. + +**System information** (Please provide as much relevant information as possible) +- Have I written custom code (as opposed to using a stock example script provided in MediaPipe): +- OS Platform and Distribution (e.g., Linux Ubuntu 16.04, Android 11, iOS 14.4): +- MediaPipe Tasks SDK version: +- Task name (e.g. Object detection, Gesture recognition etc.): +- Programming Language and version (e.g. C++, Python, Java): + +**Describe the expected behavior:** + +**Standalone code you may have used to try to get what you need:** + +If there is a problem, provide a reproducible test case that is the bare minimum necessary to generate the problem. If possible, please share a Colab link, a GitHub repo link, or anything else that we can use to reproduce the problem: + +**Other info / Complete Logs:** +Include any logs or source code that would be helpful to +diagnose the problem. If including tracebacks, please include the full +traceback. Large logs and files should be attached: diff --git a/mediapipe/opensource_only/ISSUE_TEMPLATE/12-model-maker-issue.md b/mediapipe/opensource_only/ISSUE_TEMPLATE/12-model-maker-issue.md new file mode 100644 index 000000000..687360957 --- /dev/null +++ b/mediapipe/opensource_only/ISSUE_TEMPLATE/12-model-maker-issue.md @@ -0,0 +1,25 @@ +--- +name: "Model Maker Issue" +about: Use this template for assistance with a specific task, such as "Gesture Recognition" or "Object Detection", including inference model usage/training etc. +labels: type:support + +--- +Please make sure that this is a [Model Maker](https://developers.google.com/mediapipe/solutions) issue. + +**System information** (Please provide as much relevant information as possible) +- Have I written custom code (as opposed to using a stock example script provided in MediaPipe): +- OS Platform and Distribution (e.g., Linux Ubuntu 16.04): +- Python version (e.g. 3.8): +- [MediaPipe Model Maker version](https://pypi.org/project/mediapipe-model-maker/): +- Task name (e.g. Image classification, Gesture recognition etc.): + +**Describe the expected behavior:** + +**Standalone code you may have used to try to get what you need:** + +If there is a problem, provide a reproducible test case that is the bare minimum necessary to generate the problem. If possible, please share a Colab link, a GitHub repo link, or anything else that we can use to reproduce the problem: + +**Other info / Complete Logs:** +Include any logs or source code that would be helpful to +diagnose the problem. If including tracebacks, please include the full +traceback.
Large logs and files should be attached: From 0f32072804a4e078c9f64ae8cb48d9b1777a679f Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Tue, 6 Dec 2022 14:01:49 -0800 Subject: [PATCH 175/469] Move ISSUE_TEMPLATE files to .github folder PiperOrigin-RevId: 493405734 --- .../opensource_only => .github}/ISSUE_TEMPLATE/11-tasks-issue.md | 0 .../ISSUE_TEMPLATE/12-model-maker-issue.md | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename {mediapipe/opensource_only => .github}/ISSUE_TEMPLATE/11-tasks-issue.md (100%) rename {mediapipe/opensource_only => .github}/ISSUE_TEMPLATE/12-model-maker-issue.md (100%) diff --git a/mediapipe/opensource_only/ISSUE_TEMPLATE/11-tasks-issue.md b/.github/ISSUE_TEMPLATE/11-tasks-issue.md similarity index 100% rename from mediapipe/opensource_only/ISSUE_TEMPLATE/11-tasks-issue.md rename to .github/ISSUE_TEMPLATE/11-tasks-issue.md diff --git a/mediapipe/opensource_only/ISSUE_TEMPLATE/12-model-maker-issue.md b/.github/ISSUE_TEMPLATE/12-model-maker-issue.md similarity index 100% rename from mediapipe/opensource_only/ISSUE_TEMPLATE/12-model-maker-issue.md rename to .github/ISSUE_TEMPLATE/12-model-maker-issue.md From 9bc7b120de85d4991292d831bd844264c783350b Mon Sep 17 00:00:00 2001 From: Khanh LeViet Date: Tue, 6 Dec 2022 15:12:25 -0800 Subject: [PATCH 176/469] Tweaked the issue templates. PiperOrigin-RevId: 493424927 --- .github/ISSUE_TEMPLATE/11-tasks-issue.md | 2 +- .github/ISSUE_TEMPLATE/12-model-maker-issue.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/ISSUE_TEMPLATE/11-tasks-issue.md b/.github/ISSUE_TEMPLATE/11-tasks-issue.md index ab7b38368..264371120 100644 --- a/.github/ISSUE_TEMPLATE/11-tasks-issue.md +++ b/.github/ISSUE_TEMPLATE/11-tasks-issue.md @@ -1,6 +1,6 @@ --- name: "Tasks Issue" -about: Use this template for assistance with a specific task, such as "Gesture Recognition" or "Object Detection", including inference model usage/training etc. +about: Use this template for assistance with using MediaPipe Tasks to deploy on-device ML solutions (e.g. gesture recognition etc.) on supported platforms. labels: type:support --- diff --git a/.github/ISSUE_TEMPLATE/12-model-maker-issue.md b/.github/ISSUE_TEMPLATE/12-model-maker-issue.md index 687360957..258390d5e 100644 --- a/.github/ISSUE_TEMPLATE/12-model-maker-issue.md +++ b/.github/ISSUE_TEMPLATE/12-model-maker-issue.md @@ -1,6 +1,6 @@ --- name: "Model Maker Issue" -about: Use this template for assistance with a specific task, such as "Gesture Recognition" or "Object Detection", including inference model usage/training etc. +about: Use this template for assistance with using MediaPipe Model Maker to create custom on-device ML solutions.
labels: type:support --- From fca0f5806b470a47a3c74a7085d32c32a12d61f1 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Tue, 6 Dec 2022 15:16:42 -0800 Subject: [PATCH 177/469] Create Build Rules for Apple Frameworks PiperOrigin-RevId: 493426040 --- mediapipe/examples/ios/common/BUILD | 10 ++-- mediapipe/examples/ios/faceeffect/BUILD | 10 ++-- mediapipe/gpu/BUILD | 64 ++++++++-------------- mediapipe/objc/BUILD | 68 ++++++++++------------- third_party/apple_frameworks/BUILD | 73 +++++++++++++++++++++++++ 5 files changed, 134 insertions(+), 91 deletions(-) create mode 100644 third_party/apple_frameworks/BUILD diff --git a/mediapipe/examples/ios/common/BUILD b/mediapipe/examples/ios/common/BUILD index 9b8f8a968..bfa770cec 100644 --- a/mediapipe/examples/ios/common/BUILD +++ b/mediapipe/examples/ios/common/BUILD @@ -29,12 +29,6 @@ objc_library( "Base.lproj/LaunchScreen.storyboard", "Base.lproj/Main.storyboard", ], - sdk_frameworks = [ - "AVFoundation", - "CoreGraphics", - "CoreMedia", - "UIKit", - ], visibility = [ "//mediapipe:__subpackages__", ], @@ -42,6 +36,10 @@ objc_library( "//mediapipe/objc:mediapipe_framework_ios", "//mediapipe/objc:mediapipe_input_sources_ios", "//mediapipe/objc:mediapipe_layer_renderer", + "//third_party/apple_frameworks:AVFoundation", + "//third_party/apple_frameworks:CoreGraphics", + "//third_party/apple_frameworks:CoreMedia", + "//third_party/apple_frameworks:UIKit", ], ) diff --git a/mediapipe/examples/ios/faceeffect/BUILD b/mediapipe/examples/ios/faceeffect/BUILD index 50a6f68bd..e0c3abb86 100644 --- a/mediapipe/examples/ios/faceeffect/BUILD +++ b/mediapipe/examples/ios/faceeffect/BUILD @@ -73,13 +73,11 @@ objc_library( "//mediapipe/modules/face_landmark:face_landmark.tflite", ], features = ["-layering_check"], - sdk_frameworks = [ - "AVFoundation", - "CoreGraphics", - "CoreMedia", - "UIKit", - ], deps = [ + "//third_party/apple_frameworks:AVFoundation", + "//third_party/apple_frameworks:CoreGraphics", + "//third_party/apple_frameworks:CoreMedia", + "//third_party/apple_frameworks:UIKit", "//mediapipe/objc:mediapipe_framework_ios", "//mediapipe/objc:mediapipe_input_sources_ios", "//mediapipe/objc:mediapipe_layer_renderer", diff --git a/mediapipe/gpu/BUILD b/mediapipe/gpu/BUILD index 7a8aa6557..f5cb9f715 100644 --- a/mediapipe/gpu/BUILD +++ b/mediapipe/gpu/BUILD @@ -472,13 +472,13 @@ objc_library( copts = [ "-Wno-shorten-64-to-32", ], - sdk_frameworks = [ - "Accelerate", - "CoreGraphics", - "CoreVideo", - ], visibility = ["//visibility:public"], - deps = ["//mediapipe/objc:util"], + deps = [ + "//mediapipe/objc:util", + "//third_party/apple_frameworks:Accelerate", + "//third_party/apple_frameworks:CoreGraphics", + "//third_party/apple_frameworks:CoreVideo", + ], ) objc_library( @@ -510,13 +510,11 @@ objc_library( "-x objective-c++", "-Wno-shorten-64-to-32", ], - sdk_frameworks = [ - "CoreVideo", - "Metal", - ], visibility = ["//visibility:public"], deps = [ "//mediapipe/objc:mediapipe_framework_ios", + "//third_party/apple_frameworks:CoreVideo", + "//third_party/apple_frameworks:Metal", "@com_google_absl//absl/time", "@google_toolbox_for_mac//:GTM_Defines", ], @@ -808,15 +806,13 @@ objc_library( "-Wno-shorten-64-to-32", ], features = ["-layering_check"], - sdk_frameworks = [ - "CoreVideo", - "Metal", - ], visibility = ["//visibility:public"], deps = [ ":gpu_shared_data_internal", ":graph_support", "//mediapipe/objc:mediapipe_framework_ios", + "//third_party/apple_frameworks:CoreVideo", + "//third_party/apple_frameworks:Metal", 
"@google_toolbox_for_mac//:GTM_Defines", ], ) @@ -1020,16 +1016,14 @@ objc_library( name = "metal_copy_calculator", srcs = ["MetalCopyCalculator.mm"], features = ["-layering_check"], - sdk_frameworks = [ - "CoreVideo", - "Metal", - ], visibility = ["//visibility:public"], deps = [ ":MPPMetalHelper", ":simple_shaders_mtl", "//mediapipe/gpu:copy_calculator_cc_proto", "//mediapipe/objc:mediapipe_framework_ios", + "//third_party/apple_frameworks:CoreVideo", + "//third_party/apple_frameworks:Metal", ], alwayslink = 1, ) @@ -1038,15 +1032,13 @@ objc_library( name = "metal_rgb_weight_calculator", srcs = ["MetalRgbWeightCalculator.mm"], features = ["-layering_check"], - sdk_frameworks = [ - "CoreVideo", - "Metal", - ], visibility = ["//visibility:public"], deps = [ ":MPPMetalHelper", ":simple_shaders_mtl", "//mediapipe/objc:mediapipe_framework_ios", + "//third_party/apple_frameworks:CoreVideo", + "//third_party/apple_frameworks:Metal", ], alwayslink = 1, ) @@ -1055,15 +1047,13 @@ objc_library( name = "metal_sobel_calculator", srcs = ["MetalSobelCalculator.mm"], features = ["-layering_check"], - sdk_frameworks = [ - "CoreVideo", - "Metal", - ], visibility = ["//visibility:public"], deps = [ ":MPPMetalHelper", ":simple_shaders_mtl", "//mediapipe/objc:mediapipe_framework_ios", + "//third_party/apple_frameworks:CoreVideo", + "//third_party/apple_frameworks:Metal", ], alwayslink = 1, ) @@ -1072,15 +1062,13 @@ objc_library( name = "metal_sobel_compute_calculator", srcs = ["MetalSobelComputeCalculator.mm"], features = ["-layering_check"], - sdk_frameworks = [ - "CoreVideo", - "Metal", - ], visibility = ["//visibility:public"], deps = [ ":MPPMetalHelper", ":simple_shaders_mtl", "//mediapipe/objc:mediapipe_framework_ios", + "//third_party/apple_frameworks:CoreVideo", + "//third_party/apple_frameworks:Metal", ], alwayslink = 1, ) @@ -1090,15 +1078,13 @@ objc_library( srcs = ["MPSSobelCalculator.mm"], copts = ["-std=c++17"], features = ["-layering_check"], - sdk_frameworks = [ - "CoreVideo", - "Metal", - "MetalPerformanceShaders", - ], visibility = ["//visibility:public"], deps = [ ":MPPMetalHelper", "//mediapipe/objc:mediapipe_framework_ios", + "//third_party/apple_frameworks:CoreVideo", + "//third_party/apple_frameworks:Metal", + "//third_party/apple_frameworks:MetalPerformanceShaders", ], alwayslink = 1, ) @@ -1106,15 +1092,13 @@ objc_library( objc_library( name = "mps_threshold_calculator", srcs = ["MPSThresholdCalculator.mm"], - sdk_frameworks = [ - "CoreVideo", - "Metal", - "MetalPerformanceShaders", - ], visibility = ["//visibility:public"], deps = [ ":MPPMetalHelper", "//mediapipe/objc:mediapipe_framework_ios", + "//third_party/apple_frameworks:CoreVideo", + "//third_party/apple_frameworks:Metal", + "//third_party/apple_frameworks:MetalPerformanceShaders", ], alwayslink = 1, ) diff --git a/mediapipe/objc/BUILD b/mediapipe/objc/BUILD index fafdfee8a..c71c02b6d 100644 --- a/mediapipe/objc/BUILD +++ b/mediapipe/objc/BUILD @@ -68,7 +68,6 @@ objc_library( copts = [ "-Wno-shorten-64-to-32", ], - sdk_frameworks = ["Accelerate"], # This build rule is public to allow external customers to build their own iOS apps. 
visibility = ["//visibility:public"], deps = [ @@ -90,6 +89,7 @@ objc_library( "//mediapipe/gpu:metal_shared_resources", "//mediapipe/gpu:pixel_buffer_pool_util", "//mediapipe/util:cpu_util", + "//third_party/apple_frameworks:Accelerate", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", @@ -120,13 +120,13 @@ objc_library( ], "//conditions:default": [], }), - sdk_frameworks = [ - "AVFoundation", - "CoreVideo", - "Foundation", - ], # This build rule is public to allow external customers to build their own iOS apps. visibility = ["//visibility:public"], + deps = [ + "//third_party/apple_frameworks:AVFoundation", + "//third_party/apple_frameworks:CoreVideo", + "//third_party/apple_frameworks:Foundation", + ], ) objc_library( @@ -140,16 +140,14 @@ objc_library( copts = [ "-Wno-shorten-64-to-32", ], - sdk_frameworks = [ - "Foundation", - "GLKit", - ], visibility = ["//mediapipe/framework:mediapipe_internal"], deps = [ ":mediapipe_framework_ios", "//mediapipe/gpu:gl_calculator_helper", "//mediapipe/gpu:gl_quad_renderer", "//mediapipe/gpu:gl_simple_shaders", + "//third_party/apple_frameworks:Foundation", + "//third_party/apple_frameworks:GLKit", ], ) @@ -164,16 +162,14 @@ objc_library( copts = [ "-Wno-shorten-64-to-32", ], - sdk_frameworks = [ - "Foundation", - "GLKit", - ], # This build rule is public to allow external customers to build their own iOS apps. visibility = ["//visibility:public"], deps = [ ":mediapipe_framework_ios", ":mediapipe_gl_view_renderer", "//mediapipe/gpu:gl_calculator_helper", + "//third_party/apple_frameworks:Foundation", + "//third_party/apple_frameworks:GLKit", ], ) @@ -188,13 +184,11 @@ objc_library( copts = [ "-Wno-shorten-64-to-32", ], - sdk_frameworks = [ - "CoreVideo", - "Foundation", - ], visibility = ["//mediapipe/framework:mediapipe_internal"], deps = [ ":mediapipe_framework_ios", + "//third_party/apple_frameworks:CoreVideo", + "//third_party/apple_frameworks:Foundation", "@com_google_absl//absl/strings", ], ) @@ -211,23 +205,21 @@ objc_library( copts = [ "-Wno-shorten-64-to-32", ], - sdk_frameworks = [ - "AVFoundation", - "Accelerate", - "CoreGraphics", - "CoreMedia", - "CoreVideo", - "GLKit", - "OpenGLES", - "QuartzCore", - "UIKit", - ], visibility = ["//mediapipe/framework:mediapipe_internal"], deps = [ ":CGImageRefUtils", ":Weakify", ":mediapipe_framework_ios", "//mediapipe/framework:calculator_framework", + "//third_party/apple_frameworks:AVFoundation", + "//third_party/apple_frameworks:Accelerate", + "//third_party/apple_frameworks:CoreGraphics", + "//third_party/apple_frameworks:CoreMedia", + "//third_party/apple_frameworks:CoreVideo", + "//third_party/apple_frameworks:GLKit", + "//third_party/apple_frameworks:OpenGLES", + "//third_party/apple_frameworks:QuartzCore", + "//third_party/apple_frameworks:UIKit", ], ) @@ -245,16 +237,6 @@ objc_library( data = [ "testdata/googlelogo_color_272x92dp.png", ], - sdk_frameworks = [ - "AVFoundation", - "Accelerate", - "CoreGraphics", - "CoreMedia", - "CoreVideo", - "GLKit", - "QuartzCore", - "UIKit", - ], visibility = ["//mediapipe/framework:mediapipe_internal"], deps = [ ":CGImageRefUtils", @@ -263,6 +245,14 @@ objc_library( ":mediapipe_framework_ios", ":mediapipe_input_sources_ios", "//mediapipe/calculators/core:pass_through_calculator", + "//third_party/apple_frameworks:AVFoundation", + "//third_party/apple_frameworks:Accelerate", + "//third_party/apple_frameworks:CoreGraphics", + "//third_party/apple_frameworks:CoreMedia", + 
"//third_party/apple_frameworks:CoreVideo", + "//third_party/apple_frameworks:GLKit", + "//third_party/apple_frameworks:QuartzCore", + "//third_party/apple_frameworks:UIKit", ], ) diff --git a/third_party/apple_frameworks/BUILD b/third_party/apple_frameworks/BUILD new file mode 100644 index 000000000..05f830e81 --- /dev/null +++ b/third_party/apple_frameworks/BUILD @@ -0,0 +1,73 @@ +# Build rules to inject Apple Frameworks + +package(default_visibility = ["//visibility:public"]) + +cc_library( + name = "CoreGraphics", + linkopts = ["-framework CoreGraphics"], +) + +cc_library( + name = "CoreMedia", + linkopts = ["-framework CoreMedia"], +) + +cc_library( + name = "UIKit", + linkopts = ["-framework UIKit"], +) + +cc_library( + name = "Accelerate", + linkopts = ["-framework Accelerate"], +) + +cc_library( + name = "CoreVideo", + linkopts = ["-framework CoreVideo"], +) + +cc_library( + name = "Metal", + linkopts = ["-framework Metal"], +) + +cc_library( + name = "MetalPerformanceShaders", + linkopts = ["-framework MetalPerformanceShaders"], +) + +cc_library( + name = "AVFoundation", + linkopts = ["-framework AVFoundation"], +) + +cc_library( + name = "Foundation", + linkopts = ["-framework Foundation"], +) + +cc_library( + name = "CoreImage", + linkopts = ["-framework CoreImage"], +) + +cc_library( + name = "XCTest", + linkopts = ["-framework XCTest"], +) + +cc_library( + name = "GLKit", + linkopts = ["-framework GLKit"], +) + +cc_library( + name = "OpenGLES", + linkopts = ["-framework OpenGLES"], +) + +cc_library( + name = "QuartzCore", + linkopts = ["-framework QuartzCore"], +) From 576c6da173c4b84b13787c0e6926acab05118880 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Tue, 6 Dec 2022 15:22:03 -0800 Subject: [PATCH 178/469] Internal change PiperOrigin-RevId: 493427500 --- mediapipe/tasks/python/audio/BUILD | 4 +- .../tasks/python/audio/audio_classifier.py | 34 ++++++++-- .../tasks/python/audio/audio_embedder.py | 21 ++++-- .../tasks/python/components/processors/BUILD | 9 --- .../python/components/processors/__init__.py | 3 - .../components/processors/embedder_options.py | 68 ------------------- mediapipe/tasks/python/components/utils/BUILD | 5 +- .../components/utils/cosine_similarity.py | 2 - mediapipe/tasks/python/test/audio/BUILD | 2 - .../test/audio/audio_classifier_test.py | 20 ++---- .../python/test/audio/audio_embedder_test.py | 10 +-- mediapipe/tasks/python/test/text/BUILD | 2 - .../python/test/text/text_classifier_test.py | 2 - .../python/test/text/text_embedder_test.py | 10 +-- mediapipe/tasks/python/test/vision/BUILD | 2 - .../test/vision/image_classifier_test.py | 52 +++++--------- .../python/test/vision/image_embedder_test.py | 10 +-- mediapipe/tasks/python/text/BUILD | 4 +- .../tasks/python/text/text_classifier.py | 35 ++++++++-- mediapipe/tasks/python/text/text_embedder.py | 20 ++++-- mediapipe/tasks/python/vision/BUILD | 4 +- .../tasks/python/vision/image_classifier.py | 35 ++++++++-- .../tasks/python/vision/image_embedder.py | 20 ++++-- 23 files changed, 162 insertions(+), 212 deletions(-) delete mode 100644 mediapipe/tasks/python/components/processors/embedder_options.py diff --git a/mediapipe/tasks/python/audio/BUILD b/mediapipe/tasks/python/audio/BUILD index 2e5815ff0..ce7c5ce08 100644 --- a/mediapipe/tasks/python/audio/BUILD +++ b/mediapipe/tasks/python/audio/BUILD @@ -29,11 +29,11 @@ py_library( "//mediapipe/python:packet_getter", "//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_graph_options_py_pb2", 
"//mediapipe/tasks/cc/components/containers/proto:classifications_py_pb2", + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_py_pb2", "//mediapipe/tasks/python/audio/core:audio_task_running_mode", "//mediapipe/tasks/python/audio/core:base_audio_task_api", "//mediapipe/tasks/python/components/containers:audio_data", "//mediapipe/tasks/python/components/containers:classification_result", - "//mediapipe/tasks/python/components/processors:classifier_options", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/core:optional_dependencies", "//mediapipe/tasks/python/core:task_info", @@ -51,11 +51,11 @@ py_library( "//mediapipe/python:packet_getter", "//mediapipe/tasks/cc/audio/audio_embedder/proto:audio_embedder_graph_options_py_pb2", "//mediapipe/tasks/cc/components/containers/proto:embeddings_py_pb2", + "//mediapipe/tasks/cc/components/processors/proto:embedder_options_py_pb2", "//mediapipe/tasks/python/audio/core:audio_task_running_mode", "//mediapipe/tasks/python/audio/core:base_audio_task_api", "//mediapipe/tasks/python/components/containers:audio_data", "//mediapipe/tasks/python/components/containers:embedding_result", - "//mediapipe/tasks/python/components/processors:embedder_options", "//mediapipe/tasks/python/components/utils:cosine_similarity", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/core:optional_dependencies", diff --git a/mediapipe/tasks/python/audio/audio_classifier.py b/mediapipe/tasks/python/audio/audio_classifier.py index d82b6fe27..cc87d6221 100644 --- a/mediapipe/tasks/python/audio/audio_classifier.py +++ b/mediapipe/tasks/python/audio/audio_classifier.py @@ -21,11 +21,11 @@ from mediapipe.python import packet_getter from mediapipe.python._framework_bindings import packet from mediapipe.tasks.cc.audio.audio_classifier.proto import audio_classifier_graph_options_pb2 from mediapipe.tasks.cc.components.containers.proto import classifications_pb2 +from mediapipe.tasks.cc.components.processors.proto import classifier_options_pb2 from mediapipe.tasks.python.audio.core import audio_task_running_mode as running_mode_module from mediapipe.tasks.python.audio.core import base_audio_task_api from mediapipe.tasks.python.components.containers import audio_data as audio_data_module from mediapipe.tasks.python.components.containers import classification_result as classification_result_module -from mediapipe.tasks.python.components.processors import classifier_options as classifier_options_module from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.core import task_info as task_info_module from mediapipe.tasks.python.core.optional_dependencies import doc_controls @@ -34,7 +34,7 @@ AudioClassifierResult = classification_result_module.ClassificationResult _AudioClassifierGraphOptionsProto = audio_classifier_graph_options_pb2.AudioClassifierGraphOptions _AudioData = audio_data_module.AudioData _BaseOptions = base_options_module.BaseOptions -_ClassifierOptions = classifier_options_module.ClassifierOptions +_ClassifierOptionsProto = classifier_options_pb2.ClassifierOptions _RunningMode = running_mode_module.AudioTaskRunningMode _TaskInfo = task_info_module.TaskInfo @@ -62,16 +62,31 @@ class AudioClassifierOptions: mode for running classification on the audio stream, such as from microphone. In this mode, the "result_callback" below must be specified to receive the classification results asynchronously. 
- classifier_options: Options for configuring the classifier behavior, such as - score threshold, number of results, etc. + display_names_locale: The locale to use for display names specified through + the TFLite Model Metadata. + max_results: The maximum number of top-scored classification results to + return. + score_threshold: Overrides the ones provided in the model metadata. Results + below this value are rejected. + category_allowlist: Allowlist of category names. If non-empty, + classification results whose category name is not in this set will be + filtered out. Duplicate or unknown category names are ignored. Mutually + exclusive with `category_denylist`. + category_denylist: Denylist of category names. If non-empty, classification + results whose category name is in this set will be filtered out. Duplicate + or unknown category names are ignored. Mutually exclusive with + `category_allowlist`. result_callback: The user-defined result callback for processing audio stream data. The result callback should only be specified when the running mode is set to the audio stream mode. """ base_options: _BaseOptions running_mode: _RunningMode = _RunningMode.AUDIO_CLIPS - classifier_options: Optional[_ClassifierOptions] = dataclasses.field( - default_factory=_ClassifierOptions) + display_names_locale: Optional[str] = None + max_results: Optional[int] = None + score_threshold: Optional[float] = None + category_allowlist: Optional[List[str]] = None + category_denylist: Optional[List[str]] = None result_callback: Optional[Callable[[AudioClassifierResult, int], None]] = None @doc_controls.do_not_generate_docs @@ -79,7 +94,12 @@ class AudioClassifierOptions: """Generates an AudioClassifierOptions protobuf object.""" base_options_proto = self.base_options.to_pb2() base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.AUDIO_CLIPS else True - classifier_options_proto = self.classifier_options.to_pb2() + classifier_options_proto = _ClassifierOptionsProto( + score_threshold=self.score_threshold, + category_allowlist=self.category_allowlist, + category_denylist=self.category_denylist, + display_names_locale=self.display_names_locale, + max_results=self.max_results) return _AudioClassifierGraphOptionsProto( base_options=base_options_proto, diff --git a/mediapipe/tasks/python/audio/audio_embedder.py b/mediapipe/tasks/python/audio/audio_embedder.py index 629e21882..4c37783e9 100644 --- a/mediapipe/tasks/python/audio/audio_embedder.py +++ b/mediapipe/tasks/python/audio/audio_embedder.py @@ -21,11 +21,11 @@ from mediapipe.python import packet_getter from mediapipe.python._framework_bindings import packet from mediapipe.tasks.cc.audio.audio_embedder.proto import audio_embedder_graph_options_pb2 from mediapipe.tasks.cc.components.containers.proto import embeddings_pb2 +from mediapipe.tasks.cc.components.processors.proto import embedder_options_pb2 from mediapipe.tasks.python.audio.core import audio_task_running_mode as running_mode_module from mediapipe.tasks.python.audio.core import base_audio_task_api from mediapipe.tasks.python.components.containers import audio_data as audio_data_module from mediapipe.tasks.python.components.containers import embedding_result as embedding_result_module -from mediapipe.tasks.python.components.processors import embedder_options as embedder_options_module from mediapipe.tasks.python.components.utils import cosine_similarity from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.core import task_info as 
task_info_module @@ -35,7 +35,7 @@ AudioEmbedderResult = embedding_result_module.EmbeddingResult _AudioEmbedderGraphOptionsProto = audio_embedder_graph_options_pb2.AudioEmbedderGraphOptions _AudioData = audio_data_module.AudioData _BaseOptions = base_options_module.BaseOptions -_EmbedderOptions = embedder_options_module.EmbedderOptions +_EmbedderOptionsProto = embedder_options_pb2.EmbedderOptions _RunningMode = running_mode_module.AudioTaskRunningMode _TaskInfo = task_info_module.TaskInfo @@ -63,16 +63,22 @@ class AudioEmbedderOptions: stream mode for running embedding extraction on the audio stream, such as from microphone. In this mode, the "result_callback" below must be specified to receive the embedding results asynchronously. - embedder_options: Options for configuring the embedder behavior, such as - l2_normalize and quantize. + l2_normalize: Whether to normalize the returned feature vector with L2 norm. + Use this option only if the model does not already contain a native + L2_NORMALIZATION TF Lite Op. In most cases, this is already the case and + L2 norm is thus achieved through TF Lite inference. + quantize: Whether the returned embedding should be quantized to bytes via + scalar quantization. Embeddings are implicitly assumed to be unit-norm and + therefore any dimension is guaranteed to have a value in [-1.0, 1.0]. Use + the l2_normalize option if this is not the case. result_callback: The user-defined result callback for processing audio stream data. The result callback should only be specified when the running mode is set to the audio stream mode. """ base_options: _BaseOptions running_mode: _RunningMode = _RunningMode.AUDIO_CLIPS - embedder_options: Optional[_EmbedderOptions] = dataclasses.field( - default_factory=_EmbedderOptions) + l2_normalize: Optional[bool] = None + quantize: Optional[bool] = None result_callback: Optional[Callable[[AudioEmbedderResult, int], None]] = None @doc_controls.do_not_generate_docs @@ -80,7 +86,8 @@ class AudioEmbedderOptions: """Generates an AudioEmbedderOptions protobuf object.""" base_options_proto = self.base_options.to_pb2() base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.AUDIO_CLIPS else True - embedder_options_proto = self.embedder_options.to_pb2() + embedder_options_proto = _EmbedderOptionsProto( + l2_normalize=self.l2_normalize, quantize=self.quantize) return _AudioEmbedderGraphOptionsProto( base_options=base_options_proto, diff --git a/mediapipe/tasks/python/components/processors/BUILD b/mediapipe/tasks/python/components/processors/BUILD index eef368db0..f87a579b0 100644 --- a/mediapipe/tasks/python/components/processors/BUILD +++ b/mediapipe/tasks/python/components/processors/BUILD @@ -28,12 +28,3 @@ py_library( "//mediapipe/tasks/python/core:optional_dependencies", ], ) - -py_library( - name = "embedder_options", - srcs = ["embedder_options.py"], - deps = [ - "//mediapipe/tasks/cc/components/processors/proto:embedder_options_py_pb2", - "//mediapipe/tasks/python/core:optional_dependencies", - ], -) diff --git a/mediapipe/tasks/python/components/processors/__init__.py b/mediapipe/tasks/python/components/processors/__init__.py index adcb38757..0eb73abe0 100644 --- a/mediapipe/tasks/python/components/processors/__init__.py +++ b/mediapipe/tasks/python/components/processors/__init__.py @@ -15,12 +15,9 @@ """MediaPipe Tasks Components Processors API.""" import mediapipe.tasks.python.components.processors.classifier_options -import mediapipe.tasks.python.components.processors.embedder_options 
ClassifierOptions = classifier_options.ClassifierOptions -EmbedderOptions = embedder_options.EmbedderOptions # Remove unnecessary modules to avoid duplication in API docs. del classifier_options -del embedder_options del mediapipe diff --git a/mediapipe/tasks/python/components/processors/embedder_options.py b/mediapipe/tasks/python/components/processors/embedder_options.py deleted file mode 100644 index c86a91105..000000000 --- a/mediapipe/tasks/python/components/processors/embedder_options.py +++ /dev/null @@ -1,68 +0,0 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Embedder options data class.""" - -import dataclasses -from typing import Any, Optional - -from mediapipe.tasks.cc.components.processors.proto import embedder_options_pb2 -from mediapipe.tasks.python.core.optional_dependencies import doc_controls - -_EmbedderOptionsProto = embedder_options_pb2.EmbedderOptions - - -@dataclasses.dataclass -class EmbedderOptions: - """Shared options used by all embedding extraction tasks. - - Attributes: - l2_normalize: Whether to normalize the returned feature vector with L2 norm. - Use this option only if the model does not already contain a native - L2_NORMALIZATION TF Lite Op. In most cases, this is already the case and - L2 norm is thus achieved through TF Lite inference. - quantize: Whether the returned embedding should be quantized to bytes via - scalar quantization. Embeddings are implicitly assumed to be unit-norm and - therefore any dimension is guaranteed to have a value in [-1.0, 1.0]. Use - the l2_normalize option if this is not the case. - """ - - l2_normalize: Optional[bool] = None - quantize: Optional[bool] = None - - @doc_controls.do_not_generate_docs - def to_pb2(self) -> _EmbedderOptionsProto: - """Generates a EmbedderOptions protobuf object.""" - return _EmbedderOptionsProto( - l2_normalize=self.l2_normalize, quantize=self.quantize) - - @classmethod - @doc_controls.do_not_generate_docs - def create_from_pb2(cls, pb2_obj: _EmbedderOptionsProto) -> 'EmbedderOptions': - """Creates a `EmbedderOptions` object from the given protobuf object.""" - return EmbedderOptions( - l2_normalize=pb2_obj.l2_normalize, quantize=pb2_obj.quantize) - - def __eq__(self, other: Any) -> bool: - """Checks if this object is equal to the given object. - - Args: - other: The object to be compared with. - - Returns: - True if the objects are equal. 
- """ - if not isinstance(other, EmbedderOptions): - return False - - return self.to_pb2().__eq__(other.to_pb2()) diff --git a/mediapipe/tasks/python/components/utils/BUILD b/mediapipe/tasks/python/components/utils/BUILD index b64d04c72..31114f326 100644 --- a/mediapipe/tasks/python/components/utils/BUILD +++ b/mediapipe/tasks/python/components/utils/BUILD @@ -23,8 +23,5 @@ licenses(["notice"]) py_library( name = "cosine_similarity", srcs = ["cosine_similarity.py"], - deps = [ - "//mediapipe/tasks/python/components/containers:embedding_result", - "//mediapipe/tasks/python/components/processors:embedder_options", - ], + deps = ["//mediapipe/tasks/python/components/containers:embedding_result"], ) diff --git a/mediapipe/tasks/python/components/utils/cosine_similarity.py b/mediapipe/tasks/python/components/utils/cosine_similarity.py index 486c02ece..ff8979458 100644 --- a/mediapipe/tasks/python/components/utils/cosine_similarity.py +++ b/mediapipe/tasks/python/components/utils/cosine_similarity.py @@ -16,10 +16,8 @@ import numpy as np from mediapipe.tasks.python.components.containers import embedding_result -from mediapipe.tasks.python.components.processors import embedder_options _Embedding = embedding_result.Embedding -_EmbedderOptions = embedder_options.EmbedderOptions def _compute_cosine_similarity(u, v): diff --git a/mediapipe/tasks/python/test/audio/BUILD b/mediapipe/tasks/python/test/audio/BUILD index 9278cea55..43f1d417c 100644 --- a/mediapipe/tasks/python/test/audio/BUILD +++ b/mediapipe/tasks/python/test/audio/BUILD @@ -30,7 +30,6 @@ py_test( "//mediapipe/tasks/python/audio/core:audio_task_running_mode", "//mediapipe/tasks/python/components/containers:audio_data", "//mediapipe/tasks/python/components/containers:classification_result", - "//mediapipe/tasks/python/components/processors:classifier_options", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/test:test_utils", ], @@ -48,7 +47,6 @@ py_test( "//mediapipe/tasks/python/audio/core:audio_task_running_mode", "//mediapipe/tasks/python/components/containers:audio_data", "//mediapipe/tasks/python/components/containers:embedding_result", - "//mediapipe/tasks/python/components/processors:embedder_options", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/test:test_utils", ], diff --git a/mediapipe/tasks/python/test/audio/audio_classifier_test.py b/mediapipe/tasks/python/test/audio/audio_classifier_test.py index 0d067e587..75146547c 100644 --- a/mediapipe/tasks/python/test/audio/audio_classifier_test.py +++ b/mediapipe/tasks/python/test/audio/audio_classifier_test.py @@ -27,7 +27,6 @@ from mediapipe.tasks.python.audio import audio_classifier from mediapipe.tasks.python.audio.core import audio_task_running_mode from mediapipe.tasks.python.components.containers import audio_data as audio_data_module from mediapipe.tasks.python.components.containers import classification_result as classification_result_module -from mediapipe.tasks.python.components.processors import classifier_options from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.test import test_utils @@ -36,7 +35,6 @@ _AudioClassifierOptions = audio_classifier.AudioClassifierOptions _AudioClassifierResult = classification_result_module.ClassificationResult _AudioData = audio_data_module.AudioData _BaseOptions = base_options_module.BaseOptions -_ClassifierOptions = classifier_options.ClassifierOptions _RUNNING_MODE = audio_task_running_mode.AudioTaskRunningMode _YAMNET_MODEL_FILE = 
'yamnet_audio_classifier_with_metadata.tflite' @@ -210,8 +208,7 @@ class AudioClassifierTest(parameterized.TestCase): with _AudioClassifier.create_from_options( _AudioClassifierOptions( base_options=_BaseOptions(model_asset_path=self.yamnet_model_path), - classifier_options=_ClassifierOptions( - max_results=1))) as classifier: + max_results=1)) as classifier: for audio_file in [_SPEECH_WAV_16K_MONO, _SPEECH_WAV_16K_MONO]: classification_result_list = classifier.classify( self._read_wav_file(audio_file)) @@ -222,8 +219,7 @@ class AudioClassifierTest(parameterized.TestCase): with _AudioClassifier.create_from_options( _AudioClassifierOptions( base_options=_BaseOptions(model_asset_path=self.yamnet_model_path), - classifier_options=_ClassifierOptions( - score_threshold=0.9))) as classifier: + score_threshold=0.9)) as classifier: for audio_file in [_SPEECH_WAV_16K_MONO, _SPEECH_WAV_16K_MONO]: classification_result_list = classifier.classify( self._read_wav_file(audio_file)) @@ -234,8 +230,7 @@ class AudioClassifierTest(parameterized.TestCase): with _AudioClassifier.create_from_options( _AudioClassifierOptions( base_options=_BaseOptions(model_asset_path=self.yamnet_model_path), - classifier_options=_ClassifierOptions( - category_allowlist=['Speech']))) as classifier: + category_allowlist=['Speech'])) as classifier: for audio_file in [_SPEECH_WAV_16K_MONO, _SPEECH_WAV_16K_MONO]: classification_result_list = classifier.classify( self._read_wav_file(audio_file)) @@ -250,8 +245,8 @@ class AudioClassifierTest(parameterized.TestCase): r'exclusive options.'): options = _AudioClassifierOptions( base_options=_BaseOptions(model_asset_path=self.yamnet_model_path), - classifier_options=_ClassifierOptions( - category_allowlist=['foo'], category_denylist=['bar'])) + category_allowlist=['foo'], + category_denylist=['bar']) with _AudioClassifier.create_from_options(options) as unused_classifier: pass @@ -278,8 +273,7 @@ class AudioClassifierTest(parameterized.TestCase): _AudioClassifierOptions( base_options=_BaseOptions( model_asset_path=self.two_heads_model_path), - classifier_options=_ClassifierOptions( - max_results=1))) as classifier: + max_results=1)) as classifier: for audio_file in [_TWO_HEADS_WAV_16K_MONO, _TWO_HEADS_WAV_44K_MONO]: classification_result_list = classifier.classify( self._read_wav_file(audio_file)) @@ -364,7 +358,7 @@ class AudioClassifierTest(parameterized.TestCase): options = _AudioClassifierOptions( base_options=_BaseOptions(model_asset_path=self.yamnet_model_path), running_mode=_RUNNING_MODE.AUDIO_STREAM, - classifier_options=_ClassifierOptions(max_results=1), + max_results=1, result_callback=save_result) classifier = _AudioClassifier.create_from_options(options) audio_data_list = self._read_wav_file_as_stream(audio_file) diff --git a/mediapipe/tasks/python/test/audio/audio_embedder_test.py b/mediapipe/tasks/python/test/audio/audio_embedder_test.py index 2e38ea2ee..f280235d7 100644 --- a/mediapipe/tasks/python/test/audio/audio_embedder_test.py +++ b/mediapipe/tasks/python/test/audio/audio_embedder_test.py @@ -26,7 +26,6 @@ from scipy.io import wavfile from mediapipe.tasks.python.audio import audio_embedder from mediapipe.tasks.python.audio.core import audio_task_running_mode from mediapipe.tasks.python.components.containers import audio_data as audio_data_module -from mediapipe.tasks.python.components.processors import embedder_options from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.test import test_utils @@ -35,7 +34,6 @@ 
_AudioEmbedderOptions = audio_embedder.AudioEmbedderOptions _AudioEmbedderResult = audio_embedder.AudioEmbedderResult _AudioData = audio_data_module.AudioData _BaseOptions = base_options_module.BaseOptions -_EmbedderOptions = embedder_options.EmbedderOptions _RUNNING_MODE = audio_task_running_mode.AudioTaskRunningMode _YAMNET_MODEL_FILE = 'yamnet_embedding_metadata.tflite' @@ -172,9 +170,7 @@ class AudioEmbedderTest(parameterized.TestCase): raise ValueError('model_file_type is invalid.') options = _AudioEmbedderOptions( - base_options=base_options, - embedder_options=_EmbedderOptions( - l2_normalize=l2_normalize, quantize=quantize)) + base_options=base_options, l2_normalize=l2_normalize, quantize=quantize) with _AudioEmbedder.create_from_options(options) as embedder: embedding_result0_list = embedder.embed(self._read_wav_file(audio_file0)) @@ -291,8 +287,8 @@ class AudioEmbedderTest(parameterized.TestCase): options = _AudioEmbedderOptions( base_options=_BaseOptions(model_asset_path=self.yamnet_model_path), running_mode=_RUNNING_MODE.AUDIO_STREAM, - embedder_options=_EmbedderOptions( - l2_normalize=l2_normalize, quantize=quantize), + l2_normalize=l2_normalize, + quantize=quantize, result_callback=save_result) with _AudioEmbedder.create_from_options(options) as embedder: diff --git a/mediapipe/tasks/python/test/text/BUILD b/mediapipe/tasks/python/test/text/BUILD index 38e56bdb2..0e2b06012 100644 --- a/mediapipe/tasks/python/test/text/BUILD +++ b/mediapipe/tasks/python/test/text/BUILD @@ -28,7 +28,6 @@ py_test( deps = [ "//mediapipe/tasks/python/components/containers:category", "//mediapipe/tasks/python/components/containers:classification_result", - "//mediapipe/tasks/python/components/processors:classifier_options", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/test:test_utils", "//mediapipe/tasks/python/text:text_classifier", @@ -44,7 +43,6 @@ py_test( ], deps = [ "//mediapipe/tasks/python/components/containers:embedding_result", - "//mediapipe/tasks/python/components/processors:embedder_options", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/test:test_utils", "//mediapipe/tasks/python/text:text_embedder", diff --git a/mediapipe/tasks/python/test/text/text_classifier_test.py b/mediapipe/tasks/python/test/text/text_classifier_test.py index 8678d2194..8df7dce86 100644 --- a/mediapipe/tasks/python/test/text/text_classifier_test.py +++ b/mediapipe/tasks/python/test/text/text_classifier_test.py @@ -21,14 +21,12 @@ from absl.testing import parameterized from mediapipe.tasks.python.components.containers import category from mediapipe.tasks.python.components.containers import classification_result as classification_result_module -from mediapipe.tasks.python.components.processors import classifier_options from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.test import test_utils from mediapipe.tasks.python.text import text_classifier TextClassifierResult = classification_result_module.ClassificationResult _BaseOptions = base_options_module.BaseOptions -_ClassifierOptions = classifier_options.ClassifierOptions _Category = category.Category _Classifications = classification_result_module.Classifications _TextClassifier = text_classifier.TextClassifier diff --git a/mediapipe/tasks/python/test/text/text_embedder_test.py b/mediapipe/tasks/python/test/text/text_embedder_test.py index c9090026c..1346ba373 100644 --- a/mediapipe/tasks/python/test/text/text_embedder_test.py +++ 
b/mediapipe/tasks/python/test/text/text_embedder_test.py @@ -21,13 +21,11 @@ from absl.testing import parameterized import numpy as np from mediapipe.tasks.python.components.containers import embedding_result as embedding_result_module -from mediapipe.tasks.python.components.processors import embedder_options as embedder_options_module from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.test import test_utils from mediapipe.tasks.python.text import text_embedder _BaseOptions = base_options_module.BaseOptions -_EmbedderOptions = embedder_options_module.EmbedderOptions _Embedding = embedding_result_module.Embedding _TextEmbedder = text_embedder.TextEmbedder _TextEmbedderOptions = text_embedder.TextEmbedderOptions @@ -128,10 +126,8 @@ class TextEmbedderTest(parameterized.TestCase): # Should never happen raise ValueError('model_file_type is invalid.') - embedder_options = _EmbedderOptions( - l2_normalize=l2_normalize, quantize=quantize) options = _TextEmbedderOptions( - base_options=base_options, embedder_options=embedder_options) + base_options=base_options, l2_normalize=l2_normalize, quantize=quantize) embedder = _TextEmbedder.create_from_options(options) # Extracts both embeddings. @@ -178,10 +174,8 @@ class TextEmbedderTest(parameterized.TestCase): # Should never happen raise ValueError('model_file_type is invalid.') - embedder_options = _EmbedderOptions( - l2_normalize=l2_normalize, quantize=quantize) options = _TextEmbedderOptions( - base_options=base_options, embedder_options=embedder_options) + base_options=base_options, l2_normalize=l2_normalize, quantize=quantize) with _TextEmbedder.create_from_options(options) as embedder: # Extracts both embeddings. positive_text0 = "it's a charming and often affecting journey" diff --git a/mediapipe/tasks/python/test/vision/BUILD b/mediapipe/tasks/python/test/vision/BUILD index 066107421..48ecc30b3 100644 --- a/mediapipe/tasks/python/test/vision/BUILD +++ b/mediapipe/tasks/python/test/vision/BUILD @@ -49,7 +49,6 @@ py_test( "//mediapipe/tasks/python/components/containers:category", "//mediapipe/tasks/python/components/containers:classification_result", "//mediapipe/tasks/python/components/containers:rect", - "//mediapipe/tasks/python/components/processors:classifier_options", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/test:test_utils", "//mediapipe/tasks/python/vision:image_classifier", @@ -69,7 +68,6 @@ py_test( "//mediapipe/python:_framework_bindings", "//mediapipe/tasks/python/components/containers:embedding_result", "//mediapipe/tasks/python/components/containers:rect", - "//mediapipe/tasks/python/components/processors:embedder_options", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/test:test_utils", "//mediapipe/tasks/python/vision:image_embedder", diff --git a/mediapipe/tasks/python/test/vision/image_classifier_test.py b/mediapipe/tasks/python/test/vision/image_classifier_test.py index 77f16278f..cbeaf36bd 100644 --- a/mediapipe/tasks/python/test/vision/image_classifier_test.py +++ b/mediapipe/tasks/python/test/vision/image_classifier_test.py @@ -26,7 +26,6 @@ from mediapipe.python._framework_bindings import image from mediapipe.tasks.python.components.containers import category as category_module from mediapipe.tasks.python.components.containers import classification_result as classification_result_module from mediapipe.tasks.python.components.containers import rect -from mediapipe.tasks.python.components.processors import 
classifier_options from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.test import test_utils from mediapipe.tasks.python.vision import image_classifier @@ -36,7 +35,6 @@ from mediapipe.tasks.python.vision.core import vision_task_running_mode ImageClassifierResult = classification_result_module.ClassificationResult _Rect = rect.Rect _BaseOptions = base_options_module.BaseOptions -_ClassifierOptions = classifier_options.ClassifierOptions _Category = category_module.Category _Classifications = classification_result_module.Classifications _Image = image.Image @@ -171,9 +169,8 @@ class ImageClassifierTest(parameterized.TestCase): # Should never happen raise ValueError('model_file_type is invalid.') - custom_classifier_options = _ClassifierOptions(max_results=max_results) options = _ImageClassifierOptions( - base_options=base_options, classifier_options=custom_classifier_options) + base_options=base_options, max_results=max_results) classifier = _ImageClassifier.create_from_options(options) # Performs image classification on the input. @@ -200,9 +197,8 @@ class ImageClassifierTest(parameterized.TestCase): # Should never happen raise ValueError('model_file_type is invalid.') - custom_classifier_options = _ClassifierOptions(max_results=max_results) options = _ImageClassifierOptions( - base_options=base_options, classifier_options=custom_classifier_options) + base_options=base_options, max_results=max_results) with _ImageClassifier.create_from_options(options) as classifier: # Performs image classification on the input. image_result = classifier.classify(self.test_image) @@ -212,9 +208,7 @@ class ImageClassifierTest(parameterized.TestCase): def test_classify_succeeds_with_region_of_interest(self): base_options = _BaseOptions(model_asset_path=self.model_path) - custom_classifier_options = _ClassifierOptions(max_results=1) - options = _ImageClassifierOptions( - base_options=base_options, classifier_options=custom_classifier_options) + options = _ImageClassifierOptions(base_options=base_options, max_results=1) with _ImageClassifier.create_from_options(options) as classifier: # Load the test image. test_image = _Image.create_from_file( @@ -230,11 +224,9 @@ class ImageClassifierTest(parameterized.TestCase): _generate_soccer_ball_results().to_pb2()) def test_score_threshold_option(self): - custom_classifier_options = _ClassifierOptions( - score_threshold=_SCORE_THRESHOLD) options = _ImageClassifierOptions( base_options=_BaseOptions(model_asset_path=self.model_path), - classifier_options=custom_classifier_options) + score_threshold=_SCORE_THRESHOLD) with _ImageClassifier.create_from_options(options) as classifier: # Performs image classification on the input. image_result = classifier.classify(self.test_image) @@ -249,11 +241,9 @@ class ImageClassifierTest(parameterized.TestCase): f'{classification}') def test_max_results_option(self): - custom_classifier_options = _ClassifierOptions( - score_threshold=_SCORE_THRESHOLD) options = _ImageClassifierOptions( base_options=_BaseOptions(model_asset_path=self.model_path), - classifier_options=custom_classifier_options) + score_threshold=_SCORE_THRESHOLD) with _ImageClassifier.create_from_options(options) as classifier: # Performs image classification on the input. 
image_result = classifier.classify(self.test_image) @@ -263,11 +253,9 @@ class ImageClassifierTest(parameterized.TestCase): len(categories), _MAX_RESULTS, 'Too many results returned.') def test_allow_list_option(self): - custom_classifier_options = _ClassifierOptions( - category_allowlist=_ALLOW_LIST) options = _ImageClassifierOptions( base_options=_BaseOptions(model_asset_path=self.model_path), - classifier_options=custom_classifier_options) + category_allowlist=_ALLOW_LIST) with _ImageClassifier.create_from_options(options) as classifier: # Performs image classification on the input. image_result = classifier.classify(self.test_image) @@ -280,10 +268,9 @@ class ImageClassifierTest(parameterized.TestCase): f'Label {label} found but not in label allow list') def test_deny_list_option(self): - custom_classifier_options = _ClassifierOptions(category_denylist=_DENY_LIST) options = _ImageClassifierOptions( base_options=_BaseOptions(model_asset_path=self.model_path), - classifier_options=custom_classifier_options) + category_denylist=_DENY_LIST) with _ImageClassifier.create_from_options(options) as classifier: # Performs image classification on the input. image_result = classifier.classify(self.test_image) @@ -301,19 +288,17 @@ class ImageClassifierTest(parameterized.TestCase): ValueError, r'`category_allowlist` and `category_denylist` are mutually ' r'exclusive options.'): - custom_classifier_options = _ClassifierOptions( - category_allowlist=['foo'], category_denylist=['bar']) options = _ImageClassifierOptions( base_options=_BaseOptions(model_asset_path=self.model_path), - classifier_options=custom_classifier_options) + category_allowlist=['foo'], + category_denylist=['bar']) with _ImageClassifier.create_from_options(options) as unused_classifier: pass def test_empty_classification_outputs(self): - custom_classifier_options = _ClassifierOptions(score_threshold=1) options = _ImageClassifierOptions( base_options=_BaseOptions(model_asset_path=self.model_path), - classifier_options=custom_classifier_options) + score_threshold=1) with _ImageClassifier.create_from_options(options) as classifier: # Performs image classification on the input. image_result = classifier.classify(self.test_image) @@ -386,11 +371,10 @@ class ImageClassifierTest(parameterized.TestCase): classifier.classify_for_video(self.test_image, 0) def test_classify_for_video(self): - custom_classifier_options = _ClassifierOptions(max_results=4) options = _ImageClassifierOptions( base_options=_BaseOptions(model_asset_path=self.model_path), running_mode=_RUNNING_MODE.VIDEO, - classifier_options=custom_classifier_options) + max_results=4) with _ImageClassifier.create_from_options(options) as classifier: for timestamp in range(0, 300, 30): classification_result = classifier.classify_for_video( @@ -399,11 +383,10 @@ class ImageClassifierTest(parameterized.TestCase): _generate_burger_results().to_pb2()) def test_classify_for_video_succeeds_with_region_of_interest(self): - custom_classifier_options = _ClassifierOptions(max_results=1) options = _ImageClassifierOptions( base_options=_BaseOptions(model_asset_path=self.model_path), running_mode=_RUNNING_MODE.VIDEO, - classifier_options=custom_classifier_options) + max_results=1) with _ImageClassifier.create_from_options(options) as classifier: # Load the test image. 
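[Editor's note: the region-of-interest tests above exercise classification restricted to a crop. A minimal sketch of how a caller would pass an ROI — the `ImageProcessingOptions` container and the normalized `Rect` fields are assumptions here (the ROI plumbing itself is not shown in this hunk), and the coordinates are placeholders:]

    # Sketch only, not part of this patch: classifying a normalized region
    # of interest. Rect bounds and ImageProcessingOptions are assumed from
    # the tasks Python API; `classifier` and `test_image` are as in the
    # tests above.
    from mediapipe.tasks.python.components.containers import rect
    from mediapipe.tasks.python.vision.core import (
        image_processing_options as image_processing_options_module)

    roi = rect.Rect(left=0.45, top=0.31, right=0.61, bottom=0.73)
    image_processing_options = (
        image_processing_options_module.ImageProcessingOptions(
            region_of_interest=roi))
    image_result = classifier.classify(test_image, image_processing_options)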
test_image = _Image.create_from_file( @@ -439,11 +422,10 @@ class ImageClassifierTest(parameterized.TestCase): classifier.classify_for_video(self.test_image, 0) def test_classify_async_calls_with_illegal_timestamp(self): - custom_classifier_options = _ClassifierOptions(max_results=4) options = _ImageClassifierOptions( base_options=_BaseOptions(model_asset_path=self.model_path), running_mode=_RUNNING_MODE.LIVE_STREAM, - classifier_options=custom_classifier_options, + max_results=4, result_callback=mock.MagicMock()) with _ImageClassifier.create_from_options(options) as classifier: classifier.classify_async(self.test_image, 100) @@ -466,12 +448,11 @@ class ImageClassifierTest(parameterized.TestCase): self.assertLess(observed_timestamp_ms, timestamp_ms) self.observed_timestamp_ms = timestamp_ms - custom_classifier_options = _ClassifierOptions( - max_results=4, score_threshold=threshold) options = _ImageClassifierOptions( base_options=_BaseOptions(model_asset_path=self.model_path), running_mode=_RUNNING_MODE.LIVE_STREAM, - classifier_options=custom_classifier_options, + max_results=4, + score_threshold=threshold, result_callback=check_result) with _ImageClassifier.create_from_options(options) as classifier: for timestamp in range(0, 300, 30): @@ -496,11 +477,10 @@ class ImageClassifierTest(parameterized.TestCase): self.assertLess(observed_timestamp_ms, timestamp_ms) self.observed_timestamp_ms = timestamp_ms - custom_classifier_options = _ClassifierOptions(max_results=1) options = _ImageClassifierOptions( base_options=_BaseOptions(model_asset_path=self.model_path), running_mode=_RUNNING_MODE.LIVE_STREAM, - classifier_options=custom_classifier_options, + max_results=1, result_callback=check_result) with _ImageClassifier.create_from_options(options) as classifier: for timestamp in range(0, 300, 30): diff --git a/mediapipe/tasks/python/test/vision/image_embedder_test.py b/mediapipe/tasks/python/test/vision/image_embedder_test.py index 4bb96bad6..11c0cf002 100644 --- a/mediapipe/tasks/python/test/vision/image_embedder_test.py +++ b/mediapipe/tasks/python/test/vision/image_embedder_test.py @@ -24,7 +24,6 @@ import numpy as np from mediapipe.python._framework_bindings import image as image_module from mediapipe.tasks.python.components.containers import embedding_result as embedding_result_module from mediapipe.tasks.python.components.containers import rect -from mediapipe.tasks.python.components.processors import embedder_options as embedder_options_module from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.test import test_utils from mediapipe.tasks.python.vision import image_embedder @@ -33,7 +32,6 @@ from mediapipe.tasks.python.vision.core import vision_task_running_mode as runni _Rect = rect.Rect _BaseOptions = base_options_module.BaseOptions -_EmbedderOptions = embedder_options_module.EmbedderOptions _Embedding = embedding_result_module.Embedding _Image = image_module.Image _ImageEmbedder = image_embedder.ImageEmbedder @@ -142,10 +140,8 @@ class ImageEmbedderTest(parameterized.TestCase): # Should never happen raise ValueError('model_file_type is invalid.') - embedder_options = _EmbedderOptions( - l2_normalize=l2_normalize, quantize=quantize) options = _ImageEmbedderOptions( - base_options=base_options, embedder_options=embedder_options) + base_options=base_options, l2_normalize=l2_normalize, quantize=quantize) embedder = _ImageEmbedder.create_from_options(options) image_processing_options = None @@ -186,10 +182,8 @@ class 
ImageEmbedderTest(parameterized.TestCase): # Should never happen raise ValueError('model_file_type is invalid.') - embedder_options = _EmbedderOptions( - l2_normalize=l2_normalize, quantize=quantize) options = _ImageEmbedderOptions( - base_options=base_options, embedder_options=embedder_options) + base_options=base_options, l2_normalize=l2_normalize, quantize=quantize) with _ImageEmbedder.create_from_options(options) as embedder: # Extracts both embeddings. diff --git a/mediapipe/tasks/python/text/BUILD b/mediapipe/tasks/python/text/BUILD index 10b4b8a6e..e2a51cdbd 100644 --- a/mediapipe/tasks/python/text/BUILD +++ b/mediapipe/tasks/python/text/BUILD @@ -28,9 +28,9 @@ py_library( "//mediapipe/python:packet_creator", "//mediapipe/python:packet_getter", "//mediapipe/tasks/cc/components/containers/proto:classifications_py_pb2", + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_py_pb2", "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_py_pb2", "//mediapipe/tasks/python/components/containers:classification_result", - "//mediapipe/tasks/python/components/processors:classifier_options", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/core:optional_dependencies", "//mediapipe/tasks/python/core:task_info", @@ -47,9 +47,9 @@ py_library( "//mediapipe/python:packet_creator", "//mediapipe/python:packet_getter", "//mediapipe/tasks/cc/components/containers/proto:embeddings_py_pb2", + "//mediapipe/tasks/cc/components/processors/proto:embedder_options_py_pb2", "//mediapipe/tasks/cc/text/text_embedder/proto:text_embedder_graph_options_py_pb2", "//mediapipe/tasks/python/components/containers:embedding_result", - "//mediapipe/tasks/python/components/processors:embedder_options", "//mediapipe/tasks/python/components/utils:cosine_similarity", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/core:optional_dependencies", diff --git a/mediapipe/tasks/python/text/text_classifier.py b/mediapipe/tasks/python/text/text_classifier.py index 9711e8b3a..fdb20f0ef 100644 --- a/mediapipe/tasks/python/text/text_classifier.py +++ b/mediapipe/tasks/python/text/text_classifier.py @@ -14,14 +14,14 @@ """MediaPipe text classifier task.""" import dataclasses -from typing import Optional +from typing import Optional, List from mediapipe.python import packet_creator from mediapipe.python import packet_getter from mediapipe.tasks.cc.components.containers.proto import classifications_pb2 +from mediapipe.tasks.cc.components.processors.proto import classifier_options_pb2 from mediapipe.tasks.cc.text.text_classifier.proto import text_classifier_graph_options_pb2 from mediapipe.tasks.python.components.containers import classification_result as classification_result_module -from mediapipe.tasks.python.components.processors import classifier_options from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.core import task_info as task_info_module from mediapipe.tasks.python.core.optional_dependencies import doc_controls @@ -30,7 +30,7 @@ from mediapipe.tasks.python.text.core import base_text_task_api TextClassifierResult = classification_result_module.ClassificationResult _BaseOptions = base_options_module.BaseOptions _TextClassifierGraphOptionsProto = text_classifier_graph_options_pb2.TextClassifierGraphOptions -_ClassifierOptions = classifier_options.ClassifierOptions +_ClassifierOptionsProto = classifier_options_pb2.ClassifierOptions _TaskInfo = task_info_module.TaskInfo 
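[Editor's note: the embedder test updates above drop the nested `_EmbedderOptions` object in favor of flattened `l2_normalize`/`quantize` fields. A minimal usage sketch of the new shape — the model path and input strings are placeholders, and the `cosine_similarity` helper is assumed from the `components/utils:cosine_similarity` dependency retained in the BUILD files:]

    # Sketch only, not part of this patch: flattened embedder options.
    # 'embedder.tflite' is a placeholder model path.
    from mediapipe.tasks.python.core import base_options as base_options_module
    from mediapipe.tasks.python.text import text_embedder

    options = text_embedder.TextEmbedderOptions(
        base_options=base_options_module.BaseOptions(
            model_asset_path='embedder.tflite'),
        l2_normalize=True,
        quantize=False)
    with text_embedder.TextEmbedder.create_from_options(options) as embedder:
      result0 = embedder.embed('light rain showers')
      result1 = embedder.embed('scattered drizzle')
      # Assumed helper wrapping components/utils:cosine_similarity.
      similarity = text_embedder.TextEmbedder.cosine_similarity(
          result0.embeddings[0], result1.embeddings[0])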
_CLASSIFICATIONS_STREAM_NAME = 'classifications_out' @@ -46,17 +46,38 @@ class TextClassifierOptions: Attributes: base_options: Base options for the text classifier task. - classifier_options: Options for the text classification task. + display_names_locale: The locale to use for display names specified through + the TFLite Model Metadata. + max_results: The maximum number of top-scored classification results to + return. + score_threshold: Overrides the ones provided in the model metadata. Results + below this value are rejected. + category_allowlist: Allowlist of category names. If non-empty, + classification results whose category name is not in this set will be + filtered out. Duplicate or unknown category names are ignored. Mutually + exclusive with `category_denylist`. + category_denylist: Denylist of category names. If non-empty, classification + results whose category name is in this set will be filtered out. Duplicate + or unknown category names are ignored. Mutually exclusive with + `category_allowlist`. """ base_options: _BaseOptions - classifier_options: Optional[_ClassifierOptions] = dataclasses.field( - default_factory=_ClassifierOptions) + display_names_locale: Optional[str] = None + max_results: Optional[int] = None + score_threshold: Optional[float] = None + category_allowlist: Optional[List[str]] = None + category_denylist: Optional[List[str]] = None @doc_controls.do_not_generate_docs def to_pb2(self) -> _TextClassifierGraphOptionsProto: """Generates an TextClassifierOptions protobuf object.""" base_options_proto = self.base_options.to_pb2() - classifier_options_proto = self.classifier_options.to_pb2() + classifier_options_proto = _ClassifierOptionsProto( + score_threshold=self.score_threshold, + category_allowlist=self.category_allowlist, + category_denylist=self.category_denylist, + display_names_locale=self.display_names_locale, + max_results=self.max_results) return _TextClassifierGraphOptionsProto( base_options=base_options_proto, diff --git a/mediapipe/tasks/python/text/text_embedder.py b/mediapipe/tasks/python/text/text_embedder.py index a9e560ac9..be899636d 100644 --- a/mediapipe/tasks/python/text/text_embedder.py +++ b/mediapipe/tasks/python/text/text_embedder.py @@ -19,9 +19,9 @@ from typing import Optional from mediapipe.python import packet_creator from mediapipe.python import packet_getter from mediapipe.tasks.cc.components.containers.proto import embeddings_pb2 +from mediapipe.tasks.cc.components.processors.proto import embedder_options_pb2 from mediapipe.tasks.cc.text.text_embedder.proto import text_embedder_graph_options_pb2 from mediapipe.tasks.python.components.containers import embedding_result as embedding_result_module -from mediapipe.tasks.python.components.processors import embedder_options from mediapipe.tasks.python.components.utils import cosine_similarity from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.core import task_info as task_info_module @@ -31,7 +31,7 @@ from mediapipe.tasks.python.text.core import base_text_task_api TextEmbedderResult = embedding_result_module.EmbeddingResult _BaseOptions = base_options_module.BaseOptions _TextEmbedderGraphOptionsProto = text_embedder_graph_options_pb2.TextEmbedderGraphOptions -_EmbedderOptions = embedder_options.EmbedderOptions +_EmbedderOptionsProto = embedder_options_pb2.EmbedderOptions _TaskInfo = task_info_module.TaskInfo _EMBEDDINGS_OUT_STREAM_NAME = 'embeddings_out' @@ -47,17 +47,25 @@ class TextEmbedderOptions: Attributes: base_options: Base 
options for the text embedder task. - embedder_options: Options for the text embedder task. + l2_normalize: Whether to normalize the returned feature vector with L2 norm. + Use this option only if the model does not already contain a native + L2_NORMALIZATION TF Lite Op. In most cases, this is already the case and + L2 norm is thus achieved through TF Lite inference. + quantize: Whether the returned embedding should be quantized to bytes via + scalar quantization. Embeddings are implicitly assumed to be unit-norm and + therefore any dimension is guaranteed to have a value in [-1.0, 1.0]. Use + the l2_normalize option if this is not the case. """ base_options: _BaseOptions - embedder_options: Optional[_EmbedderOptions] = dataclasses.field( - default_factory=_EmbedderOptions) + l2_normalize: Optional[bool] = None + quantize: Optional[bool] = None @doc_controls.do_not_generate_docs def to_pb2(self) -> _TextEmbedderGraphOptionsProto: """Generates an TextEmbedderOptions protobuf object.""" base_options_proto = self.base_options.to_pb2() - embedder_options_proto = self.embedder_options.to_pb2() + embedder_options_proto = _EmbedderOptionsProto( + l2_normalize=self.l2_normalize, quantize=self.quantize) return _TextEmbedderGraphOptionsProto( base_options=base_options_proto, diff --git a/mediapipe/tasks/python/vision/BUILD b/mediapipe/tasks/python/vision/BUILD index 29e7577e8..241ca4341 100644 --- a/mediapipe/tasks/python/vision/BUILD +++ b/mediapipe/tasks/python/vision/BUILD @@ -47,10 +47,10 @@ py_library( "//mediapipe/python:packet_creator", "//mediapipe/python:packet_getter", "//mediapipe/tasks/cc/components/containers/proto:classifications_py_pb2", + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_py_pb2", "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_py_pb2", "//mediapipe/tasks/python/components/containers:classification_result", "//mediapipe/tasks/python/components/containers:rect", - "//mediapipe/tasks/python/components/processors:classifier_options", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/core:optional_dependencies", "//mediapipe/tasks/python/core:task_info", @@ -89,9 +89,9 @@ py_library( "//mediapipe/python:packet_creator", "//mediapipe/python:packet_getter", "//mediapipe/tasks/cc/components/containers/proto:embeddings_py_pb2", + "//mediapipe/tasks/cc/components/processors/proto:embedder_options_py_pb2", "//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_py_pb2", "//mediapipe/tasks/python/components/containers:embedding_result", - "//mediapipe/tasks/python/components/processors:embedder_options", "//mediapipe/tasks/python/components/utils:cosine_similarity", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/core:optional_dependencies", diff --git a/mediapipe/tasks/python/vision/image_classifier.py b/mediapipe/tasks/python/vision/image_classifier.py index 6cbce7860..b60d18e31 100644 --- a/mediapipe/tasks/python/vision/image_classifier.py +++ b/mediapipe/tasks/python/vision/image_classifier.py @@ -14,17 +14,17 @@ """MediaPipe image classifier task.""" import dataclasses -from typing import Callable, Mapping, Optional +from typing import Callable, Mapping, Optional, List from mediapipe.python import packet_creator from mediapipe.python import packet_getter from mediapipe.python._framework_bindings import image as image_module from mediapipe.python._framework_bindings import packet from mediapipe.tasks.cc.components.containers.proto import 
classifications_pb2 +from mediapipe.tasks.cc.components.processors.proto import classifier_options_pb2 from mediapipe.tasks.cc.vision.image_classifier.proto import image_classifier_graph_options_pb2 from mediapipe.tasks.python.components.containers import classification_result as classification_result_module from mediapipe.tasks.python.components.containers import rect -from mediapipe.tasks.python.components.processors import classifier_options from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.core import task_info as task_info_module from mediapipe.tasks.python.core.optional_dependencies import doc_controls @@ -36,7 +36,7 @@ ImageClassifierResult = classification_result_module.ClassificationResult _NormalizedRect = rect.NormalizedRect _BaseOptions = base_options_module.BaseOptions _ImageClassifierGraphOptionsProto = image_classifier_graph_options_pb2.ImageClassifierGraphOptions -_ClassifierOptions = classifier_options.ClassifierOptions +_ClassifierOptionsProto = classifier_options_pb2.ClassifierOptions _RunningMode = vision_task_running_mode.VisionTaskRunningMode _ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions _TaskInfo = task_info_module.TaskInfo @@ -63,15 +63,31 @@ class ImageClassifierOptions: objects on single image inputs. 2) The video mode for classifying objects on the decoded frames of a video. 3) The live stream mode for classifying objects on a live stream of input data, such as from camera. - classifier_options: Options for the image classification task. + display_names_locale: The locale to use for display names specified through + the TFLite Model Metadata. + max_results: The maximum number of top-scored classification results to + return. + score_threshold: Overrides the ones provided in the model metadata. Results + below this value are rejected. + category_allowlist: Allowlist of category names. If non-empty, + classification results whose category name is not in this set will be + filtered out. Duplicate or unknown category names are ignored. Mutually + exclusive with `category_denylist`. + category_denylist: Denylist of category names. If non-empty, classification + results whose category name is in this set will be filtered out. Duplicate + or unknown category names are ignored. Mutually exclusive with + `category_allowlist`. result_callback: The user-defined result callback for processing live stream data. The result callback should only be specified when the running mode is set to the live stream mode. 
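[Editor's note: a short sketch of the flattened classifier options documented above, in live-stream mode; the model path and thresholds are placeholders, and the callback signature follows the `result_callback` type declared just below. Note that `category_allowlist` and `category_denylist` remain mutually exclusive, per the test updates earlier in this patch:]

    # Sketch only, not part of this patch: flattened classifier options.
    # 'classifier.tflite' is a placeholder model path.
    from mediapipe.tasks.python.core import base_options as base_options_module
    from mediapipe.tasks.python.vision import image_classifier
    from mediapipe.tasks.python.vision.core import vision_task_running_mode

    def print_result(result, output_image, timestamp_ms):
      # result is an ImageClassifierResult; printed here for illustration.
      print(timestamp_ms, result.classifications)

    options = image_classifier.ImageClassifierOptions(
        base_options=base_options_module.BaseOptions(
            model_asset_path='classifier.tflite'),
        running_mode=vision_task_running_mode.VisionTaskRunningMode.LIVE_STREAM,
        max_results=3,
        score_threshold=0.5,
        result_callback=print_result)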
""" base_options: _BaseOptions running_mode: _RunningMode = _RunningMode.IMAGE - classifier_options: Optional[_ClassifierOptions] = dataclasses.field( - default_factory=_ClassifierOptions) + display_names_locale: Optional[str] = None + max_results: Optional[int] = None + score_threshold: Optional[float] = None + category_allowlist: Optional[List[str]] = None + category_denylist: Optional[List[str]] = None result_callback: Optional[Callable[ [ImageClassifierResult, image_module.Image, int], None]] = None @@ -80,7 +96,12 @@ class ImageClassifierOptions: """Generates an ImageClassifierOptions protobuf object.""" base_options_proto = self.base_options.to_pb2() base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.IMAGE else True - classifier_options_proto = self.classifier_options.to_pb2() + classifier_options_proto = _ClassifierOptionsProto( + score_threshold=self.score_threshold, + category_allowlist=self.category_allowlist, + category_denylist=self.category_denylist, + display_names_locale=self.display_names_locale, + max_results=self.max_results) return _ImageClassifierGraphOptionsProto( base_options=base_options_proto, diff --git a/mediapipe/tasks/python/vision/image_embedder.py b/mediapipe/tasks/python/vision/image_embedder.py index a58dca3ae..0bae21bda 100644 --- a/mediapipe/tasks/python/vision/image_embedder.py +++ b/mediapipe/tasks/python/vision/image_embedder.py @@ -21,9 +21,9 @@ from mediapipe.python import packet_getter from mediapipe.python._framework_bindings import image as image_module from mediapipe.python._framework_bindings import packet as packet_module from mediapipe.tasks.cc.components.containers.proto import embeddings_pb2 +from mediapipe.tasks.cc.components.processors.proto import embedder_options_pb2 from mediapipe.tasks.cc.vision.image_embedder.proto import image_embedder_graph_options_pb2 from mediapipe.tasks.python.components.containers import embedding_result as embedding_result_module -from mediapipe.tasks.python.components.processors import embedder_options from mediapipe.tasks.python.components.utils import cosine_similarity from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.core import task_info as task_info_module @@ -35,7 +35,7 @@ from mediapipe.tasks.python.vision.core import vision_task_running_mode as runni ImageEmbedderResult = embedding_result_module.EmbeddingResult _BaseOptions = base_options_module.BaseOptions _ImageEmbedderGraphOptionsProto = image_embedder_graph_options_pb2.ImageEmbedderGraphOptions -_EmbedderOptions = embedder_options.EmbedderOptions +_EmbedderOptionsProto = embedder_options_pb2.EmbedderOptions _RunningMode = running_mode_module.VisionTaskRunningMode _TaskInfo = task_info_module.TaskInfo _ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions @@ -62,15 +62,22 @@ class ImageEmbedderOptions: image on single image inputs. 2) The video mode for embedding image on the decoded frames of a video. 3) The live stream mode for embedding image on a live stream of input data, such as from camera. - embedder_options: Options for the image embedder task. + l2_normalize: Whether to normalize the returned feature vector with L2 norm. + Use this option only if the model does not already contain a native + L2_NORMALIZATION TF Lite Op. In most cases, this is already the case and + L2 norm is thus achieved through TF Lite inference. + quantize: Whether the returned embedding should be quantized to bytes via + scalar quantization. 
Embeddings are implicitly assumed to be unit-norm and + therefore any dimension is guaranteed to have a value in [-1.0, 1.0]. Use + the l2_normalize option if this is not the case. result_callback: The user-defined result callback for processing live stream data. The result callback should only be specified when the running mode is set to the live stream mode. """ base_options: _BaseOptions running_mode: _RunningMode = _RunningMode.IMAGE - embedder_options: Optional[_EmbedderOptions] = dataclasses.field( - default_factory=_EmbedderOptions) + l2_normalize: Optional[bool] = None + quantize: Optional[bool] = None result_callback: Optional[Callable[ [ImageEmbedderResult, image_module.Image, int], None]] = None @@ -79,7 +86,8 @@ class ImageEmbedderOptions: """Generates an ImageEmbedderOptions protobuf object.""" base_options_proto = self.base_options.to_pb2() base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.IMAGE else True - embedder_options_proto = self.embedder_options.to_pb2() + embedder_options_proto = _EmbedderOptionsProto( + l2_normalize=self.l2_normalize, quantize=self.quantize) return _ImageEmbedderGraphOptionsProto( base_options=base_options_proto, From 1167f61f9825cc80e3e81b53b08a59f1a19ef456 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Tue, 6 Dec 2022 18:02:35 -0800 Subject: [PATCH 179/469] Remove generic Options template argument from TaskRunner PiperOrigin-RevId: 493462947 --- mediapipe/tasks/web/audio/core/BUILD | 5 +---- .../tasks/web/audio/core/audio_task_runner.ts | 3 +-- mediapipe/tasks/web/core/task_runner.ts | 14 ++++++-------- .../web/text/text_classifier/text_classifier.ts | 2 +- .../tasks/web/text/text_embedder/text_embedder.ts | 2 +- .../tasks/web/vision/core/vision_task_runner.ts | 3 +-- 6 files changed, 11 insertions(+), 18 deletions(-) diff --git a/mediapipe/tasks/web/audio/core/BUILD b/mediapipe/tasks/web/audio/core/BUILD index 9ab6c7bee..cea689838 100644 --- a/mediapipe/tasks/web/audio/core/BUILD +++ b/mediapipe/tasks/web/audio/core/BUILD @@ -7,8 +7,5 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) mediapipe_ts_library( name = "audio_task_runner", srcs = ["audio_task_runner.ts"], - deps = [ - "//mediapipe/tasks/web/core", - "//mediapipe/tasks/web/core:task_runner", - ], + deps = ["//mediapipe/tasks/web/core:task_runner"], ) diff --git a/mediapipe/tasks/web/audio/core/audio_task_runner.ts b/mediapipe/tasks/web/audio/core/audio_task_runner.ts index 00cfe0253..24d78378d 100644 --- a/mediapipe/tasks/web/audio/core/audio_task_runner.ts +++ b/mediapipe/tasks/web/audio/core/audio_task_runner.ts @@ -15,10 +15,9 @@ */ import {TaskRunner} from '../../../../tasks/web/core/task_runner'; -import {TaskRunnerOptions} from '../../../../tasks/web/core/task_runner_options'; /** Base class for all MediaPipe Audio Tasks. */ -export abstract class AudioTaskRunner<T> extends TaskRunner<TaskRunnerOptions> { +export abstract class AudioTaskRunner<T> extends TaskRunner { private defaultSampleRate = 48000; /** diff --git a/mediapipe/tasks/web/core/task_runner.ts b/mediapipe/tasks/web/core/task_runner.ts index e2ab42e31..71e159dce 100644 --- a/mediapipe/tasks/web/core/task_runner.ts +++ b/mediapipe/tasks/web/core/task_runner.ts @@ -37,10 +37,9 @@ export class GraphRunnerImageLib extends GraphRunnerImageLibType {} * supported and loads the relevant WASM binary. * @return A fully instantiated instance of `T`. 
*/ -export async function -createTaskRunner<T extends TaskRunner<O>, O extends TaskRunnerOptions>( +export async function createTaskRunner<T extends TaskRunner>( type: WasmMediaPipeConstructor<T>, initializeCanvas: boolean, - fileset: WasmFileset, options: O): Promise<T> { + fileset: WasmFileset, options: TaskRunnerOptions): Promise<T> { const fileLocator: FileLocator = { locateFile() { // The only file loaded with this mechanism is the Wasm binary @@ -61,7 +60,7 @@ createTaskRunner<T extends TaskRunner<O>, O extends TaskRunnerOptions>( } /** Base class for all MediaPipe Tasks. */ -export abstract class TaskRunner<O extends TaskRunnerOptions> { +export abstract class TaskRunner { protected abstract baseOptions: BaseOptionsProto; protected graphRunner: GraphRunnerImageLib; private processingErrors: Error[] = []; @@ -71,10 +70,9 @@ export abstract class TaskRunner<O extends TaskRunnerOptions> { * supported and loads the relevant WASM binary. * @return A fully instantiated instance of `T`. */ - protected static async createInstance<T extends TaskRunner<O>, - O extends TaskRunnerOptions>( + protected static async createInstance<T extends TaskRunner>( type: WasmMediaPipeConstructor<T>, initializeCanvas: boolean, - fileset: WasmFileset, options: O): Promise<T> { + fileset: WasmFileset, options: TaskRunnerOptions): Promise<T> { return createTaskRunner(type, initializeCanvas, fileset, options); } @@ -92,7 +90,7 @@ export abstract class TaskRunner<O extends TaskRunnerOptions> { } /** Configures the shared options of a MediaPipe Task. */ - async setOptions(options: O): Promise<void> { + async setOptions(options: TaskRunnerOptions): Promise<void> { if (options.baseOptions) { this.baseOptions = await convertBaseOptionsToProto( options.baseOptions, this.baseOptions); diff --git a/mediapipe/tasks/web/text/text_classifier/text_classifier.ts b/mediapipe/tasks/web/text/text_classifier/text_classifier.ts index 8810d4b42..4a8588836 100644 --- a/mediapipe/tasks/web/text/text_classifier/text_classifier.ts +++ b/mediapipe/tasks/web/text/text_classifier/text_classifier.ts @@ -41,7 +41,7 @@ const TEXT_CLASSIFIER_GRAPH = // tslint:disable:jspb-use-builder-pattern /** Performs Natural Language classification. */ -export class TextClassifier extends TaskRunner<TextClassifierOptions> { +export class TextClassifier extends TaskRunner { private classificationResult: TextClassifierResult = {classifications: []}; private readonly options = new TextClassifierGraphOptions(); diff --git a/mediapipe/tasks/web/text/text_embedder/text_embedder.ts b/mediapipe/tasks/web/text/text_embedder/text_embedder.ts index 62f9b06db..cd5bc644e 100644 --- a/mediapipe/tasks/web/text/text_embedder/text_embedder.ts +++ b/mediapipe/tasks/web/text/text_embedder/text_embedder.ts @@ -45,7 +45,7 @@ const TEXT_EMBEDDER_CALCULATOR = /** * Performs embedding extraction on text. */ -export class TextEmbedder extends TaskRunner<TextEmbedderOptions> { +export class TextEmbedder extends TaskRunner { private embeddingResult: TextEmbedderResult = {embeddings: []}; private readonly options = new TextEmbedderGraphOptionsProto(); diff --git a/mediapipe/tasks/web/vision/core/vision_task_runner.ts b/mediapipe/tasks/web/vision/core/vision_task_runner.ts index 78b4859f2..3432b521b 100644 --- a/mediapipe/tasks/web/vision/core/vision_task_runner.ts +++ b/mediapipe/tasks/web/vision/core/vision_task_runner.ts @@ -20,8 +20,7 @@ import {ImageSource} from '../../../../web/graph_runner/graph_runner'; import {VisionTaskOptions} from './vision_task_options'; /** Base class for all MediaPipe Vision Tasks. */ -export abstract class VisionTaskRunner<T> extends - TaskRunner<VisionTaskOptions> { +export abstract class VisionTaskRunner<T> extends TaskRunner { /** Configures the shared options of a vision task. 
*/ override async setOptions(options: VisionTaskOptions): Promise<void> { await super.setOptions(options); From 402834b4f2236ed2d707f0d20c0ebd2d1a42a721 Mon Sep 17 00:00:00 2001 From: Mark McDonald Date: Tue, 6 Dec 2022 19:46:33 -0800 Subject: [PATCH 180/469] Internal change PiperOrigin-RevId: 493480322 --- docs/build_model_maker_api_docs.py | 81 ++++++++++++++++++++++++++++++ 1 file changed, 81 insertions(+) create mode 100644 docs/build_model_maker_api_docs.py diff --git a/docs/build_model_maker_api_docs.py b/docs/build_model_maker_api_docs.py new file mode 100644 index 000000000..7732b7d56 --- /dev/null +++ b/docs/build_model_maker_api_docs.py @@ -0,0 +1,81 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +r"""MediaPipe Model Maker reference docs generation script. + +This script generates API reference docs for the `mediapipe` PIP package. + +$> pip install -U git+https://github.com/tensorflow/docs mediapipe-model-maker +$> python build_model_maker_api_docs.py +""" + +import os + +from absl import app +from absl import flags + +from tensorflow_docs.api_generator import generate_lib + +try: + # mediapipe has not been set up to work with bazel yet, so catch & report. 
+ import mediapipe_model_maker # pytype: disable=import-error +except ImportError as e: + raise ImportError('Please `pip install mediapipe-model-maker`.') from e + + +PROJECT_SHORT_NAME = 'mediapipe_model_maker' +PROJECT_FULL_NAME = 'MediaPipe Model Maker' + +_OUTPUT_DIR = flags.DEFINE_string( + 'output_dir', + default='/tmp/generated_docs', + help='Where to write the resulting docs.') + +_URL_PREFIX = flags.DEFINE_string( + 'code_url_prefix', + 'https://github.com/google/mediapipe/tree/master/mediapipe/model_maker', + 'The url prefix for links to code.') + +_SEARCH_HINTS = flags.DEFINE_bool( + 'search_hints', True, + 'Include metadata search hints in the generated files') + +_SITE_PATH = flags.DEFINE_string('site_path', '/mediapipe/api_docs/python', + 'Path prefix in the _toc.yaml') + + +def gen_api_docs(): + """Generates API docs for the mediapipe-model-maker package.""" + + doc_generator = generate_lib.DocGenerator( + root_title=PROJECT_FULL_NAME, + py_modules=[(PROJECT_SHORT_NAME, mediapipe_model_maker)], + base_dir=os.path.dirname(mediapipe_model_maker.__file__), + code_url_prefix=_URL_PREFIX.value, + search_hints=_SEARCH_HINTS.value, + site_path=_SITE_PATH.value, + callbacks=[], + ) + + doc_generator.build(_OUTPUT_DIR.value) + + print('Docs output to:', _OUTPUT_DIR.value) + + +def main(_): + gen_api_docs() + + +if __name__ == '__main__': + app.run(main) From 523d16dffab5d066879b300230cc9ac26ad49128 Mon Sep 17 00:00:00 2001 From: Camillo Lugaresi Date: Tue, 6 Dec 2022 23:54:11 -0800 Subject: [PATCH 181/469] Make GpuBuffer a shared_ptr to a storage collection PiperOrigin-RevId: 493519590 --- mediapipe/gpu/BUILD | 2 + mediapipe/gpu/gpu_buffer.cc | 102 +++++++++++++++++++++--------- mediapipe/gpu/gpu_buffer.h | 105 +++++++++++++++++-------------- mediapipe/gpu/gpu_buffer_test.cc | 22 +++++++ 4 files changed, 156 insertions(+), 75 deletions(-) diff --git a/mediapipe/gpu/BUILD b/mediapipe/gpu/BUILD index f5cb9f715..009eb3f9e 100644 --- a/mediapipe/gpu/BUILD +++ b/mediapipe/gpu/BUILD @@ -289,7 +289,9 @@ cc_library( deps = [ ":gpu_buffer_format", ":gpu_buffer_storage", + "@com_google_absl//absl/functional:bind_front", "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", "//mediapipe/framework/formats:image_frame", "//mediapipe/framework/port:logging", ":gpu_buffer_storage_image_frame", diff --git a/mediapipe/gpu/gpu_buffer.cc b/mediapipe/gpu/gpu_buffer.cc index 388960b11..628e86099 100644 --- a/mediapipe/gpu/gpu_buffer.cc +++ b/mediapipe/gpu/gpu_buffer.cc @@ -3,6 +3,7 @@ #include <memory> #include <utility> +#include "absl/functional/bind_front.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "mediapipe/framework/port/logging.h" @@ -25,57 +26,101 @@ struct StorageTypeFormatter { } // namespace std::string GpuBuffer::DebugString() const { - return absl::StrCat("GpuBuffer[", - absl::StrJoin(storages_, ", ", StorageTypeFormatter()), - "]"); + return holder_ ? 
absl::StrCat("GpuBuffer[", width(), "x", height(), " ", + format(), " as ", holder_->DebugString(), "]") : "GpuBuffer[invalid]"; } -internal::GpuBufferStorage* GpuBuffer::GetStorageForView( +std::string GpuBuffer::StorageHolder::DebugString() const { + absl::MutexLock lock(&mutex_); + return absl::StrJoin(storages_, ", ", StorageTypeFormatter()); +} + +internal::GpuBufferStorage* GpuBuffer::StorageHolder::GetStorageForView( TypeId view_provider_type, bool for_writing) const { - const std::shared_ptr<internal::GpuBufferStorage>* chosen_storage = nullptr; + std::shared_ptr<internal::GpuBufferStorage> chosen_storage; + std::function<std::shared_ptr<internal::GpuBufferStorage>()> conversion; - // First see if any current storage supports the view. - for (const auto& s : storages_) { - if (s->can_down_cast_to(view_provider_type)) { - chosen_storage = &s; - break; - } - } - - // Then try to convert existing storages to one that does. - // TODO: choose best conversion. - if (!chosen_storage) { + { + absl::MutexLock lock(&mutex_); + // First see if any current storage supports the view. for (const auto& s : storages_) { - if (auto converter = internal::GpuBufferStorageRegistry::Get() - .StorageConverterForViewProvider( - view_provider_type, s->storage_type())) { - if (auto new_storage = converter(s)) { - storages_.push_back(new_storage); - chosen_storage = &storages_.back(); + if (s->can_down_cast_to(view_provider_type)) { + chosen_storage = s; + break; + } + } + + // Then try to convert existing storages to one that does. + // TODO: choose best conversion. + if (!chosen_storage) { + for (const auto& s : storages_) { + if (auto converter = internal::GpuBufferStorageRegistry::Get() + .StorageConverterForViewProvider( + view_provider_type, s->storage_type())) { + conversion = absl::bind_front(converter, s); break; } } } } + // Avoid invoking a converter or factory while holding the mutex. + // Two reasons: + // 1. Readers that don't need a conversion will not be blocked. + // 2. We use mutexes to make sure GL contexts are not used simultaneously on + // different threads, and we also rely on Mutex's deadlock detection + // heuristic, which enforces a consistent mutex acquisition order. + // This function is likely to be called within a GL context, and the + // conversion function may in turn use a GL context, and this may cause a + // false positive in the deadlock detector. + // TODO: we could use Mutex::ForgetDeadlockInfo instead. + if (conversion) { + auto new_storage = conversion(); + absl::MutexLock lock(&mutex_); + // Another reader might have already completed and inserted the same + // conversion. TODO: prevent this? + for (const auto& s : storages_) { + if (s->can_down_cast_to(view_provider_type)) { + chosen_storage = s; + break; + } + } + if (!chosen_storage) { + storages_.push_back(std::move(new_storage)); + chosen_storage = storages_.back(); + } + } + if (for_writing) { + // This will temporarily hold storages to be released, and do so while the + // lock is not held (see above). + decltype(storages_) old_storages; + using std::swap; if (chosen_storage) { // Discard all other storages. - storages_ = {*chosen_storage}; - chosen_storage = &storages_.back(); + absl::MutexLock lock(&mutex_); + swap(old_storages, storages_); + storages_ = {chosen_storage}; } else { // Allocate a new storage supporting the requested view. 
if (auto factory = internal::GpuBufferStorageRegistry::Get() .StorageFactoryForViewProvider(view_provider_type)) { - if (auto new_storage = factory(width(), height(), format())) { + if (auto new_storage = factory(width_, height_, format_)) { + absl::MutexLock lock(&mutex_); + swap(old_storages, storages_); storages_ = {std::move(new_storage)}; - chosen_storage = &storages_.back(); + chosen_storage = storages_.back(); } } } } - return chosen_storage ? chosen_storage->get() : nullptr; + + // It is ok to return a non-owning storage pointer here because this object + // ensures the storage's lifetime. Overwriting a GpuBuffer while readers are + // active would violate this, but it's not allowed in MediaPipe. + return chosen_storage ? chosen_storage.get() : nullptr; } internal::GpuBufferStorage& GpuBuffer::GetStorageForViewOrDie( @@ -84,8 +129,7 @@ internal::GpuBufferStorage& GpuBuffer::GetStorageForViewOrDie( GpuBuffer::GetStorageForView(view_provider_type, for_writing); CHECK(chosen_storage) << "no view provider found for requested view " << view_provider_type.name() << "; storages available: " - << absl::StrJoin(storages_, ", ", - StorageTypeFormatter()); + << (holder_ ? holder_->DebugString() : "invalid"); DCHECK(chosen_storage->can_down_cast_to(view_provider_type)); return *chosen_storage; } diff --git a/mediapipe/gpu/gpu_buffer.h b/mediapipe/gpu/gpu_buffer.h index 56507d92f..b9a88aa53 100644 --- a/mediapipe/gpu/gpu_buffer.h +++ b/mediapipe/gpu/gpu_buffer.h @@ -15,9 +15,12 @@ #ifndef MEDIAPIPE_GPU_GPU_BUFFER_H_ #define MEDIAPIPE_GPU_GPU_BUFFER_H_ +#include <functional> +#include <memory> #include <string> #include <utility> +#include "absl/synchronization/mutex.h" #include "mediapipe/framework/formats/image_frame.h" #include "mediapipe/gpu/gpu_buffer_format.h" #include "mediapipe/gpu/gpu_buffer_storage.h" @@ -56,8 +59,7 @@ class GpuBuffer { // Creates an empty buffer of a given size and format. It will be allocated // when a view is requested. GpuBuffer(int width, int height, Format format) - : GpuBuffer(std::make_shared<PlaceholderGpuBufferStorage>(width, height, - format)) {} + : holder_(std::make_shared<StorageHolder>(width, height, format)) {} // Copy and move constructors and assignment operators are supported. GpuBuffer(const GpuBuffer& other) = default; @@ -70,9 +72,8 @@ class GpuBuffer { // are not portable. Applications and calculators should normally obtain // GpuBuffers in a portable way from the framework, e.g. using // GpuBufferMultiPool. - explicit GpuBuffer(std::shared_ptr<internal::GpuBufferStorage> storage) { - storages_.push_back(std::move(storage)); - } + explicit GpuBuffer(std::shared_ptr<internal::GpuBufferStorage> storage) + : holder_(std::make_shared<StorageHolder>(std::move(storage))) {} #if !MEDIAPIPE_DISABLE_GPU && MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER // This is used to support backward-compatible construction of GpuBuffer from @@ -84,9 +85,11 @@ class GpuBuffer { : GpuBuffer(internal::AsGpuBufferStorage(storage_convertible)) {} #endif // !MEDIAPIPE_DISABLE_GPU && MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER - int width() const { return current_storage().width(); } - int height() const { return current_storage().height(); } - GpuBufferFormat format() const { return current_storage().format(); } + int width() const { return holder_ ? holder_->width() : 0; } + int height() const { return holder_ ? holder_->height() : 0; } + GpuBufferFormat format() const { + return holder_ ? holder_->format() : GpuBufferFormat::kUnknown; + } // Converts to true iff valid. explicit operator bool() const { return operator!=(nullptr); } @@ -122,31 +125,17 @@ class GpuBuffer { // using views. 
template <class T> std::shared_ptr<T> internal_storage() const { - for (const auto& s : storages_) - if (s->down_cast<T>()) return std::static_pointer_cast<T>(s); - return nullptr; + return holder_ ? holder_->internal_storage<T>() : nullptr; } std::string DebugString() const; private: - class PlaceholderGpuBufferStorage - : public internal::GpuBufferStorageImpl<PlaceholderGpuBufferStorage> { - public: - PlaceholderGpuBufferStorage(int width, int height, Format format) - : width_(width), height_(height), format_(format) {} - int width() const override { return width_; } - int height() const override { return height_; } - GpuBufferFormat format() const override { return format_; } - - private: - int width_ = 0; - int height_ = 0; - GpuBufferFormat format_ = GpuBufferFormat::kUnknown; - }; - internal::GpuBufferStorage* GetStorageForView(TypeId view_provider_type, - bool for_writing) const; + internal::GpuBufferStorage* GetStorageForView(TypeId view_provider_type, + bool for_writing) const { + return holder_ ? holder_->GetStorageForView(view_provider_type, for_writing) + : nullptr; + } internal::GpuBufferStorage& GetStorageForViewOrDie(TypeId view_provider_type, bool for_writing) const; @@ -158,25 +147,49 @@ class GpuBuffer { .template down_cast<VP>(); } - std::shared_ptr<internal::GpuBufferStorage>& no_storage() const { - static auto placeholder = - std::static_pointer_cast<internal::GpuBufferStorage>( - std::make_shared<PlaceholderGpuBufferStorage>( - 0, 0, GpuBufferFormat::kUnknown)); - return placeholder; - } + // This class manages a set of alternative storages for the contents of a + // GpuBuffer. GpuBuffer was originally designed as a reference-type object, + // where a copy represents another reference to the same contents, so multiple + // GpuBuffer instances can share the same StorageHolder. + class StorageHolder { + public: + explicit StorageHolder(std::shared_ptr<internal::GpuBufferStorage> storage) + : StorageHolder(storage->width(), storage->height(), + storage->format()) { + storages_.push_back(std::move(storage)); + } + explicit StorageHolder(int width, int height, Format format) + : width_(width), height_(height), format_(format) {} - const internal::GpuBufferStorage& current_storage() const { - return storages_.empty() ? *no_storage() : *storages_[0]; - } + int width() const { return width_; } + int height() const { return height_; } + GpuBufferFormat format() const { return format_; } - internal::GpuBufferStorage& current_storage() { - return storages_.empty() ? *no_storage() : *storages_[0]; - } + internal::GpuBufferStorage* GetStorageForView(TypeId view_provider_type, + bool for_writing) const; - // This is mutable because view methods that do not change the contents may - // still need to allocate new storages. - mutable std::vector<std::shared_ptr<internal::GpuBufferStorage>> storages_; + template <class T> + std::shared_ptr<T> internal_storage() const { + absl::MutexLock lock(&mutex_); + for (const auto& s : storages_) + if (s->down_cast<T>()) return std::static_pointer_cast<T>(s); + return nullptr; + } + + std::string DebugString() const; + + private: + int width_ = 0; + int height_ = 0; + GpuBufferFormat format_ = GpuBufferFormat::kUnknown; + // This is mutable because view methods that do not change the contents may + // still need to allocate new storages. 
+ mutable absl::Mutex mutex_; + mutable std::vector<std::shared_ptr<internal::GpuBufferStorage>> storages_ ABSL_GUARDED_BY(mutex_); + }; + + std::shared_ptr<StorageHolder> holder_; #if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER friend CVPixelBufferRef GetCVPixelBufferRef(const GpuBuffer& buffer); @@ -184,15 +197,15 @@ }; inline bool GpuBuffer::operator==(std::nullptr_t other) const { - return storages_.empty(); + return holder_ == other; } inline bool GpuBuffer::operator==(const GpuBuffer& other) const { - return storages_ == other.storages_; + return holder_ == other.holder_; } inline GpuBuffer& GpuBuffer::operator=(std::nullptr_t other) { - storages_.clear(); + holder_ = other; return *this; } diff --git a/mediapipe/gpu/gpu_buffer_test.cc b/mediapipe/gpu/gpu_buffer_test.cc index 145b71806..e4be617db 100644 --- a/mediapipe/gpu/gpu_buffer_test.cc +++ b/mediapipe/gpu/gpu_buffer_test.cc @@ -20,6 +20,7 @@ #include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/tool/test_util.h" +#include "mediapipe/gpu/gl_texture_buffer.h" #include "mediapipe/gpu/gl_texture_util.h" #include "mediapipe/gpu/gpu_buffer_storage_ahwb.h" #include "mediapipe/gpu/gpu_buffer_storage_image_frame.h" @@ -228,5 +229,26 @@ TEST_F(GpuBufferTest, GlTextureViewRetainsWhatItNeeds) { EXPECT_TRUE(true); } +TEST_F(GpuBufferTest, CopiesShareConversions) { + GpuBuffer buffer(300, 200, GpuBufferFormat::kBGRA32); + { + std::shared_ptr<ImageFrame> view = buffer.GetWriteView<ImageFrame>(); + FillImageFrameRGBA(*view, 255, 0, 0, 255); + } + + GpuBuffer other_handle = buffer; + RunInGlContext([&buffer] { + TempGlFramebuffer fb; + auto view = buffer.GetReadView<GlTextureView>(0); + }); + + // Check that other_handle also sees the same GlTextureBuffer as buffer. + // Note that this is deliberately written so that it still passes on platforms + // where we use another storage for GL textures (they will both be null). + // TODO: expose more accessors for testing? + EXPECT_EQ(other_handle.internal_storage<GlTextureBuffer>(), + buffer.internal_storage<GlTextureBuffer>()); +} + } // anonymous namespace } // namespace mediapipe From aad797197bbb4c4170cd21c6baf18084bee84446 Mon Sep 17 00:00:00 2001 From: Nikolay Chirkov Date: Wed, 7 Dec 2022 07:14:46 -0800 Subject: [PATCH 182/469] TensorV1 EGL.h include fix. 
PiperOrigin-RevId: 493596083 --- mediapipe/framework/formats/tensor.h | 5 ++--- mediapipe/framework/formats/tensor_ahwb.cc | 5 +++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/mediapipe/framework/formats/tensor.h b/mediapipe/framework/formats/tensor.h index ecd63c8c6..3ed72c6fd 100644 --- a/mediapipe/framework/formats/tensor.h +++ b/mediapipe/framework/formats/tensor.h @@ -39,10 +39,9 @@ #endif // MEDIAPIPE_NO_JNI #ifdef MEDIAPIPE_TENSOR_USE_AHWB +#include <EGL/egl.h> +#include <EGL/eglext.h> #include <android/hardware_buffer.h> - -#include "third_party/GL/gl/include/EGL/egl.h" -#include "third_party/GL/gl/include/EGL/eglext.h" #endif // MEDIAPIPE_TENSOR_USE_AHWB #if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30 #include "mediapipe/gpu/gl_base.h" diff --git a/mediapipe/framework/formats/tensor_ahwb.cc b/mediapipe/framework/formats/tensor_ahwb.cc index b11f6b55b..90d89c40a 100644 --- a/mediapipe/framework/formats/tensor_ahwb.cc +++ b/mediapipe/framework/formats/tensor_ahwb.cc @@ -4,12 +4,13 @@ #include "mediapipe/framework/formats/tensor.h" #ifdef MEDIAPIPE_TENSOR_USE_AHWB +#include <EGL/egl.h> +#include <EGL/eglext.h> + #include "absl/synchronization/mutex.h" #include "mediapipe/framework/port.h" #include "mediapipe/framework/port/logging.h" #include "mediapipe/gpu/gl_base.h" -#include "third_party/GL/gl/include/EGL/egl.h" -#include "third_party/GL/gl/include/EGL/eglext.h" #endif // MEDIAPIPE_TENSOR_USE_AHWB namespace mediapipe { From d9688b769f5207aff13bd782d94dd4d2ad8dcd92 Mon Sep 17 00:00:00 2001 From: Khanh LeViet Date: Wed, 7 Dec 2022 08:13:51 -0800 Subject: [PATCH 183/469] Hide internal APIs from mediapipe pip package's API docs. PiperOrigin-RevId: 493607984 --- .../tasks/python/core/optional_dependencies.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/mediapipe/tasks/python/core/optional_dependencies.py b/mediapipe/tasks/python/core/optional_dependencies.py index d4f6a6abc..b1a0ed538 100644 --- a/mediapipe/tasks/python/core/optional_dependencies.py +++ b/mediapipe/tasks/python/core/optional_dependencies.py @@ -13,6 +13,13 @@ # limitations under the License. """MediaPipe Tasks' common but optional dependencies.""" -doc_controls = lambda: None -no_op = lambda x: x -setattr(doc_controls, 'do_not_generate_docs', no_op) +# TensorFlow isn't a dependency of mediapipe pip package. It's only +# required in the API docgen pipeline so we'll ignore it if tensorflow is not +# installed. 
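[Editor's note: for context on the change above, a sketch of the decorator form that both the real tensorflow_docs `doc_controls` and the no-op fallback must support; `Example.to_pb2` is a representative stand-in, not taken from this patch (the same pattern appears on the `to_pb2` methods earlier in this series):]

    # Sketch only: consuming the (possibly stubbed) doc_controls object.
    from mediapipe.tasks.python.core.optional_dependencies import doc_controls

    class Example:
      @doc_controls.do_not_generate_docs
      def to_pb2(self):
        """Hidden from the generated API reference docs."""
        return None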
+try: + from tensorflow.tools.docs import doc_controls +except ModuleNotFoundError: + # Replace the real doc_controls.do_not_generate_docs with an no-op + doc_controls = lambda: None + no_op = lambda x: x + setattr(doc_controls, 'do_not_generate_docs', no_op) From d84eec387bb277b4f379f360d97bbf734cb3ae13 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Wed, 7 Dec 2022 10:50:12 -0800 Subject: [PATCH 184/469] Add missing import to InferenceCalculator.proto PiperOrigin-RevId: 493649869 --- mediapipe/calculators/tensor/inference_calculator.proto | 1 + mediapipe/tasks/web/BUILD | 3 --- mediapipe/tasks/web/rollup.config.mjs | 6 ------ package.json | 1 - yarn.lock | 8 -------- 5 files changed, 1 insertion(+), 18 deletions(-) diff --git a/mediapipe/calculators/tensor/inference_calculator.proto b/mediapipe/calculators/tensor/inference_calculator.proto index 46552803b..78a0039bc 100644 --- a/mediapipe/calculators/tensor/inference_calculator.proto +++ b/mediapipe/calculators/tensor/inference_calculator.proto @@ -17,6 +17,7 @@ syntax = "proto2"; package mediapipe; import "mediapipe/framework/calculator.proto"; +import "mediapipe/framework/calculator_options.proto"; option java_package = "com.google.mediapipe.calculator.proto"; option java_outer_classname = "InferenceCalculatorProto"; diff --git a/mediapipe/tasks/web/BUILD b/mediapipe/tasks/web/BUILD index 20e717433..bc9e84147 100644 --- a/mediapipe/tasks/web/BUILD +++ b/mediapipe/tasks/web/BUILD @@ -44,7 +44,6 @@ rollup_bundle( ":audio_lib", "@npm//@rollup/plugin-commonjs", "@npm//@rollup/plugin-node-resolve", - "@npm//@rollup/plugin-replace", "@npm//@rollup/plugin-terser", "@npm//google-protobuf", ], @@ -88,7 +87,6 @@ rollup_bundle( ":text_lib", "@npm//@rollup/plugin-commonjs", "@npm//@rollup/plugin-node-resolve", - "@npm//@rollup/plugin-replace", "@npm//@rollup/plugin-terser", "@npm//google-protobuf", ], @@ -132,7 +130,6 @@ rollup_bundle( ":vision_lib", "@npm//@rollup/plugin-commonjs", "@npm//@rollup/plugin-node-resolve", - "@npm//@rollup/plugin-replace", "@npm//@rollup/plugin-terser", "@npm//google-protobuf", ], diff --git a/mediapipe/tasks/web/rollup.config.mjs b/mediapipe/tasks/web/rollup.config.mjs index e633bf702..3b5119530 100644 --- a/mediapipe/tasks/web/rollup.config.mjs +++ b/mediapipe/tasks/web/rollup.config.mjs @@ -1,15 +1,9 @@ import resolve from '@rollup/plugin-node-resolve'; import commonjs from '@rollup/plugin-commonjs'; -import replace from '@rollup/plugin-replace'; import terser from '@rollup/plugin-terser'; export default { plugins: [ - // Workaround for https://github.com/protocolbuffers/protobuf-javascript/issues/151 - replace({ - 'var calculator_options_pb = {};': 'var calculator_options_pb = {}; var mediapipe_framework_calculator_options_pb = calculator_options_pb;', - delimiters: ['', ''] - }), resolve(), commonjs(), terser() diff --git a/package.json b/package.json index 22a035b74..6ad0b52c0 100644 --- a/package.json +++ b/package.json @@ -7,7 +7,6 @@ "@bazel/typescript": "^5.7.1", "@rollup/plugin-commonjs": "^23.0.2", "@rollup/plugin-node-resolve": "^15.0.1", - "@rollup/plugin-replace": "^5.0.1", "@rollup/plugin-terser": "^0.1.0", "@types/google-protobuf": "^3.15.6", "@types/offscreencanvas": "^2019.7.0", diff --git a/yarn.lock b/yarn.lock index 19c32e322..91b50456e 100644 --- a/yarn.lock +++ b/yarn.lock @@ -148,14 +148,6 @@ is-module "^1.0.0" resolve "^1.22.1" -"@rollup/plugin-replace@^5.0.1": - version "5.0.1" - resolved 
"https://registry.yarnpkg.com/@rollup/plugin-replace/-/plugin-replace-5.0.1.tgz#49a57af3e6df111a9e75dea3f3572741f4c5c83e" - integrity sha512-Z3MfsJ4CK17BfGrZgvrcp/l6WXoKb0kokULO+zt/7bmcyayokDaQ2K3eDJcRLCTAlp5FPI4/gz9MHAsosz4Rag== - dependencies: - "@rollup/pluginutils" "^5.0.1" - magic-string "^0.26.4" - "@rollup/plugin-terser@^0.1.0": version "0.1.0" resolved "https://registry.yarnpkg.com/@rollup/plugin-terser/-/plugin-terser-0.1.0.tgz#7530c0f11667637419d71820461646c418526041" From 80c605459c2361840c1c0eab05dfa260d7dcfedc Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Wed, 7 Dec 2022 11:24:15 -0800 Subject: [PATCH 185/469] Open up framework visibility. PiperOrigin-RevId: 493660013 --- mediapipe/framework/deps/BUILD | 23 ++++---------- mediapipe/framework/port/BUILD | 42 +------------------------ mediapipe/framework/profiler/BUILD | 5 ++- mediapipe/framework/tool/testdata/BUILD | 7 +++-- 4 files changed, 13 insertions(+), 64 deletions(-) diff --git a/mediapipe/framework/deps/BUILD b/mediapipe/framework/deps/BUILD index 95ab21707..27bc105c8 100644 --- a/mediapipe/framework/deps/BUILD +++ b/mediapipe/framework/deps/BUILD @@ -20,7 +20,9 @@ load("@bazel_skylib//:bzl_library.bzl", "bzl_library") licenses(["notice"]) -package(default_visibility = ["//visibility:private"]) +package(default_visibility = [ + "//mediapipe:__subpackages__", +]) bzl_library( name = "expand_template_bzl", @@ -50,13 +52,11 @@ mediapipe_proto_library( cc_library( name = "aligned_malloc_and_free", hdrs = ["aligned_malloc_and_free.h"], - visibility = ["//visibility:public"], ) cc_library( name = "cleanup", hdrs = ["cleanup.h"], - visibility = ["//visibility:public"], deps = ["@com_google_absl//absl/base:core_headers"], ) @@ -86,7 +86,6 @@ cc_library( # Use this library through "mediapipe/framework/port:gtest_main". 
visibility = [ "//mediapipe/framework/port:__pkg__", - "//third_party/visionai/algorithms/tracking:__pkg__", ], deps = [ "//mediapipe/framework/port:core_proto", @@ -108,7 +107,6 @@ cc_library( name = "file_helpers", srcs = ["file_helpers.cc"], hdrs = ["file_helpers.h"], - visibility = ["//visibility:public"], deps = [ ":file_path", "//mediapipe/framework/port:status", @@ -134,7 +132,6 @@ cc_library( cc_library( name = "image_resizer", hdrs = ["image_resizer.h"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework/port:opencv_imgproc", ], @@ -151,7 +148,9 @@ cc_library( cc_library( name = "mathutil", hdrs = ["mathutil.h"], - visibility = ["//visibility:public"], + visibility = [ + "//mediapipe:__subpackages__", + ], deps = [ "//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:logging", @@ -171,7 +170,6 @@ cc_library( cc_library( name = "no_destructor", hdrs = ["no_destructor.h"], - visibility = ["//visibility:public"], ) cc_library( @@ -190,7 +188,6 @@ cc_library( cc_library( name = "random", hdrs = ["random_base.h"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework/port:integral_types"], ) @@ -211,14 +208,12 @@ cc_library( name = "registration_token", srcs = ["registration_token.cc"], hdrs = ["registration_token.h"], - visibility = ["//visibility:public"], ) cc_library( name = "registration", srcs = ["registration.cc"], hdrs = ["registration.h"], - visibility = ["//visibility:public"], deps = [ ":registration_token", "//mediapipe/framework/port:logging", @@ -279,7 +274,6 @@ cc_library( hdrs = [ "re2.h", ], - visibility = ["//visibility:public"], ) cc_library( @@ -310,7 +304,6 @@ cc_library( cc_library( name = "thread_options", hdrs = ["thread_options.h"], - visibility = ["//visibility:public"], ) cc_library( @@ -356,7 +349,6 @@ cc_library( cc_test( name = "mathutil_unittest", srcs = ["mathutil_unittest.cc"], - visibility = ["//visibility:public"], deps = [ ":mathutil", "//mediapipe/framework/port:benchmark", @@ -368,7 +360,6 @@ cc_test( name = "registration_token_test", srcs = ["registration_token_test.cc"], linkstatic = 1, - visibility = ["//visibility:public"], deps = [ ":registration_token", "//mediapipe/framework/port:gtest_main", @@ -381,7 +372,6 @@ cc_test( timeout = "long", srcs = ["safe_int_test.cc"], linkstatic = 1, - visibility = ["//visibility:public"], deps = [ ":intops", "//mediapipe/framework/port:gtest_main", @@ -393,7 +383,6 @@ cc_test( name = "monotonic_clock_test", srcs = ["monotonic_clock_test.cc"], linkstatic = 1, - visibility = ["//visibility:public"], deps = [ ":clock", "//mediapipe/framework/port:gtest_main", diff --git a/mediapipe/framework/port/BUILD b/mediapipe/framework/port/BUILD index e499ca3a6..1039dc1c6 100644 --- a/mediapipe/framework/port/BUILD +++ b/mediapipe/framework/port/BUILD @@ -18,7 +18,7 @@ licenses(["notice"]) package( - default_visibility = ["//visibility:private"], + default_visibility = ["//visibility:public"], features = ["-parse_headers"], ) @@ -28,7 +28,6 @@ config_setting( define_values = { "USE_MEDIAPIPE_THREADPOOL": "1", }, - visibility = ["//visibility:public"], ) #TODO : remove from OSS. 
@@ -37,13 +36,11 @@ config_setting( define_values = { "USE_MEDIAPIPE_THREADPOOL": "0", }, - visibility = ["//visibility:public"], ) cc_library( name = "aligned_malloc_and_free", hdrs = ["aligned_malloc_and_free.h"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework/deps:aligned_malloc_and_free", "@com_google_absl//absl/base:core_headers", @@ -57,7 +54,6 @@ cc_library( "advanced_proto_inc.h", "proto_ns.h", ], - visibility = ["//visibility:public"], deps = [ ":advanced_proto_lite", ":core_proto", @@ -72,7 +68,6 @@ cc_library( "advanced_proto_lite_inc.h", "proto_ns.h", ], - visibility = ["//visibility:public"], deps = [ ":core_proto", "//mediapipe/framework:port", @@ -83,7 +78,6 @@ cc_library( cc_library( name = "any_proto", hdrs = ["any_proto.h"], - visibility = ["//visibility:public"], deps = [ ":core_proto", ], @@ -94,7 +88,6 @@ cc_library( hdrs = [ "commandlineflags.h", ], - visibility = ["//visibility:public"], deps = [ "//third_party:glog", "@com_google_absl//absl/flags:flag", @@ -107,7 +100,6 @@ cc_library( "core_proto_inc.h", "proto_ns.h", ], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:port", "@com_google_protobuf//:protobuf", @@ -117,7 +109,6 @@ cc_library( cc_library( name = "file_helpers", hdrs = ["file_helpers.h"], - visibility = ["//visibility:public"], deps = [ ":status", "//mediapipe/framework/deps:file_helpers", @@ -128,7 +119,6 @@ cc_library( cc_library( name = "image_resizer", hdrs = ["image_resizer.h"], - visibility = ["//visibility:public"], deps = select({ "//conditions:default": [ "//mediapipe/framework/deps:image_resizer", @@ -140,14 +130,12 @@ cc_library( cc_library( name = "integral_types", hdrs = ["integral_types.h"], - visibility = ["//visibility:public"], ) cc_library( name = "benchmark", testonly = 1, hdrs = ["benchmark.h"], - visibility = ["//visibility:public"], deps = [ "@com_google_benchmark//:benchmark", ], @@ -158,7 +146,6 @@ cc_library( hdrs = [ "re2.h", ], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework/deps:re2", ], @@ -173,7 +160,6 @@ cc_library( "gtest-spi.h", "status_matchers.h", ], - visibility = ["//visibility:public"], deps = [ ":status_matchers", "//mediapipe/framework/deps:message_matchers", @@ -190,7 +176,6 @@ cc_library( "gtest-spi.h", "status_matchers.h", ], - visibility = ["//visibility:public"], deps = [ ":status_matchers", "//mediapipe/framework/deps:message_matchers", @@ -204,7 +189,6 @@ cc_library( hdrs = [ "logging.h", ], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:port", "//third_party:glog", @@ -217,7 +201,6 @@ cc_library( hdrs = [ "map_util.h", ], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:port", "//mediapipe/framework/deps:map_util", @@ -227,7 +210,6 @@ cc_library( cc_library( name = "numbers", hdrs = ["numbers.h"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework/deps:numbers"], ) @@ -238,13 +220,11 @@ config_setting( define_values = { "MEDIAPIPE_DISABLE_OPENCV": "1", }, - visibility = ["//visibility:public"], ) cc_library( name = "opencv_core", hdrs = ["opencv_core_inc.h"], - visibility = ["//visibility:public"], deps = [ "//third_party:opencv", ], @@ -253,7 +233,6 @@ cc_library( cc_library( name = "opencv_imgproc", hdrs = ["opencv_imgproc_inc.h"], - visibility = ["//visibility:public"], deps = [ ":opencv_core", "//third_party:opencv", @@ -263,7 +242,6 @@ cc_library( cc_library( name = "opencv_imgcodecs", hdrs = ["opencv_imgcodecs_inc.h"], - visibility = ["//visibility:public"], 
deps = [ ":opencv_core", "//third_party:opencv", @@ -273,7 +251,6 @@ cc_library( cc_library( name = "opencv_highgui", hdrs = ["opencv_highgui_inc.h"], - visibility = ["//visibility:public"], deps = [ ":opencv_core", "//third_party:opencv", @@ -283,7 +260,6 @@ cc_library( cc_library( name = "opencv_video", hdrs = ["opencv_video_inc.h"], - visibility = ["//visibility:public"], deps = [ ":opencv_core", "//mediapipe/framework:port", @@ -294,7 +270,6 @@ cc_library( cc_library( name = "opencv_features2d", hdrs = ["opencv_features2d_inc.h"], - visibility = ["//visibility:public"], deps = [ ":opencv_core", "//third_party:opencv", @@ -304,7 +279,6 @@ cc_library( cc_library( name = "opencv_calib3d", hdrs = ["opencv_calib3d_inc.h"], - visibility = ["//visibility:public"], deps = [ ":opencv_core", "//third_party:opencv", @@ -314,7 +288,6 @@ cc_library( cc_library( name = "opencv_videoio", hdrs = ["opencv_videoio_inc.h"], - visibility = ["//visibility:public"], deps = [ ":opencv_core", "//mediapipe/framework:port", @@ -328,7 +301,6 @@ cc_library( "parse_text_proto.h", "proto_ns.h", ], - visibility = ["//visibility:public"], deps = [ ":core_proto", ":logging", @@ -339,14 +311,12 @@ cc_library( cc_library( name = "point", hdrs = ["point2.h"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework/deps:point"], ) cc_library( name = "port", hdrs = ["port.h"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:port", "@com_google_absl//absl/base:core_headers", @@ -356,14 +326,12 @@ cc_library( cc_library( name = "rectangle", hdrs = ["rectangle.h"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework/deps:rectangle"], ) cc_library( name = "ret_check", hdrs = ["ret_check.h"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:port", "//mediapipe/framework/deps:ret_check", @@ -373,7 +341,6 @@ cc_library( cc_library( name = "singleton", hdrs = ["singleton.h"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework/deps:singleton"], ) @@ -382,7 +349,6 @@ cc_library( hdrs = [ "source_location.h", ], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:port", "//mediapipe/framework/deps:source_location", @@ -397,7 +363,6 @@ cc_library( "status_builder.h", "status_macros.h", ], - visibility = ["//visibility:public"], deps = [ ":source_location", "//mediapipe/framework:port", @@ -412,7 +377,6 @@ cc_library( hdrs = [ "statusor.h", ], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:port", "@com_google_absl//absl/status:statusor", @@ -423,7 +387,6 @@ cc_library( name = "status_matchers", testonly = 1, hdrs = ["status_matchers.h"], - visibility = ["//visibility:private"], deps = [ ":status", "@com_google_googletest//:gtest", @@ -433,7 +396,6 @@ cc_library( cc_library( name = "threadpool", hdrs = ["threadpool.h"], - visibility = ["//visibility:public"], deps = select({ "//conditions:default": [":threadpool_impl_default_to_google"], "//mediapipe:android": [":threadpool_impl_default_to_mediapipe"], @@ -460,7 +422,6 @@ alias( cc_library( name = "topologicalsorter", hdrs = ["topologicalsorter.h"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:port", "//mediapipe/framework/deps:topologicalsorter", @@ -470,6 +431,5 @@ cc_library( cc_library( name = "vector", hdrs = ["vector.h"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework/deps:vector"], ) diff --git a/mediapipe/framework/profiler/BUILD b/mediapipe/framework/profiler/BUILD index 
b53a1ac39..2947b9844 100644 --- a/mediapipe/framework/profiler/BUILD +++ b/mediapipe/framework/profiler/BUILD @@ -140,7 +140,7 @@ cc_library( name = "circular_buffer", hdrs = ["circular_buffer.h"], visibility = [ - "//visibility:public", + "//mediapipe:__subpackages__", ], deps = [ "//mediapipe/framework/port:integral_types", @@ -151,7 +151,6 @@ cc_test( name = "circular_buffer_test", size = "small", srcs = ["circular_buffer_test.cc"], - visibility = ["//visibility:public"], deps = [ ":circular_buffer", "//mediapipe/framework/port:gtest_main", @@ -164,7 +163,7 @@ cc_library( name = "trace_buffer", srcs = ["trace_buffer.h"], hdrs = ["trace_buffer.h"], - visibility = ["//visibility:public"], + visibility = ["//mediapipe/framework/profiler:__subpackages__"], deps = [ ":circular_buffer", "//mediapipe/framework:calculator_profile_cc_proto", diff --git a/mediapipe/framework/tool/testdata/BUILD b/mediapipe/framework/tool/testdata/BUILD index 906688520..f9aab7b72 100644 --- a/mediapipe/framework/tool/testdata/BUILD +++ b/mediapipe/framework/tool/testdata/BUILD @@ -20,7 +20,9 @@ load( licenses(["notice"]) -package(default_visibility = ["//mediapipe:__subpackages__"]) +package(default_visibility = [ + "//mediapipe:__subpackages__", +]) filegroup( name = "test_graph", @@ -40,7 +42,6 @@ mediapipe_simple_subgraph( testonly = 1, graph = "dub_quad_test_subgraph.pbtxt", register_as = "DubQuadTestSubgraph", - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:test_calculators", ], @@ -51,7 +52,7 @@ mediapipe_simple_subgraph( testonly = 1, graph = "nested_test_subgraph.pbtxt", register_as = "NestedTestSubgraph", - visibility = ["//visibility:public"], + visibility = ["//mediapipe/framework:__subpackages__"], deps = [ ":dub_quad_test_subgraph", "//mediapipe/framework:test_calculators", From 3c0ddf16b4c2b04cfff07d0db0aba48411468e9c Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Wed, 7 Dec 2022 11:37:04 -0800 Subject: [PATCH 186/469] Fix an incorrect model sanity check in the object detector graph. PiperOrigin-RevId: 493663592 --- .../tasks/cc/vision/object_detector/object_detector_graph.cc | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc b/mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc index f5dc7e061..a1625c16c 100644 --- a/mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc +++ b/mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc @@ -532,8 +532,7 @@ class ObjectDetectorGraph : public core::ModelTaskGraph { MP_RETURN_IF_ERROR(SanityCheckOptions(task_options)); // Checks that the model has 4 outputs. 
auto& model = *model_resources.GetTfLiteModel(); - if (model.subgraphs()->size() != 1 || - (*model.subgraphs())[0]->outputs()->size() != 4) { + if (model.subgraphs()->size() != 1) { return CreateStatusWithPayload( absl::StatusCode::kInvalidArgument, absl::StrFormat("Expected a model with a single subgraph, found %d.", From 2811e0c5c81e0ac7d39eab8c32efbe694de45940 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Wed, 7 Dec 2022 12:13:25 -0800 Subject: [PATCH 187/469] Open Source the first set of MediaPipe Tasks tests for Web PiperOrigin-RevId: 493673279 --- mediapipe/framework/port/build_config.bzl | 2 + .../tasks/web/components/processors/BUILD | 76 ++++ .../processors/base_options.test.ts | 127 ++++++ .../processors/classifier_options.test.ts | 114 +++++ .../processors/classifier_result.test.ts | 80 ++++ .../processors/embedder_options.test.ts | 93 ++++ .../processors/embedder_result.test.ts | 75 ++++ mediapipe/tasks/web/components/utils/BUILD | 16 + .../utils/cosine_similarity.test.ts | 85 ++++ mediapipe/tasks/web/core/BUILD | 33 ++ mediapipe/tasks/web/core/task_runner.ts | 7 +- mediapipe/tasks/web/core/task_runner_test.ts | 107 +++++ .../tasks/web/core/task_runner_test_utils.ts | 113 +++++ package.json | 5 + tsconfig.json | 2 +- yarn.lock | 419 ++++++++++++++++-- 16 files changed, 1308 insertions(+), 46 deletions(-) create mode 100644 mediapipe/tasks/web/components/processors/base_options.test.ts create mode 100644 mediapipe/tasks/web/components/processors/classifier_options.test.ts create mode 100644 mediapipe/tasks/web/components/processors/classifier_result.test.ts create mode 100644 mediapipe/tasks/web/components/processors/embedder_options.test.ts create mode 100644 mediapipe/tasks/web/components/processors/embedder_result.test.ts create mode 100644 mediapipe/tasks/web/components/utils/cosine_similarity.test.ts create mode 100644 mediapipe/tasks/web/core/task_runner_test.ts create mode 100644 mediapipe/tasks/web/core/task_runner_test_utils.ts diff --git a/mediapipe/framework/port/build_config.bzl b/mediapipe/framework/port/build_config.bzl index eaabda856..94a4a5646 100644 --- a/mediapipe/framework/port/build_config.bzl +++ b/mediapipe/framework/port/build_config.bzl @@ -228,6 +228,8 @@ def mediapipe_ts_library( srcs = srcs, visibility = visibility, deps = deps + [ + "@npm//@types/jasmine", + "@npm//@types/node", "@npm//@types/offscreencanvas", "@npm//@types/google-protobuf", ], diff --git a/mediapipe/tasks/web/components/processors/BUILD b/mediapipe/tasks/web/components/processors/BUILD index 86e743928..148a08238 100644 --- a/mediapipe/tasks/web/components/processors/BUILD +++ b/mediapipe/tasks/web/components/processors/BUILD @@ -1,5 +1,6 @@ # This package contains options shared by all MediaPipe Tasks for Web. 
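+# Each test below follows the same two-target pattern: a mediapipe_ts_library
+# compiles the *.test.ts sources into a *_test_lib target, and a
+# jasmine_node_test wraps that library so the suite runs under Node with the
+# Jasmine runner.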
+load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test") load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") package(default_visibility = ["//mediapipe/tasks:internal"]) @@ -13,6 +14,22 @@ mediapipe_ts_library( ], ) +mediapipe_ts_library( + name = "classifier_options_test_lib", + testonly = True, + srcs = ["classifier_options.test.ts"], + deps = [ + ":classifier_options", + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_jspb_proto", + "//mediapipe/tasks/web/core:classifier_options", + ], +) + +jasmine_node_test( + name = "classifier_options_test", + deps = [":classifier_options_test_lib"], +) + mediapipe_ts_library( name = "classifier_result", srcs = ["classifier_result.ts"], @@ -22,6 +39,22 @@ mediapipe_ts_library( ], ) +mediapipe_ts_library( + name = "classifier_result_test_lib", + testonly = True, + srcs = ["classifier_result.test.ts"], + deps = [ + ":classifier_result", + "//mediapipe/framework/formats:classification_jspb_proto", + "//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto", + ], +) + +jasmine_node_test( + name = "classifier_result_test", + deps = [":classifier_result_test_lib"], +) + mediapipe_ts_library( name = "embedder_result", srcs = ["embedder_result.ts"], @@ -31,6 +64,21 @@ mediapipe_ts_library( ], ) +mediapipe_ts_library( + name = "embedder_result_test_lib", + testonly = True, + srcs = ["embedder_result.test.ts"], + deps = [ + ":embedder_result", + "//mediapipe/tasks/cc/components/containers/proto:embeddings_jspb_proto", + ], +) + +jasmine_node_test( + name = "embedder_result_test", + deps = [":embedder_result_test_lib"], +) + mediapipe_ts_library( name = "embedder_options", srcs = ["embedder_options.ts"], @@ -40,6 +88,22 @@ mediapipe_ts_library( ], ) +mediapipe_ts_library( + name = "embedder_options_test_lib", + testonly = True, + srcs = ["embedder_options.test.ts"], + deps = [ + ":embedder_options", + "//mediapipe/tasks/cc/components/processors/proto:embedder_options_jspb_proto", + "//mediapipe/tasks/web/core:embedder_options", + ], +) + +jasmine_node_test( + name = "embedder_options_test", + deps = [":embedder_options_test_lib"], +) + mediapipe_ts_library( name = "base_options", srcs = [ @@ -53,3 +117,15 @@ mediapipe_ts_library( "//mediapipe/tasks/web/core", ], ) + +mediapipe_ts_library( + name = "base_options_test_lib", + testonly = True, + srcs = ["base_options.test.ts"], + deps = [":base_options"], +) + +jasmine_node_test( + name = "base_options_test", + deps = [":base_options_test_lib"], +) diff --git a/mediapipe/tasks/web/components/processors/base_options.test.ts b/mediapipe/tasks/web/components/processors/base_options.test.ts new file mode 100644 index 000000000..46c2277e9 --- /dev/null +++ b/mediapipe/tasks/web/components/processors/base_options.test.ts @@ -0,0 +1,127 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +import 'jasmine'; + +// Placeholder for internal dependency on encodeByteArray +// Placeholder for internal dependency on trusted resource URL builder + +import {convertBaseOptionsToProto} from './base_options'; + +describe('convertBaseOptionsToProto()', () => { + const mockBytes = new Uint8Array([0, 1, 2, 3]); + const mockBytesResult = { + modelAsset: { + fileContent: Buffer.from(mockBytes).toString('base64'), + fileName: undefined, + fileDescriptorMeta: undefined, + filePointerMeta: undefined, + }, + useStreamMode: false, + acceleration: { + xnnpack: undefined, + gpu: undefined, + tflite: {}, + }, + }; + + let fetchSpy: jasmine.Spy; + + beforeEach(() => { + fetchSpy = jasmine.createSpy().and.callFake(async url => { + expect(url).toEqual('foo'); + return { + arrayBuffer: () => mockBytes.buffer, + } as unknown as Response; + }); + global.fetch = fetchSpy; + }); + + it('verifies that at least one model asset option is provided', async () => { + await expectAsync(convertBaseOptionsToProto({})) + .toBeRejectedWithError( + /Either baseOptions.modelAssetPath or baseOptions.modelAssetBuffer must be set/); + }); + + it('verifies that no more than one model asset option is provided', async () => { + await expectAsync(convertBaseOptionsToProto({ + modelAssetPath: `foo`, + modelAssetBuffer: new Uint8Array([]) + })) + .toBeRejectedWithError( + /Cannot set both baseOptions.modelAssetPath and baseOptions.modelAssetBuffer/); + }); + + it('downloads model', async () => { + const baseOptionsProto = await convertBaseOptionsToProto({ + modelAssetPath: `foo`, + }); + + expect(fetchSpy).toHaveBeenCalled(); + expect(baseOptionsProto.toObject()).toEqual(mockBytesResult); + }); + + it('does not download model when bytes are provided', async () => { + const baseOptionsProto = await convertBaseOptionsToProto({ + modelAssetBuffer: new Uint8Array(mockBytes), + }); + + expect(fetchSpy).not.toHaveBeenCalled(); + expect(baseOptionsProto.toObject()).toEqual(mockBytesResult); + }); + + it('can enable CPU delegate', async () => { + const baseOptionsProto = await convertBaseOptionsToProto({ + modelAssetBuffer: new Uint8Array(mockBytes), + delegate: 'cpu', + }); + expect(baseOptionsProto.toObject()).toEqual(mockBytesResult); + }); + + it('can enable GPU delegate', async () => { + const baseOptionsProto = await convertBaseOptionsToProto({ + modelAssetBuffer: new Uint8Array(mockBytes), + delegate: 'gpu', + }); + expect(baseOptionsProto.toObject()).toEqual({ + ...mockBytesResult, + acceleration: { + xnnpack: undefined, + gpu: { + useAdvancedGpuApi: false, + api: 0, + allowPrecisionLoss: true, + cachedKernelPath: undefined, + serializedModelDir: undefined, + modelToken: undefined, + usage: 2, + }, + tflite: undefined, + }, + }); + }); + + it('can reset delegate', async () => { + let baseOptionsProto = await convertBaseOptionsToProto({ + modelAssetBuffer: new Uint8Array(mockBytes), + delegate: 'gpu', + }); + // Clear backend + baseOptionsProto = + await convertBaseOptionsToProto({delegate: undefined}, baseOptionsProto); + expect(baseOptionsProto.toObject()).toEqual(mockBytesResult); + }); +}); diff --git a/mediapipe/tasks/web/components/processors/classifier_options.test.ts b/mediapipe/tasks/web/components/processors/classifier_options.test.ts new file mode 100644 index 000000000..928bda426 --- /dev/null +++ b/mediapipe/tasks/web/components/processors/classifier_options.test.ts @@ -0,0 +1,114 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. 
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */

+import 'jasmine';

+import {ClassifierOptions as ClassifierOptionsProto} from '../../../../tasks/cc/components/processors/proto/classifier_options_pb';
+import {ClassifierOptions} from '../../../../tasks/web/core/classifier_options';

+import {convertClassifierOptionsToProto} from './classifier_options';

+interface TestCase {
+  optionName: keyof ClassifierOptions;
+  protoName: string;
+  customValue: unknown;
+  defaultValue: unknown;
+}

+describe('convertClassifierOptionsToProto()', () => {
+  function verifyOption(
+      actualClassifierOptions: ClassifierOptionsProto,
+      expectedClassifierOptions: Record<string, unknown> = {}): void {
+    expect(actualClassifierOptions.toObject())
+        .toEqual(jasmine.objectContaining(expectedClassifierOptions));
+  }

+  const testCases: TestCase[] = [
+    {
+      optionName: 'maxResults',
+      protoName: 'maxResults',
+      customValue: 5,
+      defaultValue: -1
+    },
+    {
+      optionName: 'displayNamesLocale',
+      protoName: 'displayNamesLocale',
+      customValue: 'en',
+      defaultValue: 'en'
+    },
+    {
+      optionName: 'scoreThreshold',
+      protoName: 'scoreThreshold',
+      customValue: 0.1,
+      defaultValue: undefined
+    },
+    {
+      optionName: 'categoryAllowlist',
+      protoName: 'categoryAllowlistList',
+      customValue: ['foo'],
+      defaultValue: []
+    },
+    {
+      optionName: 'categoryDenylist',
+      protoName: 'categoryDenylistList',
+      customValue: ['bar'],
+      defaultValue: []
+    },
+  ];

+  for (const testCase of testCases) {
+    it(`can set ${testCase.optionName}`, () => {
+      const classifierOptionsProto = convertClassifierOptionsToProto(
+          {[testCase.optionName]: testCase.customValue});
+      verifyOption(
+          classifierOptionsProto, {[testCase.protoName]: testCase.customValue});
+    });

+    it(`can clear ${testCase.optionName}`, () => {
+      let classifierOptionsProto = convertClassifierOptionsToProto(
+          {[testCase.optionName]: testCase.customValue});
+      verifyOption(
+          classifierOptionsProto, {[testCase.protoName]: testCase.customValue});

+      classifierOptionsProto =
+          convertClassifierOptionsToProto({[testCase.optionName]: undefined});
+      verifyOption(
+          classifierOptionsProto,
+          {[testCase.protoName]: testCase.defaultValue});
+    });
+  }

+  it('overwrites options', () => {
+    let classifierOptionsProto =
+        convertClassifierOptionsToProto({maxResults: 1});
+    verifyOption(classifierOptionsProto, {'maxResults': 1});

+    classifierOptionsProto = convertClassifierOptionsToProto(
+        {maxResults: 2}, classifierOptionsProto);
+    verifyOption(classifierOptionsProto, {'maxResults': 2});
+  });

+  it('merges options', () => {
+    let classifierOptionsProto =
+        convertClassifierOptionsToProto({maxResults: 1});
+    verifyOption(classifierOptionsProto, {'maxResults': 1});

+    classifierOptionsProto = convertClassifierOptionsToProto(
+        {displayNamesLocale: 'en'}, classifierOptionsProto);
+    verifyOption(
+        classifierOptionsProto, {'maxResults': 1, 'displayNamesLocale': 'en'});
+  });
+});
diff --git a/mediapipe/tasks/web/components/processors/classifier_result.test.ts 
b/mediapipe/tasks/web/components/processors/classifier_result.test.ts
new file mode 100644
index 000000000..4b93d0a76
--- /dev/null
+++ b/mediapipe/tasks/web/components/processors/classifier_result.test.ts
@@ -0,0 +1,80 @@
+/**
+ * Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */

+import 'jasmine';

+import {Classification, ClassificationList} from '../../../../framework/formats/classification_pb';
+import {ClassificationResult, Classifications} from '../../../../tasks/cc/components/containers/proto/classifications_pb';

+import {convertFromClassificationResultProto} from './classifier_result';

+// The OSS JS API does not support the builder pattern.
+// tslint:disable:jspb-use-builder-pattern

+describe('convertFromClassificationResultProto()', () => {
+  it('transforms custom values', () => {
+    const classificationResult = new ClassificationResult();
+    classificationResult.setTimestampMs(1);
+    const classifications = new Classifications();
+    classifications.setHeadIndex(1);
+    classifications.setHeadName('headName');
+    const classificationList = new ClassificationList();
+    const classification = new Classification();
+    classification.setIndex(2);
+    classification.setScore(0.3);
+    classification.setDisplayName('displayName');
+    classification.setLabel('categoryName');
+    classificationList.addClassification(classification);
+    classifications.setClassificationList(classificationList);
+    classificationResult.addClassifications(classifications);

+    const result = convertFromClassificationResultProto(classificationResult);

+    expect(result).toEqual({
+      classifications: [{
+        categories: [{
+          index: 2,
+          score: 0.3,
+          displayName: 'displayName',
+          categoryName: 'categoryName'
+        }],
+        headIndex: 1,
+        headName: 'headName'
+      }],
+      timestampMs: 1
+    });
+  });

+  it('transforms default values', () => {
+    const classificationResult = new ClassificationResult();
+    const classifications = new Classifications();
+    const classificationList = new ClassificationList();
+    const classification = new Classification();
+    classificationList.addClassification(classification);
+    classifications.setClassificationList(classificationList);
+    classificationResult.addClassifications(classifications);

+    const result = convertFromClassificationResultProto(classificationResult);

+    expect(result).toEqual({
+      classifications: [{
+        categories: [{index: 0, score: 0, displayName: '', categoryName: ''}],
+        headIndex: 0,
+        headName: ''
+      }],
+    });
+  });
+});
diff --git a/mediapipe/tasks/web/components/processors/embedder_options.test.ts b/mediapipe/tasks/web/components/processors/embedder_options.test.ts
new file mode 100644
index 000000000..b879a6b29
--- /dev/null
+++ b/mediapipe/tasks/web/components/processors/embedder_options.test.ts
@@ -0,0 +1,93 @@
+/**
+ * Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */

+import 'jasmine';

+import {EmbedderOptions as EmbedderOptionsProto} from '../../../../tasks/cc/components/processors/proto/embedder_options_pb';
+import {EmbedderOptions} from '../../../../tasks/web/core/embedder_options';

+import {convertEmbedderOptionsToProto} from './embedder_options';

+interface TestCase {
+  optionName: keyof EmbedderOptions;
+  protoName: string;
+  customValue: unknown;
+  defaultValue: unknown;
+}

+describe('convertEmbedderOptionsToProto()', () => {
+  function verifyOption(
+      actualEmbedderOptions: EmbedderOptionsProto,
+      expectedEmbedderOptions: Record<string, unknown> = {}): void {
+    expect(actualEmbedderOptions.toObject())
+        .toEqual(jasmine.objectContaining(expectedEmbedderOptions));
+  }

+  const testCases: TestCase[] = [
+    {
+      optionName: 'l2Normalize',
+      protoName: 'l2Normalize',
+      customValue: true,
+      defaultValue: undefined
+    },
+    {
+      optionName: 'quantize',
+      protoName: 'quantize',
+      customValue: true,
+      defaultValue: undefined
+    },
+  ];

+  for (const testCase of testCases) {
+    it(`can set ${testCase.optionName}`, () => {
+      const embedderOptionsProto = convertEmbedderOptionsToProto(
+          {[testCase.optionName]: testCase.customValue});
+      verifyOption(
+          embedderOptionsProto, {[testCase.protoName]: testCase.customValue});
+    });

+    it(`can clear ${testCase.optionName}`, () => {
+      let embedderOptionsProto = convertEmbedderOptionsToProto(
+          {[testCase.optionName]: testCase.customValue});
+      verifyOption(
+          embedderOptionsProto, {[testCase.protoName]: testCase.customValue});

+      embedderOptionsProto =
+          convertEmbedderOptionsToProto({[testCase.optionName]: undefined});
+      verifyOption(
+          embedderOptionsProto, {[testCase.protoName]: testCase.defaultValue});
+    });
+  }

+  it('overwrites options', () => {
+    let embedderOptionsProto =
+        convertEmbedderOptionsToProto({l2Normalize: true});
+    verifyOption(embedderOptionsProto, {'l2Normalize': true});

+    embedderOptionsProto = convertEmbedderOptionsToProto(
+        {l2Normalize: false}, embedderOptionsProto);
+    verifyOption(embedderOptionsProto, {'l2Normalize': false});
+  });

+  it('replaces options', () => {
+    let embedderOptionsProto = convertEmbedderOptionsToProto({quantize: true});
+    verifyOption(embedderOptionsProto, {'quantize': true});

+    embedderOptionsProto = convertEmbedderOptionsToProto(
+        {l2Normalize: true}, embedderOptionsProto);
+    verifyOption(embedderOptionsProto, {'l2Normalize': true, 'quantize': true});
+  });
+});
diff --git a/mediapipe/tasks/web/components/processors/embedder_result.test.ts b/mediapipe/tasks/web/components/processors/embedder_result.test.ts
new file mode 100644
index 000000000..97ba935c8
--- /dev/null
+++ b/mediapipe/tasks/web/components/processors/embedder_result.test.ts
@@ -0,0 +1,75 @@
+/**
+ * Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import 'jasmine'; + +import {Embedding, EmbeddingResult, FloatEmbedding, QuantizedEmbedding} from '../../../../tasks/cc/components/containers/proto/embeddings_pb'; + +import {convertFromEmbeddingResultProto} from './embedder_result'; + +// The OSS JS API does not support the builder pattern. +// tslint:disable:jspb-use-builder-pattern + +describe('convertFromEmbeddingResultProto()', () => { + it('transforms custom values', () => { + const embedding = new Embedding(); + embedding.setHeadIndex(1); + embedding.setHeadName('headName'); + + const floatEmbedding = new FloatEmbedding(); + floatEmbedding.setValuesList([0.1, 0.9]); + + embedding.setFloatEmbedding(floatEmbedding); + const resultProto = new EmbeddingResult(); + resultProto.addEmbeddings(embedding); + resultProto.setTimestampMs(1); + + const embedderResult = convertFromEmbeddingResultProto(resultProto); + const embeddings = embedderResult.embeddings; + const timestampMs = embedderResult.timestampMs; + expect(embeddings.length).toEqual(1); + expect(embeddings[0]) + .toEqual( + {floatEmbedding: [0.1, 0.9], headIndex: 1, headName: 'headName'}); + expect(timestampMs).toEqual(1); + }); + + it('transforms custom quantized values', () => { + const embedding = new Embedding(); + embedding.setHeadIndex(1); + embedding.setHeadName('headName'); + + const quantizedEmbedding = new QuantizedEmbedding(); + const quantizedValues = new Uint8Array([1, 2, 3]); + quantizedEmbedding.setValues(quantizedValues); + + embedding.setQuantizedEmbedding(quantizedEmbedding); + const resultProto = new EmbeddingResult(); + resultProto.addEmbeddings(embedding); + resultProto.setTimestampMs(1); + + const embedderResult = convertFromEmbeddingResultProto(resultProto); + const embeddings = embedderResult.embeddings; + const timestampMs = embedderResult.timestampMs; + expect(embeddings.length).toEqual(1); + expect(embeddings[0]).toEqual({ + quantizedEmbedding: new Uint8Array([1, 2, 3]), + headIndex: 1, + headName: 'headName' + }); + expect(timestampMs).toEqual(1); + }); +}); diff --git a/mediapipe/tasks/web/components/utils/BUILD b/mediapipe/tasks/web/components/utils/BUILD index 1c1ba69ca..f4a215e48 100644 --- a/mediapipe/tasks/web/components/utils/BUILD +++ b/mediapipe/tasks/web/components/utils/BUILD @@ -1,4 +1,5 @@ load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") +load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test") package(default_visibility = ["//mediapipe/tasks:internal"]) @@ -9,3 +10,18 @@ mediapipe_ts_library( "//mediapipe/tasks/web/components/containers:embedding_result", ], ) + +mediapipe_ts_library( + name = "cosine_similarity_test_lib", + testonly = True, + srcs = ["cosine_similarity.test.ts"], + deps = [ + ":cosine_similarity", + "//mediapipe/tasks/web/components/containers:embedding_result", + ], +) + +jasmine_node_test( + name = "cosine_similarity_test", + deps = [":cosine_similarity_test_lib"], +) diff --git a/mediapipe/tasks/web/components/utils/cosine_similarity.test.ts b/mediapipe/tasks/web/components/utils/cosine_similarity.test.ts new file mode 100644 index 000000000..f442caa20 --- /dev/null 
+++ b/mediapipe/tasks/web/components/utils/cosine_similarity.test.ts
@@ -0,0 +1,85 @@
+/**
+ * Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License. You may obtain a
+ * copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ *
Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +import {Embedding} from '../../../../tasks/web/components/containers/embedding_result'; + +import {computeCosineSimilarity} from './cosine_similarity'; + +describe('computeCosineSimilarity', () => { + it('fails with quantized and float embeddings', () => { + const u: Embedding = {floatEmbedding: [1.0], headIndex: 0, headName: ''}; + const v: Embedding = { + quantizedEmbedding: new Uint8Array([1.0]), + headIndex: 0, + headName: '' + }; + + expect(() => computeCosineSimilarity(u, v)) + .toThrowError( + /Cannot compute cosine similarity between quantized and float embeddings/); + }); + + it('fails with zero norm', () => { + const u = {floatEmbedding: [0.0], headIndex: 0, headName: ''}; + expect(() => computeCosineSimilarity(u, u)) + .toThrowError( + /Cannot compute cosine similarity on embedding with 0 norm/); + }); + + it('fails with different sizes', () => { + const u: + Embedding = {floatEmbedding: [1.0, 2.0], headIndex: 0, headName: ''}; + const v: Embedding = { + floatEmbedding: [1.0, 2.0, 3.0], + headIndex: 0, + headName: '' + }; + + expect(() => computeCosineSimilarity(u, v)) + .toThrowError( + /Cannot compute cosine similarity between embeddings of different sizes/); + }); + + it('succeeds with float embeddings', () => { + const u: Embedding = { + floatEmbedding: [1.0, 0.0, 0.0, 0.0], + headIndex: 0, + headName: '' + }; + const v: Embedding = { + floatEmbedding: [0.5, 0.5, 0.5, 0.5], + headIndex: 0, + headName: '' + }; + + expect(computeCosineSimilarity(u, v)).toEqual(0.5); + }); + + it('succeeds with quantized embeddings', () => { + const u: Embedding = { + quantizedEmbedding: new Uint8Array([255, 128, 128, 128]), + headIndex: 0, + headName: '' + }; + const v: Embedding = { + quantizedEmbedding: new Uint8Array([0, 128, 128, 128]), + headIndex: 0, + headName: '' + }; + + expect(computeCosineSimilarity(u, v)).toEqual(-1.0); + }); +}); diff --git a/mediapipe/tasks/web/core/BUILD b/mediapipe/tasks/web/core/BUILD index be1b71f5d..1721661f5 100644 --- a/mediapipe/tasks/web/core/BUILD +++ b/mediapipe/tasks/web/core/BUILD @@ -1,6 +1,7 @@ # This package contains options shared by all MediaPipe Tasks for Web. 
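+# The task_runner_test_utils library added below supplies a fake WasmModule
+# and graph-verification helpers, letting the task tests run in pure JS/TS
+# without the Wasm layer.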
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library") +load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test") package(default_visibility = ["//mediapipe/tasks:internal"]) @@ -32,6 +33,38 @@ mediapipe_ts_library( deps = [":core"], ) +mediapipe_ts_library( + name = "task_runner_test_utils", + testonly = True, + srcs = [ + "task_runner_test_utils.ts", + ], + deps = [ + "//mediapipe/framework:calculator_jspb_proto", + "//mediapipe/web/graph_runner:graph_runner_ts", + "//mediapipe/web/graph_runner:register_model_resources_graph_service_ts", + ], +) + +mediapipe_ts_library( + name = "task_runner_test_lib", + testonly = True, + srcs = [ + "task_runner_test.ts", + ], + deps = [ + ":task_runner", + ":task_runner_test_utils", + "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", + "//mediapipe/web/graph_runner:graph_runner_ts", + ], +) + +jasmine_node_test( + name = "task_runner_test", + deps = [":task_runner_test_lib"], +) + mediapipe_ts_declaration( name = "classifier_options", srcs = ["classifier_options.d.ts"], diff --git a/mediapipe/tasks/web/core/task_runner.ts b/mediapipe/tasks/web/core/task_runner.ts index 71e159dce..6712c4d89 100644 --- a/mediapipe/tasks/web/core/task_runner.ts +++ b/mediapipe/tasks/web/core/task_runner.ts @@ -77,9 +77,10 @@ export abstract class TaskRunner { } constructor( - wasmModule: WasmModule, - glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { - this.graphRunner = new GraphRunnerImageLib(wasmModule, glCanvas); + wasmModule: WasmModule, glCanvas?: HTMLCanvasElement|OffscreenCanvas|null, + graphRunner?: GraphRunnerImageLib) { + this.graphRunner = + graphRunner ?? new GraphRunnerImageLib(wasmModule, glCanvas); // Disables the automatic render-to-screen code, which allows for pure // CPU processing. diff --git a/mediapipe/tasks/web/core/task_runner_test.ts b/mediapipe/tasks/web/core/task_runner_test.ts new file mode 100644 index 000000000..c9aad9d25 --- /dev/null +++ b/mediapipe/tasks/web/core/task_runner_test.ts @@ -0,0 +1,107 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+import 'jasmine';

+import {BaseOptions as BaseOptionsProto} from '../../../tasks/cc/core/proto/base_options_pb';
+import {TaskRunner} from '../../../tasks/web/core/task_runner';
+import {createSpyWasmModule, SpyWasmModule} from '../../../tasks/web/core/task_runner_test_utils';
+import {ErrorListener} from '../../../web/graph_runner/graph_runner';

+import {GraphRunnerImageLib} from './task_runner';

+class TaskRunnerFake extends TaskRunner {
+  protected baseOptions = new BaseOptionsProto();
+  private errorListener: ErrorListener|undefined;
+  private errors: string[] = [];

+  static createFake(): TaskRunnerFake {
+    const wasmModule = createSpyWasmModule();
+    return new TaskRunnerFake(wasmModule);
+  }

+  constructor(wasmModuleFake: SpyWasmModule) {
+    super(
+        wasmModuleFake, /* glCanvas= */ null,
+        jasmine.createSpyObj([
+          'setAutoRenderToScreen', 'setGraph', 'finishProcessing',
+          'registerModelResourcesGraphService', 'attachErrorListener'
+        ]));
+    const graphRunner =
+        this.graphRunner as jasmine.SpyObj<GraphRunnerImageLib>;
+    expect(graphRunner.registerModelResourcesGraphService).toHaveBeenCalled();
+    expect(graphRunner.setAutoRenderToScreen).toHaveBeenCalled();
+    graphRunner.attachErrorListener.and.callFake(listener => {
+      this.errorListener = listener;
+    });
+    graphRunner.setGraph.and.callFake(() => {
+      this.throwErrors();
+    });
+    graphRunner.finishProcessing.and.callFake(() => {
+      this.throwErrors();
+    });
+  }

+  enqueueError(message: string): void {
+    this.errors.push(message);
+  }

+  override finishProcessing(): void {
+    super.finishProcessing();
+  }

+  override setGraph(graphData: Uint8Array, isBinary: boolean): void {
+    super.setGraph(graphData, isBinary);
+  }

+  private throwErrors(): void {
+    expect(this.errorListener).toBeDefined();
+    for (const error of this.errors) {
+      this.errorListener!(/* errorCode= */ -1, error);
+    }
+    this.errors = [];
+  }
+}

+describe('TaskRunner', () => {
+  it('handles errors during graph update', () => {
+    const taskRunner = TaskRunnerFake.createFake();
+    taskRunner.enqueueError('Test error');

+    expect(() => {
+      taskRunner.setGraph(new Uint8Array(0), /* isBinary= */ true);
+    }).toThrowError('Test error');
+  });

+  it('handles errors during graph execution', () => {
+    const taskRunner = TaskRunnerFake.createFake();
+    taskRunner.setGraph(new Uint8Array(0), /* isBinary= */ true);

+    taskRunner.enqueueError('Test error');

+    expect(() => {
+      taskRunner.finishProcessing();
+    }).toThrowError('Test error');
+  });

+  it('can handle multiple errors', () => {
+    const taskRunner = TaskRunnerFake.createFake();
+    taskRunner.enqueueError('Test error 1');
+    taskRunner.enqueueError('Test error 2');

+    expect(() => {
+      taskRunner.setGraph(new Uint8Array(0), /* isBinary= */ true);
+    }).toThrowError(/Test error 1, Test error 2/);
+  });
+});
diff --git a/mediapipe/tasks/web/core/task_runner_test_utils.ts b/mediapipe/tasks/web/core/task_runner_test_utils.ts
new file mode 100644
index 000000000..2a1161a55
--- /dev/null
+++ b/mediapipe/tasks/web/core/task_runner_test_utils.ts
@@ -0,0 +1,113 @@
+/**
+ * Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+import 'jasmine';

+import {CalculatorGraphConfig} from '../../../framework/calculator_pb';
+import {WasmModule} from '../../../web/graph_runner/graph_runner';
+import {WasmModuleRegisterModelResources} from '../../../web/graph_runner/register_model_resources_graph_service';

+type SpyWasmModuleInternal = WasmModule&WasmModuleRegisterModelResources;

+/**
+ * Convenience type for our fake WasmModule for Jasmine testing.
+ */
+export declare type SpyWasmModule = jasmine.SpyObj<SpyWasmModuleInternal>;

+/**
+ * Factory function for creating a fake WasmModule for our Jasmine tests,
+ * allowing our APIs to no longer rely on the Wasm layer so they can run tests
+ * in pure JS/TS (and optionally spy on the calls).
+ */
+export function createSpyWasmModule(): SpyWasmModule {
+  return jasmine.createSpyObj([
+    '_setAutoRenderToScreen', 'stringToNewUTF8', '_attachProtoListener',
+    '_attachProtoVectorListener', '_free', '_waitUntilIdle',
+    '_addStringToInputStream', '_registerModelResourcesGraphService',
+    '_configureAudio'
+  ]);
+}

+/**
+ * Sets up our equality testing to use a custom float equality checking function
+ * to avoid incorrect test results due to minor floating point inaccuracies.
+ */
+export function addJasmineCustomFloatEqualityTester() {
+  jasmine.addCustomEqualityTester((a, b) => {  // Custom float equality
+    if (a === +a && b === +b && (a !== (a | 0) || b !== (b | 0))) {
+      return Math.abs(a - b) < 5e-8;
+    }
+    return;
+  });
+}

+/** The minimum interface provided by a test fake. */
+export interface MediapipeTasksFake {
+  graph: CalculatorGraphConfig|undefined;
+  calculatorName: string;
+  attachListenerSpies: jasmine.Spy[];
+}

+/** A map of field paths to values */
+export type FieldPathToValue = [string[] | string, unknown];

+/**
+ * Verifies that the graph has been initialized and that it contains the
+ * provided options.
+ */
+export function verifyGraph(
+    tasksFake: MediapipeTasksFake,
+    expectedCalculatorOptions?: FieldPathToValue,
+    expectedBaseOptions?: FieldPathToValue,
+    ): void {
+  expect(tasksFake.graph).toBeDefined();
+  expect(tasksFake.graph!.getNodeList().length).toBe(1);
+  const node = tasksFake.graph!.getNodeList()[0].toObject();
+  expect(node).toEqual(
+      jasmine.objectContaining({calculator: tasksFake.calculatorName}));

+  if (expectedBaseOptions) {
+    const [fieldPath, value] = expectedBaseOptions;
+    let proto = (node.options as {ext: {baseOptions: unknown}}).ext.baseOptions;
+    for (const fieldName of (
+             Array.isArray(fieldPath) ? fieldPath : [fieldPath])) {
+      proto = ((proto ?? {}) as Record<string, unknown>)[fieldName];
+    }
+    expect(proto).toEqual(value);
+  }

+  if (expectedCalculatorOptions) {
+    const [fieldPath, value] = expectedCalculatorOptions;
+    let proto = (node.options as {ext: unknown}).ext;
+    for (const fieldName of (
+             Array.isArray(fieldPath) ? fieldPath : [fieldPath])) {
+      proto = ((proto ?? {}) as Record<string, unknown>)[fieldName];
+    }
+    expect(proto).toEqual(value);
+  }
+}

+/**
+ * Verifies all listeners (as exposed by `.attachListenerSpies`) have been
+ * attached at least once since the last call to `verifyListenersRegistered()`. 
+ * This helps us to ensure that listeners are re-registered with every graph + * update. + */ +export function verifyListenersRegistered(tasksFake: MediapipeTasksFake): void { + for (const spy of tasksFake.attachListenerSpies) { + expect(spy.calls.count()).toBeGreaterThanOrEqual(1); + spy.calls.reset(); + } +} diff --git a/package.json b/package.json index 6ad0b52c0..89b62bc83 100644 --- a/package.json +++ b/package.json @@ -3,14 +3,19 @@ "version": "0.0.0-alphga", "description": "MediaPipe GitHub repo", "devDependencies": { + "@bazel/jasmine": "^5.7.2", "@bazel/rollup": "^5.7.1", "@bazel/typescript": "^5.7.1", "@rollup/plugin-commonjs": "^23.0.2", "@rollup/plugin-node-resolve": "^15.0.1", "@rollup/plugin-terser": "^0.1.0", "@types/google-protobuf": "^3.15.6", + "@types/jasmine": "^4.3.1", + "@types/node": "^18.11.11", "@types/offscreencanvas": "^2019.7.0", "google-protobuf": "^3.21.2", + "jasmine": "^4.5.0", + "jasmine-core": "^4.5.0", "protobufjs": "^7.1.2", "protobufjs-cli": "^1.0.2", "rollup": "^2.3.0", diff --git a/tsconfig.json b/tsconfig.json index c17b1902e..970246dbb 100644 --- a/tsconfig.json +++ b/tsconfig.json @@ -10,7 +10,7 @@ "inlineSourceMap": true, "inlineSources": true, "strict": true, - "types": ["@types/offscreencanvas"], + "types": ["@types/offscreencanvas", "@types/jasmine", "node"], "rootDirs": [ ".", "./bazel-out/host/bin", diff --git a/yarn.lock b/yarn.lock index 91b50456e..9c4d91d30 100644 --- a/yarn.lock +++ b/yarn.lock @@ -3,34 +3,52 @@ "@babel/parser@^7.9.4": - version "7.20.3" - resolved "https://registry.yarnpkg.com/@babel/parser/-/parser-7.20.3.tgz#5358cf62e380cf69efcb87a7bb922ff88bfac6e2" - integrity sha512-OP/s5a94frIPXwjzEcv5S/tpQfc6XhxYUnmWpgdqMWGgYCuErA3SzozaRAMQgSZWKeTJxht9aWAkUY+0UzvOFg== + version "7.20.5" + resolved "https://registry.yarnpkg.com/@babel/parser/-/parser-7.20.5.tgz#7f3c7335fe417665d929f34ae5dceae4c04015e8" + integrity sha512-r27t/cy/m9uKLXQNWWebeCUHgnAZq0CpG1OwKRxzJMP1vpSU4bSIK2hq+/cp0bQxetkXx38n09rNu8jVkcK/zA== + +"@bazel/jasmine@^5.7.2": + version "5.7.2" + resolved "https://registry.yarnpkg.com/@bazel/jasmine/-/jasmine-5.7.2.tgz#438f272e66e939106cbdd58db709cd6aa008131b" + integrity sha512-RJruOB6S9e0efTNIE2JVdaslguUXh5KcmLUCq/xLCt0zENP44ssp9OooDIrZ8H+Sp4mLDNBX7CMMA9WTsbsxTQ== + dependencies: + c8 "~7.5.0" + jasmine-reporters "~2.5.0" "@bazel/rollup@^5.7.1": - version "5.7.1" - resolved "https://registry.yarnpkg.com/@bazel/rollup/-/rollup-5.7.1.tgz#6f644c2d493a5bd9cd3724a6f239e609585c6e37" - integrity sha512-LLNogoK2Qx9GIJVywQ+V/czjud8236mnaRX//g7qbOyXoWZDQvAEgsxRHq+lS/XX9USbh+zJJlfb+Dfp/PXx4A== + version "5.7.2" + resolved "https://registry.yarnpkg.com/@bazel/rollup/-/rollup-5.7.2.tgz#9953b06e3de52794791cee4f89540c263b035fcf" + integrity sha512-yGWLheSKdMnJ/Y3/qg+zCDx/qkD04FBFp+BjRS8xP4yvlz9G4rW3zc45VzHHz3oOywgQaY1vhfKuZMCcjTGEyA== dependencies: - "@bazel/worker" "5.7.1" + "@bazel/worker" "5.7.2" "@bazel/typescript@^5.7.1": - version "5.7.1" - resolved "https://registry.yarnpkg.com/@bazel/typescript/-/typescript-5.7.1.tgz#e585bcdc54a4ccb23d99c3e1206abf4853cf0682" - integrity sha512-MAnAtFxA2znadm81+rbYXcyWX1DEF/urzZ1F4LBq+w27EQ4PGyqIqCM5om7JcoSZJwjjMoBJc3SflRsMrZZ6+g== + version "5.7.2" + resolved "https://registry.yarnpkg.com/@bazel/typescript/-/typescript-5.7.2.tgz#a341215dc93ce28794e8430b311756816140bd78" + integrity sha512-tarBJBEIirnq/YaeYu18vXcDxjzlq4xhCXvXUxA0lhHX5oArjEcAEn4tmO0jF+t/7cbkAdMT7daG6vIHSz0QAA== dependencies: - "@bazel/worker" "5.7.1" + "@bazel/worker" "5.7.2" semver "5.6.0" source-map-support "0.5.9" tsutils 
"3.21.0" -"@bazel/worker@5.7.1": - version "5.7.1" - resolved "https://registry.yarnpkg.com/@bazel/worker/-/worker-5.7.1.tgz#2c4a9bd0e0ef75e496aec9599ff64a87307e7dad" - integrity sha512-UndmQVRqK0t0NMNl8I1P5XmxzdPvMA0X6jufszpfwy5gyzjOxeiOIzmC0ALCOx78CuJqOB/8WOI1pwTRmhd0tg== +"@bazel/worker@5.7.2": + version "5.7.2" + resolved "https://registry.yarnpkg.com/@bazel/worker/-/worker-5.7.2.tgz#43d800dc1b5a3707340a4eb0102da81c53fc6f63" + integrity sha512-H+auDA0QKF4mtZxKkZ2OKJvD7hGXVsVKtvcf4lbb93ur0ldpb5k810PcDxngmIGBcIX5kmyxniNTIiGFNobWTg== dependencies: google-protobuf "^3.6.1" +"@bcoe/v8-coverage@^0.2.3": + version "0.2.3" + resolved "https://registry.yarnpkg.com/@bcoe/v8-coverage/-/v8-coverage-0.2.3.tgz#75a2e8b51cb758a7553d6804a5932d7aace75c39" + integrity sha512-0hYQ8SB4Db5zvZB4axdMHGwEaQjkZzFjQiN9LVYvIFB2nSUHW9tYpxWriPrWDASIxiaXax83REcLxuSdnGPZtw== + +"@istanbuljs/schema@^0.1.2": + version "0.1.3" + resolved "https://registry.yarnpkg.com/@istanbuljs/schema/-/schema-0.1.3.tgz#e45e384e4b8ec16bce2fd903af78450f6bf7ec98" + integrity sha512-ZXRY4jNvVgSVQ8DL3LTcakaAtXwTVUxE81hslsyD2AtoXW/wVob10HkOJ1X/pAlcI7D+2YoZKg5do8G/w6RYgA== + "@jridgewell/gen-mapping@^0.3.0": version "0.3.2" resolved "https://registry.yarnpkg.com/@jridgewell/gen-mapping/-/gen-mapping-0.3.2.tgz#c1aedc61e853f2bb9f5dfe6d4442d3b565b253b9" @@ -125,9 +143,9 @@ integrity sha512-Vvn3zZrhQZkkBE8LSuW3em98c0FwgO4nxzv6OdSxPKJIEKY2bGbHn+mhGIPerzI4twdxaP8/0+06HBpwf345Lw== "@rollup/plugin-commonjs@^23.0.2": - version "23.0.2" - resolved "https://registry.yarnpkg.com/@rollup/plugin-commonjs/-/plugin-commonjs-23.0.2.tgz#3a3a5b7b1b1cb29037eb4992edcaae997d7ebd92" - integrity sha512-e9ThuiRf93YlVxc4qNIurvv+Hp9dnD+4PjOqQs5vAYfcZ3+AXSrcdzXnVjWxcGQOa6KGJFcRZyUI3ktWLavFjg== + version "23.0.3" + resolved "https://registry.yarnpkg.com/@rollup/plugin-commonjs/-/plugin-commonjs-23.0.3.tgz#442cd8ccca1b7563a503da86fc84a1a7112b54bb" + integrity sha512-31HxrT5emGfTyIfAs1lDQHj6EfYxTXcwtX5pIIhq+B/xZBNIqQ179d/CkYxlpYmFCxT78AeU4M8aL8Iv/IBxFA== dependencies: "@rollup/pluginutils" "^5.0.1" commondir "^1.0.1" @@ -174,6 +192,21 @@ resolved "https://registry.yarnpkg.com/@types/google-protobuf/-/google-protobuf-3.15.6.tgz#674a69493ef2c849b95eafe69167ea59079eb504" integrity sha512-pYVNNJ+winC4aek+lZp93sIKxnXt5qMkuKmaqS3WGuTq0Bw1ZDYNBgzG5kkdtwcv+GmYJGo3yEg6z2cKKAiEdw== +"@types/is-windows@^1.0.0": + version "1.0.0" + resolved "https://registry.yarnpkg.com/@types/is-windows/-/is-windows-1.0.0.tgz#1011fa129d87091e2f6faf9042d6704cdf2e7be0" + integrity sha512-tJ1rq04tGKuIJoWIH0Gyuwv4RQ3+tIu7wQrC0MV47raQ44kIzXSSFKfrxFUOWVRvesoF7mrTqigXmqoZJsXwTg== + +"@types/istanbul-lib-coverage@^2.0.1": + version "2.0.4" + resolved "https://registry.yarnpkg.com/@types/istanbul-lib-coverage/-/istanbul-lib-coverage-2.0.4.tgz#8467d4b3c087805d63580480890791277ce35c44" + integrity sha512-z/QT1XN4K4KYuslS23k62yDIDLwLFkzxOuMplDtObz0+y7VqJCaO2o+SPwHCvLFZh7xazvvoor2tA/hPz9ee7g== + +"@types/jasmine@^4.3.1": + version "4.3.1" + resolved "https://registry.yarnpkg.com/@types/jasmine/-/jasmine-4.3.1.tgz#2d8ab5601c2fe7d9673dcb157e03f128ab5c5fff" + integrity sha512-Vu8l+UGcshYmV1VWwULgnV/2RDbBaO6i2Ptx7nd//oJPIZGhoI1YLST4VKagD2Pq/Bc2/7zvtvhM7F3p4SN7kQ== + "@types/linkify-it@*": version "3.0.2" resolved "https://registry.yarnpkg.com/@types/linkify-it/-/linkify-it-3.0.2.tgz#fd2cd2edbaa7eaac7e7f3c1748b52a19143846c9" @@ -192,10 +225,10 @@ resolved "https://registry.yarnpkg.com/@types/mdurl/-/mdurl-1.0.2.tgz#e2ce9d83a613bacf284c7be7d491945e39e1f8e9" integrity 
sha512-eC4U9MlIcu2q0KQmXszyn5Akca/0jrQmwDRgpAMJai7qBWq4amIQhZyNau4VYGtCeALvW1/NtjzJJ567aZxfKA== -"@types/node@>=13.7.0": - version "18.11.9" - resolved "https://registry.yarnpkg.com/@types/node/-/node-18.11.9.tgz#02d013de7058cea16d36168ef2fc653464cfbad4" - integrity sha512-CRpX21/kGdzjOpFsZSkcrXMGIBWMGNIHXXBVFSH+ggkftxg+XYP20TESbh+zFvFj3EQOl5byk0HTRn1IL6hbqg== +"@types/node@>=13.7.0", "@types/node@^18.11.11": + version "18.11.11" + resolved "https://registry.yarnpkg.com/@types/node/-/node-18.11.11.tgz#1d455ac0211549a8409d3cdb371cd55cc971e8dc" + integrity sha512-KJ021B1nlQUBLopzZmPBVuGU9un7WJd/W4ya7Ih02B4Uwky5Nja0yGYav2EfYIk0RR2Q9oVhf60S2XR1BCWJ2g== "@types/offscreencanvas@^2019.7.0": version "2019.7.0" @@ -207,6 +240,11 @@ resolved "https://registry.yarnpkg.com/@types/resolve/-/resolve-1.20.2.tgz#97d26e00cd4a0423b4af620abecf3e6f442b7975" integrity sha512-60BCwRFOZCQhDncwQdxxeOEEkbc5dIMccYLwbxsS4TUNeVECQ/pBJ0j09mrHOl/JJvpRPGwO9SvE4nR2Nb/a4Q== +"@xmldom/xmldom@^0.8.5": + version "0.8.6" + resolved "https://registry.yarnpkg.com/@xmldom/xmldom/-/xmldom-0.8.6.tgz#8a1524eb5bd5e965c1e3735476f0262469f71440" + integrity sha512-uRjjusqpoqfmRkTaNuLJ2VohVr67Q5YwDATW3VU7PfzTj6IRaihGrYI7zckGZjxQPBIp63nfvJbM+Yu5ICh0Bg== + acorn-jsx@^5.3.2: version "5.3.2" resolved "https://registry.yarnpkg.com/acorn-jsx/-/acorn-jsx-5.3.2.tgz#7ed5bb55908b3b2f1bc55c6af1653bada7f07937" @@ -217,7 +255,12 @@ acorn@^8.5.0, acorn@^8.8.0: resolved "https://registry.yarnpkg.com/acorn/-/acorn-8.8.1.tgz#0a3f9cbecc4ec3bea6f0a80b66ae8dd2da250b73" integrity sha512-7zFpHzhnqYKrkYdUjF1HI1bzd0VygEGX8lFk4k5zVMqHEoES+P+7TKI+EvLO9WVMJ8eekdO0aDEK044xTXwPPA== -ansi-styles@^4.1.0: +ansi-regex@^5.0.1: + version "5.0.1" + resolved "https://registry.yarnpkg.com/ansi-regex/-/ansi-regex-5.0.1.tgz#082cb2c89c9fe8659a311a53bd6a4dc5301db304" + integrity sha512-quJQXlTSUGL2LH9SUXo8VwsY4soanhgo6LNSm84E1LBcE8s3O0wpdiRzyR9z/ZZJMlMWv37qOOb9pdJlMUEKFQ== + +ansi-styles@^4.0.0, ansi-styles@^4.1.0: version "4.3.0" resolved "https://registry.yarnpkg.com/ansi-styles/-/ansi-styles-4.3.0.tgz#edd803628ae71c04c85ae7a0906edad34b648937" integrity sha512-zbB9rCJAT1rbjiVDb2hqKFHNYLxgtk8NURxZ3IZwD3F6NtxbXZQCnnSi1Lkx+IDohdPlFp222wVALIheZJQSEg== @@ -264,6 +307,25 @@ builtin-modules@^3.3.0: resolved "https://registry.yarnpkg.com/builtin-modules/-/builtin-modules-3.3.0.tgz#cae62812b89801e9656336e46223e030386be7b6" integrity sha512-zhaCDicdLuWN5UbN5IMnFqNMhNfo919sH85y2/ea+5Yg9TsTkeZxpL+JLbp6cgYFS4sRLp3YV4S6yDuqVWHYOw== +c8@~7.5.0: + version "7.5.0" + resolved "https://registry.yarnpkg.com/c8/-/c8-7.5.0.tgz#a69439ab82848f344a74bb25dc5dd4e867764481" + integrity sha512-GSkLsbvDr+FIwjNSJ8OwzWAyuznEYGTAd1pzb/Kr0FMLuV4vqYJTyjboDTwmlUNAG6jAU3PFWzqIdKrOt1D8tw== + dependencies: + "@bcoe/v8-coverage" "^0.2.3" + "@istanbuljs/schema" "^0.1.2" + find-up "^5.0.0" + foreground-child "^2.0.0" + furi "^2.0.0" + istanbul-lib-coverage "^3.0.0" + istanbul-lib-report "^3.0.0" + istanbul-reports "^3.0.2" + rimraf "^3.0.0" + test-exclude "^6.0.0" + v8-to-istanbul "^7.1.0" + yargs "^16.0.0" + yargs-parser "^20.0.0" + catharsis@^0.9.0: version "0.9.0" resolved "https://registry.yarnpkg.com/catharsis/-/catharsis-0.9.0.tgz#40382a168be0e6da308c277d3a2b3eb40c7d2121" @@ -279,6 +341,15 @@ chalk@^4.0.0: ansi-styles "^4.1.0" supports-color "^7.1.0" +cliui@^7.0.2: + version "7.0.4" + resolved "https://registry.yarnpkg.com/cliui/-/cliui-7.0.4.tgz#a0265ee655476fc807aea9df3df8df7783808b4f" + integrity sha512-OcRE68cOsVMXp1Yvonl/fzkQOyjLSu/8bhPDfQt0e0/Eb283TKP20Fs2MqoPsr9SwA595rRCA+QMzYc9nBP+JQ== + dependencies: + 
string-width "^4.2.0" + strip-ansi "^6.0.0" + wrap-ansi "^7.0.0" + color-convert@^2.0.1: version "2.0.1" resolved "https://registry.yarnpkg.com/color-convert/-/color-convert-2.0.1.tgz#72d3a68d598c9bdb3af2ad1e84f21d896abd4de3" @@ -306,6 +377,20 @@ concat-map@0.0.1: resolved "https://registry.yarnpkg.com/concat-map/-/concat-map-0.0.1.tgz#d8a96bd77fd68df7793a73036a3ba0d5405d477b" integrity sha512-/Srv4dswyQNBfohGpz9o6Yb3Gz3SrUDqBH5rTuhGR7ahtlbYKnVxw2bCFMRljaA7EXHaXZ8wsHdodFvbkhKmqg== +convert-source-map@^1.6.0: + version "1.9.0" + resolved "https://registry.yarnpkg.com/convert-source-map/-/convert-source-map-1.9.0.tgz#7faae62353fb4213366d0ca98358d22e8368b05f" + integrity sha512-ASFBup0Mz1uyiIjANan1jzLQami9z1PoYSZCiiYW2FczPbenXc45FZdBZLzOT+r6+iciuEModtmCti+hjaAk0A== + +cross-spawn@^7.0.0: + version "7.0.3" + resolved "https://registry.yarnpkg.com/cross-spawn/-/cross-spawn-7.0.3.tgz#f73a85b9d5d41d045551c177e2882d4ac85728a6" + integrity sha512-iRDPJKUPVEND7dHPO8rkbOnPpyDygcDFtWjpeWNCgy8WP2rXcxXL8TskReQl6OrB2G7+UJrags1q15Fudc7G6w== + dependencies: + path-key "^3.1.0" + shebang-command "^2.0.0" + which "^2.0.1" + deep-is@~0.1.3: version "0.1.4" resolved "https://registry.yarnpkg.com/deep-is/-/deep-is-0.1.4.tgz#a6f2dce612fadd2ef1f519b73551f17e85199831" @@ -316,11 +401,21 @@ deepmerge@^4.2.2: resolved "https://registry.yarnpkg.com/deepmerge/-/deepmerge-4.2.2.tgz#44d2ea3679b8f4d4ffba33f03d865fc1e7bf4955" integrity sha512-FJ3UgI4gIl+PHZm53knsuSFpE+nESMr7M4v9QcgB7S63Kj/6WqMiFQJpBBYz1Pt+66bZpP3Q7Lye0Oo9MPKEdg== +emoji-regex@^8.0.0: + version "8.0.0" + resolved "https://registry.yarnpkg.com/emoji-regex/-/emoji-regex-8.0.0.tgz#e818fd69ce5ccfcb404594f842963bf53164cc37" + integrity sha512-MSjYzcWNOA0ewAHpz0MxpYFvwg6yjy1NG3xteoqz644VCo/RPgnr1/GGt+ic3iJTzQ8Eu3TdM14SawnVUmGE6A== + entities@~2.1.0: version "2.1.0" resolved "https://registry.yarnpkg.com/entities/-/entities-2.1.0.tgz#992d3129cf7df6870b96c57858c249a120f8b8b5" integrity sha512-hCx1oky9PFrJ611mf0ifBLBRW8lUUVRlFolb5gWRfIELabBlbp9xZvrqZLZAs+NxFnbfQoeGd8wDkygjg7U85w== +escalade@^3.1.1: + version "3.1.1" + resolved "https://registry.yarnpkg.com/escalade/-/escalade-3.1.1.tgz#d8cfdc7000965c5a0174b4a82eaa5c0552742e40" + integrity sha512-k0er2gUkLf8O0zKJiAhmkTnJlTvINGv7ygDNPbeIsX/TJjGJZHuh9B2UxbsaEkmlEo9MfhrSzmhIlhRlI2GXnw== + escape-string-regexp@^2.0.0: version "2.0.0" resolved "https://registry.yarnpkg.com/escape-string-regexp/-/escape-string-regexp-2.0.0.tgz#a30304e99daa32e23b2fd20f51babd07cffca344" @@ -382,6 +477,22 @@ fast-levenshtein@~2.0.6: resolved "https://registry.yarnpkg.com/fast-levenshtein/-/fast-levenshtein-2.0.6.tgz#3d8a5c66883a16a30ca8643e851f19baa7797917" integrity sha512-DCXu6Ifhqcks7TZKY3Hxp3y6qphY5SJZmrWMDrKcERSOXWQdMhU9Ig/PYrzyw/ul9jOIyh0N4M0tbC5hodg8dw== +find-up@^5.0.0: + version "5.0.0" + resolved "https://registry.yarnpkg.com/find-up/-/find-up-5.0.0.tgz#4c92819ecb7083561e4f4a240a86be5198f536fc" + integrity sha512-78/PXT1wlLLDgTzDs7sjq9hzz0vXD+zn+7wypEe4fXQxCmdmqfGsEPQxmiCSQI3ajFV91bVSsvNtrJRiW6nGng== + dependencies: + locate-path "^6.0.0" + path-exists "^4.0.0" + +foreground-child@^2.0.0: + version "2.0.0" + resolved "https://registry.yarnpkg.com/foreground-child/-/foreground-child-2.0.0.tgz#71b32800c9f15aa8f2f83f4a6bd9bff35d861a53" + integrity sha512-dCIq9FpEcyQyXKCkyzmlPTFNgrCzPudOe+mhvJU5zAtlBnGVy2yKxtfsxK2tQBThwq225jcvBjpw1Gr40uzZCA== + dependencies: + cross-spawn "^7.0.0" + signal-exit "^3.0.2" + fs.realpath@^1.0.0: version "1.0.0" resolved 
"https://registry.yarnpkg.com/fs.realpath/-/fs.realpath-1.0.0.tgz#1504ad2523158caa40db4a2787cb01411994ea4f" @@ -397,7 +508,20 @@ function-bind@^1.1.1: resolved "https://registry.yarnpkg.com/function-bind/-/function-bind-1.1.1.tgz#a56899d3ea3c9bab874bb9773b7c5ede92f4895d" integrity sha512-yIovAzMX49sF8Yl58fSCWJ5svSLuaibPxXQJFLmBObTuCr0Mf1KiPopGM9NiFjiYBCbfaa2Fh6breQ6ANVTI0A== -glob@^7.1.3: +furi@^2.0.0: + version "2.0.0" + resolved "https://registry.yarnpkg.com/furi/-/furi-2.0.0.tgz#13d85826a1af21acc691da6254b3888fc39f0b4a" + integrity sha512-uKuNsaU0WVaK/vmvj23wW1bicOFfyqSsAIH71bRZx8kA4Xj+YCHin7CJKJJjkIsmxYaPFLk9ljmjEyB7xF7WvQ== + dependencies: + "@types/is-windows" "^1.0.0" + is-windows "^1.0.2" + +get-caller-file@^2.0.5: + version "2.0.5" + resolved "https://registry.yarnpkg.com/get-caller-file/-/get-caller-file-2.0.5.tgz#4f94412a82db32f36e3b0b9741f8a97feb031f7e" + integrity sha512-DyFP3BM/3YHTQOCUL/w0OZHR0lpKeGrxotcHWcqNEdnltqFwXVfhEBQ94eIo34AfQpo0rGki4cyIiftY06h2Fg== + +glob@^7.1.3, glob@^7.1.4, glob@^7.1.6: version "7.2.3" resolved "https://registry.yarnpkg.com/glob/-/glob-7.2.3.tgz#b8df0fb802bbfa8e89bd1d938b4e16578ed44f2b" integrity sha512-nFR0zLpU2YCaRxwoCJvL6UvCH2JFyFVIvwTLsIf21AuHlMskA1hhTdk+LlYJtOlYt9v6dvszD2BGRqBL+iQK9Q== @@ -442,6 +566,11 @@ has@^1.0.3: dependencies: function-bind "^1.1.1" +html-escaper@^2.0.0: + version "2.0.2" + resolved "https://registry.yarnpkg.com/html-escaper/-/html-escaper-2.0.2.tgz#dfd60027da36a36dfcbe236262c00a5822681453" + integrity sha512-H2iMtd0I4Mt5eYiapRdIDjp+XzelXQ0tFE4JS7YFwFevXXMmOp9myNrUvCg0D6ws8iqkRPBfKHgbwig1SmlLfg== + inflight@^1.0.4: version "1.0.6" resolved "https://registry.yarnpkg.com/inflight/-/inflight-1.0.6.tgz#49bd6331d7d02d0c09bc910a1075ba8165b56df9" @@ -469,6 +598,11 @@ is-core-module@^2.9.0: dependencies: has "^1.0.3" +is-fullwidth-code-point@^3.0.0: + version "3.0.0" + resolved "https://registry.yarnpkg.com/is-fullwidth-code-point/-/is-fullwidth-code-point-3.0.0.tgz#f116f8064fe90b3f7844a38997c0b75051269f1d" + integrity sha512-zymm5+u+sCsSWyD9qNaejV3DFvhCKclKdizYaJUuHA83RLjb7nSuGnddCHGv0hk+KY7BMAlsWeK4Ueg6EV6XQg== + is-module@^1.0.0: version "1.0.0" resolved "https://registry.yarnpkg.com/is-module/-/is-module-1.0.0.tgz#3258fb69f78c14d5b815d664336b4cffb6441591" @@ -481,6 +615,59 @@ is-reference@1.2.1: dependencies: "@types/estree" "*" +is-windows@^1.0.2: + version "1.0.2" + resolved "https://registry.yarnpkg.com/is-windows/-/is-windows-1.0.2.tgz#d1850eb9791ecd18e6182ce12a30f396634bb19d" + integrity sha512-eXK1UInq2bPmjyX6e3VHIzMLobc4J94i4AWn+Hpq3OU5KkrRC96OAcR3PRJ/pGu6m8TRnBHP9dkXQVsT/COVIA== + +isexe@^2.0.0: + version "2.0.0" + resolved "https://registry.yarnpkg.com/isexe/-/isexe-2.0.0.tgz#e8fbf374dc556ff8947a10dcb0572d633f2cfa10" + integrity sha512-RHxMLp9lnKHGHRng9QFhRCMbYAcVpn69smSGcq3f36xjgVVWThj4qqLbTLlq7Ssj8B+fIQ1EuCEGI2lKsyQeIw== + +istanbul-lib-coverage@^3.0.0: + version "3.2.0" + resolved "https://registry.yarnpkg.com/istanbul-lib-coverage/-/istanbul-lib-coverage-3.2.0.tgz#189e7909d0a39fa5a3dfad5b03f71947770191d3" + integrity sha512-eOeJ5BHCmHYvQK7xt9GkdHuzuCGS1Y6g9Gvnx3Ym33fz/HpLRYxiS0wHNr+m/MBC8B647Xt608vCDEvhl9c6Mw== + +istanbul-lib-report@^3.0.0: + version "3.0.0" + resolved "https://registry.yarnpkg.com/istanbul-lib-report/-/istanbul-lib-report-3.0.0.tgz#7518fe52ea44de372f460a76b5ecda9ffb73d8a6" + integrity sha512-wcdi+uAKzfiGT2abPpKZ0hSU1rGQjUQnLvtY5MpQ7QCTahD3VODhcu4wcfY1YtkGaDD5yuydOLINXsfbus9ROw== + dependencies: + istanbul-lib-coverage "^3.0.0" + make-dir "^3.0.0" + supports-color "^7.1.0" + 
+istanbul-reports@^3.0.2: + version "3.1.5" + resolved "https://registry.yarnpkg.com/istanbul-reports/-/istanbul-reports-3.1.5.tgz#cc9a6ab25cb25659810e4785ed9d9fb742578bae" + integrity sha512-nUsEMa9pBt/NOHqbcbeJEgqIlY/K7rVWUX6Lql2orY5e9roQOthbR3vtY4zzf2orPELg80fnxxk9zUyPlgwD1w== + dependencies: + html-escaper "^2.0.0" + istanbul-lib-report "^3.0.0" + +jasmine-core@^4.5.0: + version "4.5.0" + resolved "https://registry.yarnpkg.com/jasmine-core/-/jasmine-core-4.5.0.tgz#1a6bd0bde3f60996164311c88a0995d67ceda7c3" + integrity sha512-9PMzyvhtocxb3aXJVOPqBDswdgyAeSB81QnLop4npOpbqnheaTEwPc9ZloQeVswugPManznQBjD8kWDTjlnHuw== + +jasmine-reporters@~2.5.0: + version "2.5.2" + resolved "https://registry.yarnpkg.com/jasmine-reporters/-/jasmine-reporters-2.5.2.tgz#b5dfa1d9c40b8020c5225e0e1e2b9953d66a4d69" + integrity sha512-qdewRUuFOSiWhiyWZX8Yx3YNQ9JG51ntBEO4ekLQRpktxFTwUHy24a86zD/Oi2BRTKksEdfWQZcQFqzjqIkPig== + dependencies: + "@xmldom/xmldom" "^0.8.5" + mkdirp "^1.0.4" + +jasmine@^4.5.0: + version "4.5.0" + resolved "https://registry.yarnpkg.com/jasmine/-/jasmine-4.5.0.tgz#8d3c0d0a33a61e4d05c9f9747ee5dfaf6f7b5d3d" + integrity sha512-9olGRvNZyADIwYL9XBNBst5BTU/YaePzuddK+YRslc7rI9MdTIE4r3xaBKbv2GEmzYYUfMOdTR8/i6JfLZaxSQ== + dependencies: + glob "^7.1.6" + jasmine-core "^4.5.0" + js2xmlparser@^4.0.2: version "4.0.2" resolved "https://registry.yarnpkg.com/js2xmlparser/-/js2xmlparser-4.0.2.tgz#2a1fdf01e90585ef2ae872a01bc169c6a8d5e60a" @@ -531,7 +718,14 @@ linkify-it@^3.0.1: dependencies: uc.micro "^1.0.1" -lodash@^4.17.14, lodash@^4.17.15: +locate-path@^6.0.0: + version "6.0.0" + resolved "https://registry.yarnpkg.com/locate-path/-/locate-path-6.0.0.tgz#55321eb309febbc59c4801d931a72452a681d286" + integrity sha512-iPZK6eYjbxRu3uB4/WZ3EsEIMJFMqAoopl3R+zuq0UjcAm/MO6KCweDgPfP3elTztoKP3KtnVHxTn2NHBSDVUw== + dependencies: + p-locate "^5.0.0" + +lodash@^4.17.15, lodash@^4.17.21: version "4.17.21" resolved "https://registry.yarnpkg.com/lodash/-/lodash-4.17.21.tgz#679591c564c3bffaae8454cf0b3df370c3d6911c" integrity sha512-v2kDEe57lecTulaDIuNTPy3Ry4gLGJ6Z1O3vE1krgXZNrsQ+LFTGHVxVjcXPs17LhbZVGedAJv8XZ1tvj5FvSg== @@ -555,6 +749,13 @@ magic-string@^0.26.4: dependencies: sourcemap-codec "^1.4.8" +make-dir@^3.0.0: + version "3.1.0" + resolved "https://registry.yarnpkg.com/make-dir/-/make-dir-3.1.0.tgz#415e967046b3a7f1d185277d84aa58203726a13f" + integrity sha512-g3FeP20LNwhALb/6Cz6Dd4F2ngze0jz7tbzrD2wAV+o9FeNHe4rL+yK2md0J/fiSf1sa1ADhXqi5+oVwOM/eGw== + dependencies: + semver "^6.0.0" + markdown-it-anchor@^8.4.1: version "8.6.5" resolved "https://registry.yarnpkg.com/markdown-it-anchor/-/markdown-it-anchor-8.6.5.tgz#30c4bc5bbff327f15ce3c429010ec7ba75e7b5f8" @@ -572,16 +773,16 @@ markdown-it@^12.3.2: uc.micro "^1.0.5" marked@^4.0.10: - version "4.2.2" - resolved "https://registry.yarnpkg.com/marked/-/marked-4.2.2.tgz#1d2075ad6cdfe42e651ac221c32d949a26c0672a" - integrity sha512-JjBTFTAvuTgANXx82a5vzK9JLSMoV6V3LBVn4Uhdso6t7vXrGx7g1Cd2r6NYSsxrYbQGFCMqBDhFHyK5q2UvcQ== + version "4.2.3" + resolved "https://registry.yarnpkg.com/marked/-/marked-4.2.3.tgz#bd76a5eb510ff1d8421bc6c3b2f0b93488c15bea" + integrity sha512-slWRdJkbTZ+PjkyJnE30Uid64eHwbwa1Q25INCAYfZlK4o6ylagBy/Le9eWntqJFoFT93ikUKMv47GZ4gTwHkw== mdurl@^1.0.1: version "1.0.1" resolved "https://registry.yarnpkg.com/mdurl/-/mdurl-1.0.1.tgz#fe85b2ec75a59037f2adfec100fd6c601761152e" integrity sha512-/sKlQJCBYVY9Ers9hqzKou4H6V5UWc/M59TH2dvkt+84itfnq7uFOMLpOiOS4ujvHP4etln18fmIxA5R5fll0g== -minimatch@^3.1.1: +minimatch@^3.0.4, minimatch@^3.1.1: version "3.1.2" resolved 
"https://registry.yarnpkg.com/minimatch/-/minimatch-3.1.2.tgz#19cd194bfd3e428f049a70817c038d89ab4be35b" integrity sha512-J7p63hRiAjw1NDEww1W7i37+ByIrOWO5XQQAzZ3VOcL0PNybwpfmV/N05zFAzwQ9USyEcX6t3UO+K5aqBQOIHw== @@ -589,9 +790,9 @@ minimatch@^3.1.1: brace-expansion "^1.1.7" minimatch@^5.0.1: - version "5.1.0" - resolved "https://registry.yarnpkg.com/minimatch/-/minimatch-5.1.0.tgz#1717b464f4971b144f6aabe8f2d0b8e4511e09c7" - integrity sha512-9TPBGGak4nHfGZsPBohm9AWg6NoT7QTCehS3BIJABslyZbzxfV78QM2Y6+i741OPZIafFAaiiEMh5OyIrJPgtg== + version "5.1.1" + resolved "https://registry.yarnpkg.com/minimatch/-/minimatch-5.1.1.tgz#6c9dffcf9927ff2a31e74b5af11adf8b9604b022" + integrity sha512-362NP+zlprccbEt/SkxKfRMHnNY85V74mVnpUpNyr3F35covl09Kec7/sEFLt3RA4oXmewtoaanoIf67SE5Y5g== dependencies: brace-expansion "^2.0.1" @@ -624,11 +825,35 @@ optionator@^0.8.1: type-check "~0.3.2" word-wrap "~1.2.3" +p-limit@^3.0.2: + version "3.1.0" + resolved "https://registry.yarnpkg.com/p-limit/-/p-limit-3.1.0.tgz#e1daccbe78d0d1388ca18c64fea38e3e57e3706b" + integrity sha512-TYOanM3wGwNGsZN2cVTYPArw454xnXj5qmWF1bEoAc4+cU/ol7GVh7odevjp1FNHduHc3KZMcFduxU5Xc6uJRQ== + dependencies: + yocto-queue "^0.1.0" + +p-locate@^5.0.0: + version "5.0.0" + resolved "https://registry.yarnpkg.com/p-locate/-/p-locate-5.0.0.tgz#83c8315c6785005e3bd021839411c9e110e6d834" + integrity sha512-LaNjtRWUBY++zB5nE/NwcaoMylSPk+S+ZHNB1TzdbMJMny6dynpAGt7X/tl/QYq3TIeE6nxHppbo2LGymrG5Pw== + dependencies: + p-limit "^3.0.2" + +path-exists@^4.0.0: + version "4.0.0" + resolved "https://registry.yarnpkg.com/path-exists/-/path-exists-4.0.0.tgz#513bdbe2d3b95d7762e8c1137efa195c6c61b5b3" + integrity sha512-ak9Qy5Q7jYb2Wwcey5Fpvg2KoAc/ZIhLSLOSBmRmygPsGwkVVt0fZa0qrtMz+m6tJTAHfZQ8FnmB4MG4LWy7/w== + path-is-absolute@^1.0.0: version "1.0.1" resolved "https://registry.yarnpkg.com/path-is-absolute/-/path-is-absolute-1.0.1.tgz#174b9268735534ffbc7ace6bf53a5a9e1b5c5f5f" integrity sha512-AVbw3UJ2e9bq64vSaS9Am0fje1Pa8pbGqTTsmXfaIiMpnr5DlDhfJOuLj9Sf95ZPVDAUerDfEk88MPmPe7UCQg== +path-key@^3.1.0: + version "3.1.1" + resolved "https://registry.yarnpkg.com/path-key/-/path-key-3.1.1.tgz#581f6ade658cbba65a0d3380de7753295054f375" + integrity sha512-ojmeN0qd+y0jszEtoY48r0Peq5dwMEkIlCOu6Q5f41lfkswXuKtYrhgoTpLnyIcHm24Uhqx+5Tqm2InSwLhE6Q== + path-parse@^1.0.7: version "1.0.7" resolved "https://registry.yarnpkg.com/path-parse/-/path-parse-1.0.7.tgz#fbc114b60ca42b30d9daf5858e4bd68bbedb6735" @@ -678,12 +903,17 @@ protobufjs@^7.1.2: "@types/node" ">=13.7.0" long "^5.0.0" +require-directory@^2.1.1: + version "2.1.1" + resolved "https://registry.yarnpkg.com/require-directory/-/require-directory-2.1.1.tgz#8c64ad5fd30dab1c976e2344ffe7f792a6a6df42" + integrity sha512-fGxEI7+wsG9xrvdjsrlmL22OMTTiHRwAMroiEeMgq8gzoLC/PQr7RsRDSTLUg/bZAZtF+TVIkHc6/4RIKrui+Q== + requizzle@^0.2.3: - version "0.2.3" - resolved "https://registry.yarnpkg.com/requizzle/-/requizzle-0.2.3.tgz#4675c90aacafb2c036bd39ba2daa4a1cb777fded" - integrity sha512-YanoyJjykPxGHii0fZP0uUPEXpvqfBDxWV7s6GKAiiOsiqhX6vHNyW3Qzdmqp/iq/ExbhaGbVrjB4ruEVSM4GQ== + version "0.2.4" + resolved "https://registry.yarnpkg.com/requizzle/-/requizzle-0.2.4.tgz#319eb658b28c370f0c20f968fa8ceab98c13d27c" + integrity sha512-JRrFk1D4OQ4SqovXOgdav+K8EAhSB/LJZqCz8tbX0KObcdeM15Ss59ozWMBWmmINMagCwmqn4ZNryUGpBsl6Jw== dependencies: - lodash "^4.17.14" + lodash "^4.17.21" resolve@^1.22.1: version "1.22.1" @@ -713,6 +943,11 @@ semver@5.6.0: resolved "https://registry.yarnpkg.com/semver/-/semver-5.6.0.tgz#7e74256fbaa49c75aa7c7a205cc22799cac80004" integrity 
sha512-RS9R6R35NYgQn++fkDWaOmqGoj4Ek9gGs+DPxNUZKuwE183xjJroKvyo1IzVFeXvUrvmALy6FWD5xrdJT25gMg== +semver@^6.0.0: + version "6.3.0" + resolved "https://registry.yarnpkg.com/semver/-/semver-6.3.0.tgz#ee0a64c8af5e8ceea67687b133761e1becbd1d3d" + integrity sha512-b39TBaTSfV6yBrapU89p5fKekE2m/NwnDocOVruQFS1/veMgdzuPcnOM34M6CwxW8jH/lxEa5rBoDeUwu5HHTw== + semver@^7.1.2: version "7.3.8" resolved "https://registry.yarnpkg.com/semver/-/semver-7.3.8.tgz#07a78feafb3f7b32347d725e33de7e2a2df67798" @@ -720,6 +955,23 @@ semver@^7.1.2: dependencies: lru-cache "^6.0.0" +shebang-command@^2.0.0: + version "2.0.0" + resolved "https://registry.yarnpkg.com/shebang-command/-/shebang-command-2.0.0.tgz#ccd0af4f8835fbdc265b82461aaf0c36663f34ea" + integrity sha512-kHxr2zZpYtdmrN1qDjrrX/Z1rR1kG8Dx+gkpK1G4eXmvXswmcE1hTWBWYUzlraYw1/yZp6YuDY77YtvbN0dmDA== + dependencies: + shebang-regex "^3.0.0" + +shebang-regex@^3.0.0: + version "3.0.0" + resolved "https://registry.yarnpkg.com/shebang-regex/-/shebang-regex-3.0.0.tgz#ae16f1644d873ecad843b0307b143362d4c42172" + integrity sha512-7++dFhtcx3353uBaq8DDR4NuxBetBzC7ZQOhmTQInHEd6bSrXdiEyzCvG07Z44UYdLShWUyXt5M/yhz8ekcb1A== + +signal-exit@^3.0.2: + version "3.0.7" + resolved "https://registry.yarnpkg.com/signal-exit/-/signal-exit-3.0.7.tgz#a9a1767f8af84155114eaabd73f99273c8f59ad9" + integrity sha512-wnD2ZE+l+SPC/uoS0vXeE9L1+0wuaMqKlfz9AMUo38JsyLSBWSFcHR1Rri62LZc12vLr1gb3jl7iwQhgwpAbGQ== + source-map-support@0.5.9: version "0.5.9" resolved "https://registry.yarnpkg.com/source-map-support/-/source-map-support-0.5.9.tgz#41bc953b2534267ea2d605bccfa7bfa3111ced5f" @@ -741,11 +993,32 @@ source-map@^0.6.0, source-map@~0.6.1: resolved "https://registry.yarnpkg.com/source-map/-/source-map-0.6.1.tgz#74722af32e9614e9c287a8d0bbde48b5e2f1a263" integrity sha512-UjgapumWlbMhkBgzT7Ykc5YXUT46F0iKu8SGXq0bcwP5dz/h0Plj6enJqjz1Zbq2l5WaqYnrVbwWOWMyF3F47g== +source-map@^0.7.3: + version "0.7.4" + resolved "https://registry.yarnpkg.com/source-map/-/source-map-0.7.4.tgz#a9bbe705c9d8846f4e08ff6765acf0f1b0898656" + integrity sha512-l3BikUxvPOcn5E74dZiq5BGsTb5yEwhaTSzccU6t4sDOH8NWJCstKO5QT2CvtFoK6F0saL7p9xHAqHOlCPJygA== + sourcemap-codec@^1.4.8: version "1.4.8" resolved "https://registry.yarnpkg.com/sourcemap-codec/-/sourcemap-codec-1.4.8.tgz#ea804bd94857402e6992d05a38ef1ae35a9ab4c4" integrity sha512-9NykojV5Uih4lgo5So5dtw+f0JgJX30KCNI8gwhz2J9A15wD0Ml6tjHKwf6fTSa6fAdVBdZeNOs9eJ71qCk8vA== +string-width@^4.1.0, string-width@^4.2.0: + version "4.2.3" + resolved "https://registry.yarnpkg.com/string-width/-/string-width-4.2.3.tgz#269c7117d27b05ad2e536830a8ec895ef9c6d010" + integrity sha512-wKyQRQpjJ0sIp62ErSZdGsjMJWsap5oRNihHhu6G7JVO/9jIB6UyevL+tXuOqrng8j/cxKTWyWUwvSTriiZz/g== + dependencies: + emoji-regex "^8.0.0" + is-fullwidth-code-point "^3.0.0" + strip-ansi "^6.0.1" + +strip-ansi@^6.0.0, strip-ansi@^6.0.1: + version "6.0.1" + resolved "https://registry.yarnpkg.com/strip-ansi/-/strip-ansi-6.0.1.tgz#9e26c63d30f53443e9489495b2105d37b67a85d9" + integrity sha512-Y38VPSHcqkFrCpFnQ9vuSXmquuv5oXOKpGeT6aGrr3o3Gc9AlVa6JBfUSOCnbxGGZF+/0ooI7KrPuUSztUdU5A== + dependencies: + ansi-regex "^5.0.1" + strip-json-comments@^3.1.0: version "3.1.1" resolved "https://registry.yarnpkg.com/strip-json-comments/-/strip-json-comments-3.1.1.tgz#31f1281b3832630434831c310c01cccda8cbe006" @@ -769,15 +1042,24 @@ taffydb@2.6.2: integrity sha512-y3JaeRSplks6NYQuCOj3ZFMO3j60rTwbuKCvZxsAraGYH2epusatvZ0baZYA01WsGqJBq/Dl6vOrMUJqyMj8kA== terser@^5.15.1: - version "5.15.1" - resolved 
"https://registry.yarnpkg.com/terser/-/terser-5.15.1.tgz#8561af6e0fd6d839669c73b92bdd5777d870ed6c" - integrity sha512-K1faMUvpm/FBxjBXud0LWVAGxmvoPbZbfTCYbSgaaYQaIXI3/TdI7a7ZGA73Zrou6Q8Zmz3oeUTsp/dj+ag2Xw== + version "5.16.1" + resolved "https://registry.yarnpkg.com/terser/-/terser-5.16.1.tgz#5af3bc3d0f24241c7fb2024199d5c461a1075880" + integrity sha512-xvQfyfA1ayT0qdK47zskQgRZeWLoOQ8JQ6mIgRGVNwZKdQMU+5FkCBjmv4QjcrTzyZquRw2FVtlJSRUmMKQslw== dependencies: "@jridgewell/source-map" "^0.3.2" acorn "^8.5.0" commander "^2.20.0" source-map-support "~0.5.20" +test-exclude@^6.0.0: + version "6.0.0" + resolved "https://registry.yarnpkg.com/test-exclude/-/test-exclude-6.0.0.tgz#04a8698661d805ea6fa293b6cb9e63ac044ef15e" + integrity sha512-cAGWPIyOHU6zlmg88jwm7VRyXnMN7iV68OGAbYDk/Mh/xC/pzVPlQtY6ngoIH/5/tciuhGfvESU8GrHrcxD56w== + dependencies: + "@istanbuljs/schema" "^0.1.2" + glob "^7.1.4" + minimatch "^3.0.4" + tmp@^0.2.1: version "0.2.1" resolved "https://registry.yarnpkg.com/tmp/-/tmp-0.2.1.tgz#8457fc3037dcf4719c251367a1af6500ee1ccf14" @@ -812,9 +1094,9 @@ type-check@~0.3.2: prelude-ls "~1.1.2" typescript@^4.8.4: - version "4.8.4" - resolved "https://registry.yarnpkg.com/typescript/-/typescript-4.8.4.tgz#c464abca159669597be5f96b8943500b238e60e6" - integrity sha512-QCh+85mCy+h0IGff8r5XWzOVSbBO+KfeYrMQh7NJ58QujwcE22u+NUSmUxqF+un70P9GXKxa2HCNiTTMJknyjQ== + version "4.9.3" + resolved "https://registry.yarnpkg.com/typescript/-/typescript-4.9.3.tgz#3aea307c1746b8c384435d8ac36b8a2e580d85db" + integrity sha512-CIfGzTelbKNEnLpLdGFgdyKhG23CKdKgQPOBc+OUNrkJ2vr+KSzsSV5kq5iWhEQbok+quxgGzrAtGWCyU7tHnA== uc.micro@^1.0.1, uc.micro@^1.0.5: version "1.0.6" @@ -831,11 +1113,36 @@ underscore@~1.13.2: resolved "https://registry.yarnpkg.com/underscore/-/underscore-1.13.6.tgz#04786a1f589dc6c09f761fc5f45b89e935136441" integrity sha512-+A5Sja4HP1M08MaXya7p5LvjuM7K6q/2EaC0+iovj/wOcMsTzMvDFbasi/oSapiwOlt252IqsKqPjCl7huKS0A== +v8-to-istanbul@^7.1.0: + version "7.1.2" + resolved "https://registry.yarnpkg.com/v8-to-istanbul/-/v8-to-istanbul-7.1.2.tgz#30898d1a7fa0c84d225a2c1434fb958f290883c1" + integrity sha512-TxNb7YEUwkLXCQYeudi6lgQ/SZrzNO4kMdlqVxaZPUIUjCv6iSSypUQX70kNBSERpQ8fk48+d61FXk+tgqcWow== + dependencies: + "@types/istanbul-lib-coverage" "^2.0.1" + convert-source-map "^1.6.0" + source-map "^0.7.3" + +which@^2.0.1: + version "2.0.2" + resolved "https://registry.yarnpkg.com/which/-/which-2.0.2.tgz#7c6a8dd0a636a0327e10b59c9286eee93f3f51b1" + integrity sha512-BLI3Tl1TW3Pvl70l3yq3Y64i+awpwXqsGBYWkkqMtnbXgrMD+yj7rhW0kuEDxzJaYXGjEW5ogapKNMEKNMjibA== + dependencies: + isexe "^2.0.0" + word-wrap@~1.2.3: version "1.2.3" resolved "https://registry.yarnpkg.com/word-wrap/-/word-wrap-1.2.3.tgz#610636f6b1f703891bd34771ccb17fb93b47079c" integrity sha512-Hz/mrNwitNRh/HUAtM/VT/5VH+ygD6DV7mYKZAtHOrbs8U7lvPS6xf7EJKMF0uW1KJCl0H701g3ZGus+muE5vQ== +wrap-ansi@^7.0.0: + version "7.0.0" + resolved "https://registry.yarnpkg.com/wrap-ansi/-/wrap-ansi-7.0.0.tgz#67e145cff510a6a6984bdf1152911d69d2eb9e43" + integrity sha512-YVGIj2kamLSTxw6NsZjoBxfSwsn0ycdesmc4p+Q21c5zPuZ1pl+NfxVdxPtdHvmNVOQ6XSYG4AUtyt/Fi7D16Q== + dependencies: + ansi-styles "^4.0.0" + string-width "^4.1.0" + strip-ansi "^6.0.0" + wrappy@1: version "1.0.2" resolved "https://registry.yarnpkg.com/wrappy/-/wrappy-1.0.2.tgz#b5243d8f3ec1aa35f1364605bc0d1036e30ab69f" @@ -846,7 +1153,35 @@ xmlcreate@^2.0.4: resolved "https://registry.yarnpkg.com/xmlcreate/-/xmlcreate-2.0.4.tgz#0c5ab0f99cdd02a81065fa9cd8f8ae87624889be" integrity 
sha512-nquOebG4sngPmGPICTS5EnxqhKbCmz5Ox5hsszI2T6U5qdrJizBc+0ilYSEjTSzU0yZcmvppztXe/5Al5fUwdg== +y18n@^5.0.5: + version "5.0.8" + resolved "https://registry.yarnpkg.com/y18n/-/y18n-5.0.8.tgz#7f4934d0f7ca8c56f95314939ddcd2dd91ce1d55" + integrity sha512-0pfFzegeDWJHJIAmTLRP2DwHjdF5s7jo9tuztdQxAhINCdvS+3nGINqPd00AphqJR/0LhANUS6/+7SCb98YOfA== + yallist@^4.0.0: version "4.0.0" resolved "https://registry.yarnpkg.com/yallist/-/yallist-4.0.0.tgz#9bb92790d9c0effec63be73519e11a35019a3a72" integrity sha512-3wdGidZyq5PB084XLES5TpOSRA3wjXAlIWMhum2kRcv/41Sn2emQ0dycQW4uZXLejwKvg6EsvbdlVL+FYEct7A== + +yargs-parser@^20.0.0, yargs-parser@^20.2.2: + version "20.2.9" + resolved "https://registry.yarnpkg.com/yargs-parser/-/yargs-parser-20.2.9.tgz#2eb7dc3b0289718fc295f362753845c41a0c94ee" + integrity sha512-y11nGElTIV+CT3Zv9t7VKl+Q3hTQoT9a1Qzezhhl6Rp21gJ/IVTW7Z3y9EWXhuUBC2Shnf+DX0antecpAwSP8w== + +yargs@^16.0.0: + version "16.2.0" + resolved "https://registry.yarnpkg.com/yargs/-/yargs-16.2.0.tgz#1c82bf0f6b6a66eafce7ef30e376f49a12477f66" + integrity sha512-D1mvvtDG0L5ft/jGWkLpG1+m0eQxOfaBvTNELraWj22wSVUMWxZUvYgJYcKh6jGGIkJFhH4IZPQhR4TKpc8mBw== + dependencies: + cliui "^7.0.2" + escalade "^3.1.1" + get-caller-file "^2.0.5" + require-directory "^2.1.1" + string-width "^4.2.0" + y18n "^5.0.5" + yargs-parser "^20.2.2" + +yocto-queue@^0.1.0: + version "0.1.0" + resolved "https://registry.yarnpkg.com/yocto-queue/-/yocto-queue-0.1.0.tgz#0294eb3dee05028d31ee1a5fa2c556a6aaf10a1b" + integrity sha512-rVksvsnNCdJ/ohGc6xgPwyN8eheCxsiLM8mxuE/t/mOVqJewPuO1miLpTHQiRgTKCLexL4MeAFVagts7HmNZ2Q== From 955f090f9f0e69300c3bc331c52a426b0dec5dab Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Wed, 7 Dec 2022 13:06:57 -0800 Subject: [PATCH 188/469] Retire the visibility group "//mediapipe/framework:mediapipe_internal". 
PiperOrigin-RevId: 493687025 --- mediapipe/framework/profiler/BUILD | 4 +--- mediapipe/framework/tool/BUILD | 19 +++++++++---------- 2 files changed, 10 insertions(+), 13 deletions(-) diff --git a/mediapipe/framework/profiler/BUILD b/mediapipe/framework/profiler/BUILD index 2947b9844..3b6976fc8 100644 --- a/mediapipe/framework/profiler/BUILD +++ b/mediapipe/framework/profiler/BUILD @@ -291,9 +291,7 @@ cc_library( "-ObjC++", ], }), - visibility = [ - "//mediapipe/framework:mediapipe_internal", - ], + visibility = ["//visibility:private"], deps = [ "@com_google_absl//absl/flags:flag", "//mediapipe/framework/port:logging", diff --git a/mediapipe/framework/tool/BUILD b/mediapipe/framework/tool/BUILD index 52d04b4b1..453b5a0e8 100644 --- a/mediapipe/framework/tool/BUILD +++ b/mediapipe/framework/tool/BUILD @@ -90,7 +90,7 @@ mediapipe_proto_library( name = "packet_generator_wrapper_calculator_proto", srcs = ["packet_generator_wrapper_calculator.proto"], def_py_proto = False, - visibility = ["//mediapipe/framework:mediapipe_internal"], + visibility = ["//mediapipe/framework:__subpackages__"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:packet_generator_proto", @@ -120,13 +120,13 @@ cc_library( name = "fill_packet_set", srcs = ["fill_packet_set.cc"], hdrs = ["fill_packet_set.h"], - visibility = ["//mediapipe/framework:mediapipe_internal"], + visibility = ["//mediapipe/framework:__subpackages__"], deps = [ + ":status_util", "//mediapipe/framework:packet_set", "//mediapipe/framework:packet_type", "//mediapipe/framework/port:status", "//mediapipe/framework/port:statusor", - "//mediapipe/framework/tool:status_util", "@com_google_absl//absl/memory", ], ) @@ -162,7 +162,6 @@ cc_library( cc_test( name = "executor_util_test", srcs = ["executor_util_test.cc"], - visibility = ["//mediapipe/framework:mediapipe_internal"], deps = [ ":executor_util", "//mediapipe/framework/port:gtest_main", @@ -173,7 +172,7 @@ cc_test( cc_library( name = "options_map", hdrs = ["options_map.h"], - visibility = ["//mediapipe/framework:mediapipe_internal"], + visibility = ["//mediapipe:__subpackages__"], deps = [ ":type_util", "//mediapipe/framework:calculator_cc_proto", @@ -193,7 +192,7 @@ cc_library( name = "options_field_util", srcs = ["options_field_util.cc"], hdrs = ["options_field_util.h"], - visibility = ["//mediapipe/framework:mediapipe_internal"], + visibility = ["//visibility:private"], deps = [ ":field_data_cc_proto", ":name_util", @@ -216,7 +215,7 @@ cc_library( name = "options_syntax_util", srcs = ["options_syntax_util.cc"], hdrs = ["options_syntax_util.h"], - visibility = ["//mediapipe/framework:mediapipe_internal"], + visibility = ["//visibility:private"], deps = [ ":name_util", ":options_field_util", @@ -235,8 +234,9 @@ cc_library( name = "options_util", srcs = ["options_util.cc"], hdrs = ["options_util.h"], - visibility = ["//mediapipe/framework:mediapipe_internal"], + visibility = ["//visibility:public"], deps = [ + ":name_util", ":options_field_util", ":options_map", ":options_registry", @@ -254,7 +254,6 @@ cc_library( "//mediapipe/framework/port:advanced_proto", "//mediapipe/framework/port:any_proto", "//mediapipe/framework/port:status", - "//mediapipe/framework/tool:name_util", "@com_google_absl//absl/strings", ], ) @@ -323,7 +322,7 @@ mediapipe_cc_test( cc_library( name = "packet_generator_wrapper_calculator", srcs = ["packet_generator_wrapper_calculator.cc"], - visibility = ["//mediapipe/framework:mediapipe_internal"], + visibility = 
["//mediapipe/framework:__subpackages__"], deps = [ ":packet_generator_wrapper_calculator_cc_proto", "//mediapipe/framework:calculator_base", From ea74db86dd2926278c9d2486bf58e23caa1a97a6 Mon Sep 17 00:00:00 2001 From: Nikolay Chirkov Date: Wed, 7 Dec 2022 14:04:37 -0800 Subject: [PATCH 189/469] Tensor: clang tidy fixes. PiperOrigin-RevId: 493703073 --- mediapipe/framework/formats/tensor.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mediapipe/framework/formats/tensor.cc b/mediapipe/framework/formats/tensor.cc index c31eba350..9e1406dbb 100644 --- a/mediapipe/framework/formats/tensor.cc +++ b/mediapipe/framework/formats/tensor.cc @@ -551,7 +551,7 @@ Tensor::CpuReadView Tensor::GetCpuReadView() const { }); } else #endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 - + { // Transfer data from texture if not transferred from SSBO/MTLBuffer // yet. if (valid_ & kValidOpenGlTexture2d) { @@ -582,6 +582,7 @@ Tensor::CpuReadView Tensor::GetCpuReadView() const { } }); } + } #endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30 valid_ |= kValidCpu; } From 7faee517c4606e647ae63ae4296fae54d08f6abb Mon Sep 17 00:00:00 2001 From: Nikolay Chirkov Date: Wed, 7 Dec 2022 14:31:02 -0800 Subject: [PATCH 190/469] Tensor: Move general CPU/SSBO tensor storage into Ahwb-backed CPU/SSBO storage. PiperOrigin-RevId: 493710495 --- mediapipe/framework/formats/tensor.h | 1 + mediapipe/framework/formats/tensor_ahwb.cc | 40 +++++++++++++++++++++- 2 files changed, 40 insertions(+), 1 deletion(-) diff --git a/mediapipe/framework/formats/tensor.h b/mediapipe/framework/formats/tensor.h index 3ed72c6fd..151aa299d 100644 --- a/mediapipe/framework/formats/tensor.h +++ b/mediapipe/framework/formats/tensor.h @@ -418,6 +418,7 @@ class Tensor { void ReleaseAhwbStuff(); void* MapAhwbToCpuRead() const; void* MapAhwbToCpuWrite() const; + void MoveCpuOrSsboToAhwb() const; #if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30 mutable std::shared_ptr gl_context_; diff --git a/mediapipe/framework/formats/tensor_ahwb.cc b/mediapipe/framework/formats/tensor_ahwb.cc index 90d89c40a..21bae9593 100644 --- a/mediapipe/framework/formats/tensor_ahwb.cc +++ b/mediapipe/framework/formats/tensor_ahwb.cc @@ -215,10 +215,15 @@ Tensor::AHardwareBufferView Tensor::GetAHardwareBufferReadView() const { CHECK(ahwb_ || !(valid_ & kValidOpenGlBuffer)) << "Interoperability bettween OpenGL buffer and AHardwareBuffer is not " "supported on targe system."; + bool transfer = !ahwb_; CHECK(AllocateAHardwareBuffer()) << "AHardwareBuffer is not supported on the target system."; valid_ |= kValidAHardwareBuffer; - if (valid_ & kValidOpenGlBuffer) CreateEglSyncAndFd(); + if (transfer) { + MoveCpuOrSsboToAhwb(); + } else { + if (valid_ & kValidOpenGlBuffer) CreateEglSyncAndFd(); + } return {ahwb_, ssbo_written_, &fence_fd_, // The FD is created for SSBO -> AHWB synchronization. @@ -303,6 +308,39 @@ bool Tensor::AllocateAhwbMapToSsbo() const { return false; } +// Moves Cpu/Ssbo resource under the Ahwb backed memory. 
+void Tensor::MoveCpuOrSsboToAhwb() const { + void* dest = nullptr; + if (__builtin_available(android 26, *)) { + auto error = AHardwareBuffer_lock( + ahwb_, AHARDWAREBUFFER_USAGE_CPU_WRITE_RARELY, -1, nullptr, &dest); + CHECK(error == 0) << "AHardwareBuffer_lock " << error; + } + if (valid_ & kValidOpenGlBuffer) { + gl_context_->Run([this, dest]() { + glBindBuffer(GL_SHADER_STORAGE_BUFFER, opengl_buffer_); + const void* src = glMapBufferRange(GL_SHADER_STORAGE_BUFFER, 0, bytes(), + GL_MAP_READ_BIT); + std::memcpy(dest, src, bytes()); + glUnmapBuffer(GL_SHADER_STORAGE_BUFFER); + glDeleteBuffers(1, &opengl_buffer_); + }); + opengl_buffer_ = GL_INVALID_INDEX; + gl_context_ = nullptr; + } else if (valid_ & kValidCpu) { + std::memcpy(dest, cpu_buffer_, bytes()); + // Free CPU memory because next time AHWB is mapped instead. + free(cpu_buffer_); + cpu_buffer_ = nullptr; + } else { + LOG(FATAL) << "Can't convert tensor with mask " << valid_ << " into AHWB."; + } + if (__builtin_available(android 26, *)) { + auto error = AHardwareBuffer_unlock(ahwb_, nullptr); + CHECK(error == 0) << "AHardwareBuffer_unlock " << error; + } +} + // SSBO is created on top of AHWB. A fence is inserted into the GPU queue before // the GPU task that is going to read from the SSBO. When the writing into AHWB // is finished then the GPU reads from the SSBO. From ef1507ed5df48f00daaa2111518cc0a32faec3a6 Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Wed, 7 Dec 2022 14:43:59 -0800 Subject: [PATCH 191/469] Retire the visibility group "//mediapipe/framework:mediapipe_internal". PiperOrigin-RevId: 493713823 --- mediapipe/tasks/cc/core/BUILD | 5 ++++- mediapipe/tasks/cc/text/tokenizers/BUILD | 2 +- mediapipe/tasks/testdata/text/BUILD | 5 ++++- 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/mediapipe/tasks/cc/core/BUILD b/mediapipe/tasks/cc/core/BUILD index f8004d257..d440271df 100644 --- a/mediapipe/tasks/cc/core/BUILD +++ b/mediapipe/tasks/cc/core/BUILD @@ -309,7 +309,10 @@ cc_library( name = "utils", srcs = ["utils.cc"], hdrs = ["utils.h"], - visibility = ["//mediapipe/framework:mediapipe_internal"], + visibility = [ + "//mediapipe/calculators:__subpackages__", + "//mediapipe/tasks:internal", + ], deps = [ "//mediapipe/calculators/core:flow_limiter_calculator_cc_proto", "//mediapipe/framework:calculator_cc_proto", diff --git a/mediapipe/tasks/cc/text/tokenizers/BUILD b/mediapipe/tasks/cc/text/tokenizers/BUILD index 7f1ea2848..92fac8eaa 100644 --- a/mediapipe/tasks/cc/text/tokenizers/BUILD +++ b/mediapipe/tasks/cc/text/tokenizers/BUILD @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-package(default_visibility = ["//mediapipe/framework:mediapipe_internal"]) +package(default_visibility = ["//mediapipe/calculators/tensor:__subpackages__"]) licenses(["notice"]) diff --git a/mediapipe/tasks/testdata/text/BUILD b/mediapipe/tasks/testdata/text/BUILD index 081e63c2c..a0131c056 100644 --- a/mediapipe/tasks/testdata/text/BUILD +++ b/mediapipe/tasks/testdata/text/BUILD @@ -18,7 +18,10 @@ load( ) package( - default_visibility = ["//mediapipe/framework:mediapipe_internal"], + default_visibility = [ + "//mediapipe/calculators/tensor:__subpackages__", + "//mediapipe/tasks:__subpackages__", + ], licenses = ["notice"], # Apache 2.0 ) From 91664eb254bb44adb03b1e4823e0a5250d1f3837 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Wed, 7 Dec 2022 14:52:58 -0800 Subject: [PATCH 192/469] Object Detector deduplication PiperOrigin-RevId: 493716159 --- mediapipe/calculators/util/BUILD | 17 +++ .../util/detections_deduplicate_calculator.cc | 114 ++++++++++++++++++ .../tasks/cc/vision/object_detector/BUILD | 1 + .../object_detector/object_detector_graph.cc | 7 +- 4 files changed, 138 insertions(+), 1 deletion(-) create mode 100644 mediapipe/calculators/util/detections_deduplicate_calculator.cc diff --git a/mediapipe/calculators/util/BUILD b/mediapipe/calculators/util/BUILD index 43eadd53b..1529ead8a 100644 --- a/mediapipe/calculators/util/BUILD +++ b/mediapipe/calculators/util/BUILD @@ -456,6 +456,23 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "detections_deduplicate_calculator", + srcs = [ + "detections_deduplicate_calculator.cc", + ], + deps = [ + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:node", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/formats:detection_cc_proto", + "//mediapipe/framework/formats:location_data_cc_proto", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + ], + alwayslink = 1, +) + cc_library( name = "rect_transformation_calculator", srcs = ["rect_transformation_calculator.cc"], diff --git a/mediapipe/calculators/util/detections_deduplicate_calculator.cc b/mediapipe/calculators/util/detections_deduplicate_calculator.cc new file mode 100644 index 000000000..2dfa09028 --- /dev/null +++ b/mediapipe/calculators/util/detections_deduplicate_calculator.cc @@ -0,0 +1,114 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#include <functional>
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/status/status.h"
+#include "mediapipe/framework/api2/node.h"
+#include "mediapipe/framework/api2/port.h"
+#include "mediapipe/framework/calculator_framework.h"
+#include "mediapipe/framework/formats/detection.pb.h"
+#include "mediapipe/framework/formats/location_data.pb.h"
+
+namespace mediapipe {
+namespace api2 {
+namespace {
+
+struct BoundingBoxHash {
+  size_t operator()(const LocationData::BoundingBox& bbox) const {
+    return std::hash<int>{}(bbox.xmin()) ^ std::hash<int>{}(bbox.ymin()) ^
+           std::hash<int>{}(bbox.width()) ^ std::hash<int>{}(bbox.height());
+  }
+};
+
+struct BoundingBoxEq {
+  bool operator()(const LocationData::BoundingBox& lhs,
+                  const LocationData::BoundingBox& rhs) const {
+    return lhs.xmin() == rhs.xmin() && lhs.ymin() == rhs.ymin() &&
+           lhs.width() == rhs.width() && lhs.height() == rhs.height();
+  }
+};
+
+}  // namespace
+
+// This Calculator deduplicates the bounding boxes with exactly the same
+// coordinates, and folds the labels into a single Detection proto. Note that
+// non-maximum-suppression removes the overlapping bounding boxes within a
+// class, while the deduplication operation merges bounding boxes from
+// different classes.
+
+// Example config:
+// node {
+//   calculator: "DetectionsDeduplicateCalculator"
+//   input_stream: "detections"
+//   output_stream: "deduplicated_detections"
+// }
+class DetectionsDeduplicateCalculator : public Node {
+ public:
+  static constexpr Input<std::vector<Detection>> kIn{""};
+  static constexpr Output<std::vector<Detection>> kOut{""};
+
+  MEDIAPIPE_NODE_CONTRACT(kIn, kOut);
+
+  absl::Status Open(mediapipe::CalculatorContext* cc) {
+    cc->SetOffset(::mediapipe::TimestampDiff(0));
+    return absl::OkStatus();
+  }
+
+  absl::Status Process(mediapipe::CalculatorContext* cc) {
+    const std::vector<Detection>& raw_detections = kIn(cc).Get();
+    absl::flat_hash_map<LocationData::BoundingBox, Detection*,
+                        BoundingBoxHash, BoundingBoxEq>
+        bbox_to_detections;
+    std::vector<Detection> deduplicated_detections;
+    for (const auto& detection : raw_detections) {
+      if (!detection.has_location_data() ||
+          !detection.location_data().has_bounding_box()) {
+        return absl::InvalidArgumentError(
+            "The location data of Detections must be BoundingBox.");
+      }
+      if (bbox_to_detections.contains(
+              detection.location_data().bounding_box())) {
+        // The bbox location already exists. Merge the detection labels into
+        // the existing detection proto.
+        Detection& deduplicated_detection =
+            *bbox_to_detections[detection.location_data().bounding_box()];
+        deduplicated_detection.mutable_score()->MergeFrom(detection.score());
+        deduplicated_detection.mutable_label()->MergeFrom(detection.label());
+        deduplicated_detection.mutable_label_id()->MergeFrom(
+            detection.label_id());
+        deduplicated_detection.mutable_display_name()->MergeFrom(
+            detection.display_name());
+      } else {
+        // The bbox location appears for the first time. Add the detection to
+        // the output detection vector.
+        deduplicated_detections.push_back(detection);
+        bbox_to_detections[detection.location_data().bounding_box()] =
+            &deduplicated_detections.back();
+      }
+    }
+    kOut(cc).Send(std::move(deduplicated_detections));
+    return absl::OkStatus();
+  }
+};
+
+MEDIAPIPE_REGISTER_NODE(DetectionsDeduplicateCalculator);
+
+}  // namespace api2
+}  // namespace mediapipe
diff --git a/mediapipe/tasks/cc/vision/object_detector/BUILD b/mediapipe/tasks/cc/vision/object_detector/BUILD
index c2dd9995d..224eca520 100644
--- a/mediapipe/tasks/cc/vision/object_detector/BUILD
+++ b/mediapipe/tasks/cc/vision/object_detector/BUILD
@@ -63,6 +63,7 @@ cc_library(
         "//mediapipe/calculators/util:detection_label_id_to_text_calculator_cc_proto",
         "//mediapipe/calculators/util:detection_projection_calculator",
         "//mediapipe/calculators/util:detection_transformation_calculator",
+        "//mediapipe/calculators/util:detections_deduplicate_calculator",
         "//mediapipe/framework:calculator_cc_proto",
         "//mediapipe/framework/api2:builder",
         "//mediapipe/framework/api2:port",
diff --git a/mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc b/mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc
index a1625c16c..fd95bb1ac 100644
--- a/mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc
+++ b/mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc
@@ -662,11 +662,16 @@ class ObjectDetectorGraph : public core::ModelTaskGraph {
     detection_transformation.Out(kPixelDetectionsTag) >>
         detection_label_id_to_text.In("");
 
+    // Deduplicate Detections with the same bounding box coordinates.
+    auto& detections_deduplicate =
+        graph.AddNode("DetectionsDeduplicateCalculator");
+    detection_label_id_to_text.Out("") >> detections_deduplicate.In("");
+
     // Outputs the labeled detections and the processed image as the subgraph
     // output streams.
     return {{
         /* detections= */
-        detection_label_id_to_text[Output<std::vector<Detection>>("")],
+        detections_deduplicate[Output<std::vector<Detection>>("")],
         /* image= */ preprocessing[Output<Image>(kImageTag)],
     }};
   }

From 5f97b29b3ba41d1a7765221ed242e0fab9a89751 Mon Sep 17 00:00:00 2001
From: Sebastian Schmidt 
Date: Wed, 7 Dec 2022 15:23:10 -0800
Subject: [PATCH 193/469] Update Bazel dependencies for Apple

PiperOrigin-RevId: 493723833
---
 WORKSPACE | 56 +++++++++++++++++++++++++------------------------------
 1 file changed, 25 insertions(+), 31 deletions(-)

diff --git a/WORKSPACE b/WORKSPACE
index d43394883..bf5e4236b 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -320,12 +320,30 @@ http_archive(
     ],
 )
 
-# iOS basic build deps.
+# Load Zlib before initializing TensorFlow and the iOS build rules to guarantee
+# that the target @zlib//:mini_zlib is available
+http_archive(
+    name = "zlib",
+    build_file = "//third_party:zlib.BUILD",
+    sha256 = "c3e5e9fdd5004dcb542feda5ee4f0ff0744628baf8ed2dd5d66f8ca1197cb1a1",
+    strip_prefix = "zlib-1.2.11",
+    urls = [
+        "http://mirror.bazel.build/zlib.net/fossils/zlib-1.2.11.tar.gz",
+        "http://zlib.net/fossils/zlib-1.2.11.tar.gz",  # 2017-01-15
+    ],
+    patches = [
+        "@//third_party:zlib.diff",
+    ],
+    patch_args = [
+        "-p1",
+    ],
+)
+# iOS basic build deps.
http_archive( name = "build_bazel_rules_apple", - sha256 = "77e8bf6fda706f420a55874ae6ee4df0c9d95da6c7838228b26910fc82eea5a2", - url = "https://github.com/bazelbuild/rules_apple/releases/download/0.32.0/rules_apple.0.32.0.tar.gz", + sha256 = "f94e6dddf74739ef5cb30f000e13a2a613f6ebfa5e63588305a71fce8a8a9911", + url = "https://github.com/bazelbuild/rules_apple/releases/download/1.1.3/rules_apple.1.1.3.tar.gz", patches = [ # Bypass checking ios unit test runner when building MP ios applications. "@//third_party:build_bazel_rules_apple_bypass_test_runner_check.diff" @@ -339,29 +357,24 @@ load( "@build_bazel_rules_apple//apple:repositories.bzl", "apple_rules_dependencies", ) - apple_rules_dependencies() load( "@build_bazel_rules_swift//swift:repositories.bzl", "swift_rules_dependencies", ) - swift_rules_dependencies() -http_archive( - name = "build_bazel_apple_support", - sha256 = "741366f79d900c11e11d8efd6cc6c66a31bfb2451178b58e0b5edc6f1db17b35", - urls = [ - "https://github.com/bazelbuild/apple_support/releases/download/0.10.0/apple_support.0.10.0.tar.gz" - ], +load( + "@build_bazel_rules_swift//swift:extras.bzl", + "swift_rules_extra_dependencies", ) +swift_rules_extra_dependencies() load( "@build_bazel_apple_support//lib:repositories.bzl", "apple_support_dependencies", ) - apple_support_dependencies() # More iOS deps. @@ -442,25 +455,6 @@ http_archive( ], ) -# Load Zlib before initializing TensorFlow to guarantee that the target -# @zlib//:mini_zlib is available -http_archive( - name = "zlib", - build_file = "//third_party:zlib.BUILD", - sha256 = "c3e5e9fdd5004dcb542feda5ee4f0ff0744628baf8ed2dd5d66f8ca1197cb1a1", - strip_prefix = "zlib-1.2.11", - urls = [ - "http://mirror.bazel.build/zlib.net/fossils/zlib-1.2.11.tar.gz", - "http://zlib.net/fossils/zlib-1.2.11.tar.gz", # 2017-01-15 - ], - patches = [ - "@//third_party:zlib.diff", - ], - patch_args = [ - "-p1", - ], -) - # TensorFlow repo should always go after the other external dependencies. # TF on 2022-08-10. _TENSORFLOW_GIT_COMMIT = "af1d5bc4fbb66d9e6cc1cf89503014a99233583b" From a59f0a99243a77c5f1cef684c5cd542c320c59f8 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Wed, 7 Dec 2022 15:49:08 -0800 Subject: [PATCH 194/469] Make java/C++/python tasks API public visible. 
PiperOrigin-RevId: 493730506 --- mediapipe/tasks/cc/audio/audio_classifier/BUILD | 4 +--- mediapipe/tasks/cc/audio/audio_embedder/BUILD | 4 +--- mediapipe/tasks/cc/vision/object_detector/BUILD | 4 +--- mediapipe/tasks/java/com/google/mediapipe/tasks/audio/BUILD | 2 +- .../com/google/mediapipe/tasks/components/containers/BUILD | 2 +- .../com/google/mediapipe/tasks/components/processors/BUILD | 2 +- .../java/com/google/mediapipe/tasks/components/utils/BUILD | 2 +- mediapipe/tasks/java/com/google/mediapipe/tasks/core/BUILD | 2 +- mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD | 2 +- mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD | 2 +- mediapipe/tasks/python/audio/BUILD | 2 +- mediapipe/tasks/python/audio/core/BUILD | 2 +- mediapipe/tasks/python/components/containers/BUILD | 2 +- mediapipe/tasks/python/components/processors/BUILD | 2 +- mediapipe/tasks/python/components/utils/BUILD | 2 +- mediapipe/tasks/python/core/BUILD | 2 +- mediapipe/tasks/python/text/BUILD | 2 +- mediapipe/tasks/python/text/core/BUILD | 2 +- mediapipe/tasks/python/vision/BUILD | 2 +- mediapipe/tasks/python/vision/core/BUILD | 2 +- 20 files changed, 20 insertions(+), 26 deletions(-) diff --git a/mediapipe/tasks/cc/audio/audio_classifier/BUILD b/mediapipe/tasks/cc/audio/audio_classifier/BUILD index f61472413..c575caabe 100644 --- a/mediapipe/tasks/cc/audio/audio_classifier/BUILD +++ b/mediapipe/tasks/cc/audio/audio_classifier/BUILD @@ -22,9 +22,7 @@ cc_library( name = "audio_classifier", srcs = ["audio_classifier.cc"], hdrs = ["audio_classifier.h"], - visibility = [ - "//mediapipe/tasks:users", - ], + visibility = ["//visibility:public"], deps = [ ":audio_classifier_graph", "//mediapipe/framework/api2:builder", diff --git a/mediapipe/tasks/cc/audio/audio_embedder/BUILD b/mediapipe/tasks/cc/audio/audio_embedder/BUILD index 6a0f627b2..1dfdd6f1b 100644 --- a/mediapipe/tasks/cc/audio/audio_embedder/BUILD +++ b/mediapipe/tasks/cc/audio/audio_embedder/BUILD @@ -22,9 +22,7 @@ cc_library( name = "audio_embedder", srcs = ["audio_embedder.cc"], hdrs = ["audio_embedder.h"], - visibility = [ - "//mediapipe/tasks:users", - ], + visibility = ["//visibility:public"], deps = [ ":audio_embedder_graph", "//mediapipe/framework/api2:builder", diff --git a/mediapipe/tasks/cc/vision/object_detector/BUILD b/mediapipe/tasks/cc/vision/object_detector/BUILD index 224eca520..77373303a 100644 --- a/mediapipe/tasks/cc/vision/object_detector/BUILD +++ b/mediapipe/tasks/cc/vision/object_detector/BUILD @@ -22,9 +22,7 @@ cc_library( name = "object_detector", srcs = ["object_detector.cc"], hdrs = ["object_detector.h"], - visibility = [ - "//mediapipe/tasks:users", - ], + visibility = ["//visibility:public"], deps = [ ":object_detector_graph", "//mediapipe/calculators/core:concatenate_vector_calculator", diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/BUILD index 2d29ccf23..e5d472e8a 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/BUILD @@ -14,7 +14,7 @@ licenses(["notice"]) -package(default_visibility = ["//mediapipe/tasks:internal"]) +package(default_visibility = ["//visibility:public"]) android_library( name = "core", diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD index ad17d5552..4d302b950 100644 --- 
a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -package(default_visibility = ["//mediapipe/tasks:internal"]) +package(default_visibility = ["//visibility:public"]) licenses(["notice"]) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors/BUILD index 1f99f1612..b4d453935 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors/BUILD @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -package(default_visibility = ["//mediapipe/tasks:internal"]) +package(default_visibility = ["//visibility:public"]) licenses(["notice"]) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/utils/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/utils/BUILD index b2d27bfa7..6c724106f 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/utils/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/utils/BUILD @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -package(default_visibility = ["//mediapipe/tasks:internal"]) +package(default_visibility = ["//visibility:public"]) licenses(["notice"]) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/BUILD index 01b1f653a..31f885267 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/BUILD @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -package(default_visibility = ["//mediapipe/tasks:internal"]) +package(default_visibility = ["//visibility:public"]) android_library( name = "core", diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD index 5b10e9aab..31cd2c89a 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD @@ -14,7 +14,7 @@ licenses(["notice"]) -package(default_visibility = ["//mediapipe/tasks:internal"]) +package(default_visibility = ["//visibility:public"]) # The native library of all MediaPipe text tasks. cc_binary( diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD index 6161fe032..f469aed0c 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD @@ -14,7 +14,7 @@ licenses(["notice"]) -package(default_visibility = ["//mediapipe/tasks:internal"]) +package(default_visibility = ["//visibility:public"]) android_library( name = "core", diff --git a/mediapipe/tasks/python/audio/BUILD b/mediapipe/tasks/python/audio/BUILD index ce7c5ce08..6dda7a53c 100644 --- a/mediapipe/tasks/python/audio/BUILD +++ b/mediapipe/tasks/python/audio/BUILD @@ -14,7 +14,7 @@ # Placeholder for internal Python strict library and test compatibility macro. 
-package(default_visibility = ["//mediapipe/tasks:internal"]) +package(default_visibility = ["//visibility:public"]) licenses(["notice"]) diff --git a/mediapipe/tasks/python/audio/core/BUILD b/mediapipe/tasks/python/audio/core/BUILD index 3cb9cb8e8..5b4203d7b 100644 --- a/mediapipe/tasks/python/audio/core/BUILD +++ b/mediapipe/tasks/python/audio/core/BUILD @@ -14,7 +14,7 @@ # Placeholder for internal Python strict library and test compatibility macro. -package(default_visibility = ["//mediapipe/tasks:internal"]) +package(default_visibility = ["//visibility:public"]) licenses(["notice"]) diff --git a/mediapipe/tasks/python/components/containers/BUILD b/mediapipe/tasks/python/components/containers/BUILD index 9d275e167..7108617ff 100644 --- a/mediapipe/tasks/python/components/containers/BUILD +++ b/mediapipe/tasks/python/components/containers/BUILD @@ -14,7 +14,7 @@ # Placeholder for internal Python strict library and test compatibility macro. -package(default_visibility = ["//mediapipe/tasks:internal"]) +package(default_visibility = ["//visibility:public"]) licenses(["notice"]) diff --git a/mediapipe/tasks/python/components/processors/BUILD b/mediapipe/tasks/python/components/processors/BUILD index f87a579b0..695f6df91 100644 --- a/mediapipe/tasks/python/components/processors/BUILD +++ b/mediapipe/tasks/python/components/processors/BUILD @@ -16,7 +16,7 @@ # Placeholder for internal Python strict library and test compatibility macro. -package(default_visibility = ["//mediapipe/tasks:internal"]) +package(default_visibility = ["//visibility:public"]) licenses(["notice"]) diff --git a/mediapipe/tasks/python/components/utils/BUILD b/mediapipe/tasks/python/components/utils/BUILD index 31114f326..1a18531c6 100644 --- a/mediapipe/tasks/python/components/utils/BUILD +++ b/mediapipe/tasks/python/components/utils/BUILD @@ -16,7 +16,7 @@ # Placeholder for internal Python strict library and test compatibility macro. -package(default_visibility = ["//mediapipe/tasks:internal"]) +package(default_visibility = ["//visibility:public"]) licenses(["notice"]) diff --git a/mediapipe/tasks/python/core/BUILD b/mediapipe/tasks/python/core/BUILD index fc0018ab1..447189d6f 100644 --- a/mediapipe/tasks/python/core/BUILD +++ b/mediapipe/tasks/python/core/BUILD @@ -14,7 +14,7 @@ # Placeholder for internal Python strict library and test compatibility macro. -package(default_visibility = ["//mediapipe/tasks:internal"]) +package(default_visibility = ["//visibility:public"]) licenses(["notice"]) diff --git a/mediapipe/tasks/python/text/BUILD b/mediapipe/tasks/python/text/BUILD index e2a51cdbd..9d5d23261 100644 --- a/mediapipe/tasks/python/text/BUILD +++ b/mediapipe/tasks/python/text/BUILD @@ -14,7 +14,7 @@ # Placeholder for internal Python strict library and test compatibility macro. -package(default_visibility = ["//mediapipe/tasks:internal"]) +package(default_visibility = ["//visibility:public"]) licenses(["notice"]) diff --git a/mediapipe/tasks/python/text/core/BUILD b/mediapipe/tasks/python/text/core/BUILD index 072a0c7d8..e76bd4b6d 100644 --- a/mediapipe/tasks/python/text/core/BUILD +++ b/mediapipe/tasks/python/text/core/BUILD @@ -14,7 +14,7 @@ # Placeholder for internal Python strict library and test compatibility macro. 
-package(default_visibility = ["//mediapipe/tasks:internal"]) +package(default_visibility = ["//visibility:public"]) licenses(["notice"]) diff --git a/mediapipe/tasks/python/vision/BUILD b/mediapipe/tasks/python/vision/BUILD index 241ca4341..5f4aa38ff 100644 --- a/mediapipe/tasks/python/vision/BUILD +++ b/mediapipe/tasks/python/vision/BUILD @@ -14,7 +14,7 @@ # Placeholder for internal Python strict library and test compatibility macro. -package(default_visibility = ["//mediapipe/tasks:internal"]) +package(default_visibility = ["//visibility:public"]) licenses(["notice"]) diff --git a/mediapipe/tasks/python/vision/core/BUILD b/mediapipe/tasks/python/vision/core/BUILD index e2b2b3dec..18df690a0 100644 --- a/mediapipe/tasks/python/vision/core/BUILD +++ b/mediapipe/tasks/python/vision/core/BUILD @@ -14,7 +14,7 @@ # Placeholder for internal Python strict library and test compatibility macro. -package(default_visibility = ["//mediapipe/tasks:internal"]) +package(default_visibility = ["//visibility:public"]) licenses(["notice"]) From a0efcb47f23666f84448d82fcede6dab9fdfbf55 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Wed, 7 Dec 2022 16:37:08 -0800 Subject: [PATCH 195/469] internal change PiperOrigin-RevId: 493742399 --- .../tasks/cc/components/containers/BUILD | 13 + .../components/containers/detection_result.cc | 73 ++++++ .../components/containers/detection_result.h | 52 ++++ .../tasks/cc/components/containers/rect.cc | 34 +++ .../tasks/cc/components/containers/rect.h | 29 ++- .../cc/vision/core/base_vision_task_api.h | 4 +- .../cc/vision/core/image_processing_options.h | 3 +- ...hand_landmarks_deduplication_calculator.cc | 14 +- .../hand_landmarker/hand_landmarker_test.cc | 4 +- .../image_classifier/image_classifier_test.cc | 19 +- .../image_embedder/image_embedder_test.cc | 6 +- .../image_segmenter/image_segmenter_test.cc | 4 +- .../tasks/cc/vision/object_detector/BUILD | 1 + .../vision/object_detector/object_detector.cc | 15 +- .../vision/object_detector/object_detector.h | 12 +- .../object_detector/object_detector_test.cc | 227 ++++++++++-------- .../tasks/cc/vision/utils/landmarks_utils.cc | 8 +- .../tasks/cc/vision/utils/landmarks_utils.h | 10 +- 18 files changed, 377 insertions(+), 151 deletions(-) create mode 100644 mediapipe/tasks/cc/components/containers/detection_result.cc create mode 100644 mediapipe/tasks/cc/components/containers/detection_result.h create mode 100644 mediapipe/tasks/cc/components/containers/rect.cc diff --git a/mediapipe/tasks/cc/components/containers/BUILD b/mediapipe/tasks/cc/components/containers/BUILD index 35d3f4785..0750a1482 100644 --- a/mediapipe/tasks/cc/components/containers/BUILD +++ b/mediapipe/tasks/cc/components/containers/BUILD @@ -18,6 +18,7 @@ licenses(["notice"]) cc_library( name = "rect", + srcs = ["rect.cc"], hdrs = ["rect.h"], ) @@ -41,6 +42,18 @@ cc_library( ], ) +cc_library( + name = "detection_result", + srcs = ["detection_result.cc"], + hdrs = ["detection_result.h"], + deps = [ + ":category", + ":rect", + "//mediapipe/framework/formats:detection_cc_proto", + "//mediapipe/framework/formats:location_data_cc_proto", + ], +) + cc_library( name = "embedding_result", srcs = ["embedding_result.cc"], diff --git a/mediapipe/tasks/cc/components/containers/detection_result.cc b/mediapipe/tasks/cc/components/containers/detection_result.cc new file mode 100644 index 000000000..43c8ca0f5 --- /dev/null +++ b/mediapipe/tasks/cc/components/containers/detection_result.cc @@ -0,0 +1,73 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. 
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "mediapipe/tasks/cc/components/containers/detection_result.h"
+
+#include <optional>
+#include <string>
+#include <vector>
+
+#include "mediapipe/framework/formats/detection.pb.h"
+#include "mediapipe/framework/formats/location_data.pb.h"
+#include "mediapipe/tasks/cc/components/containers/category.h"
+#include "mediapipe/tasks/cc/components/containers/rect.h"
+
+namespace mediapipe::tasks::components::containers {
+
+constexpr int kDefaultCategoryIndex = -1;
+
+Detection ConvertToDetection(const mediapipe::Detection& detection_proto) {
+  Detection detection;
+  for (int idx = 0; idx < detection_proto.score_size(); ++idx) {
+    detection.categories.push_back(
+        {/* index= */ detection_proto.label_id_size() > idx
+             ? detection_proto.label_id(idx)
+             : kDefaultCategoryIndex,
+         /* score= */ detection_proto.score(idx),
+         /* category_name= */ detection_proto.label_size() > idx
+             ? detection_proto.label(idx)
+             : "",
+         /* display_name= */ detection_proto.display_name_size() > idx
+             ? detection_proto.display_name(idx)
+             : ""});
+  }
+  Rect bounding_box;
+  if (detection_proto.location_data().has_bounding_box()) {
+    mediapipe::LocationData::BoundingBox bounding_box_proto =
+        detection_proto.location_data().bounding_box();
+    bounding_box.left = bounding_box_proto.xmin();
+    bounding_box.top = bounding_box_proto.ymin();
+    bounding_box.right = bounding_box_proto.xmin() + bounding_box_proto.width();
+    bounding_box.bottom =
+        bounding_box_proto.ymin() + bounding_box_proto.height();
+  }
+  detection.bounding_box = bounding_box;
+  return detection;
+}
+
+DetectionResult ConvertToDetectionResult(
+    std::vector<mediapipe::Detection> detections_proto) {
+  DetectionResult detection_result;
+  detection_result.detections.reserve(detections_proto.size());
+  for (const auto& detection_proto : detections_proto) {
+    detection_result.detections.push_back(ConvertToDetection(detection_proto));
+  }
+  return detection_result;
+}
+}  // namespace mediapipe::tasks::components::containers
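A minimal usage sketch of the converter above (not part of the patch; the hand-built proto stands in for a real graph output packet, and only APIs added in this patch are assumed):

#include <iostream>

#include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/tasks/cc/components/containers/detection_result.h"

namespace containers = mediapipe::tasks::components::containers;

int main() {
  // Hand-build a detection proto; real callers get this from a graph packet.
  mediapipe::Detection proto;
  proto.add_label("cat");
  proto.add_score(0.63f);
  auto* box = proto.mutable_location_data()->mutable_bounding_box();
  box->set_xmin(14);
  box->set_ymin(197);
  box->set_width(98);
  box->set_height(99);

  // ConvertToDetectionResult wraps each proto into the task-level structs.
  containers::DetectionResult result =
      containers::ConvertToDetectionResult({proto});
  const containers::Detection& detection = result.detections[0];
  std::cout << "score=" << detection.categories[0].score
            << " box right=" << detection.bounding_box.right << std::endl;
  return 0;
}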
diff --git a/mediapipe/tasks/cc/components/containers/detection_result.h b/mediapipe/tasks/cc/components/containers/detection_result.h
new file mode 100644
index 000000000..546f324d6
--- /dev/null
+++ b/mediapipe/tasks/cc/components/containers/detection_result.h
@@ -0,0 +1,52 @@
+/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_DETECTION_RESULT_H_
+#define MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_DETECTION_RESULT_H_
+
+#include <optional>
+#include <string>
+#include <vector>
+
+#include "mediapipe/framework/formats/detection.pb.h"
+#include "mediapipe/tasks/cc/components/containers/category.h"
+#include "mediapipe/tasks/cc/components/containers/rect.h"
+
+namespace mediapipe::tasks::components::containers {
+
+// Detection for a single bounding box.
+struct Detection {
+  // A vector of detected categories.
+  std::vector<Category> categories;
+  // The bounding box location.
+  Rect bounding_box;
+};
+
+// Detection results of a model.
+struct DetectionResult {
+  // A vector of Detections.
+  std::vector<Detection> detections;
+};
+
+// Utility function to convert from Detection proto to Detection struct.
+Detection ConvertToDetection(const mediapipe::Detection& detection_proto);
+
+// Utility function to convert from list of Detection proto to DetectionResult
+// struct.
+DetectionResult ConvertToDetectionResult(
+    std::vector<mediapipe::Detection> detections_proto);
+
+}  // namespace mediapipe::tasks::components::containers
+#endif  // MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_DETECTION_RESULT_H_
diff --git a/mediapipe/tasks/cc/components/containers/rect.cc b/mediapipe/tasks/cc/components/containers/rect.cc
new file mode 100644
index 000000000..4a94832a6
--- /dev/null
+++ b/mediapipe/tasks/cc/components/containers/rect.cc
@@ -0,0 +1,34 @@
+/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "mediapipe/tasks/cc/components/containers/rect.h"
+
+namespace mediapipe::tasks::components::containers {
+
+RectF ToRectF(const Rect& rect, int image_height, int image_width) {
+  return RectF{static_cast<float>(rect.left) / image_width,
+               static_cast<float>(rect.top) / image_height,
+               static_cast<float>(rect.right) / image_width,
+               static_cast<float>(rect.bottom) / image_height};
+}
+
+Rect ToRect(const RectF& rect, int image_height, int image_width) {
+  return Rect{static_cast<int>(rect.left * image_width),
+              static_cast<int>(rect.top * image_height),
+              static_cast<int>(rect.right * image_width),
+              static_cast<int>(rect.bottom * image_height)};
+}
+
+}  // namespace mediapipe::tasks::components::containers
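The ToRectF/ToRect helpers are plain scale conversions between pixel and normalized coordinates. A small self-contained sketch (values chosen so float truncation round-trips exactly; illustrative, not part of the patch):

#include <cassert>

#include "mediapipe/tasks/cc/components/containers/rect.h"

namespace containers = mediapipe::tasks::components::containers;

int main() {
  // A pixel-space box on a 640x480 image.
  containers::Rect pixel_box{/*left=*/64, /*top=*/48, /*right=*/320,
                             /*bottom=*/240};
  // Normalization divides x-coordinates by width and y-coordinates by height.
  containers::RectF norm = containers::ToRectF(pixel_box, /*image_height=*/480,
                                               /*image_width=*/640);
  // RectF equality is tolerance-based (kRectFTolerance); Rect equality is exact.
  assert(norm == (containers::RectF{0.1f, 0.1f, 0.5f, 0.5f}));
  // ToRect truncates back to integer pixels.
  containers::Rect round_trip = containers::ToRect(norm, 480, 640);
  assert(round_trip == pixel_box);
  return 0;
}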
diff --git a/mediapipe/tasks/cc/components/containers/rect.h b/mediapipe/tasks/cc/components/containers/rect.h
index 3f5432cf2..551d91588 100644
--- a/mediapipe/tasks/cc/components/containers/rect.h
+++ b/mediapipe/tasks/cc/components/containers/rect.h
@@ -16,20 +16,47 @@ limitations under the License.
 #ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_RECT_H_
 #define MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_RECT_H_
 
+#include <cmath>
+
 namespace mediapipe::tasks::components::containers {
 
+constexpr float kRectFTolerance = 1e-4;
+
 // Defines a rectangle, used e.g. as part of detection results or as input
 // region-of-interest.
 //
+struct Rect {
+  int left;
+  int top;
+  int right;
+  int bottom;
+};
+
+inline bool operator==(const Rect& lhs, const Rect& rhs) {
+  return lhs.left == rhs.left && lhs.top == rhs.top && lhs.right == rhs.right &&
+         lhs.bottom == rhs.bottom;
+}
+
 // The coordinates are normalized wrt the image dimensions, i.e. generally in
 // [0,1] but they may exceed these bounds if describing a region overlapping the
 // image. The origin is on the top-left corner of the image.
-struct Rect {
+struct RectF {
   float left;
   float top;
   float right;
   float bottom;
 };
 
+inline bool operator==(const RectF& lhs, const RectF& rhs) {
+  return std::abs(lhs.left - rhs.left) < kRectFTolerance &&
+         std::abs(lhs.top - rhs.top) < kRectFTolerance &&
+         std::abs(lhs.right - rhs.right) < kRectFTolerance &&
+         std::abs(lhs.bottom - rhs.bottom) < kRectFTolerance;
+}
+
+RectF ToRectF(const Rect& rect, int image_height, int image_width);
+
+Rect ToRect(const RectF& rect, int image_height, int image_width);
+
 }  // namespace mediapipe::tasks::components::containers
 #endif  // MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_RECT_H_
diff --git a/mediapipe/tasks/cc/vision/core/base_vision_task_api.h b/mediapipe/tasks/cc/vision/core/base_vision_task_api.h
index c3c0a0261..a86b2cca8 100644
--- a/mediapipe/tasks/cc/vision/core/base_vision_task_api.h
+++ b/mediapipe/tasks/cc/vision/core/base_vision_task_api.h
@@ -129,13 +129,13 @@ class BaseVisionTaskApi : public tasks::core::BaseTaskApi {
       if (roi.left >= roi.right || roi.top >= roi.bottom) {
         return CreateStatusWithPayload(
             absl::StatusCode::kInvalidArgument,
-            "Expected Rect with left < right and top < bottom.",
+            "Expected RectF with left < right and top < bottom.",
             MediaPipeTasksStatus::kImageProcessingInvalidArgumentError);
       }
       if (roi.left < 0 || roi.top < 0 || roi.right > 1 || roi.bottom > 1) {
         return CreateStatusWithPayload(
             absl::StatusCode::kInvalidArgument,
-            "Expected Rect values to be in [0,1].",
+            "Expected RectF values to be in [0,1].",
             MediaPipeTasksStatus::kImageProcessingInvalidArgumentError);
       }
       normalized_rect.set_x_center((roi.left + roi.right) / 2.0);
diff --git a/mediapipe/tasks/cc/vision/core/image_processing_options.h b/mediapipe/tasks/cc/vision/core/image_processing_options.h
index 7e764c1fe..1983272fc 100644
--- a/mediapipe/tasks/cc/vision/core/image_processing_options.h
+++ b/mediapipe/tasks/cc/vision/core/image_processing_options.h
@@ -35,7 +35,8 @@ struct ImageProcessingOptions {
   // the full image is used.
   //
   // Coordinates must be in [0,1] with 'left' < 'right' and 'top' < 'bottom'.
-  std::optional<components::containers::Rect> region_of_interest = std::nullopt;
+  std::optional<components::containers::RectF> region_of_interest =
+      std::nullopt;
 
   // The rotation to apply to the image (or cropped region-of-interest), in
   // degrees clockwise.
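Since ImageProcessingOptions now carries a RectF, constructing a region-of-interest looks like the sketch below (a minimal illustration; the field names are exactly those in the header above, everything else is scaffolding):

#include "mediapipe/tasks/cc/components/containers/rect.h"
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"

using ::mediapipe::tasks::components::containers::RectF;
using ::mediapipe::tasks::vision::core::ImageProcessingOptions;

ImageProcessingOptions MakeCenterRoiOptions() {
  // Normalized ROI covering the central 80% of the image, no rotation.
  RectF roi{/*left=*/0.1f, /*top=*/0.1f, /*right=*/0.9f, /*bottom=*/0.9f};
  ImageProcessingOptions options;
  options.region_of_interest = roi;
  options.rotation_degrees = 0;
  return options;
}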
diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_landmarks_deduplication_calculator.cc b/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_landmarks_deduplication_calculator.cc
index 564184c64..266ce223f 100644
--- a/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_landmarks_deduplication_calculator.cc
+++ b/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_landmarks_deduplication_calculator.cc
@@ -44,7 +44,7 @@ namespace {
 using ::mediapipe::api2::Input;
 using ::mediapipe::api2::Output;
 using ::mediapipe::api2::builder::Source;
-using ::mediapipe::tasks::components::containers::Rect;
+using ::mediapipe::tasks::components::containers::RectF;
 using ::mediapipe::tasks::vision::utils::CalculateIOU;
 using ::mediapipe::tasks::vision::utils::DuplicatesFinder;
 
@@ -126,7 +126,7 @@ absl::StatusOr<float> HandBaselineDistance(
   return distance;
 }
 
-Rect CalculateBound(const NormalizedLandmarkList& list) {
+RectF CalculateBound(const NormalizedLandmarkList& list) {
   constexpr float kMinInitialValue = std::numeric_limits<float>::max();
   constexpr float kMaxInitialValue = std::numeric_limits<float>::lowest();
 
@@ -144,10 +144,10 @@ Rect CalculateBound(const NormalizedLandmarkList& list) {
   }
 
   // Populate normalized non rotated face bounding box
-  return Rect{/*left=*/bounding_box_left,
-              /*top=*/bounding_box_top,
-              /*right=*/bounding_box_right,
-              /*bottom=*/bounding_box_bottom};
+  return RectF{/*left=*/bounding_box_left,
+               /*top=*/bounding_box_top,
+               /*right=*/bounding_box_right,
+               /*bottom=*/bounding_box_bottom};
 }
 
 // Uses IoU and distance of some corresponding hand landmarks to detect
@@ -172,7 +172,7 @@ class HandDuplicatesFinder : public DuplicatesFinder {
     const int num = multi_landmarks.size();
     std::vector<float> baseline_distances;
     baseline_distances.reserve(num);
-    std::vector<Rect> bounds;
+    std::vector<RectF> bounds;
     bounds.reserve(num);
     for (const NormalizedLandmarkList& list : multi_landmarks) {
       ASSIGN_OR_RETURN(const float baseline_distance,
diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_test.cc b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_test.cc
index fa49a4c1f..94d1b1c12 100644
--- a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_test.cc
+++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_test.cc
@@ -50,7 +50,7 @@ namespace {
 
 using ::file::Defaults;
 using ::mediapipe::file::JoinPath;
-using ::mediapipe::tasks::components::containers::Rect;
+using ::mediapipe::tasks::components::containers::RectF;
 using ::mediapipe::tasks::containers::proto::LandmarksDetectionResult;
 using ::mediapipe::tasks::vision::core::ImageProcessingOptions;
 using ::testing::EqualsProto;
@@ -188,7 +188,7 @@ TEST_F(ImageModeTest, FailsWithRegionOfInterest) {
   options->running_mode = core::RunningMode::IMAGE;
   MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HandLandmarker> hand_landmarker,
                           HandLandmarker::Create(std::move(options)));
-  Rect roi{/*left=*/0.1, /*top=*/0, /*right=*/0.9, /*bottom=*/1};
+  RectF roi{/*left=*/0.1, /*top=*/0, /*right=*/0.9, /*bottom=*/1};
   ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0};
 
   auto results = hand_landmarker->Detect(image, image_processing_options);
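The deduplication logic above reduces each landmark list to a normalized bounding box and then compares boxes by IoU. A simplified sketch of the bound computation (Point and BoundOf are illustrative stand-ins, not MediaPipe types; an empty input would return an inverted box):

#include <algorithm>
#include <limits>
#include <vector>

#include "mediapipe/tasks/cc/components/containers/rect.h"

namespace containers = mediapipe::tasks::components::containers;

// Simplified stand-in for NormalizedLandmarkList: just (x, y) pairs.
struct Point {
  float x;
  float y;
};

// Mirrors the calculator's CalculateBound: min/max over landmark coordinates.
containers::RectF BoundOf(const std::vector<Point>& points) {
  float left = std::numeric_limits<float>::max();
  float top = std::numeric_limits<float>::max();
  float right = std::numeric_limits<float>::lowest();
  float bottom = std::numeric_limits<float>::lowest();
  for (const Point& p : points) {
    left = std::min(left, p.x);
    top = std::min(top, p.y);
    right = std::max(right, p.x);
    bottom = std::max(bottom, p.y);
  }
  return containers::RectF{left, top, right, bottom};
}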
diff --git a/mediapipe/tasks/cc/vision/image_classifier/image_classifier_test.cc b/mediapipe/tasks/cc/vision/image_classifier/image_classifier_test.cc
index 1144e9032..7aa2a148c 100644
--- a/mediapipe/tasks/cc/vision/image_classifier/image_classifier_test.cc
+++ b/mediapipe/tasks/cc/vision/image_classifier/image_classifier_test.cc
@@ -52,7 +52,7 @@ namespace {
 
 using ::mediapipe::file::JoinPath;
 using ::mediapipe::tasks::components::containers::Category;
 using ::mediapipe::tasks::components::containers::Classifications;
-using ::mediapipe::tasks::components::containers::Rect;
+using ::mediapipe::tasks::components::containers::RectF;
 using ::mediapipe::tasks::vision::core::ImageProcessingOptions;
 using ::testing::HasSubstr;
 using ::testing::Optional;
@@ -472,7 +472,7 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterest) {
   MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier,
                           ImageClassifier::Create(std::move(options)));
   // Region-of-interest around the soccer ball.
-  Rect roi{/*left=*/0.45, /*top=*/0.3075, /*right=*/0.614, /*bottom=*/0.7345};
+  RectF roi{/*left=*/0.45, /*top=*/0.3075, /*right=*/0.614, /*bottom=*/0.7345};
   ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0};
 
   MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(
@@ -526,7 +526,8 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterestAndRotation) {
   MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier,
                           ImageClassifier::Create(std::move(options)));
   // Region-of-interest around the chair, with 90° anti-clockwise rotation.
-  Rect roi{/*left=*/0.006, /*top=*/0.1763, /*right=*/0.5702, /*bottom=*/0.3049};
+  RectF roi{/*left=*/0.006, /*top=*/0.1763, /*right=*/0.5702,
+            /*bottom=*/0.3049};
   ImageProcessingOptions image_processing_options{roi,
                                                   /*rotation_degrees=*/-90};
 
@@ -554,13 +555,13 @@ TEST_F(ImageModeTest, FailsWithInvalidImageProcessingOptions) {
                           ImageClassifier::Create(std::move(options)));
 
   // Invalid: left > right.
-  Rect roi{/*left=*/0.9, /*top=*/0, /*right=*/0.1, /*bottom=*/1};
+  RectF roi{/*left=*/0.9, /*top=*/0, /*right=*/0.1, /*bottom=*/1};
   ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0};
 
   auto results = image_classifier->Classify(image, image_processing_options);
   EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument);
   EXPECT_THAT(results.status().message(),
-              HasSubstr("Expected Rect with left < right and top < bottom"));
+              HasSubstr("Expected RectF with left < right and top < bottom"));
   EXPECT_THAT(
       results.status().GetPayload(kMediaPipeTasksPayload),
       Optional(absl::Cord(absl::StrCat(
@@ -573,7 +574,7 @@
   results = image_classifier->Classify(image, image_processing_options);
   EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument);
   EXPECT_THAT(results.status().message(),
-              HasSubstr("Expected Rect with left < right and top < bottom"));
+              HasSubstr("Expected RectF with left < right and top < bottom"));
   EXPECT_THAT(
       results.status().GetPayload(kMediaPipeTasksPayload),
       Optional(absl::Cord(absl::StrCat(
@@ -586,7 +587,7 @@
   results = image_classifier->Classify(image, image_processing_options);
   EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument);
   EXPECT_THAT(results.status().message(),
-              HasSubstr("Expected Rect values to be in [0,1]"));
+              HasSubstr("Expected RectF values to be in [0,1]"));
   EXPECT_THAT(
       results.status().GetPayload(kMediaPipeTasksPayload),
       Optional(absl::Cord(absl::StrCat(
@@ -695,7 +696,7 @@ TEST_F(VideoModeTest, SucceedsWithRegionOfInterest) {
                           ImageClassifier::Create(std::move(options)));
   // Crop around the soccer ball.
   // Region-of-interest around the soccer ball.
-  Rect roi{/*left=*/0.45, /*top=*/0.3075, /*right=*/0.614, /*bottom=*/0.7345};
+  RectF roi{/*left=*/0.45, /*top=*/0.3075, /*right=*/0.614, /*bottom=*/0.7345};
   ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0};
 
   for (int i = 0; i < iterations; ++i) {
@@ -837,7 +838,7 @@ TEST_F(LiveStreamModeTest, SucceedsWithRegionOfInterest) {
   MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier,
                           ImageClassifier::Create(std::move(options)));
   // Crop around the soccer ball.
-  Rect roi{/*left=*/0.45, /*top=*/0.3075, /*right=*/0.614, /*bottom=*/0.7345};
+  RectF roi{/*left=*/0.45, /*top=*/0.3075, /*right=*/0.614, /*bottom=*/0.7345};
   ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0};
 
   for (int i = 0; i < iterations; ++i) {
diff --git a/mediapipe/tasks/cc/vision/image_embedder/image_embedder_test.cc b/mediapipe/tasks/cc/vision/image_embedder/image_embedder_test.cc
index 6098a9a70..dd602bef5 100644
--- a/mediapipe/tasks/cc/vision/image_embedder/image_embedder_test.cc
+++ b/mediapipe/tasks/cc/vision/image_embedder/image_embedder_test.cc
@@ -41,7 +41,7 @@ namespace image_embedder {
 namespace {
 
 using ::mediapipe::file::JoinPath;
-using ::mediapipe::tasks::components::containers::Rect;
+using ::mediapipe::tasks::components::containers::RectF;
 using ::mediapipe::tasks::vision::core::ImageProcessingOptions;
 using ::testing::HasSubstr;
 using ::testing::Optional;
@@ -320,7 +320,7 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterest) {
       Image crop, DecodeImageFromFile(
                       JoinPath("./", kTestDataDirectory, "burger_crop.jpg")));
   // Region-of-interest in "burger.jpg" corresponding to "burger_crop.jpg".
-  Rect roi{/*left=*/0, /*top=*/0, /*right=*/0.833333, /*bottom=*/1};
+  RectF roi{/*left=*/0, /*top=*/0, /*right=*/0.833333, /*bottom=*/1};
   ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0};
 
   // Extract both embeddings.
@@ -388,7 +388,7 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterestAndRotation) {
                    DecodeImageFromFile(JoinPath("./", kTestDataDirectory,
                                                 "burger_rotated.jpg")));
   // Region-of-interest corresponding to burger_crop.jpg.
-  Rect roi{/*left=*/0, /*top=*/0, /*right=*/1, /*bottom=*/0.8333333};
+  RectF roi{/*left=*/0, /*top=*/0, /*right=*/1, /*bottom=*/0.8333333};
   ImageProcessingOptions image_processing_options{roi,
                                                   /*rotation_degrees=*/-90};
 
diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc
index d5ea088a1..f9618c1b1 100644
--- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc
+++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc
@@ -47,7 +47,7 @@ namespace {
 
 using ::mediapipe::Image;
 using ::mediapipe::file::JoinPath;
-using ::mediapipe::tasks::components::containers::Rect;
+using ::mediapipe::tasks::components::containers::RectF;
 using ::mediapipe::tasks::vision::core::ImageProcessingOptions;
 using ::testing::HasSubstr;
 using ::testing::Optional;
@@ -299,7 +299,7 @@ TEST_F(ImageModeTest, FailsWithRegionOfInterest) {
   MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
                           ImageSegmenter::Create(std::move(options)));
 
-  Rect roi{/*left=*/0.1, /*top=*/0, /*right=*/0.9, /*bottom=*/1};
+  RectF roi{/*left=*/0.1, /*top=*/0, /*right=*/0.9, /*bottom=*/1};
   ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0};
 
   auto results = segmenter->Segment(image, image_processing_options);
diff --git a/mediapipe/tasks/cc/vision/object_detector/BUILD b/mediapipe/tasks/cc/vision/object_detector/BUILD
index 77373303a..5269796ae 100644
--- a/mediapipe/tasks/cc/vision/object_detector/BUILD
+++ b/mediapipe/tasks/cc/vision/object_detector/BUILD
@@ -33,6 +33,7 @@ cc_library(
         "//mediapipe/framework/formats:rect_cc_proto",
         "//mediapipe/tasks/cc:common",
         "//mediapipe/tasks/cc/components/calculators:score_calibration_calculator",
+        "//mediapipe/tasks/cc/components/containers:detection_result",
         "//mediapipe/tasks/cc/core:base_options",
         "//mediapipe/tasks/cc/core:utils",
         "//mediapipe/tasks/cc/core/proto:base_options_cc_proto",
diff --git a/mediapipe/tasks/cc/vision/object_detector/object_detector.cc b/mediapipe/tasks/cc/vision/object_detector/object_detector.cc
index dd19237ff..e0222dd70 100644
--- a/mediapipe/tasks/cc/vision/object_detector/object_detector.cc
+++ b/mediapipe/tasks/cc/vision/object_detector/object_detector.cc
@@ -30,6 +30,7 @@ limitations under the License.
#include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/tasks/cc/common.h" +#include "mediapipe/tasks/cc/components/containers/detection_result.h" #include "mediapipe/tasks/cc/core/base_options.h" #include "mediapipe/tasks/cc/core/proto/base_options.pb.h" #include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h" @@ -56,6 +57,7 @@ constexpr char kSubgraphTypeName[] = "mediapipe.tasks.vision.ObjectDetectorGraph"; constexpr int kMicroSecondsPerMilliSecond = 1000; +using ::mediapipe::tasks::components::containers::ConvertToDetectionResult; using ObjectDetectorOptionsProto = object_detector::proto::ObjectDetectorOptions; @@ -129,7 +131,8 @@ absl::StatusOr> ObjectDetector::Create( Packet detections_packet = status_or_packets.value()[kDetectionsOutStreamName]; Packet image_packet = status_or_packets.value()[kImageOutStreamName]; - result_callback(detections_packet.Get>(), + result_callback(ConvertToDetectionResult( + detections_packet.Get>()), image_packet.Get(), detections_packet.Timestamp().Value() / kMicroSecondsPerMilliSecond); @@ -144,7 +147,7 @@ absl::StatusOr> ObjectDetector::Create( std::move(packets_callback)); } -absl::StatusOr> ObjectDetector::Detect( +absl::StatusOr ObjectDetector::Detect( mediapipe::Image image, std::optional image_processing_options) { if (image.UsesGpu()) { @@ -161,10 +164,11 @@ absl::StatusOr> ObjectDetector::Detect( ProcessImageData( {{kImageInStreamName, MakePacket(std::move(image))}, {kNormRectName, MakePacket(std::move(norm_rect))}})); - return output_packets[kDetectionsOutStreamName].Get>(); + return ConvertToDetectionResult( + output_packets[kDetectionsOutStreamName].Get>()); } -absl::StatusOr> ObjectDetector::DetectForVideo( +absl::StatusOr ObjectDetector::DetectForVideo( mediapipe::Image image, int64 timestamp_ms, std::optional image_processing_options) { if (image.UsesGpu()) { @@ -185,7 +189,8 @@ absl::StatusOr> ObjectDetector::DetectForVideo( {kNormRectName, MakePacket(std::move(norm_rect)) .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}})); - return output_packets[kDetectionsOutStreamName].Get>(); + return ConvertToDetectionResult( + output_packets[kDetectionsOutStreamName].Get>()); } absl::Status ObjectDetector::DetectAsync( diff --git a/mediapipe/tasks/cc/vision/object_detector/object_detector.h b/mediapipe/tasks/cc/vision/object_detector/object_detector.h index 44ce68ed9..249a2ebf5 100644 --- a/mediapipe/tasks/cc/vision/object_detector/object_detector.h +++ b/mediapipe/tasks/cc/vision/object_detector/object_detector.h @@ -27,6 +27,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "mediapipe/framework/formats/detection.pb.h" #include "mediapipe/framework/formats/image.h" +#include "mediapipe/tasks/cc/components/containers/detection_result.h" #include "mediapipe/tasks/cc/core/base_options.h" #include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h" #include "mediapipe/tasks/cc/vision/core/image_processing_options.h" @@ -36,6 +37,10 @@ namespace mediapipe { namespace tasks { namespace vision { +// Alias the shared DetectionResult struct as result typo. +using ObjectDetectorResult = + ::mediapipe::tasks::components::containers::DetectionResult; + // The options for configuring a mediapipe object detector task. struct ObjectDetectorOptions { // Base options for configuring MediaPipe Tasks, such as specifying the TfLite @@ -79,8 +84,7 @@ struct ObjectDetectorOptions { // The user-defined result callback for processing live stream data. 
diff --git a/mediapipe/tasks/cc/vision/object_detector/object_detector.h b/mediapipe/tasks/cc/vision/object_detector/object_detector.h
index 44ce68ed9..249a2ebf5 100644
--- a/mediapipe/tasks/cc/vision/object_detector/object_detector.h
+++ b/mediapipe/tasks/cc/vision/object_detector/object_detector.h
@@ -27,6 +27,7 @@ limitations under the License.
 #include "absl/status/statusor.h"
 #include "mediapipe/framework/formats/detection.pb.h"
 #include "mediapipe/framework/formats/image.h"
+#include "mediapipe/tasks/cc/components/containers/detection_result.h"
 #include "mediapipe/tasks/cc/core/base_options.h"
 #include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
 #include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
@@ -36,6 +37,10 @@ namespace mediapipe {
 namespace tasks {
 namespace vision {
 
+// Alias the shared DetectionResult struct as result type.
+using ObjectDetectorResult =
+    ::mediapipe::tasks::components::containers::DetectionResult;
+
 // The options for configuring a mediapipe object detector task.
 struct ObjectDetectorOptions {
   // Base options for configuring MediaPipe Tasks, such as specifying the TfLite
@@ -79,8 +84,7 @@ struct ObjectDetectorOptions {
   // The user-defined result callback for processing live stream data.
   // The result callback should only be specified when the running mode is set
   // to RunningMode::LIVE_STREAM.
-  std::function<void(absl::StatusOr<std::vector<Detection>>,
-                     const Image&, int64)>
+  std::function<void(absl::StatusOr<ObjectDetectorResult>, const Image&, int64)>
       result_callback = nullptr;
 };
 
@@ -165,7 +169,7 @@ class ObjectDetector : tasks::vision::core::BaseVisionTaskApi {
   // underlying image data.
   // TODO: Describes the output bounding boxes for gpu input
   // images after enabling the gpu support in MediaPipe Tasks.
-  absl::StatusOr<std::vector<Detection>> Detect(
+  absl::StatusOr<ObjectDetectorResult> Detect(
       mediapipe::Image image,
       std::optional<core::ImageProcessingOptions> image_processing_options =
           std::nullopt);
@@ -188,7 +192,7 @@ class ObjectDetector : tasks::vision::core::BaseVisionTaskApi {
   // unrotated input frame of reference coordinates system, i.e. in `[0,
   // image_width) x [0, image_height)`, which are the dimensions of the
   // underlying image data.
-  absl::StatusOr<std::vector<Detection>> DetectForVideo(
+  absl::StatusOr<ObjectDetectorResult> DetectForVideo(
      mediapipe::Image image, int64 timestamp_ms,
      std::optional<core::ImageProcessingOptions> image_processing_options =
          std::nullopt);
diff --git a/mediapipe/tasks/cc/vision/object_detector/object_detector_test.cc b/mediapipe/tasks/cc/vision/object_detector/object_detector_test.cc
index 1747685dd..798e3f238 100644
--- a/mediapipe/tasks/cc/vision/object_detector/object_detector_test.cc
+++ b/mediapipe/tasks/cc/vision/object_detector/object_detector_test.cc
@@ -35,6 +35,7 @@ limitations under the License.
 #include "mediapipe/framework/port/gtest.h"
 #include "mediapipe/framework/port/parse_text_proto.h"
 #include "mediapipe/framework/port/status_matchers.h"
+#include "mediapipe/tasks/cc/components/containers/detection_result.h"
 #include "mediapipe/tasks/cc/components/containers/rect.h"
 #include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
 #include "mediapipe/tasks/cc/vision/core/running_mode.h"
@@ -65,10 +66,14 @@ namespace vision {
 namespace {
 
 using ::mediapipe::file::JoinPath;
-using ::mediapipe::tasks::components::containers::Rect;
+using ::mediapipe::tasks::components::containers::ConvertToDetectionResult;
+using ::mediapipe::tasks::components::containers::Detection;
+using ::mediapipe::tasks::components::containers::DetectionResult;
+using ::mediapipe::tasks::components::containers::RectF;
 using ::mediapipe::tasks::vision::core::ImageProcessingOptions;
 using ::testing::HasSubstr;
 using ::testing::Optional;
+using DetectionProto = mediapipe::Detection;
 
 constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/";
 constexpr char kMobileSsdWithMetadata[] =
@@ -83,47 +88,45 @@ constexpr char kEfficientDetWithMetadata[] =
 
 // Checks that the two provided `Detection` proto vectors are equal, with a
 // tolerance on floating-point scores to account for numerical instabilities.
 // If the proto definition changes, please also change this function.
-void ExpectApproximatelyEqual(const std::vector& actual, - const std::vector& expected) { +void ExpectApproximatelyEqual(const ObjectDetectorResult& actual, + const ObjectDetectorResult& expected) { const float kPrecision = 1e-6; - EXPECT_EQ(actual.size(), expected.size()); - for (int i = 0; i < actual.size(); ++i) { - const Detection& a = actual[i]; - const Detection& b = expected[i]; - EXPECT_THAT(a.location_data().bounding_box(), - EqualsProto(b.location_data().bounding_box())); - EXPECT_EQ(a.label_size(), 1); - EXPECT_EQ(b.label_size(), 1); - EXPECT_EQ(a.label(0), b.label(0)); - EXPECT_EQ(a.score_size(), 1); - EXPECT_EQ(b.score_size(), 1); - EXPECT_NEAR(a.score(0), b.score(0), kPrecision); + EXPECT_EQ(actual.detections.size(), expected.detections.size()); + for (int i = 0; i < actual.detections.size(); ++i) { + const Detection& a = actual.detections[i]; + const Detection& b = expected.detections[i]; + EXPECT_EQ(a.bounding_box, b.bounding_box); + EXPECT_EQ(a.categories.size(), 1); + EXPECT_EQ(b.categories.size(), 1); + EXPECT_EQ(a.categories[0].category_name, b.categories[0].category_name); + EXPECT_NEAR(a.categories[0].score, b.categories[0].score, kPrecision); } } -std::vector GenerateMobileSsdNoImageResizingFullExpectedResults() { - return {ParseTextProtoOrDie(R"pb( +std::vector +GenerateMobileSsdNoImageResizingFullExpectedResults() { + return {ParseTextProtoOrDie(R"pb( label: "cat" score: 0.6328125 location_data { format: BOUNDING_BOX bounding_box { xmin: 14 ymin: 197 width: 98 height: 99 } })pb"), - ParseTextProtoOrDie(R"pb( + ParseTextProtoOrDie(R"pb( label: "cat" score: 0.59765625 location_data { format: BOUNDING_BOX bounding_box { xmin: 151 ymin: 78 width: 104 height: 223 } })pb"), - ParseTextProtoOrDie(R"pb( + ParseTextProtoOrDie(R"pb( label: "cat" score: 0.5 location_data { format: BOUNDING_BOX bounding_box { xmin: 65 ymin: 199 width: 41 height: 101 } })pb"), - ParseTextProtoOrDie(R"pb( + ParseTextProtoOrDie(R"pb( label: "dog" score: 0.48828125 location_data { @@ -263,8 +266,8 @@ TEST_F(CreateFromOptionsTest, FailsWithIllegalCallbackInImageOrVideoMode) { JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata); options->running_mode = running_mode; options->result_callback = - [](absl::StatusOr> detections, - const Image& image, int64 timestamp_ms) {}; + [](absl::StatusOr detections, const Image& image, + int64 timestamp_ms) {}; absl::StatusOr> object_detector = ObjectDetector::Create(std::move(options)); EXPECT_EQ(object_detector.status().code(), @@ -340,34 +343,36 @@ TEST_F(ImageModeTest, Succeeds) { MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image)); MP_ASSERT_OK(object_detector->Close()); ExpectApproximatelyEqual( - results, {ParseTextProtoOrDie(R"pb( - label: "cat" - score: 0.69921875 - location_data { - format: BOUNDING_BOX - bounding_box { xmin: 608 ymin: 161 width: 381 height: 439 } - })pb"), - ParseTextProtoOrDie(R"pb( - label: "cat" - score: 0.64453125 - location_data { - format: BOUNDING_BOX - bounding_box { xmin: 60 ymin: 398 width: 386 height: 196 } - })pb"), - ParseTextProtoOrDie(R"pb( - label: "cat" - score: 0.51171875 - location_data { - format: BOUNDING_BOX - bounding_box { xmin: 256 ymin: 395 width: 173 height: 202 } - })pb"), - ParseTextProtoOrDie(R"pb( - label: "cat" - score: 0.48828125 - location_data { - format: BOUNDING_BOX - bounding_box { xmin: 362 ymin: 191 width: 325 height: 419 } - })pb")}); + results, + ConvertToDetectionResult( + {ParseTextProtoOrDie(R"pb( + label: "cat" + score: 0.69921875 + location_data { + format: 
BOUNDING_BOX + bounding_box { xmin: 608 ymin: 161 width: 381 height: 439 } + })pb"), + ParseTextProtoOrDie(R"pb( + label: "cat" + score: 0.64453125 + location_data { + format: BOUNDING_BOX + bounding_box { xmin: 60 ymin: 398 width: 386 height: 196 } + })pb"), + ParseTextProtoOrDie(R"pb( + label: "cat" + score: 0.51171875 + location_data { + format: BOUNDING_BOX + bounding_box { xmin: 256 ymin: 395 width: 173 height: 202 } + })pb"), + ParseTextProtoOrDie(R"pb( + label: "cat" + score: 0.48828125 + location_data { + format: BOUNDING_BOX + bounding_box { xmin: 362 ymin: 191 width: 325 height: 419 } + })pb")})); } TEST_F(ImageModeTest, SucceedsEfficientDetModel) { @@ -383,34 +388,36 @@ TEST_F(ImageModeTest, SucceedsEfficientDetModel) { MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image)); MP_ASSERT_OK(object_detector->Close()); ExpectApproximatelyEqual( - results, {ParseTextProtoOrDie(R"pb( - label: "cat" - score: 0.7578125 - location_data { - format: BOUNDING_BOX - bounding_box { xmin: 858 ymin: 408 width: 225 height: 187 } - })pb"), - ParseTextProtoOrDie(R"pb( - label: "cat" - score: 0.72265625 - location_data { - format: BOUNDING_BOX - bounding_box { xmin: 67 ymin: 401 width: 399 height: 192 } - })pb"), - ParseTextProtoOrDie(R"pb( - label: "cat" - score: 0.6289063 - location_data { - format: BOUNDING_BOX - bounding_box { xmin: 368 ymin: 210 width: 272 height: 385 } - })pb"), - ParseTextProtoOrDie(R"pb( - label: "cat" - score: 0.5859375 - location_data { - format: BOUNDING_BOX - bounding_box { xmin: 601 ymin: 166 width: 298 height: 437 } - })pb")}); + results, + ConvertToDetectionResult( + {ParseTextProtoOrDie(R"pb( + label: "cat" + score: 0.7578125 + location_data { + format: BOUNDING_BOX + bounding_box { xmin: 858 ymin: 408 width: 225 height: 187 } + })pb"), + ParseTextProtoOrDie(R"pb( + label: "cat" + score: 0.72265625 + location_data { + format: BOUNDING_BOX + bounding_box { xmin: 67 ymin: 401 width: 399 height: 192 } + })pb"), + ParseTextProtoOrDie(R"pb( + label: "cat" + score: 0.6289063 + location_data { + format: BOUNDING_BOX + bounding_box { xmin: 368 ymin: 210 width: 272 height: 385 } + })pb"), + ParseTextProtoOrDie(R"pb( + label: "cat" + score: 0.5859375 + location_data { + format: BOUNDING_BOX + bounding_box { xmin: 601 ymin: 166 width: 298 height: 437 } + })pb")})); } TEST_F(ImageModeTest, SucceedsWithoutImageResizing) { @@ -426,7 +433,8 @@ TEST_F(ImageModeTest, SucceedsWithoutImageResizing) { MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image)); MP_ASSERT_OK(object_detector->Close()); ExpectApproximatelyEqual( - results, GenerateMobileSsdNoImageResizingFullExpectedResults()); + results, ConvertToDetectionResult( + GenerateMobileSsdNoImageResizingFullExpectedResults())); } TEST_F(ImageModeTest, SucceedsWithScoreCalibration) { @@ -442,13 +450,14 @@ TEST_F(ImageModeTest, SucceedsWithScoreCalibration) { MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image)); MP_ASSERT_OK(object_detector->Close()); ExpectApproximatelyEqual( - results, {ParseTextProtoOrDie(R"pb( + results, + ConvertToDetectionResult({ParseTextProtoOrDie(R"pb( label: "cat" score: 0.6531269142 location_data { format: BOUNDING_BOX bounding_box { xmin: 14 ymin: 197 width: 98 height: 99 } - })pb")}); + })pb")})); } TEST_F(ImageModeTest, SucceedsWithScoreThresholdOption) { @@ -463,11 +472,13 @@ TEST_F(ImageModeTest, SucceedsWithScoreThresholdOption) { ObjectDetector::Create(std::move(options))); MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image)); 
MP_ASSERT_OK(object_detector->Close()); - std::vector full_expected_results = + std::vector full_expected_results = GenerateMobileSsdNoImageResizingFullExpectedResults(); - ExpectApproximatelyEqual(results, - {full_expected_results[0], full_expected_results[1], - full_expected_results[2]}); + + ExpectApproximatelyEqual( + results, ConvertToDetectionResult({full_expected_results[0], + full_expected_results[1], + full_expected_results[2]})); } TEST_F(ImageModeTest, SucceedsWithMaxResultsOption) { @@ -482,10 +493,11 @@ TEST_F(ImageModeTest, SucceedsWithMaxResultsOption) { ObjectDetector::Create(std::move(options))); MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image)); MP_ASSERT_OK(object_detector->Close()); - std::vector full_expected_results = + std::vector full_expected_results = GenerateMobileSsdNoImageResizingFullExpectedResults(); ExpectApproximatelyEqual( - results, {full_expected_results[0], full_expected_results[1]}); + results, ConvertToDetectionResult( + {full_expected_results[0], full_expected_results[1]})); } TEST_F(ImageModeTest, SucceedsWithAllowlistOption) { @@ -501,9 +513,10 @@ TEST_F(ImageModeTest, SucceedsWithAllowlistOption) { ObjectDetector::Create(std::move(options))); MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image)); MP_ASSERT_OK(object_detector->Close()); - std::vector full_expected_results = + std::vector full_expected_results = GenerateMobileSsdNoImageResizingFullExpectedResults(); - ExpectApproximatelyEqual(results, {full_expected_results[3]}); + ExpectApproximatelyEqual( + results, ConvertToDetectionResult({full_expected_results[3]})); } TEST_F(ImageModeTest, SucceedsWithDenylistOption) { @@ -519,9 +532,10 @@ TEST_F(ImageModeTest, SucceedsWithDenylistOption) { ObjectDetector::Create(std::move(options))); MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image)); MP_ASSERT_OK(object_detector->Close()); - std::vector full_expected_results = + std::vector full_expected_results = GenerateMobileSsdNoImageResizingFullExpectedResults(); - ExpectApproximatelyEqual(results, {full_expected_results[3]}); + ExpectApproximatelyEqual( + results, ConvertToDetectionResult({full_expected_results[3]})); } TEST_F(ImageModeTest, SucceedsWithRotation) { @@ -541,13 +555,14 @@ TEST_F(ImageModeTest, SucceedsWithRotation) { auto results, object_detector->Detect(image, image_processing_options)); MP_ASSERT_OK(object_detector->Close()); ExpectApproximatelyEqual( - results, {ParseTextProtoOrDie(R"pb( + results, + ConvertToDetectionResult({ParseTextProtoOrDie(R"pb( label: "cat" score: 0.7109375 location_data { format: BOUNDING_BOX bounding_box { xmin: 0 ymin: 622 width: 436 height: 276 } - })pb")}); + })pb")})); } TEST_F(ImageModeTest, FailsWithRegionOfInterest) { @@ -560,7 +575,7 @@ TEST_F(ImageModeTest, FailsWithRegionOfInterest) { JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata); MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr object_detector, ObjectDetector::Create(std::move(options))); - Rect roi{/*left=*/0.1, /*top=*/0, /*right=*/0.9, /*bottom=*/1}; + RectF roi{/*left=*/0.1, /*top=*/0, /*right=*/0.9, /*bottom=*/1}; ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0}; auto results = object_detector->Detect(image, image_processing_options); @@ -619,10 +634,11 @@ TEST_F(VideoModeTest, Succeeds) { for (int i = 0; i < iterations; ++i) { MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->DetectForVideo(image, i)); - std::vector full_expected_results = + std::vector full_expected_results = 
        GenerateMobileSsdNoImageResizingFullExpectedResults();
     ExpectApproximatelyEqual(
-        results, {full_expected_results[0], full_expected_results[1]});
+        results, ConvertToDetectionResult(
+                     {full_expected_results[0], full_expected_results[1]}));
   }
   MP_ASSERT_OK(object_detector->Close());
 }
@@ -637,9 +653,8 @@ TEST_F(LiveStreamModeTest, FailsWithCallingWrongMethod) {
   options->base_options.model_asset_path =
       JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata);
   options->running_mode = core::RunningMode::LIVE_STREAM;
-  options->result_callback =
-      [](absl::StatusOr<std::vector<Detection>> detections, const Image& image,
-         int64 timestamp_ms) {};
+  options->result_callback = [](absl::StatusOr<ObjectDetectorResult> detections,
+                                const Image& image, int64 timestamp_ms) {};
   MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ObjectDetector> object_detector,
                           ObjectDetector::Create(std::move(options)));
 
@@ -669,9 +684,8 @@ TEST_F(LiveStreamModeTest, FailsWithOutOfOrderInputTimestamps) {
   options->running_mode = core::RunningMode::LIVE_STREAM;
   options->base_options.model_asset_path =
       JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata);
-  options->result_callback =
-      [](absl::StatusOr<std::vector<Detection>> detections, const Image& image,
-         int64 timestamp_ms) {};
+  options->result_callback = [](absl::StatusOr<ObjectDetectorResult> detections,
+                                const Image& image, int64 timestamp_ms) {};
   MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ObjectDetector> object_detector,
                           ObjectDetector::Create(std::move(options)));
   MP_ASSERT_OK(object_detector->DetectAsync(image, 1));
@@ -695,14 +709,14 @@ TEST_F(LiveStreamModeTest, Succeeds) {
   auto options = std::make_unique<ObjectDetectorOptions>();
   options->max_results = 2;
   options->running_mode = core::RunningMode::LIVE_STREAM;
-  std::vector<std::vector<Detection>> detection_results;
+  std::vector<ObjectDetectorResult> detection_results;
   std::vector<std::pair<int, int>> image_sizes;
   std::vector<int64> timestamps;
   options->base_options.model_asset_path =
       JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata);
   options->result_callback = [&detection_results, &image_sizes, &timestamps](
-      absl::StatusOr<std::vector<Detection>> detections, const Image& image,
+      absl::StatusOr<ObjectDetectorResult> detections, const Image& image,
       int64 timestamp_ms) {
     MP_ASSERT_OK(detections.status());
     detection_results.push_back(std::move(detections).value());
@@ -719,11 +733,12 @@ TEST_F(LiveStreamModeTest, Succeeds) {
   // number of iterations.
   ASSERT_LE(detection_results.size(), iterations);
   ASSERT_GT(detection_results.size(), 0);
-  std::vector<Detection> full_expected_results =
+  std::vector<DetectionProto> full_expected_results =
       GenerateMobileSsdNoImageResizingFullExpectedResults();
   for (const auto& detection_result : detection_results) {
     ExpectApproximatelyEqual(
-        detection_result, {full_expected_results[0], full_expected_results[1]});
+        detection_result, ConvertToDetectionResult({full_expected_results[0],
+                                                    full_expected_results[1]}));
   }
   for (const auto& image_size : image_sizes) {
     EXPECT_EQ(image_size.first, image.width());
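The IoU helpers exercised by this deduplication and test code (and switched to RectF in the next file) compose as in this small sketch (illustrative only; the values make the arithmetic easy to check by hand):

#include <iostream>

#include "mediapipe/tasks/cc/components/containers/rect.h"
#include "mediapipe/tasks/cc/vision/utils/landmarks_utils.h"

namespace containers = mediapipe::tasks::components::containers;
using ::mediapipe::tasks::vision::utils::CalculateIOU;

int main() {
  // Two overlapping normalized boxes; IoU = intersection / union.
  containers::RectF a{/*left=*/0.0f, /*top=*/0.0f, /*right=*/0.5f,
                      /*bottom=*/0.5f};
  containers::RectF b{/*left=*/0.25f, /*top=*/0.25f, /*right=*/0.75f,
                      /*bottom=*/0.75f};
  // Each area is 0.25; the intersection is 0.25 * 0.25 = 0.0625, so the
  // union is 0.25 + 0.25 - 0.0625 = 0.4375 and IoU is roughly 0.1429.
  std::cout << "IoU = " << CalculateIOU(a, b) << std::endl;
  return 0;
}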
diff --git a/mediapipe/tasks/cc/vision/utils/landmarks_utils.cc b/mediapipe/tasks/cc/vision/utils/landmarks_utils.cc
index 2ce9e2454..fe4e63824 100644
--- a/mediapipe/tasks/cc/vision/utils/landmarks_utils.cc
+++ b/mediapipe/tasks/cc/vision/utils/landmarks_utils.cc
@@ -22,13 +22,13 @@ limitations under the License.
 
 namespace mediapipe::tasks::vision::utils {
 
-using ::mediapipe::tasks::components::containers::Rect;
+using ::mediapipe::tasks::components::containers::RectF;
 
-float CalculateArea(const Rect& rect) {
+float CalculateArea(const RectF& rect) {
   return (rect.right - rect.left) * (rect.bottom - rect.top);
 }
 
-float CalculateIntersectionArea(const Rect& a, const Rect& b) {
+float CalculateIntersectionArea(const RectF& a, const RectF& b) {
   const float intersection_left = std::max(a.left, b.left);
   const float intersection_top = std::max(a.top, b.top);
   const float intersection_right = std::min(a.right, b.right);
@@ -38,7 +38,7 @@ float CalculateIntersectionArea(const Rect& a, const Rect& b) {
          std::max(intersection_right - intersection_left, 0.0);
 }
 
-float CalculateIOU(const Rect& a, const Rect& b) {
+float CalculateIOU(const RectF& a, const RectF& b) {
   const float area_a = CalculateArea(a);
   const float area_b = CalculateArea(b);
   if (area_a <= 0 || area_b <= 0) return 0.0;
diff --git a/mediapipe/tasks/cc/vision/utils/landmarks_utils.h b/mediapipe/tasks/cc/vision/utils/landmarks_utils.h
index 73114d2ef..4d1fac62f 100644
--- a/mediapipe/tasks/cc/vision/utils/landmarks_utils.h
+++ b/mediapipe/tasks/cc/vision/utils/landmarks_utils.h
@@ -27,15 +27,15 @@ limitations under the License.
 namespace mediapipe::tasks::vision::utils {
 
 // Calculates intersection over union for two bounds.
-float CalculateIOU(const components::containers::Rect& a,
-                   const components::containers::Rect& b);
+float CalculateIOU(const components::containers::RectF& a,
+                   const components::containers::RectF& b);
 
 // Calculates area for face bound.
-float CalculateArea(const components::containers::Rect& rect);
+float CalculateArea(const components::containers::RectF& rect);
 
 // Calculates intersection area of two face bounds.
-float CalculateIntersectionArea(const components::containers::Rect& a,
-                                const components::containers::Rect& b);
+float CalculateIntersectionArea(const components::containers::RectF& a,
+                                const components::containers::RectF& b);
 
 }  // namespace mediapipe::tasks::vision::utils
 #endif  // MEDIAPIPE_TASKS_CC_VISION_UTILS_LANDMARKS_UTILS_H_

From 700c7b4b2258d3a01bf8424146a4cf94e8ca7282 Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Wed, 7 Dec 2022 18:54:34 -0800
Subject: [PATCH 196/469] Internal refactoring for TextEmbedder.
PiperOrigin-RevId: 493766612 --- .../tasks/cc/components/processors/BUILD | 3 + .../cc/components/processors/proto/BUILD | 6 + .../processors/proto/text_model_type.proto | 31 +++++ .../text_preprocessing_graph_options.proto | 15 +-- .../processors/text_preprocessing_graph.cc | 126 +++++------------- mediapipe/tasks/cc/text/utils/BUILD | 40 ++++++ .../tasks/cc/text/utils/text_model_utils.cc | 119 +++++++++++++++++ .../tasks/cc/text/utils/text_model_utils.h | 33 +++++ .../cc/text/utils/text_model_utils_test.cc | 108 +++++++++++++++ 9 files changed, 375 insertions(+), 106 deletions(-) create mode 100644 mediapipe/tasks/cc/components/processors/proto/text_model_type.proto create mode 100644 mediapipe/tasks/cc/text/utils/text_model_utils.cc create mode 100644 mediapipe/tasks/cc/text/utils/text_model_utils.h create mode 100644 mediapipe/tasks/cc/text/utils/text_model_utils_test.cc diff --git a/mediapipe/tasks/cc/components/processors/BUILD b/mediapipe/tasks/cc/components/processors/BUILD index 185bf231b..cec44a9e3 100644 --- a/mediapipe/tasks/cc/components/processors/BUILD +++ b/mediapipe/tasks/cc/components/processors/BUILD @@ -150,9 +150,12 @@ cc_library( "//mediapipe/framework/api2:builder", "//mediapipe/framework/api2:port", "//mediapipe/framework/formats:tensor", + "//mediapipe/tasks/cc:common", + "//mediapipe/tasks/cc/components/processors/proto:text_model_type_cc_proto", "//mediapipe/tasks/cc/components/processors/proto:text_preprocessing_graph_options_cc_proto", "//mediapipe/tasks/cc/core:model_resources", "//mediapipe/tasks/cc/metadata:metadata_extractor", + "//mediapipe/tasks/cc/text/utils:text_model_utils", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", diff --git a/mediapipe/tasks/cc/components/processors/proto/BUILD b/mediapipe/tasks/cc/components/processors/proto/BUILD index f48c4bad8..816ba47e3 100644 --- a/mediapipe/tasks/cc/components/processors/proto/BUILD +++ b/mediapipe/tasks/cc/components/processors/proto/BUILD @@ -60,10 +60,16 @@ mediapipe_proto_library( ], ) +mediapipe_proto_library( + name = "text_model_type_proto", + srcs = ["text_model_type.proto"], +) + mediapipe_proto_library( name = "text_preprocessing_graph_options_proto", srcs = ["text_preprocessing_graph_options.proto"], deps = [ + ":text_model_type_proto", "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", ], diff --git a/mediapipe/tasks/cc/components/processors/proto/text_model_type.proto b/mediapipe/tasks/cc/components/processors/proto/text_model_type.proto new file mode 100644 index 000000000..7ffc0db07 --- /dev/null +++ b/mediapipe/tasks/cc/components/processors/proto/text_model_type.proto @@ -0,0 +1,31 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +syntax = "proto2"; + +package mediapipe.tasks.components.processors.proto; + +message TextModelType { + // TFLite text models supported by MediaPipe tasks. + enum ModelType { + UNSPECIFIED_MODEL = 0; + // A BERT-based model. + BERT_MODEL = 1; + // A model expecting input passed through a regex-based tokenizer. + REGEX_MODEL = 2; + // A model taking a string tensor input. + STRING_MODEL = 3; + } +} diff --git a/mediapipe/tasks/cc/components/processors/proto/text_preprocessing_graph_options.proto b/mediapipe/tasks/cc/components/processors/proto/text_preprocessing_graph_options.proto index a67cfd8a9..b610f7757 100644 --- a/mediapipe/tasks/cc/components/processors/proto/text_preprocessing_graph_options.proto +++ b/mediapipe/tasks/cc/components/processors/proto/text_preprocessing_graph_options.proto @@ -18,25 +18,16 @@ syntax = "proto2"; package mediapipe.tasks.components.processors.proto; import "mediapipe/framework/calculator.proto"; +import "mediapipe/tasks/cc/components/processors/proto/text_model_type.proto"; message TextPreprocessingGraphOptions { extend mediapipe.CalculatorOptions { optional TextPreprocessingGraphOptions ext = 476978751; } - // The type of text preprocessor required for the TFLite model. - enum PreprocessorType { - UNSPECIFIED_PREPROCESSOR = 0; - // Used for the BertPreprocessorCalculator. - BERT_PREPROCESSOR = 1; - // Used for the RegexPreprocessorCalculator. - REGEX_PREPROCESSOR = 2; - // Used for the TextToTensorCalculator. - STRING_PREPROCESSOR = 3; - } - optional PreprocessorType preprocessor_type = 1; + optional TextModelType.ModelType model_type = 1; // The maximum input sequence length for the TFLite model. Used with - // BERT_PREPROCESSOR and REGEX_PREPROCESSOR. + // BERT_MODEL and REGEX_MODEL. optional int32 max_seq_len = 2; } diff --git a/mediapipe/tasks/cc/components/processors/text_preprocessing_graph.cc b/mediapipe/tasks/cc/components/processors/text_preprocessing_graph.cc index de16375bd..f6c15c441 100644 --- a/mediapipe/tasks/cc/components/processors/text_preprocessing_graph.cc +++ b/mediapipe/tasks/cc/components/processors/text_preprocessing_graph.cc @@ -25,15 +25,14 @@ limitations under the License. 
#include "mediapipe/framework/api2/port.h" #include "mediapipe/framework/formats/tensor.h" #include "mediapipe/framework/subgraph.h" +#include "mediapipe/tasks/cc/common.h" +#include "mediapipe/tasks/cc/components/processors/proto/text_model_type.pb.h" #include "mediapipe/tasks/cc/components/processors/proto/text_preprocessing_graph_options.pb.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/metadata/metadata_extractor.h" +#include "mediapipe/tasks/cc/text/utils/text_model_utils.h" -namespace mediapipe { -namespace tasks { -namespace components { -namespace processors { - +namespace mediapipe::tasks::components::processors { namespace { using ::mediapipe::api2::Input; @@ -42,91 +41,35 @@ using ::mediapipe::api2::SideInput; using ::mediapipe::api2::builder::Graph; using ::mediapipe::api2::builder::SideSource; using ::mediapipe::api2::builder::Source; +using ::mediapipe::tasks::components::processors::proto::TextModelType; using ::mediapipe::tasks::components::processors::proto:: TextPreprocessingGraphOptions; using ::mediapipe::tasks::core::ModelResources; using ::mediapipe::tasks::metadata::ModelMetadataExtractor; +using ::mediapipe::tasks::text::utils::GetModelType; constexpr char kTextTag[] = "TEXT"; constexpr char kMetadataExtractorTag[] = "METADATA_EXTRACTOR"; constexpr char kTensorsTag[] = "TENSORS"; -constexpr int kNumInputTensorsForBert = 3; -constexpr int kNumInputTensorsForRegex = 1; - -// Gets the name of the MediaPipe calculator associated with -// `preprocessor_type`. -absl::StatusOr GetCalculatorNameFromPreprocessorType( - TextPreprocessingGraphOptions::PreprocessorType preprocessor_type) { - switch (preprocessor_type) { - case TextPreprocessingGraphOptions::UNSPECIFIED_PREPROCESSOR: +// Gets the name of the MediaPipe preprocessor calculator associated with +// `model_type`. +absl::StatusOr GetCalculatorNameFromModelType( + TextModelType::ModelType model_type) { + switch (model_type) { + case TextModelType::UNSPECIFIED_MODEL: return CreateStatusWithPayload( - absl::StatusCode::kInvalidArgument, "Unspecified preprocessor type", + absl::StatusCode::kInvalidArgument, "Unspecified model type", MediaPipeTasksStatus::kInvalidArgumentError); - case TextPreprocessingGraphOptions::BERT_PREPROCESSOR: + case TextModelType::BERT_MODEL: return "BertPreprocessorCalculator"; - case TextPreprocessingGraphOptions::REGEX_PREPROCESSOR: + case TextModelType::REGEX_MODEL: return "RegexPreprocessorCalculator"; - case TextPreprocessingGraphOptions::STRING_PREPROCESSOR: + case TextModelType::STRING_MODEL: return "TextToTensorCalculator"; } } -// Determines the PreprocessorType for the model based on its metadata as well -// as its input tensors' type and count. Returns an error if there is no -// compatible preprocessor. 
-absl::StatusOr<TextPreprocessingGraphOptions::PreprocessorType>
-GetPreprocessorType(const ModelResources& model_resources) {
-  const tflite::SubGraph& model_graph =
-      *(*model_resources.GetTfLiteModel()->subgraphs())[0];
-  bool all_int32_tensors =
-      absl::c_all_of(*model_graph.inputs(), [&model_graph](int i) {
-        return (*model_graph.tensors())[i]->type() == tflite::TensorType_INT32;
-      });
-  bool all_string_tensors =
-      absl::c_all_of(*model_graph.inputs(), [&model_graph](int i) {
-        return (*model_graph.tensors())[i]->type() == tflite::TensorType_STRING;
-      });
-  if (!all_int32_tensors && !all_string_tensors) {
-    return CreateStatusWithPayload(
-        absl::StatusCode::kInvalidArgument,
-        "All input tensors should have type int32 or all should have type "
-        "string",
-        MediaPipeTasksStatus::kInvalidInputTensorTypeError);
-  }
-  if (all_string_tensors) {
-    return TextPreprocessingGraphOptions::STRING_PREPROCESSOR;
-  }
-
-  // Otherwise, all tensors should have type int32
-  const ModelMetadataExtractor* metadata_extractor =
-      model_resources.GetMetadataExtractor();
-  if (metadata_extractor->GetModelMetadata() == nullptr ||
-      metadata_extractor->GetModelMetadata()->subgraph_metadata() == nullptr) {
-    return CreateStatusWithPayload(
-        absl::StatusCode::kInvalidArgument,
-        "Text models with int32 input tensors require TFLite Model "
-        "Metadata but none was found",
-        MediaPipeTasksStatus::kMetadataNotFoundError);
-  }
-
-  if (model_graph.inputs()->size() == kNumInputTensorsForBert) {
-    return TextPreprocessingGraphOptions::BERT_PREPROCESSOR;
-  }
-
-  if (model_graph.inputs()->size() == kNumInputTensorsForRegex) {
-    return TextPreprocessingGraphOptions::REGEX_PREPROCESSOR;
-  }
-
-  return CreateStatusWithPayload(
-      absl::StatusCode::kInvalidArgument,
-      absl::Substitute("Models with int32 input tensors should take exactly $0 "
-                       "or $1 input tensors, but found $2",
-                       kNumInputTensorsForBert, kNumInputTensorsForRegex,
-                       model_graph.inputs()->size()),
-      MediaPipeTasksStatus::kInvalidNumInputTensorsError);
-}
-
 // Returns the maximum input sequence length accepted by the TFLite
 // model that owns `model graph` or returns an error if the model's input
 // tensors' shape is invalid for text preprocessing. This util assumes that the
@@ -181,17 +124,16 @@ absl::Status ConfigureTextPreprocessingGraph(
         MediaPipeTasksStatus::kInvalidArgumentError);
   }
 
-  ASSIGN_OR_RETURN(
-      TextPreprocessingGraphOptions::PreprocessorType preprocessor_type,
-      GetPreprocessorType(model_resources));
-  options.set_preprocessor_type(preprocessor_type);
-  switch (preprocessor_type) {
-    case TextPreprocessingGraphOptions::UNSPECIFIED_PREPROCESSOR:
-    case TextPreprocessingGraphOptions::STRING_PREPROCESSOR: {
+  ASSIGN_OR_RETURN(TextModelType::ModelType model_type,
+                   GetModelType(model_resources));
+  options.set_model_type(model_type);
+  switch (model_type) {
+    case TextModelType::UNSPECIFIED_MODEL:
+    case TextModelType::STRING_MODEL: {
       break;
    }
-    case TextPreprocessingGraphOptions::BERT_PREPROCESSOR:
-    case TextPreprocessingGraphOptions::REGEX_PREPROCESSOR: {
+    case TextModelType::BERT_MODEL:
+    case TextModelType::REGEX_MODEL: {
       ASSIGN_OR_RETURN(
           int max_seq_len,
          GetMaxSeqLen(*(*model_resources.GetTfLiteModel()->subgraphs())[0]));
@@ -239,23 +181,22 @@ class TextPreprocessingGraph : public mediapipe::Subgraph {
   absl::StatusOr<Source<std::vector<Tensor>>> BuildTextPreprocessing(
       const TextPreprocessingGraphOptions& options, Source<std::string> text_in,
       SideSource<ModelMetadataExtractor> metadata_extractor_in, Graph& graph) {
-    ASSIGN_OR_RETURN(
-        std::string preprocessor_name,
-        GetCalculatorNameFromPreprocessorType(options.preprocessor_type()));
+    ASSIGN_OR_RETURN(std::string preprocessor_name,
+                     GetCalculatorNameFromModelType(options.model_type()));
     auto& text_preprocessor = graph.AddNode(preprocessor_name);
-    switch (options.preprocessor_type()) {
-      case TextPreprocessingGraphOptions::UNSPECIFIED_PREPROCESSOR:
-      case TextPreprocessingGraphOptions::STRING_PREPROCESSOR: {
+    switch (options.model_type()) {
+      case TextModelType::UNSPECIFIED_MODEL:
+      case TextModelType::STRING_MODEL: {
        break;
      }
-      case TextPreprocessingGraphOptions::BERT_PREPROCESSOR: {
+      case TextModelType::BERT_MODEL: {
         text_preprocessor.GetOptions<BertPreprocessorCalculatorOptions>()
             .set_bert_max_seq_len(options.max_seq_len());
         metadata_extractor_in >>
             text_preprocessor.SideIn(kMetadataExtractorTag);
         break;
       }
-      case TextPreprocessingGraphOptions::REGEX_PREPROCESSOR: {
+      case TextModelType::REGEX_MODEL: {
         text_preprocessor.GetOptions<RegexPreprocessorCalculatorOptions>()
             .set_max_seq_len(options.max_seq_len());
         metadata_extractor_in >>
@@ -270,7 +211,4 @@ class TextPreprocessingGraph : public mediapipe::Subgraph {
 
 REGISTER_MEDIAPIPE_GRAPH(
     ::mediapipe::tasks::components::processors::TextPreprocessingGraph);
 
-}  // namespace processors
-}  // namespace components
-}  // namespace tasks
-}  // namespace mediapipe
+}  // namespace mediapipe::tasks::components::processors
diff --git a/mediapipe/tasks/cc/text/utils/BUILD b/mediapipe/tasks/cc/text/utils/BUILD
index 710e8a984..092a7d450 100644
--- a/mediapipe/tasks/cc/text/utils/BUILD
+++ b/mediapipe/tasks/cc/text/utils/BUILD
@@ -43,3 +43,43 @@ cc_test(
         "@com_google_absl//absl/container:node_hash_map",
     ],
 )
+
+cc_library(
+    name = "text_model_utils",
+    srcs = ["text_model_utils.cc"],
+    hdrs = ["text_model_utils.h"],
+    deps = [
+        "//mediapipe/tasks/cc:common",
+        "//mediapipe/tasks/cc/components/processors/proto:text_model_type_cc_proto",
+        "//mediapipe/tasks/cc/core:model_resources",
+        "//mediapipe/tasks/cc/metadata:metadata_extractor",
+        "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings",
+        "@org_tensorflow//tensorflow/lite/schema:schema_fbs",
+    ],
+)
+
+cc_test(
+    name = "text_model_utils_test",
+    srcs = ["text_model_utils_test.cc"],
+    data =
[ + "//mediapipe/tasks/testdata/text:bert_text_classifier_models", + "//mediapipe/tasks/testdata/text:mobilebert_embedding_model", + "//mediapipe/tasks/testdata/text:regex_embedding_with_metadata", + "//mediapipe/tasks/testdata/text:text_classifier_models", + ], + deps = [ + ":text_model_utils", + "//mediapipe/framework/deps:file_path", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/tasks/cc/components/processors/proto:text_model_type_cc_proto", + "//mediapipe/tasks/cc/core:model_resources", + "//mediapipe/tasks/cc/core/proto:external_file_cc_proto", + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util", + ], +) diff --git a/mediapipe/tasks/cc/text/utils/text_model_utils.cc b/mediapipe/tasks/cc/text/utils/text_model_utils.cc new file mode 100644 index 000000000..9d0005ec1 --- /dev/null +++ b/mediapipe/tasks/cc/text/utils/text_model_utils.cc @@ -0,0 +1,119 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/tasks/cc/text/utils/text_model_utils.h" + +#include "absl/algorithm/container.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/substitute.h" +#include "mediapipe/tasks/cc/common.h" +#include "mediapipe/tasks/cc/components/processors/proto/text_model_type.pb.h" +#include "mediapipe/tasks/cc/core/model_resources.h" +#include "mediapipe/tasks/cc/metadata/metadata_extractor.h" +#include "tensorflow/lite/schema/schema_generated.h" + +namespace mediapipe::tasks::text::utils { +namespace { + +using ::mediapipe::tasks::components::processors::proto::TextModelType; +using ::mediapipe::tasks::core::ModelResources; +using ::mediapipe::tasks::metadata::ModelMetadataExtractor; + +constexpr int kNumInputTensorsForBert = 3; +constexpr int kNumInputTensorsForRegex = 1; +constexpr int kNumInputTensorsForStringPreprocessor = 1; + +// Determines the ModelType for a model with int32 input tensors based +// on the number of input tensors. Returns an error if there is missing metadata +// or an invalid number of input tensors. 
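+// For example, a BERT-based model typically takes exactly three int32 input
+// tensors (token ids, an input mask, and segment ids) and so maps to
+// TextModelType::BERT_MODEL, while a regex-tokenized model takes a single
+// int32 tensor of token ids and maps to TextModelType::REGEX_MODEL.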
+absl::StatusOr<TextModelType::ModelType> GetIntTensorModelType(
+    const ModelResources& model_resources, int num_input_tensors) {
+  const ModelMetadataExtractor* metadata_extractor =
+      model_resources.GetMetadataExtractor();
+  if (metadata_extractor->GetModelMetadata() == nullptr ||
+      metadata_extractor->GetModelMetadata()->subgraph_metadata() == nullptr) {
+    return CreateStatusWithPayload(
+        absl::StatusCode::kInvalidArgument,
+        "Text models with int32 input tensors require TFLite Model "
+        "Metadata but none was found",
+        MediaPipeTasksStatus::kMetadataNotFoundError);
+  }
+
+  if (num_input_tensors == kNumInputTensorsForBert) {
+    return TextModelType::BERT_MODEL;
+  }
+
+  if (num_input_tensors == kNumInputTensorsForRegex) {
+    return TextModelType::REGEX_MODEL;
+  }
+
+  return CreateStatusWithPayload(
+      absl::StatusCode::kInvalidArgument,
+      absl::Substitute("Models with int32 input tensors should take exactly $0 "
+                       "or $1 input tensors, but found $2",
+                       kNumInputTensorsForBert, kNumInputTensorsForRegex,
+                       num_input_tensors),
+      MediaPipeTasksStatus::kInvalidNumInputTensorsError);
+}
+
+// Determines the ModelType for a model with string input tensors based
+// on the number of input tensors. Returns an error if there is an invalid
+// number of input tensors.
+absl::StatusOr<TextModelType::ModelType> GetStringTensorModelType(
+    const ModelResources& model_resources, int num_input_tensors) {
+  if (num_input_tensors == kNumInputTensorsForStringPreprocessor) {
+    return TextModelType::STRING_MODEL;
+  }
+
+  return CreateStatusWithPayload(
+      absl::StatusCode::kInvalidArgument,
+      absl::Substitute("Models with string input tensors should take exactly "
+                       "$0 tensors, but found $1",
+                       kNumInputTensorsForStringPreprocessor,
+                       num_input_tensors),
+      MediaPipeTasksStatus::kInvalidNumInputTensorsError);
+}
+}  // namespace
+
+absl::StatusOr<TextModelType::ModelType> GetModelType(
+    const ModelResources& model_resources) {
+  const tflite::SubGraph& model_graph =
+      *(*model_resources.GetTfLiteModel()->subgraphs())[0];
+  bool all_int32_tensors =
+      absl::c_all_of(*model_graph.inputs(), [&model_graph](int i) {
+        return (*model_graph.tensors())[i]->type() == tflite::TensorType_INT32;
+      });
+  bool all_string_tensors =
+      absl::c_all_of(*model_graph.inputs(), [&model_graph](int i) {
+        return (*model_graph.tensors())[i]->type() == tflite::TensorType_STRING;
+      });
+  if (!all_int32_tensors && !all_string_tensors) {
+    return CreateStatusWithPayload(
+        absl::StatusCode::kInvalidArgument,
+        "All input tensors should have type int32 or all should have type "
+        "string",
+        MediaPipeTasksStatus::kInvalidInputTensorTypeError);
+  }
+  if (all_string_tensors) {
+    return GetStringTensorModelType(model_resources,
+                                    model_graph.inputs()->size());
+  }
+
+  // Otherwise, all tensors should have type int32
+  return GetIntTensorModelType(model_resources, model_graph.inputs()->size());
+}
+
+}  // namespace mediapipe::tasks::text::utils
diff --git a/mediapipe/tasks/cc/text/utils/text_model_utils.h b/mediapipe/tasks/cc/text/utils/text_model_utils.h
new file mode 100644
index 000000000..da8783d33
--- /dev/null
+++ b/mediapipe/tasks/cc/text/utils/text_model_utils.h
@@ -0,0 +1,33 @@
+/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef MEDIAPIPE_TASKS_CC_TEXT_UTILS_TEXT_MODEL_UTILS_H_
+#define MEDIAPIPE_TASKS_CC_TEXT_UTILS_TEXT_MODEL_UTILS_H_
+
+#include "absl/status/statusor.h"
+#include "mediapipe/tasks/cc/components/processors/proto/text_model_type.pb.h"
+#include "mediapipe/tasks/cc/core/model_resources.h"
+
+namespace mediapipe::tasks::text::utils {
+
+// Determines the ModelType for the model based on its metadata as well
+// as its input tensors' type and count. Returns an error if there is no
+// compatible model type.
+absl::StatusOr<components::processors::proto::TextModelType::ModelType>
+GetModelType(const core::ModelResources& model_resources);
+
+}  // namespace mediapipe::tasks::text::utils
+
+#endif  // MEDIAPIPE_TASKS_CC_TEXT_UTILS_TEXT_MODEL_UTILS_H_
diff --git a/mediapipe/tasks/cc/text/utils/text_model_utils_test.cc b/mediapipe/tasks/cc/text/utils/text_model_utils_test.cc
new file mode 100644
index 000000000..c02f8eca5
--- /dev/null
+++ b/mediapipe/tasks/cc/text/utils/text_model_utils_test.cc
@@ -0,0 +1,108 @@
+/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "mediapipe/tasks/cc/text/utils/text_model_utils.h"
+
+#include <memory>
+#include <string>
+
+#include "absl/flags/flag.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "mediapipe/framework/deps/file_path.h"
+#include "mediapipe/framework/port/gmock.h"
+#include "mediapipe/framework/port/gtest.h"
+#include "mediapipe/framework/port/status_matchers.h"
+#include "mediapipe/tasks/cc/components/processors/proto/text_model_type.pb.h"
+#include "mediapipe/tasks/cc/core/model_resources.h"
+#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
+#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
+
+namespace mediapipe::tasks::text::utils {
+
+namespace {
+
+using ::mediapipe::file::JoinPath;
+using ::mediapipe::tasks::components::processors::proto::TextModelType;
+using ::mediapipe::tasks::core::ModelResources;
+using ::mediapipe::tasks::core::proto::ExternalFile;
+
+constexpr absl::string_view kTestModelResourcesTag = "test_model_resources";
+
+constexpr absl::string_view kTestDataDirectory =
+    "/mediapipe/tasks/testdata/text/";
+// Classification model with BERT preprocessing.
+constexpr absl::string_view kBertClassifierPath = "bert_text_classifier.tflite";
+// Embedding model with BERT preprocessing.
+constexpr absl::string_view kMobileBert =
+    "mobilebert_embedding_with_metadata.tflite";
+// Classification model with regex preprocessing.
+constexpr absl::string_view kRegexClassifierPath =
+    "test_model_text_classifier_with_regex_tokenizer.tflite";
+// Embedding model with regex preprocessing.
+constexpr absl::string_view kRegexOneEmbeddingModel =
+    "regex_one_embedding_with_metadata.tflite";
+// Classification model that takes a string tensor and outputs a bool tensor.
+constexpr absl::string_view kStringToBoolModelPath =
+    "test_model_text_classifier_bool_output.tflite";
+
+std::string GetFullPath(absl::string_view file_name) {
+  return JoinPath("./", kTestDataDirectory, file_name);
+}
+
+absl::StatusOr<TextModelType::ModelType> GetModelTypeFromFile(
+    absl::string_view file_name) {
+  auto model_file = std::make_unique<ExternalFile>();
+  model_file->set_file_name(GetFullPath(file_name));
+  ASSIGN_OR_RETURN(auto model_resources,
+                   ModelResources::Create(std::string(kTestModelResourcesTag),
+                                          std::move(model_file)));
+  return GetModelType(*model_resources);
+}
+
+}  // namespace
+
+class TextModelUtilsTest : public tflite_shims::testing::Test {};
+
+TEST_F(TextModelUtilsTest, BertClassifierModelTest) {
+  MP_ASSERT_OK_AND_ASSIGN(auto model_type,
+                          GetModelTypeFromFile(kBertClassifierPath));
+  ASSERT_EQ(model_type, TextModelType::BERT_MODEL);
+}
+
+TEST_F(TextModelUtilsTest, BertEmbedderModelTest) {
+  MP_ASSERT_OK_AND_ASSIGN(auto model_type, GetModelTypeFromFile(kMobileBert));
+  ASSERT_EQ(model_type, TextModelType::BERT_MODEL);
+}
+
+TEST_F(TextModelUtilsTest, RegexClassifierModelTest) {
+  MP_ASSERT_OK_AND_ASSIGN(auto model_type,
+                          GetModelTypeFromFile(kRegexClassifierPath));
+  ASSERT_EQ(model_type, TextModelType::REGEX_MODEL);
+}
+
+TEST_F(TextModelUtilsTest, RegexEmbedderModelTest) {
+  MP_ASSERT_OK_AND_ASSIGN(auto model_type,
+                          GetModelTypeFromFile(kRegexOneEmbeddingModel));
+  ASSERT_EQ(model_type, TextModelType::REGEX_MODEL);
+}
+
+TEST_F(TextModelUtilsTest, StringInputModelTest) {
+  MP_ASSERT_OK_AND_ASSIGN(auto model_type,
+                          GetModelTypeFromFile(kStringToBoolModelPath));
+  ASSERT_EQ(model_type, TextModelType::STRING_MODEL);
+}
+
+}  // namespace mediapipe::tasks::text::utils
From 24c8fa97e9aeb75ac6344957aff2a2d5b953061b Mon Sep 17 00:00:00 2001
From: Adam Cozzette
Date: Wed, 7 Dec 2022 19:04:31 -0800
Subject: [PATCH 197/469] Internal change

PiperOrigin-RevId: 493768013
---
 mediapipe/examples/ios/faceeffect/BUILD      | 4 ++--
 mediapipe/examples/ios/facemeshgpu/BUILD     | 2 +-
 mediapipe/examples/ios/handtrackinggpu/BUILD | 2 +-
 mediapipe/examples/ios/iristrackinggpu/BUILD | 2 +-
 mediapipe/examples/ios/posetrackinggpu/BUILD | 2 +-
 5 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/mediapipe/examples/ios/faceeffect/BUILD b/mediapipe/examples/ios/faceeffect/BUILD
index e0c3abb86..7d3a75cc6 100644
--- a/mediapipe/examples/ios/faceeffect/BUILD
+++ b/mediapipe/examples/ios/faceeffect/BUILD
@@ -74,10 +74,12 @@ objc_library(
     ],
     features = ["-layering_check"],
     deps = [
+        "//mediapipe/framework/formats:matrix_data_cc_proto",
         "//third_party/apple_frameworks:AVFoundation",
         "//third_party/apple_frameworks:CoreGraphics",
         "//third_party/apple_frameworks:CoreMedia",
         "//third_party/apple_frameworks:UIKit",
+        "//mediapipe/modules/face_geometry/protos:face_geometry_cc_proto",
         "//mediapipe/objc:mediapipe_framework_ios",
         "//mediapipe/objc:mediapipe_input_sources_ios",
         "//mediapipe/objc:mediapipe_layer_renderer",
     ] + select({
         "//mediapipe:ios_i386": [],
         "//mediapipe:ios_x86_64": [],
         "//conditions:default": [
-            "//mediapipe/framework/formats:matrix_data_cc_proto",
             "//mediapipe/graphs/face_effect:face_effect_gpu_deps",
-
"//mediapipe/modules/face_geometry/protos:face_geometry_cc_proto", ], }), ) diff --git a/mediapipe/examples/ios/facemeshgpu/BUILD b/mediapipe/examples/ios/facemeshgpu/BUILD index 02103ce2f..6caf8c09c 100644 --- a/mediapipe/examples/ios/facemeshgpu/BUILD +++ b/mediapipe/examples/ios/facemeshgpu/BUILD @@ -67,12 +67,12 @@ objc_library( ], deps = [ "//mediapipe/examples/ios/common:CommonMediaPipeAppLibrary", + "//mediapipe/framework/formats:landmark_cc_proto", ] + select({ "//mediapipe:ios_i386": [], "//mediapipe:ios_x86_64": [], "//conditions:default": [ "//mediapipe/graphs/face_mesh:mobile_calculators", - "//mediapipe/framework/formats:landmark_cc_proto", ], }), ) diff --git a/mediapipe/examples/ios/handtrackinggpu/BUILD b/mediapipe/examples/ios/handtrackinggpu/BUILD index 647b7670a..c5b8e7b58 100644 --- a/mediapipe/examples/ios/handtrackinggpu/BUILD +++ b/mediapipe/examples/ios/handtrackinggpu/BUILD @@ -68,12 +68,12 @@ objc_library( ], deps = [ "//mediapipe/examples/ios/common:CommonMediaPipeAppLibrary", + "//mediapipe/framework/formats:landmark_cc_proto", ] + select({ "//mediapipe:ios_i386": [], "//mediapipe:ios_x86_64": [], "//conditions:default": [ "//mediapipe/graphs/hand_tracking:mobile_calculators", - "//mediapipe/framework/formats:landmark_cc_proto", ], }), ) diff --git a/mediapipe/examples/ios/iristrackinggpu/BUILD b/mediapipe/examples/ios/iristrackinggpu/BUILD index 056447d63..646d2e5a2 100644 --- a/mediapipe/examples/ios/iristrackinggpu/BUILD +++ b/mediapipe/examples/ios/iristrackinggpu/BUILD @@ -68,12 +68,12 @@ objc_library( ], deps = [ "//mediapipe/examples/ios/common:CommonMediaPipeAppLibrary", + "//mediapipe/framework/formats:landmark_cc_proto", ] + select({ "//mediapipe:ios_i386": [], "//mediapipe:ios_x86_64": [], "//conditions:default": [ "//mediapipe/graphs/iris_tracking:iris_tracking_gpu_deps", - "//mediapipe/framework/formats:landmark_cc_proto", ], }), ) diff --git a/mediapipe/examples/ios/posetrackinggpu/BUILD b/mediapipe/examples/ios/posetrackinggpu/BUILD index 86b41ed36..4fbc2280c 100644 --- a/mediapipe/examples/ios/posetrackinggpu/BUILD +++ b/mediapipe/examples/ios/posetrackinggpu/BUILD @@ -67,12 +67,12 @@ objc_library( ], deps = [ "//mediapipe/examples/ios/common:CommonMediaPipeAppLibrary", + "//mediapipe/framework/formats:landmark_cc_proto", ] + select({ "//mediapipe:ios_i386": [], "//mediapipe:ios_x86_64": [], "//conditions:default": [ "//mediapipe/graphs/pose_tracking:pose_tracking_gpu_deps", - "//mediapipe/framework/formats:landmark_cc_proto", ], }), ) From 9ae2e43b70188cd73fd478364b71d32410f9c21c Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Wed, 7 Dec 2022 19:17:14 -0800 Subject: [PATCH 198/469] Open Source the remaining MediaPipe Tasks tests for Web PiperOrigin-RevId: 493769657 --- .../audio_classifier_graph_options.proto | 1 + .../proto/audio_embedder_graph_options.proto | 1 + .../proto/text_classifier_graph_options.proto | 1 + .../proto/text_embedder_graph_options.proto | 1 + .../gesture_classifier_graph_options.proto | 1 + .../gesture_embedder_graph_options.proto | 1 + .../gesture_recognizer_graph_options.proto | 1 + ...and_gesture_recognizer_graph_options.proto | 1 + .../proto/hand_detector_graph_options.proto | 1 + .../proto/hand_landmarker_graph_options.proto | 1 + ...and_landmarks_detector_graph_options.proto | 1 + .../image_classifier_graph_options.proto | 1 + .../proto/image_embedder_graph_options.proto | 1 + .../proto/image_segmenter_graph_options.proto | 1 + .../proto/object_detector_options.proto | 1 + 
.../tasks/web/audio/audio_classifier/BUILD | 21 ++ .../audio_classifier/audio_classifier_test.ts | 208 ++++++++++++ .../tasks/web/audio/audio_embedder/BUILD | 21 ++ .../audio_embedder/audio_embedder_test.ts | 185 +++++++++++ .../tasks/web/text/text_classifier/BUILD | 22 ++ .../text_classifier/text_classifier_test.ts | 152 +++++++++ mediapipe/tasks/web/text/text_embedder/BUILD | 21 ++ .../text/text_embedder/text_embedder_test.ts | 165 ++++++++++ mediapipe/tasks/web/vision/core/BUILD | 18 + .../vision/core/vision_task_runner.test.ts | 99 ++++++ .../tasks/web/vision/gesture_recognizer/BUILD | 25 ++ .../gesture_recognizer_test.ts | 307 ++++++++++++++++++ .../tasks/web/vision/hand_landmarker/BUILD | 25 ++ .../hand_landmarker/hand_landmarker_test.ts | 251 ++++++++++++++ .../tasks/web/vision/image_classifier/BUILD | 24 ++ .../image_classifier/image_classifier_test.ts | 150 +++++++++ .../tasks/web/vision/image_embedder/BUILD | 21 ++ .../image_embedder/image_embedder_test.ts | 158 +++++++++ .../tasks/web/vision/object_detector/BUILD | 24 ++ .../object_detector/object_detector_test.ts | 229 +++++++++++++ 35 files changed, 2141 insertions(+) create mode 100644 mediapipe/tasks/web/audio/audio_classifier/audio_classifier_test.ts create mode 100644 mediapipe/tasks/web/audio/audio_embedder/audio_embedder_test.ts create mode 100644 mediapipe/tasks/web/text/text_classifier/text_classifier_test.ts create mode 100644 mediapipe/tasks/web/text/text_embedder/text_embedder_test.ts create mode 100644 mediapipe/tasks/web/vision/core/vision_task_runner.test.ts create mode 100644 mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_test.ts create mode 100644 mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_test.ts create mode 100644 mediapipe/tasks/web/vision/image_classifier/image_classifier_test.ts create mode 100644 mediapipe/tasks/web/vision/image_embedder/image_embedder_test.ts create mode 100644 mediapipe/tasks/web/vision/object_detector/object_detector_test.ts diff --git a/mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options.proto b/mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options.proto index 5d4ba3296..cc26b3070 100644 --- a/mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options.proto +++ b/mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options.proto @@ -18,6 +18,7 @@ syntax = "proto2"; package mediapipe.tasks.audio.audio_classifier.proto; import "mediapipe/framework/calculator.proto"; +import "mediapipe/framework/calculator_options.proto"; import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; diff --git a/mediapipe/tasks/cc/audio/audio_embedder/proto/audio_embedder_graph_options.proto b/mediapipe/tasks/cc/audio/audio_embedder/proto/audio_embedder_graph_options.proto index 25c5d5474..367a1bf26 100644 --- a/mediapipe/tasks/cc/audio/audio_embedder/proto/audio_embedder_graph_options.proto +++ b/mediapipe/tasks/cc/audio/audio_embedder/proto/audio_embedder_graph_options.proto @@ -18,6 +18,7 @@ syntax = "proto2"; package mediapipe.tasks.audio.audio_embedder.proto; import "mediapipe/framework/calculator.proto"; +import "mediapipe/framework/calculator_options.proto"; import "mediapipe/tasks/cc/components/processors/proto/embedder_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; diff --git a/mediapipe/tasks/cc/text/text_classifier/proto/text_classifier_graph_options.proto 
b/mediapipe/tasks/cc/text/text_classifier/proto/text_classifier_graph_options.proto index 8f4d7eea6..41f87b519 100644 --- a/mediapipe/tasks/cc/text/text_classifier/proto/text_classifier_graph_options.proto +++ b/mediapipe/tasks/cc/text/text_classifier/proto/text_classifier_graph_options.proto @@ -18,6 +18,7 @@ syntax = "proto2"; package mediapipe.tasks.text.text_classifier.proto; import "mediapipe/framework/calculator.proto"; +import "mediapipe/framework/calculator_options.proto"; import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; diff --git a/mediapipe/tasks/cc/text/text_embedder/proto/text_embedder_graph_options.proto b/mediapipe/tasks/cc/text/text_embedder/proto/text_embedder_graph_options.proto index e7e3a63c7..fc8e02858 100644 --- a/mediapipe/tasks/cc/text/text_embedder/proto/text_embedder_graph_options.proto +++ b/mediapipe/tasks/cc/text/text_embedder/proto/text_embedder_graph_options.proto @@ -18,6 +18,7 @@ syntax = "proto2"; package mediapipe.tasks.text.text_embedder.proto; import "mediapipe/framework/calculator.proto"; +import "mediapipe/framework/calculator_options.proto"; import "mediapipe/tasks/cc/components/processors/proto/embedder_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_classifier_graph_options.proto b/mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_classifier_graph_options.proto index dcefa075f..edbabc018 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_classifier_graph_options.proto +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_classifier_graph_options.proto @@ -18,6 +18,7 @@ syntax = "proto2"; package mediapipe.tasks.vision.gesture_recognizer.proto; import "mediapipe/framework/calculator.proto"; +import "mediapipe/framework/calculator_options.proto"; import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_embedder_graph_options.proto b/mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_embedder_graph_options.proto index bff4e0a9c..df909a6db 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_embedder_graph_options.proto +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_embedder_graph_options.proto @@ -18,6 +18,7 @@ syntax = "proto2"; package mediapipe.tasks.vision.gesture_recognizer.proto; import "mediapipe/framework/calculator.proto"; +import "mediapipe/framework/calculator_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; option java_package = "com.google.mediapipe.tasks.vision.gesturerecognizer.proto"; diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_recognizer_graph_options.proto b/mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_recognizer_graph_options.proto index 57d8a3746..fef22c07c 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_recognizer_graph_options.proto +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_recognizer_graph_options.proto @@ -18,6 +18,7 @@ syntax = "proto2"; package mediapipe.tasks.vision.gesture_recognizer.proto; import "mediapipe/framework/calculator.proto"; +import "mediapipe/framework/calculator_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; import 
"mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.proto"; import "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options.proto"; diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.proto b/mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.proto index 7df2fed37..ae85509da 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.proto +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.proto @@ -18,6 +18,7 @@ syntax = "proto2"; package mediapipe.tasks.vision.gesture_recognizer.proto; import "mediapipe/framework/calculator.proto"; +import "mediapipe/framework/calculator_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; import "mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_classifier_graph_options.proto"; import "mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_embedder_graph_options.proto"; diff --git a/mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.proto b/mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.proto index a009f2365..bede70da5 100644 --- a/mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.proto +++ b/mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.proto @@ -18,6 +18,7 @@ syntax = "proto2"; package mediapipe.tasks.vision.hand_detector.proto; import "mediapipe/framework/calculator.proto"; +import "mediapipe/framework/calculator_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; option java_package = "com.google.mediapipe.tasks.vision.handdetector.proto"; diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options.proto b/mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options.proto index 51e4e129a..d0edf99c0 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options.proto +++ b/mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options.proto @@ -18,6 +18,7 @@ syntax = "proto2"; package mediapipe.tasks.vision.hand_landmarker.proto; import "mediapipe/framework/calculator.proto"; +import "mediapipe/framework/calculator_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; import "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.proto"; import "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.proto"; diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.proto b/mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.proto index 195f6e5cc..a2d520963 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.proto +++ b/mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.proto @@ -18,6 +18,7 @@ syntax = "proto2"; package mediapipe.tasks.vision.hand_landmarker.proto; import "mediapipe/framework/calculator.proto"; +import "mediapipe/framework/calculator_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; option java_package = "com.google.mediapipe.tasks.vision.handlandmarker.proto"; diff --git a/mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_graph_options.proto 
b/mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_graph_options.proto index 76315e230..24b126a35 100644 --- a/mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_graph_options.proto +++ b/mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_graph_options.proto @@ -18,6 +18,7 @@ syntax = "proto2"; package mediapipe.tasks.vision.image_classifier.proto; import "mediapipe/framework/calculator.proto"; +import "mediapipe/framework/calculator_options.proto"; import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; diff --git a/mediapipe/tasks/cc/vision/image_embedder/proto/image_embedder_graph_options.proto b/mediapipe/tasks/cc/vision/image_embedder/proto/image_embedder_graph_options.proto index 72b3e7ee3..24ee866f2 100644 --- a/mediapipe/tasks/cc/vision/image_embedder/proto/image_embedder_graph_options.proto +++ b/mediapipe/tasks/cc/vision/image_embedder/proto/image_embedder_graph_options.proto @@ -18,6 +18,7 @@ syntax = "proto2"; package mediapipe.tasks.vision.image_embedder.proto; import "mediapipe/framework/calculator.proto"; +import "mediapipe/framework/calculator_options.proto"; import "mediapipe/tasks/cc/components/processors/proto/embedder_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; diff --git a/mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.proto b/mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.proto index 4d8100842..5c7d2ec71 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.proto +++ b/mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.proto @@ -18,6 +18,7 @@ syntax = "proto2"; package mediapipe.tasks.vision.image_segmenter.proto; import "mediapipe/framework/calculator.proto"; +import "mediapipe/framework/calculator_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; import "mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.proto"; diff --git a/mediapipe/tasks/cc/vision/object_detector/proto/object_detector_options.proto b/mediapipe/tasks/cc/vision/object_detector/proto/object_detector_options.proto index cba58ace8..3f6932f8f 100644 --- a/mediapipe/tasks/cc/vision/object_detector/proto/object_detector_options.proto +++ b/mediapipe/tasks/cc/vision/object_detector/proto/object_detector_options.proto @@ -18,6 +18,7 @@ syntax = "proto2"; package mediapipe.tasks.vision.object_detector.proto; import "mediapipe/framework/calculator.proto"; +import "mediapipe/framework/calculator_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; option java_package = "com.google.mediapipe.tasks.vision.objectdetector.proto"; diff --git a/mediapipe/tasks/web/audio/audio_classifier/BUILD b/mediapipe/tasks/web/audio/audio_classifier/BUILD index dc82a4a24..24ef31feb 100644 --- a/mediapipe/tasks/web/audio/audio_classifier/BUILD +++ b/mediapipe/tasks/web/audio/audio_classifier/BUILD @@ -2,6 +2,7 @@ # # This task takes audio data and outputs the classification result. 
+load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test") load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library") package(default_visibility = ["//mediapipe/tasks:internal"]) @@ -44,3 +45,23 @@ mediapipe_ts_declaration( "//mediapipe/tasks/web/core:classifier_options", ], ) + +mediapipe_ts_library( + name = "audio_classifier_test_lib", + testonly = True, + srcs = [ + "audio_classifier_test.ts", + ], + deps = [ + ":audio_classifier", + "//mediapipe/framework:calculator_jspb_proto", + "//mediapipe/framework/formats:classification_jspb_proto", + "//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto", + "//mediapipe/tasks/web/core:task_runner_test_utils", + ], +) + +jasmine_node_test( + name = "audio_classifier_test", + deps = [":audio_classifier_test_lib"], +) diff --git a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier_test.ts b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier_test.ts new file mode 100644 index 000000000..d5c0a9429 --- /dev/null +++ b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier_test.ts @@ -0,0 +1,208 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import 'jasmine'; + +import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; +import {Classification, ClassificationList} from '../../../../framework/formats/classification_pb'; +import {ClassificationResult, Classifications} from '../../../../tasks/cc/components/containers/proto/classifications_pb'; +import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils'; + +import {AudioClassifier} from './audio_classifier'; + +// The OSS JS API does not support the builder pattern. 
+// tslint:disable:jspb-use-builder-pattern
+
+class AudioClassifierFake extends AudioClassifier implements
+    MediapipeTasksFake {
+  lastSampleRate: number|undefined;
+  calculatorName =
+      'mediapipe.tasks.audio.audio_classifier.AudioClassifierGraph';
+  attachListenerSpies: jasmine.Spy[] = [];
+  graph: CalculatorGraphConfig|undefined;
+
+  private protoVectorListener: ((binaryProtos: Uint8Array[]) => void)|undefined;
+  private resultProtoVector: ClassificationResult[] = [];
+
+  constructor() {
+    super(createSpyWasmModule(), /* glCanvas= */ null);
+
+    this.attachListenerSpies[0] =
+        spyOn(this.graphRunner, 'attachProtoVectorListener')
+            .and.callFake((stream, listener) => {
+              expect(stream).toEqual('timestamped_classifications');
+              this.protoVectorListener = listener;
+            });
+    spyOn(this.graphRunner, 'addDoubleToStream')
+        .and.callFake((sampleRate, streamName, timestamp) => {
+          if (streamName === 'sample_rate') {
+            this.lastSampleRate = sampleRate;
+          }
+        });
+    spyOn(this.graphRunner, 'addAudioToStreamWithShape')
+        .and.callFake(
+            (audioData, numChannels, numSamples, streamName, timestamp) => {
+              expect(numChannels).toBe(1);
+            });
+    spyOn(this.graphRunner, 'finishProcessing').and.callFake(() => {
+      if (!this.protoVectorListener) return;
+      this.protoVectorListener(this.resultProtoVector.map(
+          classificationResult => classificationResult.serializeBinary()));
+    });
+    spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => {
+      this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph);
+    });
+  }
+
+  /** Sets the Protobuf that will be sent to the API. */
+  setResults(results: ClassificationResult[]): void {
+    this.resultProtoVector = results;
+  }
+}
+
+describe('AudioClassifier', () => {
+  let audioClassifier: AudioClassifierFake;
+
+  beforeEach(async () => {
+    addJasmineCustomFloatEqualityTester();
+    audioClassifier = new AudioClassifierFake();
+    await audioClassifier.setOptions({});  // Initialize graph
+  });
+
+  it('initializes graph', async () => {
+    verifyGraph(audioClassifier);
+    verifyListenersRegistered(audioClassifier);
+  });
+
+  it('reloads graph when settings are changed', async () => {
+    await audioClassifier.setOptions({maxResults: 1});
+    verifyGraph(audioClassifier, [['classifierOptions', 'maxResults'], 1]);
+    verifyListenersRegistered(audioClassifier);
+
+    await audioClassifier.setOptions({maxResults: 5});
+    verifyGraph(audioClassifier, [['classifierOptions', 'maxResults'], 5]);
+    verifyListenersRegistered(audioClassifier);
+  });
+
+  it('merges options', async () => {
+    await audioClassifier.setOptions({maxResults: 1});
+    await audioClassifier.setOptions({displayNamesLocale: 'en'});
+    verifyGraph(audioClassifier, [
+      'classifierOptions', {
+        maxResults: 1,
+        displayNamesLocale: 'en',
+        scoreThreshold: undefined,
+        categoryAllowlistList: [],
+        categoryDenylistList: []
+      }
+    ]);
+  });
+
+  it('uses a sample rate of 48000 by default', async () => {
+    audioClassifier.classify(new Float32Array([]));
+    expect(audioClassifier.lastSampleRate).toEqual(48000);
+  });
+
+  it('uses default sample rate if none provided', async () => {
+    audioClassifier.setDefaultSampleRate(16000);
+    audioClassifier.classify(new Float32Array([]));
+    expect(audioClassifier.lastSampleRate).toEqual(16000);
+  });
+
+  it('uses custom sample rate if provided', async () => {
+    audioClassifier.setDefaultSampleRate(16000);
+    audioClassifier.classify(new Float32Array([]), 44100);
+    expect(audioClassifier.lastSampleRate).toEqual(44100);
+  });
+
+  it('transforms results', async () => {
+    const resultProtoVector: ClassificationResult[] = [];
+
+    let classificationResult = new ClassificationResult();
+    classificationResult.setTimestampMs(0);
+    let classifications = new Classifications();
+    classifications.setHeadIndex(1);
+    classifications.setHeadName('headName');
+    let classificationList = new ClassificationList();
+    let classification = new Classification();
+    classification.setIndex(1);
+    classification.setScore(0.2);
+    classification.setDisplayName('displayName');
+    classification.setLabel('categoryName');
+    classificationList.addClassification(classification);
+    classifications.setClassificationList(classificationList);
+    classificationResult.addClassifications(classifications);
+    resultProtoVector.push(classificationResult);
+
+    classificationResult = new ClassificationResult();
+    classificationResult.setTimestampMs(1);
+    classifications = new Classifications();
+    classificationList = new ClassificationList();
+    classification = new Classification();
+    classification.setIndex(2);
+    classification.setScore(0.3);
+    classificationList.addClassification(classification);
+    classifications.setClassificationList(classificationList);
+    classificationResult.addClassifications(classifications);
+    resultProtoVector.push(classificationResult);
+
+    // Invoke the audio classifier
+    audioClassifier.setResults(resultProtoVector);
+    const results = audioClassifier.classify(new Float32Array([]));
+    expect(results.length).toEqual(2);
+    expect(results[0]).toEqual({
+      classifications: [{
+        categories: [{
+          index: 1,
+          score: 0.2,
+          displayName: 'displayName',
+          categoryName: 'categoryName'
+        }],
+        headIndex: 1,
+        headName: 'headName'
+      }],
+      timestampMs: 0
+    });
+    expect(results[1]).toEqual({
+      classifications: [{
+        categories: [{index: 2, score: 0.3, displayName: '', categoryName: ''}],
+        headIndex: 0,
+        headName: ''
+      }],
+      timestampMs: 1
+    });
+  });
+
+  it('clears results between invocations', async () => {
+    const classificationResult = new ClassificationResult();
+    const classifications = new Classifications();
+    const classificationList = new ClassificationList();
+    const classification = new Classification();
+    classificationList.addClassification(classification);
+    classifications.setClassificationList(classificationList);
+    classificationResult.addClassifications(classifications);
+
+    audioClassifier.setResults([classificationResult]);
+
+    // Invoke the audio classifier twice
+    const classifications1 = audioClassifier.classify(new Float32Array([]));
+    const classifications2 = audioClassifier.classify(new Float32Array([]));
+
+    // Verify that classifications2 is not a concatenation of all previously
+    // returned classifications.
+    expect(classifications1).toEqual(classifications2);
+  });
+});
diff --git a/mediapipe/tasks/web/audio/audio_embedder/BUILD b/mediapipe/tasks/web/audio/audio_embedder/BUILD
index dc84d0cd6..0817776c5 100644
--- a/mediapipe/tasks/web/audio/audio_embedder/BUILD
+++ b/mediapipe/tasks/web/audio/audio_embedder/BUILD
@@ -3,6 +3,7 @@
 # This task takes audio input and performs embedding.
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library") +load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test") package(default_visibility = ["//mediapipe/tasks:internal"]) @@ -43,3 +44,23 @@ mediapipe_ts_declaration( "//mediapipe/tasks/web/core:embedder_options", ], ) + +mediapipe_ts_library( + name = "audio_embedder_test_lib", + testonly = True, + srcs = [ + "audio_embedder_test.ts", + ], + deps = [ + ":audio_embedder", + "//mediapipe/framework:calculator_jspb_proto", + "//mediapipe/tasks/cc/components/containers/proto:embeddings_jspb_proto", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:task_runner_test_utils", + ], +) + +jasmine_node_test( + name = "audio_embedder_test", + deps = [":audio_embedder_test_lib"], +) diff --git a/mediapipe/tasks/web/audio/audio_embedder/audio_embedder_test.ts b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder_test.ts new file mode 100644 index 000000000..2f605ff98 --- /dev/null +++ b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder_test.ts @@ -0,0 +1,185 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import 'jasmine'; + +// Placeholder for internal dependency on encodeByteArray +import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; +import {Embedding, EmbeddingResult as EmbeddingResultProto, FloatEmbedding} from '../../../../tasks/cc/components/containers/proto/embeddings_pb'; +import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils'; + +import {AudioEmbedder, AudioEmbedderResult} from './audio_embedder'; + + +// The OSS JS API does not support the builder pattern. 
+// tslint:disable:jspb-use-builder-pattern + +class AudioEmbedderFake extends AudioEmbedder implements MediapipeTasksFake { + lastSampleRate: number|undefined; + calculatorName = 'mediapipe.tasks.audio.audio_embedder.AudioEmbedderGraph'; + graph: CalculatorGraphConfig|undefined; + attachListenerSpies: jasmine.Spy[] = []; + fakeWasmModule: SpyWasmModule; + + protoListener: ((binaryProto: Uint8Array) => void)|undefined; + protoVectorListener: ((binaryProtos: Uint8Array[]) => void)|undefined; + + constructor() { + super(createSpyWasmModule(), /* glCanvas= */ null); + this.fakeWasmModule = + this.graphRunner.wasmModule as unknown as SpyWasmModule; + + this.attachListenerSpies[0] = + spyOn(this.graphRunner, 'attachProtoListener') + .and.callFake((stream, listener) => { + expect(stream).toEqual('embeddings_out'); + this.protoListener = listener; + }); + this.attachListenerSpies[1] = + spyOn(this.graphRunner, 'attachProtoVectorListener') + .and.callFake((stream, listener) => { + expect(stream).toEqual('timestamped_embeddings_out'); + this.protoVectorListener = listener; + }); + spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => { + this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph); + }); + spyOn(this.graphRunner, 'addDoubleToStream').and.callFake(sampleRate => { + this.lastSampleRate = sampleRate; + }); + spyOn(this.graphRunner, 'addAudioToStreamWithShape'); + } +} + +describe('AudioEmbedder', () => { + let audioEmbedder: AudioEmbedderFake; + + beforeEach(async () => { + addJasmineCustomFloatEqualityTester(); + audioEmbedder = new AudioEmbedderFake(); + await audioEmbedder.setOptions({}); // Initialize graph + }); + + it('initializes graph', () => { + verifyGraph(audioEmbedder); + verifyListenersRegistered(audioEmbedder); + }); + + it('reloads graph when settings are changed', async () => { + await audioEmbedder.setOptions({quantize: true}); + verifyGraph(audioEmbedder, [['embedderOptions', 'quantize'], true]); + verifyListenersRegistered(audioEmbedder); + + await audioEmbedder.setOptions({quantize: undefined}); + verifyGraph(audioEmbedder, [['embedderOptions', 'quantize'], undefined]); + verifyListenersRegistered(audioEmbedder); + }); + + it('can use custom models', async () => { + const newModel = new Uint8Array([0, 1, 2, 3, 4]); + const newModelBase64 = Buffer.from(newModel).toString('base64'); + await audioEmbedder.setOptions({ + baseOptions: { + modelAssetBuffer: newModel, + } + }); + + verifyGraph( + audioEmbedder, + /* expectedCalculatorOptions= */ undefined, + /* expectedBaseOptions= */[ + 'modelAsset', { + fileContent: newModelBase64, + fileName: undefined, + fileDescriptorMeta: undefined, + filePointerMeta: undefined + } + ]); + }); + + it('combines options', async () => { + await audioEmbedder.setOptions({quantize: true}); + await audioEmbedder.setOptions({l2Normalize: true}); + verifyGraph( + audioEmbedder, + ['embedderOptions', {'quantize': true, 'l2Normalize': true}]); + }); + + it('uses a sample rate of 48000 by default', async () => { + audioEmbedder.embed(new Float32Array([])); + expect(audioEmbedder.lastSampleRate).toEqual(48000); + }); + + it('uses default sample rate if none provided', async () => { + audioEmbedder.setDefaultSampleRate(16000); + audioEmbedder.embed(new Float32Array([])); + expect(audioEmbedder.lastSampleRate).toEqual(16000); + }); + + it('uses custom sample rate if provided', async () => { + audioEmbedder.setDefaultSampleRate(16000); + audioEmbedder.embed(new Float32Array([]), 44100); + 
expect(audioEmbedder.lastSampleRate).toEqual(44100);
+  });
+
+  describe('transforms results', () => {
+    const embedding = new Embedding();
+    embedding.setHeadIndex(1);
+    embedding.setHeadName('headName');
+
+    const floatEmbedding = new FloatEmbedding();
+    floatEmbedding.setValuesList([0.1, 0.9]);
+
+    embedding.setFloatEmbedding(floatEmbedding);
+    const resultProto = new EmbeddingResultProto();
+    resultProto.addEmbeddings(embedding);
+
+    function validateEmbeddingResult(
+        expectedEmbeddingResult: AudioEmbedderResult[]) {
+      expect(expectedEmbeddingResult.length).toEqual(1);
+
+      const [embeddingResult] = expectedEmbeddingResult;
+      expect(embeddingResult.embeddings.length).toEqual(1);
+      expect(embeddingResult.embeddings[0])
+          .toEqual(
+              {floatEmbedding: [0.1, 0.9], headIndex: 1, headName: 'headName'});
+    }
+
+    it('from embeddings stream', async () => {
+      audioEmbedder.fakeWasmModule._waitUntilIdle.and.callFake(() => {
+        verifyListenersRegistered(audioEmbedder);
+        // Pass the test data to our listener
+        audioEmbedder.protoListener!(resultProto.serializeBinary());
+      });
+
+      // Invoke the audio embedder
+      const embeddingResults = audioEmbedder.embed(new Float32Array([]));
+      validateEmbeddingResult(embeddingResults);
+    });
+
+    it('from timestamped embeddings stream', async () => {
+      audioEmbedder.fakeWasmModule._waitUntilIdle.and.callFake(() => {
+        verifyListenersRegistered(audioEmbedder);
+        // Pass the test data to our listener
+        audioEmbedder.protoVectorListener!([resultProto.serializeBinary()]);
+      });
+
+      // Invoke the audio embedder
+      const embeddingResults = audioEmbedder.embed(new Float32Array([]), 42);
+      validateEmbeddingResult(embeddingResults);
+    });
+  });
+});
diff --git a/mediapipe/tasks/web/text/text_classifier/BUILD b/mediapipe/tasks/web/text/text_classifier/BUILD
index 07f78ac20..fd97c3db4 100644
--- a/mediapipe/tasks/web/text/text_classifier/BUILD
+++ b/mediapipe/tasks/web/text/text_classifier/BUILD
@@ -4,6 +4,7 @@
 # BERT-based text classification).
 
 load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library")
+load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test")
 
 package(default_visibility = ["//mediapipe/tasks:internal"])
 
@@ -45,3 +46,24 @@ mediapipe_ts_declaration(
         "//mediapipe/tasks/web/core:classifier_options",
     ],
 )
+
+mediapipe_ts_library(
+    name = "text_classifier_test_lib",
+    testonly = True,
+    srcs = [
+        "text_classifier_test.ts",
+    ],
+    deps = [
+        ":text_classifier",
+        "//mediapipe/framework:calculator_jspb_proto",
+        "//mediapipe/framework/formats:classification_jspb_proto",
+        "//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto",
+        "//mediapipe/tasks/web/core",
+        "//mediapipe/tasks/web/core:task_runner_test_utils",
+    ],
+)
+
+jasmine_node_test(
+    name = "text_classifier_test",
+    deps = [":text_classifier_test_lib"],
+)
diff --git a/mediapipe/tasks/web/text/text_classifier/text_classifier_test.ts b/mediapipe/tasks/web/text/text_classifier/text_classifier_test.ts
new file mode 100644
index 000000000..841bf8c48
--- /dev/null
+++ b/mediapipe/tasks/web/text/text_classifier/text_classifier_test.ts
@@ -0,0 +1,152 @@
+/**
+ * Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import 'jasmine'; + +// Placeholder for internal dependency on encodeByteArray +import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; +import {Classification, ClassificationList} from '../../../../framework/formats/classification_pb'; +import {ClassificationResult, Classifications} from '../../../../tasks/cc/components/containers/proto/classifications_pb'; +import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils'; + +import {TextClassifier} from './text_classifier'; + +// The OSS JS API does not support the builder pattern. +// tslint:disable:jspb-use-builder-pattern + +class TextClassifierFake extends TextClassifier implements MediapipeTasksFake { + calculatorName = 'mediapipe.tasks.text.text_classifier.TextClassifierGraph'; + attachListenerSpies: jasmine.Spy[] = []; + graph: CalculatorGraphConfig|undefined; + fakeWasmModule: SpyWasmModule; + protoListener: ((binaryProto: Uint8Array) => void)|undefined; + + constructor() { + super(createSpyWasmModule(), /* glCanvas= */ null); + this.fakeWasmModule = + this.graphRunner.wasmModule as unknown as SpyWasmModule; + this.attachListenerSpies[0] = + spyOn(this.graphRunner, 'attachProtoListener') + .and.callFake((stream, listener) => { + expect(stream).toEqual('classifications_out'); + this.protoListener = listener; + }); + spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => { + this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph); + }); + } +} + +describe('TextClassifier', () => { + let textClassifier: TextClassifierFake; + + beforeEach(async () => { + addJasmineCustomFloatEqualityTester(); + textClassifier = new TextClassifierFake(); + await textClassifier.setOptions({}); // Initialize graph + }); + + it('initializes graph', async () => { + verifyGraph(textClassifier); + verifyListenersRegistered(textClassifier); + }); + + it('reloads graph when settings are changed', async () => { + await textClassifier.setOptions({maxResults: 1}); + verifyGraph(textClassifier, [['classifierOptions', 'maxResults'], 1]); + verifyListenersRegistered(textClassifier); + + await textClassifier.setOptions({maxResults: 5}); + verifyGraph(textClassifier, [['classifierOptions', 'maxResults'], 5]); + verifyListenersRegistered(textClassifier); + }); + + it('can use custom models', async () => { + const newModel = new Uint8Array([0, 1, 2, 3, 4]); + const newModelBase64 = Buffer.from(newModel).toString('base64'); + await textClassifier.setOptions({ + baseOptions: { + modelAssetBuffer: newModel, + } + }); + + verifyGraph( + textClassifier, + /* expectedCalculatorOptions= */ undefined, + /* expectedBaseOptions= */ + [ + 'modelAsset', { + fileContent: newModelBase64, + fileName: undefined, + fileDescriptorMeta: undefined, + filePointerMeta: undefined + } + ]); + }); + + it('merges options', async () => { + await textClassifier.setOptions({maxResults: 1}); + await textClassifier.setOptions({displayNamesLocale: 'en'}); + verifyGraph(textClassifier, [ + 
      'classifierOptions', {
+        maxResults: 1,
+        displayNamesLocale: 'en',
+        scoreThreshold: undefined,
+        categoryAllowlistList: [],
+        categoryDenylistList: []
+      }
+    ]);
+  });
+
+  it('transforms results', async () => {
+    const classificationResult = new ClassificationResult();
+    const classifications = new Classifications();
+    classifications.setHeadIndex(1);
+    classifications.setHeadName('headName');
+    const classificationList = new ClassificationList();
+    const classification = new Classification();
+    classification.setIndex(1);
+    classification.setScore(0.2);
+    classification.setDisplayName('displayName');
+    classification.setLabel('categoryName');
+    classificationList.addClassification(classification);
+    classifications.setClassificationList(classificationList);
+    classificationResult.addClassifications(classifications);
+
+    // Pass the test data to our listener
+    textClassifier.fakeWasmModule._waitUntilIdle.and.callFake(() => {
+      verifyListenersRegistered(textClassifier);
+      textClassifier.protoListener!(classificationResult.serializeBinary());
+    });
+
+    // Invoke the text classifier
+    const result = textClassifier.classify('foo');
+
+    expect(textClassifier.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
+    expect(result).toEqual({
+      classifications: [{
+        categories: [{
+          index: 1,
+          score: 0.2,
+          displayName: 'displayName',
+          categoryName: 'categoryName'
+        }],
+        headIndex: 1,
+        headName: 'headName'
+      }]
+    });
+  });
+});
diff --git a/mediapipe/tasks/web/text/text_embedder/BUILD b/mediapipe/tasks/web/text/text_embedder/BUILD
index 7d796fb7e..1514944bf 100644
--- a/mediapipe/tasks/web/text/text_embedder/BUILD
+++ b/mediapipe/tasks/web/text/text_embedder/BUILD
@@ -4,6 +4,7 @@
 #
 
 load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library")
+load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test")
 
 package(default_visibility = ["//mediapipe/tasks:internal"])
 
@@ -44,3 +45,23 @@ mediapipe_ts_declaration(
         "//mediapipe/tasks/web/core:embedder_options",
     ],
 )
+
+mediapipe_ts_library(
+    name = "text_embedder_test_lib",
+    testonly = True,
+    srcs = [
+        "text_embedder_test.ts",
+    ],
+    deps = [
+        ":text_embedder",
+        "//mediapipe/framework:calculator_jspb_proto",
+        "//mediapipe/tasks/cc/components/containers/proto:embeddings_jspb_proto",
+        "//mediapipe/tasks/web/core",
+        "//mediapipe/tasks/web/core:task_runner_test_utils",
+    ],
+)
+
+jasmine_node_test(
+    name = "text_embedder_test",
+    deps = [":text_embedder_test_lib"],
+)
diff --git a/mediapipe/tasks/web/text/text_embedder/text_embedder_test.ts b/mediapipe/tasks/web/text/text_embedder/text_embedder_test.ts
new file mode 100644
index 000000000..04a9b371a
--- /dev/null
+++ b/mediapipe/tasks/web/text/text_embedder/text_embedder_test.ts
@@ -0,0 +1,165 @@
+/**
+ * Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +import 'jasmine'; + +// Placeholder for internal dependency on encodeByteArray +import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; +import {Embedding, EmbeddingResult, FloatEmbedding, QuantizedEmbedding} from '../../../../tasks/cc/components/containers/proto/embeddings_pb'; +import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils'; + +import {TextEmbedder} from './text_embedder'; + +// The OSS JS API does not support the builder pattern. +// tslint:disable:jspb-use-builder-pattern + +class TextEmbedderFake extends TextEmbedder implements MediapipeTasksFake { + calculatorName = 'mediapipe.tasks.text.text_embedder.TextEmbedderGraph'; + graph: CalculatorGraphConfig|undefined; + attachListenerSpies: jasmine.Spy[] = []; + fakeWasmModule: SpyWasmModule; + protoListener: ((binaryProtos: Uint8Array) => void)|undefined; + + constructor() { + super(createSpyWasmModule(), /* glCanvas= */ null); + this.fakeWasmModule = + this.graphRunner.wasmModule as unknown as SpyWasmModule; + + this.attachListenerSpies[0] = + spyOn(this.graphRunner, 'attachProtoListener') + .and.callFake((stream, listener) => { + expect(stream).toEqual('embeddings_out'); + this.protoListener = listener; + }); + spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => { + this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph); + }); + } +} + +describe('TextEmbedder', () => { + let textEmbedder: TextEmbedderFake; + + beforeEach(async () => { + addJasmineCustomFloatEqualityTester(); + textEmbedder = new TextEmbedderFake(); + await textEmbedder.setOptions({}); // Initialize graph + }); + + it('initializes graph', async () => { + verifyGraph(textEmbedder); + verifyListenersRegistered(textEmbedder); + }); + + it('reloads graph when settings are changed', async () => { + await textEmbedder.setOptions({quantize: true}); + verifyGraph(textEmbedder, [['embedderOptions', 'quantize'], true]); + verifyListenersRegistered(textEmbedder); + + await textEmbedder.setOptions({quantize: undefined}); + verifyGraph(textEmbedder, [['embedderOptions', 'quantize'], undefined]); + verifyListenersRegistered(textEmbedder); + }); + + it('can use custom models', async () => { + const newModel = new Uint8Array([0, 1, 2, 3, 4]); + const newModelBase64 = Buffer.from(newModel).toString('base64'); + await textEmbedder.setOptions({ + baseOptions: { + modelAssetBuffer: newModel, + } + }); + + verifyGraph( + textEmbedder, + /* expectedCalculatorOptions= */ undefined, + /* expectedBaseOptions= */[ + 'modelAsset', { + fileContent: newModelBase64, + fileName: undefined, + fileDescriptorMeta: undefined, + filePointerMeta: undefined + } + ]); + }); + + it('combines options', async () => { + await textEmbedder.setOptions({quantize: true}); + await textEmbedder.setOptions({l2Normalize: true}); + verifyGraph( + textEmbedder, + ['embedderOptions', {'quantize': true, 'l2Normalize': true}]); + }); + + it('transforms results', async () => { + const embedding = new Embedding(); + embedding.setHeadIndex(1); + embedding.setHeadName('headName'); + + const floatEmbedding = new FloatEmbedding(); + floatEmbedding.setValuesList([0.1, 0.9]); + + embedding.setFloatEmbedding(floatEmbedding); + const resultProto = new EmbeddingResult(); + resultProto.addEmbeddings(embedding); + + // Pass the test data to our listener + textEmbedder.fakeWasmModule._waitUntilIdle.and.callFake(() => { + 
verifyListenersRegistered(textEmbedder); + textEmbedder.protoListener!(resultProto.serializeBinary()); + }); + + // Invoke the text embedder + const embeddingResult = textEmbedder.embed('foo'); + + expect(textEmbedder.fakeWasmModule._waitUntilIdle).toHaveBeenCalled(); + expect(embeddingResult.embeddings.length).toEqual(1); + expect(embeddingResult.embeddings[0]) + .toEqual( + {floatEmbedding: [0.1, 0.9], headIndex: 1, headName: 'headName'}); + }); + + it('transforms custom quantized values', async () => { + const embedding = new Embedding(); + embedding.setHeadIndex(1); + embedding.setHeadName('headName'); + + const quantizedEmbedding = new QuantizedEmbedding(); + const quantizedValues = new Uint8Array([1, 2, 3]); + quantizedEmbedding.setValues(quantizedValues); + + embedding.setQuantizedEmbedding(quantizedEmbedding); + const resultProto = new EmbeddingResult(); + resultProto.addEmbeddings(embedding); + + // Pass the test data to our listener + textEmbedder.fakeWasmModule._waitUntilIdle.and.callFake(() => { + verifyListenersRegistered(textEmbedder); + textEmbedder.protoListener!(resultProto.serializeBinary()); + }); + + // Invoke the text embedder + const embeddingsResult = textEmbedder.embed('foo'); + + expect(textEmbedder.fakeWasmModule._waitUntilIdle).toHaveBeenCalled(); + expect(embeddingsResult.embeddings.length).toEqual(1); + expect(embeddingsResult.embeddings[0]).toEqual({ + quantizedEmbedding: new Uint8Array([1, 2, 3]), + headIndex: 1, + headName: 'headName' + }); + }); +}); diff --git a/mediapipe/tasks/web/vision/core/BUILD b/mediapipe/tasks/web/vision/core/BUILD index b389a9b01..e4ea3036f 100644 --- a/mediapipe/tasks/web/vision/core/BUILD +++ b/mediapipe/tasks/web/vision/core/BUILD @@ -1,5 +1,6 @@ # This package contains options shared by all MediaPipe Vision Tasks for Web. +load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test") load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library") package(default_visibility = ["//mediapipe/tasks:internal"]) @@ -22,3 +23,20 @@ mediapipe_ts_library( "//mediapipe/web/graph_runner:graph_runner_ts", ], ) + +mediapipe_ts_library( + name = "vision_task_runner_test_lib", + testonly = True, + srcs = ["vision_task_runner.test.ts"], + deps = [ + ":vision_task_runner", + "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", + "//mediapipe/tasks/web/core:task_runner_test_utils", + "//mediapipe/web/graph_runner:graph_runner_ts", + ], +) + +jasmine_node_test( + name = "vision_task_runner_test", + deps = [":vision_task_runner_test_lib"], +) diff --git a/mediapipe/tasks/web/vision/core/vision_task_runner.test.ts b/mediapipe/tasks/web/vision/core/vision_task_runner.test.ts new file mode 100644 index 000000000..6cc9ea328 --- /dev/null +++ b/mediapipe/tasks/web/vision/core/vision_task_runner.test.ts @@ -0,0 +1,99 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +import 'jasmine'; + +import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; +import {createSpyWasmModule} from '../../../../tasks/web/core/task_runner_test_utils'; +import {ImageSource} from '../../../../web/graph_runner/graph_runner'; + +import {VisionTaskRunner} from './vision_task_runner'; + +class VisionTaskRunnerFake extends VisionTaskRunner { + baseOptions = new BaseOptionsProto(); + + constructor() { + super(createSpyWasmModule(), /* glCanvas= */ null); + } + + protected override process(): void {} + + override processImageData(image: ImageSource): void { + super.processImageData(image); + } + + override processVideoData(imageFrame: ImageSource, timestamp: number): void { + super.processVideoData(imageFrame, timestamp); + } +} + +describe('VisionTaskRunner', () => { + const streamMode = { + modelAsset: undefined, + useStreamMode: true, + acceleration: undefined, + }; + + const imageMode = { + modelAsset: undefined, + useStreamMode: false, + acceleration: undefined, + }; + + let visionTaskRunner: VisionTaskRunnerFake; + + beforeEach(() => { + visionTaskRunner = new VisionTaskRunnerFake(); + }); + + it('can enable image mode', async () => { + await visionTaskRunner.setOptions({runningMode: 'image'}); + expect(visionTaskRunner.baseOptions.toObject()).toEqual(imageMode); + }); + + it('can enable video mode', async () => { + await visionTaskRunner.setOptions({runningMode: 'video'}); + expect(visionTaskRunner.baseOptions.toObject()).toEqual(streamMode); + }); + + it('can clear running mode', async () => { + await visionTaskRunner.setOptions({runningMode: 'video'}); + + // Clear running mode + await visionTaskRunner.setOptions({runningMode: undefined}); + expect(visionTaskRunner.baseOptions.toObject()).toEqual(imageMode); + }); + + it('cannot process images with video mode', async () => { + await visionTaskRunner.setOptions({runningMode: 'video'}); + expect(() => { + visionTaskRunner.processImageData({} as HTMLImageElement); + }).toThrowError(/Task is not initialized with image mode./); + }); + + it('cannot process video with image mode', async () => { + // Use default for `useStreamMode` + expect(() => { + visionTaskRunner.processVideoData({} as HTMLImageElement, 42); + }).toThrowError(/Task is not initialized with video mode./); + + // Explicitly set to image mode + await visionTaskRunner.setOptions({runningMode: 'image'}); + expect(() => { + visionTaskRunner.processVideoData({} as HTMLImageElement, 42); + }).toThrowError(/Task is not initialized with video mode./); + }); +}); diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/BUILD b/mediapipe/tasks/web/vision/gesture_recognizer/BUILD index 6e2e56196..aa2f9c366 100644 --- a/mediapipe/tasks/web/vision/gesture_recognizer/BUILD +++ b/mediapipe/tasks/web/vision/gesture_recognizer/BUILD @@ -4,6 +4,7 @@ # the detection results for one or more gesture categories, using Gesture Recognizer. 
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library") +load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test") package(default_visibility = ["//mediapipe/tasks:internal"]) @@ -52,3 +53,27 @@ mediapipe_ts_declaration( "//mediapipe/tasks/web/vision/core:vision_task_options", ], ) + +mediapipe_ts_library( + name = "gesture_recognizer_test_lib", + testonly = True, + srcs = [ + "gesture_recognizer_test.ts", + ], + deps = [ + ":gesture_recognizer", + ":gesture_recognizer_types", + "//mediapipe/framework:calculator_jspb_proto", + "//mediapipe/framework/formats:classification_jspb_proto", + "//mediapipe/framework/formats:landmark_jspb_proto", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:task_runner", + "//mediapipe/tasks/web/core:task_runner_test_utils", + ], +) + +jasmine_node_test( + name = "gesture_recognizer_test", + tags = ["nomsan"], + deps = [":gesture_recognizer_test_lib"], +) diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_test.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_test.ts new file mode 100644 index 000000000..c0f0d1554 --- /dev/null +++ b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_test.ts @@ -0,0 +1,307 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +import 'jasmine'; + +import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; +import {Classification, ClassificationList} from '../../../../framework/formats/classification_pb'; +import {Landmark, LandmarkList, NormalizedLandmark, NormalizedLandmarkList} from '../../../../framework/formats/landmark_pb'; +import {GraphRunnerImageLib} from '../../../../tasks/web/core/task_runner'; +import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils'; + +import {GestureRecognizer, GestureRecognizerOptions} from './gesture_recognizer'; + +// The OSS JS API does not support the builder pattern. 
+// tslint:disable:jspb-use-builder-pattern
+
+type ProtoListener = ((binaryProtos: Uint8Array[]) => void);
+
+function createHandednesses(): Uint8Array[] {
+ const handsProto = new ClassificationList();
+ const classification = new Classification();
+ classification.setScore(0.1);
+ classification.setIndex(1);
+ classification.setLabel('handedness_label');
+ classification.setDisplayName('handedness_display_name');
+ handsProto.addClassification(classification);
+ return [handsProto.serializeBinary()];
+}
+
+function createGestures(): Uint8Array[] {
+ const gesturesProto = new ClassificationList();
+ const classification = new Classification();
+ classification.setScore(0.2);
+ classification.setIndex(2);
+ classification.setLabel('gesture_label');
+ classification.setDisplayName('gesture_display_name');
+ gesturesProto.addClassification(classification);
+ return [gesturesProto.serializeBinary()];
+}
+
+function createLandmarks(): Uint8Array[] {
+ const handLandmarksProto = new NormalizedLandmarkList();
+ const landmark = new NormalizedLandmark();
+ landmark.setX(0.3);
+ landmark.setY(0.4);
+ landmark.setZ(0.5);
+ handLandmarksProto.addLandmark(landmark);
+ return [handLandmarksProto.serializeBinary()];
+}
+
+function createWorldLandmarks(): Uint8Array[] {
+ const handLandmarksProto = new LandmarkList();
+ const landmark = new Landmark();
+ landmark.setX(21);
+ landmark.setY(22);
+ landmark.setZ(23);
+ handLandmarksProto.addLandmark(landmark);
+ return [handLandmarksProto.serializeBinary()];
+}
+
+class GestureRecognizerFake extends GestureRecognizer implements
+ MediapipeTasksFake {
+ calculatorName =
+ 'mediapipe.tasks.vision.gesture_recognizer.GestureRecognizerGraph';
+ attachListenerSpies: jasmine.Spy[] = [];
+ graph: CalculatorGraphConfig|undefined;
+ fakeWasmModule: SpyWasmModule;
+ listeners = new Map<string, ProtoListener>();
+
+ constructor() {
+ super(createSpyWasmModule(), /* glCanvas= */ null);
+ this.fakeWasmModule =
+ this.graphRunner.wasmModule as unknown as SpyWasmModule;
+ this.attachListenerSpies[0] =
+ spyOn(this.graphRunner, 'attachProtoVectorListener')
+ .and.callFake((stream, listener) => {
+ expect(stream).toMatch(
+ /(hand_landmarks|world_hand_landmarks|handedness|hand_gestures)/);
+ this.listeners.set(stream, listener);
+ });
+
+ spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => {
+ this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph);
+ });
+ spyOn(this.graphRunner, 'addGpuBufferAsImageToStream');
+ spyOn(this.graphRunner, 'addProtoToStream');
+ }
+
+ getGraphRunner(): GraphRunnerImageLib {
+ return this.graphRunner;
+ }
+}
+
+describe('GestureRecognizer', () => {
+ let gestureRecognizer: GestureRecognizerFake;
+
+ beforeEach(async () => {
+ addJasmineCustomFloatEqualityTester();
+ gestureRecognizer = new GestureRecognizerFake();
+ await gestureRecognizer.setOptions({}); // Initialize graph
+ });
+
+ it('initializes graph', async () => {
+ verifyGraph(gestureRecognizer);
+ verifyListenersRegistered(gestureRecognizer);
+ });
+
+ it('reloads graph when settings are changed', async () => {
+ await gestureRecognizer.setOptions({numHands: 1});
+ verifyGraph(gestureRecognizer, [
+ ['handLandmarkerGraphOptions', 'handDetectorGraphOptions', 'numHands'], 1
+ ]);
+ verifyListenersRegistered(gestureRecognizer);
+
+ await gestureRecognizer.setOptions({numHands: 5});
+ verifyGraph(gestureRecognizer, [
+ ['handLandmarkerGraphOptions', 'handDetectorGraphOptions', 'numHands'], 5
+ ]);
+ verifyListenersRegistered(gestureRecognizer);
+ });
+
+ it('merges options', async () =>
{
+ await gestureRecognizer.setOptions({numHands: 1});
+ await gestureRecognizer.setOptions({minHandDetectionConfidence: 0.5});
+ verifyGraph(gestureRecognizer, [
+ ['handLandmarkerGraphOptions', 'handDetectorGraphOptions', 'numHands'], 1
+ ]);
+ verifyGraph(gestureRecognizer, [
+ [
+ 'handLandmarkerGraphOptions', 'handDetectorGraphOptions',
+ 'minDetectionConfidence'
+ ],
+ 0.5
+ ]);
+ });
+
+ describe('setOptions() ', () => {
+ interface TestCase {
+ optionPath: [keyof GestureRecognizerOptions, ...string[]];
+ fieldPath: string[];
+ customValue: unknown;
+ defaultValue: unknown;
+ }
+
+ const testCases: TestCase[] = [
+ {
+ optionPath: ['numHands'],
+ fieldPath: [
+ 'handLandmarkerGraphOptions', 'handDetectorGraphOptions', 'numHands'
+ ],
+ customValue: 5,
+ defaultValue: 1
+ },
+ {
+ optionPath: ['minHandDetectionConfidence'],
+ fieldPath: [
+ 'handLandmarkerGraphOptions', 'handDetectorGraphOptions',
+ 'minDetectionConfidence'
+ ],
+ customValue: 0.1,
+ defaultValue: 0.5
+ },
+ {
+ optionPath: ['minHandPresenceConfidence'],
+ fieldPath: [
+ 'handLandmarkerGraphOptions', 'handLandmarksDetectorGraphOptions',
+ 'minDetectionConfidence'
+ ],
+ customValue: 0.2,
+ defaultValue: 0.5
+ },
+ {
+ optionPath: ['minTrackingConfidence'],
+ fieldPath: ['handLandmarkerGraphOptions', 'minTrackingConfidence'],
+ customValue: 0.3,
+ defaultValue: 0.5
+ },
+ {
+ optionPath: ['cannedGesturesClassifierOptions', 'scoreThreshold'],
+ fieldPath: [
+ 'handGestureRecognizerGraphOptions',
+ 'cannedGestureClassifierGraphOptions', 'classifierOptions',
+ 'scoreThreshold'
+ ],
+ customValue: 0.4,
+ defaultValue: undefined
+ },
+ {
+ optionPath: ['customGesturesClassifierOptions', 'scoreThreshold'],
+ fieldPath: [
+ 'handGestureRecognizerGraphOptions',
+ 'customGestureClassifierGraphOptions', 'classifierOptions',
+ 'scoreThreshold'
+ ],
+ customValue: 0.5,
+ defaultValue: undefined,
+ },
+ ];
+
+ /** Creates an options object that can be passed to setOptions() */
+ function createOptions(
+ path: string[], value: unknown): GestureRecognizerOptions {
+ const options: Record<string, unknown> = {};
+ let currentLevel = options;
+ for (const element of path.slice(0, -1)) {
+ currentLevel[element] = {};
+ currentLevel = currentLevel[element] as Record<string, unknown>;
+ }
+ currentLevel[path[path.length - 1]] = value;
+ return options;
+ }
+
+ for (const testCase of testCases) {
+ it(`uses default value for ${testCase.optionPath[0]}`, async () => {
+ verifyGraph(
+ gestureRecognizer, [testCase.fieldPath, testCase.defaultValue]);
+ });
+
+ it(`can set ${testCase.optionPath[0]}`, async () => {
+ await gestureRecognizer.setOptions(
+ createOptions(testCase.optionPath, testCase.customValue));
+ verifyGraph(
+ gestureRecognizer, [testCase.fieldPath, testCase.customValue]);
+ });
+
+ it(`can clear ${testCase.optionPath[0]}`, async () => {
+ await gestureRecognizer.setOptions(
+ createOptions(testCase.optionPath, testCase.customValue));
+ verifyGraph(
+ gestureRecognizer, [testCase.fieldPath, testCase.customValue]);
+
+ await gestureRecognizer.setOptions(
+ createOptions(testCase.optionPath, undefined));
+ verifyGraph(
+ gestureRecognizer, [testCase.fieldPath, testCase.defaultValue]);
+ });
+ }
+ });
+
+ it('transforms results', async () => {
+ // Pass the test data to our listener
+ gestureRecognizer.fakeWasmModule._waitUntilIdle.and.callFake(() => {
+ verifyListenersRegistered(gestureRecognizer);
+ gestureRecognizer.listeners.get('hand_landmarks')!(createLandmarks());
+ gestureRecognizer.listeners.get('world_hand_landmarks')!
+ (createWorldLandmarks());
+ gestureRecognizer.listeners.get('handedness')!(createHandednesses());
+ gestureRecognizer.listeners.get('hand_gestures')!(createGestures());
+ });
+
+ // Invoke the gesture recognizer
+ const gestures = gestureRecognizer.recognize({} as HTMLImageElement);
+ expect(gestureRecognizer.getGraphRunner().addProtoToStream)
+ .toHaveBeenCalledTimes(1);
+ expect(gestureRecognizer.getGraphRunner().addGpuBufferAsImageToStream)
+ .toHaveBeenCalledTimes(1);
+ expect(gestureRecognizer.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
+
+ expect(gestures).toEqual({
+ 'gestures': [[{
+ 'score': 0.2,
+ 'index': 2,
+ 'categoryName': 'gesture_label',
+ 'displayName': 'gesture_display_name'
+ }]],
+ 'landmarks': [[{'x': 0.3, 'y': 0.4, 'z': 0.5}]],
+ 'worldLandmarks': [[{'x': 21, 'y': 22, 'z': 23}]],
+ 'handednesses': [[{
+ 'score': 0.1,
+ 'index': 1,
+ 'categoryName': 'handedness_label',
+ 'displayName': 'handedness_display_name'
+ }]]
+ });
+ });
+
+ it('clears results between invocations', async () => {
+ // Pass the test data to our listener
+ gestureRecognizer.fakeWasmModule._waitUntilIdle.and.callFake(() => {
+ gestureRecognizer.listeners.get('hand_landmarks')!(createLandmarks());
+ gestureRecognizer.listeners.get('world_hand_landmarks')!
+ (createWorldLandmarks());
+ gestureRecognizer.listeners.get('handedness')!(createHandednesses());
+ gestureRecognizer.listeners.get('hand_gestures')!(createGestures());
+ });
+
+ // Invoke the gesture recognizer twice
+ const gestures1 = gestureRecognizer.recognize({} as HTMLImageElement);
+ const gestures2 = gestureRecognizer.recognize({} as HTMLImageElement);
+
+ // Verify that gestures2 is not a concatenation of all previously returned
+ // gestures.
+ expect(gestures2).toEqual(gestures1);
+ });
+});
diff --git a/mediapipe/tasks/web/vision/hand_landmarker/BUILD b/mediapipe/tasks/web/vision/hand_landmarker/BUILD
index 520898e34..d1f1e48f3 100644
--- a/mediapipe/tasks/web/vision/hand_landmarker/BUILD
+++ b/mediapipe/tasks/web/vision/hand_landmarker/BUILD
@@ -4,6 +4,7 @@
 # the detection results for one or more hand categories, using Hand Landmarker.

 load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library")
+load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test")

 package(default_visibility = ["//mediapipe/tasks:internal"])

@@ -47,3 +48,27 @@ mediapipe_ts_declaration(
 "//mediapipe/tasks/web/vision/core:vision_task_options",
 ],
 )
+
+mediapipe_ts_library(
+ name = "hand_landmarker_test_lib",
+ testonly = True,
+ srcs = [
+ "hand_landmarker_test.ts",
+ ],
+ deps = [
+ ":hand_landmarker",
+ ":hand_landmarker_types",
+ "//mediapipe/framework:calculator_jspb_proto",
+ "//mediapipe/framework/formats:classification_jspb_proto",
+ "//mediapipe/framework/formats:landmark_jspb_proto",
+ "//mediapipe/tasks/web/core",
+ "//mediapipe/tasks/web/core:task_runner",
+ "//mediapipe/tasks/web/core:task_runner_test_utils",
+ ],
+)
+
+jasmine_node_test(
+ name = "hand_landmarker_test",
+ tags = ["nomsan"],
+ deps = [":hand_landmarker_test_lib"],
+)
diff --git a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_test.ts b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_test.ts
new file mode 100644
index 000000000..fc26680e0
--- /dev/null
+++ b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_test.ts
@@ -0,0 +1,251 @@
+/**
+ * Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+import 'jasmine';
+
+import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
+import {Classification, ClassificationList} from '../../../../framework/formats/classification_pb';
+import {Landmark, LandmarkList, NormalizedLandmark, NormalizedLandmarkList} from '../../../../framework/formats/landmark_pb';
+import {GraphRunnerImageLib} from '../../../../tasks/web/core/task_runner';
+import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils';
+
+import {HandLandmarker} from './hand_landmarker';
+import {HandLandmarkerOptions} from './hand_landmarker_options';
+
+// The OSS JS API does not support the builder pattern.
+// tslint:disable:jspb-use-builder-pattern
+
+type ProtoListener = ((binaryProtos: Uint8Array[]) => void);
+
+function createHandednesses(): Uint8Array[] {
+ const handsProto = new ClassificationList();
+ const classification = new Classification();
+ classification.setScore(0.1);
+ classification.setIndex(1);
+ classification.setLabel('handedness_label');
+ classification.setDisplayName('handedness_display_name');
+ handsProto.addClassification(classification);
+ return [handsProto.serializeBinary()];
+}
+
+function createLandmarks(): Uint8Array[] {
+ const handLandmarksProto = new NormalizedLandmarkList();
+ const landmark = new NormalizedLandmark();
+ landmark.setX(0.3);
+ landmark.setY(0.4);
+ landmark.setZ(0.5);
+ handLandmarksProto.addLandmark(landmark);
+ return [handLandmarksProto.serializeBinary()];
+}
+
+function createWorldLandmarks(): Uint8Array[] {
+ const handLandmarksProto = new LandmarkList();
+ const landmark = new Landmark();
+ landmark.setX(21);
+ landmark.setY(22);
+ landmark.setZ(23);
+ handLandmarksProto.addLandmark(landmark);
+ return [handLandmarksProto.serializeBinary()];
+}
+
+class HandLandmarkerFake extends HandLandmarker implements MediapipeTasksFake {
+ calculatorName = 'mediapipe.tasks.vision.hand_landmarker.HandLandmarkerGraph';
+ attachListenerSpies: jasmine.Spy[] = [];
+ graph: CalculatorGraphConfig|undefined;
+ fakeWasmModule: SpyWasmModule;
+ listeners = new Map<string, ProtoListener>();
+
+ constructor() {
+ super(createSpyWasmModule(), /* glCanvas= */ null);
+ this.fakeWasmModule =
+ this.graphRunner.wasmModule as unknown as SpyWasmModule;
+
+ this.attachListenerSpies[0] =
+ spyOn(this.graphRunner, 'attachProtoVectorListener')
+ .and.callFake((stream, listener) => {
+ expect(stream).toMatch(
+ /(hand_landmarks|world_hand_landmarks|handedness)/);
+ this.listeners.set(stream, listener);
+ });
+
+ spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => {
+ this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph);
+ });
+ spyOn(this.graphRunner, 'addGpuBufferAsImageToStream');
+ spyOn(this.graphRunner, 'addProtoToStream');
+ }
+
+ getGraphRunner(): GraphRunnerImageLib {
+ return this.graphRunner;
+ }
+}
+
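+// The fake intercepts listener registration and graph initialization so that
+// each test can inject canned protos without running a real Wasm graph.
+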
+describe('HandLandmarker', () => {
+ let handLandmarker: HandLandmarkerFake;
+
+ beforeEach(async () => {
+ addJasmineCustomFloatEqualityTester();
+ handLandmarker = new HandLandmarkerFake();
+ await handLandmarker.setOptions({}); // Initialize graph
+ });
+
+ it('initializes graph', async () => {
+ verifyGraph(handLandmarker);
+ verifyListenersRegistered(handLandmarker);
+ });
+
+ it('reloads graph when settings are changed', async () => {
+ verifyListenersRegistered(handLandmarker);
+
+ await handLandmarker.setOptions({numHands: 1});
+ verifyGraph(handLandmarker, [['handDetectorGraphOptions', 'numHands'], 1]);
+ verifyListenersRegistered(handLandmarker);
+
+ await handLandmarker.setOptions({numHands: 5});
+ verifyGraph(handLandmarker, [['handDetectorGraphOptions', 'numHands'], 5]);
+ verifyListenersRegistered(handLandmarker);
+ });
+
+ it('merges options', async () => {
+ await handLandmarker.setOptions({numHands: 1});
+ await handLandmarker.setOptions({minHandDetectionConfidence: 0.5});
+ verifyGraph(handLandmarker, [
+ 'handDetectorGraphOptions',
+ {numHands: 1, baseOptions: undefined, minDetectionConfidence: 0.5}
+ ]);
+ });
+
+ describe('setOptions() ', () => {
+ interface TestCase {
+ optionPath: [keyof HandLandmarkerOptions, ...string[]];
+ fieldPath: string[];
+ customValue: unknown;
+ defaultValue: unknown;
+ }
+
+ const testCases: TestCase[] = [
+ {
+ optionPath: ['numHands'],
+ fieldPath: ['handDetectorGraphOptions', 'numHands'],
+ customValue: 5,
+ defaultValue: 1
+ },
+ {
+ optionPath: ['minHandDetectionConfidence'],
+ fieldPath: ['handDetectorGraphOptions', 'minDetectionConfidence'],
+ customValue: 0.1,
+ defaultValue: 0.5
+ },
+ {
+ optionPath: ['minHandPresenceConfidence'],
+ fieldPath:
+ ['handLandmarksDetectorGraphOptions', 'minDetectionConfidence'],
+ customValue: 0.2,
+ defaultValue: 0.5
+ },
+ {
+ optionPath: ['minTrackingConfidence'],
+ fieldPath: ['minTrackingConfidence'],
+ customValue: 0.3,
+ defaultValue: 0.5
+ },
+ ];
+
+ /** Creates an options object that can be passed to setOptions() */
+ function createOptions(
+ path: string[], value: unknown): HandLandmarkerOptions {
+ const options: Record<string, unknown> = {};
+ let currentLevel = options;
+ for (const element of path.slice(0, -1)) {
+ currentLevel[element] = {};
+ currentLevel = currentLevel[element] as Record<string, unknown>;
+ }
+ currentLevel[path[path.length - 1]] = value;
+ return options;
+ }
+
+ for (const testCase of testCases) {
+ it(`uses default value for ${testCase.optionPath[0]}`, async () => {
+ verifyGraph(
+ handLandmarker, [testCase.fieldPath, testCase.defaultValue]);
+ });
+
+ it(`can set ${testCase.optionPath[0]}`, async () => {
+ await handLandmarker.setOptions(
+ createOptions(testCase.optionPath, testCase.customValue));
+ verifyGraph(handLandmarker, [testCase.fieldPath, testCase.customValue]);
+ });
+
+ it(`can clear ${testCase.optionPath[0]}`, async () => {
+ await handLandmarker.setOptions(
+ createOptions(testCase.optionPath, testCase.customValue));
+ verifyGraph(handLandmarker, [testCase.fieldPath, testCase.customValue]);
+
+ await handLandmarker.setOptions(
+ createOptions(testCase.optionPath, undefined));
+ verifyGraph(
+ handLandmarker, [testCase.fieldPath, testCase.defaultValue]);
+ });
+ }
+ });
+
+ it('transforms results', async () => {
+ // Pass the test data to our listener
+ handLandmarker.fakeWasmModule._waitUntilIdle.and.callFake(() => {
+ verifyListenersRegistered(handLandmarker);
+ handLandmarker.listeners.get('hand_landmarks')!(createLandmarks());
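+ // Feed one canned proto vector to each registered output stream.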
+ handLandmarker.listeners.get('world_hand_landmarks')!
+ (createWorldLandmarks());
+ handLandmarker.listeners.get('handedness')!(createHandednesses());
+ });
+
+ // Invoke the hand landmarker
+ const landmarks = handLandmarker.detect({} as HTMLImageElement);
+ expect(handLandmarker.getGraphRunner().addProtoToStream)
+ .toHaveBeenCalledTimes(1);
+ expect(handLandmarker.getGraphRunner().addGpuBufferAsImageToStream)
+ .toHaveBeenCalledTimes(1);
+ expect(handLandmarker.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
+
+ expect(landmarks).toEqual({
+ 'landmarks': [[{'x': 0.3, 'y': 0.4, 'z': 0.5}]],
+ 'worldLandmarks': [[{'x': 21, 'y': 22, 'z': 23}]],
+ 'handednesses': [[{
+ 'score': 0.1,
+ 'index': 1,
+ 'categoryName': 'handedness_label',
+ 'displayName': 'handedness_display_name'
+ }]]
+ });
+ });
+
+ it('clears results between invocations', async () => {
+ // Pass the test data to our listener
+ handLandmarker.fakeWasmModule._waitUntilIdle.and.callFake(() => {
+ handLandmarker.listeners.get('hand_landmarks')!(createLandmarks());
+ handLandmarker.listeners.get('world_hand_landmarks')!
+ (createWorldLandmarks());
+ handLandmarker.listeners.get('handedness')!(createHandednesses());
+ });
+
+ // Invoke the hand landmarker twice
+ const landmarks1 = handLandmarker.detect({} as HTMLImageElement);
+ const landmarks2 = handLandmarker.detect({} as HTMLImageElement);
+
+ // Verify that landmarks2 is not a concatenation of all previously returned
+ // landmarks.
+ expect(landmarks1).toEqual(landmarks2);
+ });
+});
diff --git a/mediapipe/tasks/web/vision/image_classifier/BUILD b/mediapipe/tasks/web/vision/image_classifier/BUILD
index 848c162ae..310575964 100644
--- a/mediapipe/tasks/web/vision/image_classifier/BUILD
+++ b/mediapipe/tasks/web/vision/image_classifier/BUILD
@@ -3,6 +3,7 @@
 # This task takes video or image frames and outputs the classification result.

 load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library")
+load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test")

 package(default_visibility = ["//mediapipe/tasks:internal"])

@@ -44,3 +45,26 @@ mediapipe_ts_declaration(
 "//mediapipe/tasks/web/vision/core:vision_task_options",
 ],
 )
+
+mediapipe_ts_library(
+ name = "image_classifier_test_lib",
+ testonly = True,
+ srcs = [
+ "image_classifier_test.ts",
+ ],
+ deps = [
+ ":image_classifier",
+ ":image_classifier_types",
+ "//mediapipe/framework:calculator_jspb_proto",
+ "//mediapipe/framework/formats:classification_jspb_proto",
+ "//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto",
+ "//mediapipe/tasks/web/core",
+ "//mediapipe/tasks/web/core:task_runner_test_utils",
+ ],
+)
+
+jasmine_node_test(
+ name = "image_classifier_test",
+ tags = ["nomsan"],
+ deps = [":image_classifier_test_lib"],
+)
diff --git a/mediapipe/tasks/web/vision/image_classifier/image_classifier_test.ts b/mediapipe/tasks/web/vision/image_classifier/image_classifier_test.ts
new file mode 100644
index 000000000..2041a0cef
--- /dev/null
+++ b/mediapipe/tasks/web/vision/image_classifier/image_classifier_test.ts
@@ -0,0 +1,150 @@
+/**
+ * Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import 'jasmine'; + +// Placeholder for internal dependency on encodeByteArray +import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; +import {Classification, ClassificationList} from '../../../../framework/formats/classification_pb'; +import {ClassificationResult, Classifications} from '../../../../tasks/cc/components/containers/proto/classifications_pb'; +import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils'; + +import {ImageClassifier} from './image_classifier'; + +// The OSS JS API does not support the builder pattern. +// tslint:disable:jspb-use-builder-pattern + +class ImageClassifierFake extends ImageClassifier implements + MediapipeTasksFake { + calculatorName = + 'mediapipe.tasks.vision.image_classifier.ImageClassifierGraph'; + attachListenerSpies: jasmine.Spy[] = []; + graph: CalculatorGraphConfig|undefined; + + fakeWasmModule: SpyWasmModule; + protoListener: ((binaryProto: Uint8Array) => void)|undefined; + + constructor() { + super(createSpyWasmModule(), /* glCanvas= */ null); + this.fakeWasmModule = + this.graphRunner.wasmModule as unknown as SpyWasmModule; + + this.attachListenerSpies[0] = + spyOn(this.graphRunner, 'attachProtoListener') + .and.callFake((stream, listener) => { + expect(stream).toEqual('classifications'); + this.protoListener = listener; + }); + spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => { + this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph); + }); + spyOn(this.graphRunner, 'addGpuBufferAsImageToStream'); + } +} + +describe('ImageClassifier', () => { + let imageClassifier: ImageClassifierFake; + + beforeEach(async () => { + addJasmineCustomFloatEqualityTester(); + imageClassifier = new ImageClassifierFake(); + await imageClassifier.setOptions({}); // Initialize graph + }); + + it('initializes graph', async () => { + verifyGraph(imageClassifier); + verifyListenersRegistered(imageClassifier); + }); + + it('reloads graph when settings are changed', async () => { + await imageClassifier.setOptions({maxResults: 1}); + verifyGraph(imageClassifier, [['classifierOptions', 'maxResults'], 1]); + verifyListenersRegistered(imageClassifier); + + await imageClassifier.setOptions({maxResults: 5}); + verifyGraph(imageClassifier, [['classifierOptions', 'maxResults'], 5]); + verifyListenersRegistered(imageClassifier); + }); + + it('can use custom models', async () => { + const newModel = new Uint8Array([0, 1, 2, 3, 4]); + const newModelBase64 = Buffer.from(newModel).toString('base64'); + await imageClassifier.setOptions({ + baseOptions: { + modelAssetBuffer: newModel, + } + }); + + verifyGraph( + imageClassifier, + /* expectedCalculatorOptions= */ undefined, + /* expectedBaseOptions= */[ + 'modelAsset', { + fileContent: newModelBase64, + fileName: undefined, + fileDescriptorMeta: undefined, + filePointerMeta: undefined + } + ]); + }); + + it('merges options', async () => { + await imageClassifier.setOptions({maxResults: 1}); + await 
imageClassifier.setOptions({displayNamesLocale: 'en'});
+ verifyGraph(imageClassifier, [['classifierOptions', 'maxResults'], 1]);
+ verifyGraph(
+ imageClassifier, [['classifierOptions', 'displayNamesLocale'], 'en']);
+ });
+
+ it('transforms results', async () => {
+ const classificationResult = new ClassificationResult();
+ const classifications = new Classifications();
+ classifications.setHeadIndex(1);
+ classifications.setHeadName('headName');
+ const classificationList = new ClassificationList();
+ const classification = new Classification();
+ classification.setIndex(1);
+ classification.setScore(0.2);
+ classification.setDisplayName('displayName');
+ classification.setLabel('categoryName');
+ classificationList.addClassification(classification);
+ classifications.setClassificationList(classificationList);
+ classificationResult.addClassifications(classifications);
+
+ // Pass the test data to our listener
+ imageClassifier.fakeWasmModule._waitUntilIdle.and.callFake(() => {
+ verifyListenersRegistered(imageClassifier);
+ imageClassifier.protoListener!(classificationResult.serializeBinary());
+ });
+
+ // Invoke the image classifier
+ const result = imageClassifier.classify({} as HTMLImageElement);
+
+ expect(imageClassifier.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
+ expect(result).toEqual({
+ classifications: [{
+ categories: [{
+ index: 1,
+ score: 0.2,
+ displayName: 'displayName',
+ categoryName: 'categoryName'
+ }],
+ headIndex: 1,
+ headName: 'headName'
+ }]
+ });
+ });
+});
diff --git a/mediapipe/tasks/web/vision/image_embedder/BUILD b/mediapipe/tasks/web/vision/image_embedder/BUILD
index 6c9d80fb1..de4785e6c 100644
--- a/mediapipe/tasks/web/vision/image_embedder/BUILD
+++ b/mediapipe/tasks/web/vision/image_embedder/BUILD
@@ -3,6 +3,7 @@
 # This task performs embedding extraction on images.

 load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library")
+load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test")

 package(default_visibility = ["//mediapipe/tasks:internal"])

@@ -45,3 +46,23 @@ mediapipe_ts_declaration(
 "//mediapipe/tasks/web/vision/core:vision_task_options",
 ],
 )
+
+mediapipe_ts_library(
+ name = "image_embedder_test_lib",
+ testonly = True,
+ srcs = [
+ "image_embedder_test.ts",
+ ],
+ deps = [
+ ":image_embedder",
+ "//mediapipe/framework:calculator_jspb_proto",
+ "//mediapipe/tasks/cc/components/containers/proto:embeddings_jspb_proto",
+ "//mediapipe/tasks/web/core",
+ "//mediapipe/tasks/web/core:task_runner_test_utils",
+ ],
+)
+
+jasmine_node_test(
+ name = "image_embedder_test",
+ deps = [":image_embedder_test_lib"],
+)
diff --git a/mediapipe/tasks/web/vision/image_embedder/image_embedder_test.ts b/mediapipe/tasks/web/vision/image_embedder/image_embedder_test.ts
new file mode 100644
index 000000000..cafe0f3d8
--- /dev/null
+++ b/mediapipe/tasks/web/vision/image_embedder/image_embedder_test.ts
@@ -0,0 +1,158 @@
+/**
+ * Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import 'jasmine'; + +// Placeholder for internal dependency on encodeByteArray +import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; +import {Embedding, EmbeddingResult, FloatEmbedding} from '../../../../tasks/cc/components/containers/proto/embeddings_pb'; +import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils'; + +import {ImageEmbedder} from './image_embedder'; + +// The OSS JS API does not support the builder pattern. +// tslint:disable:jspb-use-builder-pattern + +class ImageEmbedderFake extends ImageEmbedder implements MediapipeTasksFake { + calculatorName = 'mediapipe.tasks.vision.image_embedder.ImageEmbedderGraph'; + graph: CalculatorGraphConfig|undefined; + attachListenerSpies: jasmine.Spy[] = []; + fakeWasmModule: SpyWasmModule; + protoListener: ((binaryProtos: Uint8Array) => void)|undefined; + + constructor() { + super(createSpyWasmModule(), /* glCanvas= */ null); + this.fakeWasmModule = + this.graphRunner.wasmModule as unknown as SpyWasmModule; + + this.attachListenerSpies[0] = + spyOn(this.graphRunner, 'attachProtoListener') + .and.callFake((stream, listener) => { + expect(stream).toEqual('embeddings_out'); + this.protoListener = listener; + }); + spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => { + this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph); + }); + spyOn(this.graphRunner, 'addGpuBufferAsImageToStream'); + } +} + +describe('ImageEmbedder', () => { + let imageEmbedder: ImageEmbedderFake; + + beforeEach(async () => { + addJasmineCustomFloatEqualityTester(); + imageEmbedder = new ImageEmbedderFake(); + await imageEmbedder.setOptions({}); // Initialize graph + }); + + it('initializes graph', async () => { + verifyGraph(imageEmbedder); + verifyListenersRegistered(imageEmbedder); + }); + + it('reloads graph when settings are changed', async () => { + verifyListenersRegistered(imageEmbedder); + + await imageEmbedder.setOptions({quantize: true}); + verifyGraph(imageEmbedder, [['embedderOptions', 'quantize'], true]); + verifyListenersRegistered(imageEmbedder); + + await imageEmbedder.setOptions({quantize: undefined}); + verifyGraph(imageEmbedder, [['embedderOptions', 'quantize'], undefined]); + verifyListenersRegistered(imageEmbedder); + }); + + it('can use custom models', async () => { + const newModel = new Uint8Array([0, 1, 2, 3, 4]); + const newModelBase64 = Buffer.from(newModel).toString('base64'); + await imageEmbedder.setOptions({ + baseOptions: { + modelAssetBuffer: newModel, + } + }); + + verifyGraph( + imageEmbedder, + /* expectedCalculatorOptions= */ undefined, + /* expectedBaseOptions= */[ + 'modelAsset', { + fileContent: newModelBase64, + fileName: undefined, + fileDescriptorMeta: undefined, + filePointerMeta: undefined + } + ]); + }); + + it('overrides options', async () => { + await imageEmbedder.setOptions({quantize: true}); + await imageEmbedder.setOptions({l2Normalize: true}); + verifyGraph( + imageEmbedder, + ['embedderOptions', {'quantize': true, 'l2Normalize': true}]); + }); + + describe('transforms result', () => { + beforeEach(() => { + const floatEmbedding = new FloatEmbedding(); + floatEmbedding.setValuesList([0.1, 0.9]); + + const embedding = new Embedding(); + embedding.setHeadIndex(1); + embedding.setHeadName('headName'); + 
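// Attach the float values to the Embedding proto before wrapping it
+ // in an EmbeddingResult.
+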
embedding.setFloatEmbedding(floatEmbedding); + + const resultProto = new EmbeddingResult(); + resultProto.addEmbeddings(embedding); + resultProto.setTimestampMs(42); + + // Pass the test data to our listener + imageEmbedder.fakeWasmModule._waitUntilIdle.and.callFake(() => { + verifyListenersRegistered(imageEmbedder); + imageEmbedder.protoListener!(resultProto.serializeBinary()); + }); + }); + + it('for image mode', async () => { + // Invoke the image embedder + const embeddingResult = imageEmbedder.embed({} as HTMLImageElement); + + expect(imageEmbedder.fakeWasmModule._waitUntilIdle).toHaveBeenCalled(); + expect(embeddingResult).toEqual({ + embeddings: + [{headIndex: 1, headName: 'headName', floatEmbedding: [0.1, 0.9]}], + timestampMs: 42 + }); + }); + + it('for video mode', async () => { + await imageEmbedder.setOptions({runningMode: 'video'}); + + // Invoke the video embedder + const embeddingResult = + imageEmbedder.embedForVideo({} as HTMLImageElement, 42); + + expect(imageEmbedder.fakeWasmModule._waitUntilIdle).toHaveBeenCalled(); + expect(embeddingResult).toEqual({ + embeddings: + [{headIndex: 1, headName: 'headName', floatEmbedding: [0.1, 0.9]}], + timestampMs: 42 + }); + }); + }); +}); diff --git a/mediapipe/tasks/web/vision/object_detector/BUILD b/mediapipe/tasks/web/vision/object_detector/BUILD index f73790895..fc206a2d7 100644 --- a/mediapipe/tasks/web/vision/object_detector/BUILD +++ b/mediapipe/tasks/web/vision/object_detector/BUILD @@ -4,6 +4,7 @@ # the detection results for one or more object categories, using Object Detector. load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library") +load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test") package(default_visibility = ["//mediapipe/tasks:internal"]) @@ -41,3 +42,26 @@ mediapipe_ts_declaration( "//mediapipe/tasks/web/vision/core:vision_task_options", ], ) + +mediapipe_ts_library( + name = "object_detector_test_lib", + testonly = True, + srcs = [ + "object_detector_test.ts", + ], + deps = [ + ":object_detector", + ":object_detector_types", + "//mediapipe/framework:calculator_jspb_proto", + "//mediapipe/framework/formats:detection_jspb_proto", + "//mediapipe/framework/formats:location_data_jspb_proto", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:task_runner_test_utils", + ], +) + +jasmine_node_test( + name = "object_detector_test", + tags = ["nomsan"], + deps = [":object_detector_test_lib"], +) diff --git a/mediapipe/tasks/web/vision/object_detector/object_detector_test.ts b/mediapipe/tasks/web/vision/object_detector/object_detector_test.ts new file mode 100644 index 000000000..fff1a1c48 --- /dev/null +++ b/mediapipe/tasks/web/vision/object_detector/object_detector_test.ts @@ -0,0 +1,229 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +import 'jasmine'; + +// Placeholder for internal dependency on encodeByteArray +import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; +import {Detection as DetectionProto} from '../../../../framework/formats/detection_pb'; +import {LocationData} from '../../../../framework/formats/location_data_pb'; +import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils'; + +import {ObjectDetector} from './object_detector'; +import {ObjectDetectorOptions} from './object_detector_options'; + +// The OSS JS API does not support the builder pattern. +// tslint:disable:jspb-use-builder-pattern + +class ObjectDetectorFake extends ObjectDetector implements MediapipeTasksFake { + lastSampleRate: number|undefined; + calculatorName = 'mediapipe.tasks.vision.ObjectDetectorGraph'; + attachListenerSpies: jasmine.Spy[] = []; + graph: CalculatorGraphConfig|undefined; + + fakeWasmModule: SpyWasmModule; + protoListener: ((binaryProtos: Uint8Array[]) => void)|undefined; + + constructor() { + super(createSpyWasmModule(), /* glCanvas= */ null); + this.fakeWasmModule = + this.graphRunner.wasmModule as unknown as SpyWasmModule; + + this.attachListenerSpies[0] = + spyOn(this.graphRunner, 'attachProtoVectorListener') + .and.callFake((stream, listener) => { + expect(stream).toEqual('detections'); + this.protoListener = listener; + }); + spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => { + this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph); + }); + spyOn(this.graphRunner, 'addGpuBufferAsImageToStream'); + } +} + +describe('ObjectDetector', () => { + let objectDetector: ObjectDetectorFake; + + beforeEach(async () => { + addJasmineCustomFloatEqualityTester(); + objectDetector = new ObjectDetectorFake(); + await objectDetector.setOptions({}); // Initialize graph + }); + + it('initializes graph', async () => { + verifyGraph(objectDetector); + verifyListenersRegistered(objectDetector); + }); + + it('reloads graph when settings are changed', async () => { + await objectDetector.setOptions({maxResults: 1}); + verifyGraph(objectDetector, ['maxResults', 1]); + verifyListenersRegistered(objectDetector); + + await objectDetector.setOptions({maxResults: 5}); + verifyGraph(objectDetector, ['maxResults', 5]); + verifyListenersRegistered(objectDetector); + }); + + it('can use custom models', async () => { + const newModel = new Uint8Array([0, 1, 2, 3, 4]); + const newModelBase64 = Buffer.from(newModel).toString('base64'); + await objectDetector.setOptions({ + baseOptions: { + modelAssetBuffer: newModel, + } + }); + + verifyGraph( + objectDetector, + /* expectedCalculatorOptions= */ undefined, + /* expectedBaseOptions= */ + [ + 'modelAsset', { + fileContent: newModelBase64, + fileName: undefined, + fileDescriptorMeta: undefined, + filePointerMeta: undefined + } + ]); + }); + + it('merges options', async () => { + await objectDetector.setOptions({maxResults: 1}); + await objectDetector.setOptions({displayNamesLocale: 'en'}); + verifyGraph(objectDetector, ['maxResults', 1]); + verifyGraph(objectDetector, ['displayNamesLocale', 'en']); + }); + + describe('setOptions() ', () => { + interface TestCase { + optionName: keyof ObjectDetectorOptions; + protoName: string; + customValue: unknown; + defaultValue: unknown; + } + + const testCases: TestCase[] = [ + { + optionName: 'maxResults', + protoName: 'maxResults', + customValue: 5, + defaultValue: -1 + }, 
+ { + optionName: 'displayNamesLocale', + protoName: 'displayNamesLocale', + customValue: 'en', + defaultValue: 'en' + }, + { + optionName: 'scoreThreshold', + protoName: 'scoreThreshold', + customValue: 0.1, + defaultValue: undefined + }, + { + optionName: 'categoryAllowlist', + protoName: 'categoryAllowlistList', + customValue: ['foo'], + defaultValue: [] + }, + { + optionName: 'categoryDenylist', + protoName: 'categoryDenylistList', + customValue: ['bar'], + defaultValue: [] + }, + ]; + + for (const testCase of testCases) { + it(`can set ${testCase.optionName}`, async () => { + await objectDetector.setOptions( + {[testCase.optionName]: testCase.customValue}); + verifyGraph(objectDetector, [testCase.protoName, testCase.customValue]); + }); + + it(`can clear ${testCase.optionName}`, async () => { + await objectDetector.setOptions( + {[testCase.optionName]: testCase.customValue}); + verifyGraph(objectDetector, [testCase.protoName, testCase.customValue]); + await objectDetector.setOptions({[testCase.optionName]: undefined}); + verifyGraph( + objectDetector, [testCase.protoName, testCase.defaultValue]); + }); + } + }); + + it('transforms results', async () => { + const detectionProtos: Uint8Array[] = []; + + // Add a detection with all optional properties + let detection = new DetectionProto(); + detection.addScore(0.1); + detection.addLabelId(1); + detection.addLabel('foo'); + detection.addDisplayName('bar'); + let locationData = new LocationData(); + let boundingBox = new LocationData.BoundingBox(); + boundingBox.setXmin(1); + boundingBox.setYmin(2); + boundingBox.setWidth(3); + boundingBox.setHeight(4); + locationData.setBoundingBox(boundingBox); + detection.setLocationData(locationData); + detectionProtos.push(detection.serializeBinary()); + + // Add a detection without optional properties + detection = new DetectionProto(); + detection.addScore(0.2); + locationData = new LocationData(); + boundingBox = new LocationData.BoundingBox(); + locationData.setBoundingBox(boundingBox); + detection.setLocationData(locationData); + detectionProtos.push(detection.serializeBinary()); + + // Pass the test data to our listener + objectDetector.fakeWasmModule._waitUntilIdle.and.callFake(() => { + verifyListenersRegistered(objectDetector); + objectDetector.protoListener!(detectionProtos); + }); + + // Invoke the object detector + const detections = objectDetector.detect({} as HTMLImageElement); + + expect(objectDetector.fakeWasmModule._waitUntilIdle).toHaveBeenCalled(); + expect(detections.length).toEqual(2); + expect(detections[0]).toEqual({ + categories: [{ + score: 0.1, + index: 1, + categoryName: 'foo', + displayName: 'bar', + }], + boundingBox: {originX: 1, originY: 2, width: 3, height: 4} + }); + expect(detections[1]).toEqual({ + categories: [{ + score: 0.2, + index: -1, + categoryName: '', + displayName: '', + }], + boundingBox: {originX: 0, originY: 0, width: 0, height: 0} + }); + }); +}); From d1820320b15893a0f0b947ed208bdcfb630bb938 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Thu, 8 Dec 2022 10:23:53 +0530 Subject: [PATCH 199/469] Added base options --- mediapipe/tasks/ios/core/BUILD | 33 ++++++++++++ .../tasks/ios/core/sources/MPPBaseOptions.h | 51 +++++++++++++++++++ .../tasks/ios/core/sources/MPPBaseOptions.m | 36 +++++++++++++ .../tasks/ios/core/sources/MPPExternalFile.h | 28 ++++++++++ .../tasks/ios/core/sources/MPPExternalFile.m | 27 ++++++++++ 5 files changed, 175 insertions(+) create mode 100644 mediapipe/tasks/ios/core/BUILD create mode 100644 
mediapipe/tasks/ios/core/sources/MPPBaseOptions.h
 create mode 100644 mediapipe/tasks/ios/core/sources/MPPBaseOptions.m
 create mode 100644 mediapipe/tasks/ios/core/sources/MPPExternalFile.h
 create mode 100644 mediapipe/tasks/ios/core/sources/MPPExternalFile.m
diff --git a/mediapipe/tasks/ios/core/BUILD b/mediapipe/tasks/ios/core/BUILD
new file mode 100644
index 000000000..9b8ad7bec
--- /dev/null
+++ b/mediapipe/tasks/ios/core/BUILD
@@ -0,0 +1,33 @@
+# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+package(default_visibility = ["//mediapipe/tasks:internal"])
+
+licenses(["notice"])
+
+objc_library(
+ name = "MPPExternalFile",
+ srcs = ["sources/MPPExternalFile.m"],
+ hdrs = ["sources/MPPExternalFile.h"],
+)
+
+objc_library(
+ name = "MPPBaseOptions",
+ srcs = ["sources/MPPBaseOptions.m"],
+ hdrs = ["sources/MPPBaseOptions.h"],
+ deps = [
+ ":MPPExternalFile",
+
+ ],
+)
diff --git a/mediapipe/tasks/ios/core/sources/MPPBaseOptions.h b/mediapipe/tasks/ios/core/sources/MPPBaseOptions.h
new file mode 100644
index 000000000..87b6826df
--- /dev/null
+++ b/mediapipe/tasks/ios/core/sources/MPPBaseOptions.h
@@ -0,0 +1,51 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+ http://www.apache.org/licenses/LICENSE-2.0
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ ==============================================================================*/
+#import <Foundation/Foundation.h>
+#import "mediapipe/tasks/ios/core/sources/MPPExternalFile.h"
+
+NS_ASSUME_NONNULL_BEGIN
+
+/**
+ * MediaPipe Tasks delegate.
+ */
+typedef NS_ENUM(NSUInteger, MPPDelegate) {
+
+ /** CPU. */
+ MPPDelegateCPU,
+
+ /** GPU. */
+ MPPDelegateGPU
+} NS_SWIFT_NAME(Delegate);
+
+/**
+ * Holds the base options that are used for the creation of any type of task. It has fields with
+ * important information such as the acceleration configuration and the TFLite model source.
+ */
+NS_SWIFT_NAME(BaseOptions)
+@interface MPPBaseOptions : NSObject <NSCopying>
+
+/**
+ * The external model file, as a single standalone TFLite file. It could be packed with TFLite Model
+ * Metadata[1] and associated files, if they exist. Failing to provide the necessary metadata and
+ * associated files might result in errors.
+ */
+@property(nonatomic, copy) MPPExternalFile *modelAssetFile;
+
+/**
+ * Device delegate to run the MediaPipe pipeline. If the delegate is not set, the default
+ * delegate, CPU, is used.
+ */
+@property(nonatomic) MPPDelegate delegate;
+
+@end
+
+NS_ASSUME_NONNULL_END
diff --git a/mediapipe/tasks/ios/core/sources/MPPBaseOptions.m b/mediapipe/tasks/ios/core/sources/MPPBaseOptions.m
new file mode 100644
index 000000000..4c25b80e8
--- /dev/null
+++ b/mediapipe/tasks/ios/core/sources/MPPBaseOptions.m
@@ -0,0 +1,36 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ ==============================================================================*/
+#import "mediapipe/tasks/ios/core/sources/MPPBaseOptions.h"
+
+@implementation MPPBaseOptions
+
+- (instancetype)init {
+ self = [super init];
+ if (self) {
+ self.modelAssetFile = [[MPPExternalFile alloc] init];
+ }
+ return self;
+}
+
+- (id)copyWithZone:(NSZone *)zone {
+ MPPBaseOptions *baseOptions = [[MPPBaseOptions alloc] init];
+
+ baseOptions.modelAssetFile = self.modelAssetFile;
+ baseOptions.delegate = self.delegate;
+
+ return baseOptions;
+}
+
+@end
diff --git a/mediapipe/tasks/ios/core/sources/MPPExternalFile.h b/mediapipe/tasks/ios/core/sources/MPPExternalFile.h
new file mode 100644
index 000000000..a97802002
--- /dev/null
+++ b/mediapipe/tasks/ios/core/sources/MPPExternalFile.h
@@ -0,0 +1,28 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+ http://www.apache.org/licenses/LICENSE-2.0
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ ==============================================================================*/
+#import <Foundation/Foundation.h>
+
+NS_ASSUME_NONNULL_BEGIN
+
+/**
+ * Holds information about an external file.
+ */
+NS_SWIFT_NAME(ExternalFile)
+@interface MPPExternalFile : NSObject <NSCopying>
+
+/** Path to the file in bundle. */
+@property(nonatomic, copy) NSString *filePath;
+/// Add provision for other sources in the future.
+
+@end
+
+NS_ASSUME_NONNULL_END
diff --git a/mediapipe/tasks/ios/core/sources/MPPExternalFile.m b/mediapipe/tasks/ios/core/sources/MPPExternalFile.m
new file mode 100644
index 000000000..70d85657c
--- /dev/null
+++ b/mediapipe/tasks/ios/core/sources/MPPExternalFile.m
@@ -0,0 +1,27 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and + limitations under the License. + ==============================================================================*/ +#import "mediapipe/tasks/ios/core/sources/MPPExternalFile.h" + +@implementation MPPExternalFile + +- (id)copyWithZone:(NSZone *)zone { + MPPExternalFile *externalFile = [[MPPExternalFile alloc] init]; + + externalFile.filePath = self.filePath; + + return externalFile; +} + +@end From 66dbd9969a0aae7f71ce7096135a3b436ea76473 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Thu, 8 Dec 2022 10:25:01 +0530 Subject: [PATCH 200/469] Updated license text --- .../tasks/ios/core/sources/MPPBaseOptions.h | 25 +++++++++++-------- .../tasks/ios/core/sources/MPPExternalFile.h | 25 +++++++++++-------- 2 files changed, 28 insertions(+), 22 deletions(-) diff --git a/mediapipe/tasks/ios/core/sources/MPPBaseOptions.h b/mediapipe/tasks/ios/core/sources/MPPBaseOptions.h index 87b6826df..258b49b3b 100644 --- a/mediapipe/tasks/ios/core/sources/MPPBaseOptions.h +++ b/mediapipe/tasks/ios/core/sources/MPPBaseOptions.h @@ -1,14 +1,17 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ==============================================================================*/ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + #import #import "mediapipe/tasks/ios/core/sources/MPPExternalFile.h" diff --git a/mediapipe/tasks/ios/core/sources/MPPExternalFile.h b/mediapipe/tasks/ios/core/sources/MPPExternalFile.h index a97802002..300fd4778 100644 --- a/mediapipe/tasks/ios/core/sources/MPPExternalFile.h +++ b/mediapipe/tasks/ios/core/sources/MPPExternalFile.h @@ -1,14 +1,17 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ==============================================================================*/ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. 
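
For comparison with the Python Tasks API used elsewhere in this series, the Objective-C types above reduce to a small value-object pattern: an enum for the device delegate, a file wrapper, and a copyable options holder. Below is a rough Python analogue, purely illustrative; the class names mirror the NS_SWIFT_NAME aliases and are not an actual MediaPipe API.

import copy
import dataclasses
import enum


class Delegate(enum.Enum):
  """Mirrors MPPDelegate: the device that runs the MediaPipe pipeline."""
  CPU = 0
  GPU = 1


@dataclasses.dataclass
class ExternalFile:
  """Mirrors MPPExternalFile: where the standalone TFLite model lives."""
  file_path: str = ''


@dataclasses.dataclass
class BaseOptions:
  """Mirrors MPPBaseOptions, including the -copyWithZone: semantics."""
  model_asset_file: ExternalFile = dataclasses.field(
      default_factory=ExternalFile)
  delegate: Delegate = Delegate.CPU  # CPU is the default, as in the header.

  def __copy__(self):
    options = BaseOptions()
    options.model_asset_file = self.model_asset_file
    options.delegate = self.delegate
    return options


options = BaseOptions(model_asset_file=ExternalFile('model.tflite'))
options_copy = copy.copy(options)
assert options_copy.delegate is Delegate.CPU

As in -copyWithZone:, the copy is shallow: the original and the copy share the same ExternalFile instance.
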
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + #import NS_ASSUME_NONNULL_BEGIN From 13f8fa51393a2883ec825ad717bfffb693d59376 Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Thu, 8 Dec 2022 07:59:46 -0800 Subject: [PATCH 201/469] Retire the visibility group "//mediapipe/framework:mediapipe_internal" in the "mediapipe/calculators/tensor" dir. PiperOrigin-RevId: 493895834 --- mediapipe/calculators/tensor/BUILD | 76 ++++-------------------------- 1 file changed, 8 insertions(+), 68 deletions(-) diff --git a/mediapipe/calculators/tensor/BUILD b/mediapipe/calculators/tensor/BUILD index 577ac4111..dec68deac 100644 --- a/mediapipe/calculators/tensor/BUILD +++ b/mediapipe/calculators/tensor/BUILD @@ -24,7 +24,7 @@ load("//mediapipe/framework:encode_binary_proto.bzl", "encode_binary_proto") licenses(["notice"]) -package(default_visibility = ["//visibility:private"]) +package(default_visibility = ["//visibility:public"]) exports_files( glob(["testdata/image_to_tensor/*"]), @@ -44,9 +44,6 @@ selects.config_setting_group( mediapipe_proto_library( name = "audio_to_tensor_calculator_proto", srcs = ["audio_to_tensor_calculator.proto"], - visibility = [ - "//mediapipe/framework:mediapipe_internal", - ], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -64,9 +61,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = [ - "//mediapipe/framework:mediapipe_internal", - ], deps = [ ":audio_to_tensor_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -113,9 +107,6 @@ cc_test( mediapipe_proto_library( name = "tensors_to_audio_calculator_proto", srcs = ["tensors_to_audio_calculator.proto"], - visibility = [ - "//mediapipe/framework:mediapipe_internal", - ], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -125,9 +116,6 @@ mediapipe_proto_library( cc_library( name = "tensors_to_audio_calculator", srcs = ["tensors_to_audio_calculator.cc"], - visibility = [ - "//mediapipe/framework:mediapipe_internal", - ], deps = [ ":tensors_to_audio_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -164,9 +152,6 @@ cc_test( mediapipe_proto_library( name = "feedback_tensors_calculator_proto", srcs = ["feedback_tensors_calculator.proto"], - visibility = [ - "//mediapipe/framework:mediapipe_internal", - ], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -184,9 +169,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = [ - "//mediapipe/framework:mediapipe_internal", - ], deps = [ ":feedback_tensors_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -216,9 +198,6 @@ cc_test( mediapipe_proto_library( name = "bert_preprocessor_calculator_proto", srcs = ["bert_preprocessor_calculator.proto"], - visibility = [ - "//mediapipe/framework:mediapipe_internal", - ], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -228,9 +207,6 @@ 
mediapipe_proto_library( cc_library( name = "bert_preprocessor_calculator", srcs = ["bert_preprocessor_calculator.cc"], - visibility = [ - "//mediapipe/framework:mediapipe_internal", - ], deps = [ ":bert_preprocessor_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -274,9 +250,6 @@ cc_test( mediapipe_proto_library( name = "regex_preprocessor_calculator_proto", srcs = ["regex_preprocessor_calculator.proto"], - visibility = [ - "//mediapipe/framework:mediapipe_internal", - ], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -286,9 +259,6 @@ mediapipe_proto_library( cc_library( name = "regex_preprocessor_calculator", srcs = ["regex_preprocessor_calculator.cc"], - visibility = [ - "//mediapipe/framework:mediapipe_internal", - ], deps = [ ":regex_preprocessor_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -330,9 +300,6 @@ cc_test( cc_library( name = "text_to_tensor_calculator", srcs = ["text_to_tensor_calculator.cc"], - visibility = [ - "//mediapipe/framework:mediapipe_internal", - ], deps = [ "//mediapipe/framework:calculator_context", "//mediapipe/framework:calculator_framework", @@ -366,9 +333,6 @@ cc_test( cc_library( name = "universal_sentence_encoder_preprocessor_calculator", srcs = ["universal_sentence_encoder_preprocessor_calculator.cc"], - visibility = [ - "//mediapipe/framework:mediapipe_internal", - ], deps = [ "//mediapipe/framework:calculator_context", "//mediapipe/framework:calculator_framework", @@ -408,7 +372,6 @@ cc_test( mediapipe_proto_library( name = "inference_calculator_proto", srcs = ["inference_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -435,7 +398,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":inference_calculator_cc_proto", ":inference_calculator_options_lib", @@ -460,7 +422,6 @@ cc_library( name = "inference_calculator_gl", srcs = ["inference_calculator_gl.cc"], tags = ["nomac"], # config problem with cpuinfo via TF - visibility = ["//visibility:public"], deps = [ ":inference_calculator_cc_proto", ":inference_calculator_interface", @@ -478,7 +439,6 @@ cc_library( name = "inference_calculator_gl_advanced", srcs = ["inference_calculator_gl_advanced.cc"], tags = ["nomac"], - visibility = ["//visibility:public"], deps = [ ":inference_calculator_interface", "@com_google_absl//absl/memory", @@ -509,7 +469,6 @@ cc_library( "-framework MetalKit", ], tags = ["ios"], - visibility = ["//visibility:public"], deps = [ "inference_calculator_interface", "//mediapipe/gpu:MPPMetalHelper", @@ -538,7 +497,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_context", "//mediapipe/framework/formats:tensor", @@ -558,7 +516,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":inference_runner", "//mediapipe/framework:mediapipe_profiling", @@ -588,7 +545,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":inference_calculator_interface", ":inference_calculator_utils", @@ -635,7 +591,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":inference_calculator_interface", ":inference_calculator_utils", @@ -651,7 +606,6 @@ cc_library( cc_library( name = 
"inference_calculator_gl_if_compute_shader_available", - visibility = ["//visibility:public"], deps = selects.with_or({ ":compute_shader_unavailable": [], "//conditions:default": [ @@ -667,7 +621,6 @@ cc_library( # inference_calculator_interface. cc_library( name = "inference_calculator", - visibility = ["//visibility:public"], deps = [ ":inference_calculator_interface", ":inference_calculator_cpu", @@ -681,7 +634,6 @@ cc_library( mediapipe_proto_library( name = "tensor_converter_calculator_proto", srcs = ["tensor_converter_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -706,7 +658,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":tensor_converter_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -725,6 +676,7 @@ cc_library( cc_library( name = "tensor_converter_calculator_gpu_deps", + visibility = ["//visibility:private"], deps = select({ "//mediapipe:android": [ "//mediapipe/gpu:gl_calculator_helper", @@ -769,7 +721,6 @@ cc_test( mediapipe_proto_library( name = "tensors_to_detections_calculator_proto", srcs = ["tensors_to_detections_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -794,7 +745,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":tensors_to_detections_calculator_cc_proto", "//mediapipe/framework/formats:detection_cc_proto", @@ -817,6 +767,7 @@ cc_library( cc_library( name = "tensors_to_detections_calculator_gpu_deps", + visibility = ["//visibility:private"], deps = select({ "//mediapipe:ios": [ "//mediapipe/gpu:MPPMetalUtil", @@ -832,7 +783,6 @@ cc_library( mediapipe_proto_library( name = "tensors_to_landmarks_calculator_proto", srcs = ["tensors_to_landmarks_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -849,7 +799,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":tensors_to_landmarks_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -864,7 +813,6 @@ cc_library( mediapipe_proto_library( name = "landmarks_to_tensor_calculator_proto", srcs = ["landmarks_to_tensor_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -882,7 +830,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":landmarks_to_tensor_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -915,7 +862,6 @@ cc_test( mediapipe_proto_library( name = "tensors_to_floats_calculator_proto", srcs = ["tensors_to_floats_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -932,7 +878,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":tensors_to_floats_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -970,7 +915,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":tensors_to_classification_calculator_cc_proto", "@com_google_absl//absl/container:node_hash_map", @@ -1001,7 +945,6 @@ 
cc_library( mediapipe_proto_library( name = "tensors_to_classification_calculator_proto", srcs = ["tensors_to_classification_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -1039,7 +982,6 @@ cc_library( "//conditions:default": [], }), features = ["-layering_check"], # allow depending on image_to_tensor_calculator_gpu_deps - visibility = ["//visibility:public"], deps = [ ":image_to_tensor_calculator_cc_proto", ":image_to_tensor_converter", @@ -1068,6 +1010,7 @@ cc_library( cc_library( name = "image_to_tensor_calculator_gpu_deps", + visibility = ["//visibility:private"], deps = selects.with_or({ "//mediapipe:android": [ ":image_to_tensor_converter_gl_buffer", @@ -1091,7 +1034,6 @@ cc_library( mediapipe_proto_library( name = "image_to_tensor_calculator_proto", srcs = ["image_to_tensor_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -1154,7 +1096,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":image_to_tensor_utils", "//mediapipe/framework/formats:image", @@ -1174,7 +1115,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":image_to_tensor_converter", ":image_to_tensor_utils", @@ -1194,6 +1134,7 @@ cc_library( name = "image_to_tensor_converter_gl_buffer", srcs = ["image_to_tensor_converter_gl_buffer.cc"], hdrs = ["image_to_tensor_converter_gl_buffer.h"], + visibility = ["//visibility:private"], deps = ["//mediapipe/framework:port"] + selects.with_or({ "//mediapipe:apple": [], "//conditions:default": [ @@ -1227,6 +1168,7 @@ cc_library( name = "image_to_tensor_converter_gl_texture", srcs = ["image_to_tensor_converter_gl_texture.cc"], hdrs = ["image_to_tensor_converter_gl_texture.h"], + visibility = ["//visibility:private"], deps = ["//mediapipe/framework:port"] + select({ "//mediapipe/gpu:disable_gpu": [], "//conditions:default": [ @@ -1251,6 +1193,7 @@ cc_library( name = "image_to_tensor_converter_gl_utils", srcs = ["image_to_tensor_converter_gl_utils.cc"], hdrs = ["image_to_tensor_converter_gl_utils.h"], + visibility = ["//visibility:private"], deps = ["//mediapipe/framework:port"] + select({ "//mediapipe/gpu:disable_gpu": [], "//conditions:default": [ @@ -1280,6 +1223,7 @@ cc_library( ], "//conditions:default": [], }), + visibility = ["//visibility:private"], deps = ["//mediapipe/framework:port"] + select({ "//mediapipe:apple": [ ":image_to_tensor_converter", @@ -1311,7 +1255,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":image_to_tensor_calculator_cc_proto", "@com_google_absl//absl/status", @@ -1354,7 +1297,6 @@ selects.config_setting_group( mediapipe_proto_library( name = "tensors_to_segmentation_calculator_proto", srcs = ["tensors_to_segmentation_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -1372,7 +1314,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":tensors_to_segmentation_calculator_cc_proto", "@com_google_absl//absl/strings:str_format", @@ -1430,7 +1371,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_context", 
"//mediapipe/framework:calculator_framework", From a641ea12e15ec8e3ff552647ca569dc1ee9f59bc Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Thu, 8 Dec 2022 11:30:39 -0800 Subject: [PATCH 202/469] Update gesture recognizer to new mediapipe tasks pipeline PiperOrigin-RevId: 493950564 --- .../python/vision/gesture_recognizer/BUILD | 14 ++-- .../vision/gesture_recognizer/dataset.py | 62 ++++++++++------- .../vision/gesture_recognizer/dataset_test.py | 67 ++++++++----------- .../gesture_recognizer/metadata_writer.py | 62 +++++++++++++---- .../metadata_writer_test.py | 17 +++++ mediapipe/tasks/python/core/BUILD | 8 +-- mediapipe/tasks/python/vision/BUILD | 4 ++ 7 files changed, 147 insertions(+), 87 deletions(-) diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/BUILD b/mediapipe/model_maker/python/vision/gesture_recognizer/BUILD index 256447a8d..9123e36b0 100644 --- a/mediapipe/model_maker/python/vision/gesture_recognizer/BUILD +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/BUILD @@ -35,20 +35,21 @@ py_library( srcs = ["constants.py"], ) -# TODO: Change to py_library after migrating the MediaPipe hand solution -# library to MediaPipe hand task library. py_library( name = "dataset", srcs = ["dataset.py"], deps = [ ":constants", + ":metadata_writer", "//mediapipe/model_maker/python/core/data:classification_dataset", - "//mediapipe/model_maker/python/core/data:data_util", "//mediapipe/model_maker/python/core/utils:model_util", - "//mediapipe/python/solutions:hands", + "//mediapipe/python:_framework_bindings", + "//mediapipe/tasks/python/core:base_options", + "//mediapipe/tasks/python/vision:hand_landmarker", ], ) +# TODO: Remove notsan tag once tasks no longer has race condition issue py_test( name = "dataset_test", srcs = ["dataset_test.py"], @@ -56,10 +57,11 @@ py_test( ":testdata", "//mediapipe/model_maker/models/gesture_recognizer:models", ], + tags = ["notsan"], deps = [ ":dataset", - "//mediapipe/python/solutions:hands", "//mediapipe/tasks/python/test:test_utils", + "//mediapipe/tasks/python/vision:hand_landmarker", ], ) @@ -131,6 +133,7 @@ py_library( ], ) +# TODO: Remove notsan tag once tasks no longer has race condition issue py_test( name = "gesture_recognizer_test", size = "large", @@ -140,6 +143,7 @@ py_test( "//mediapipe/model_maker/models/gesture_recognizer:models", ], shard_count = 2, + tags = ["notsan"], deps = [ ":gesture_recognizer_import", "//mediapipe/model_maker/python/core/utils:test_util", diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/dataset.py b/mediapipe/model_maker/python/vision/gesture_recognizer/dataset.py index 256f26fd6..6a2c878c0 100644 --- a/mediapipe/model_maker/python/vision/gesture_recognizer/dataset.py +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/dataset.py @@ -16,16 +16,22 @@ import dataclasses import os import random -from typing import List, NamedTuple, Optional +from typing import List, Optional -import cv2 import tensorflow as tf from mediapipe.model_maker.python.core.data import classification_dataset -from mediapipe.model_maker.python.core.data import data_util from mediapipe.model_maker.python.core.utils import model_util from mediapipe.model_maker.python.vision.gesture_recognizer import constants -from mediapipe.python.solutions import hands as mp_hands +from mediapipe.model_maker.python.vision.gesture_recognizer import metadata_writer +from mediapipe.python._framework_bindings import image as image_module +from mediapipe.tasks.python.core import base_options as 
base_options_module +from mediapipe.tasks.python.vision import hand_landmarker as hand_landmarker_module + +_Image = image_module.Image +_HandLandmarker = hand_landmarker_module.HandLandmarker +_HandLandmarkerOptions = hand_landmarker_module.HandLandmarkerOptions +_HandLandmarkerResult = hand_landmarker_module.HandLandmarkerResult @dataclasses.dataclass @@ -59,7 +65,7 @@ class HandData: handedness: List[float] -def _validate_data_sample(data: NamedTuple) -> bool: +def _validate_data_sample(data: _HandLandmarkerResult) -> bool: """Validates the input hand data sample. Args: @@ -70,19 +76,17 @@ def _validate_data_sample(data: NamedTuple) -> bool: 'multi_hand_landmarks' or 'multi_hand_world_landmarks' or 'multi_handedness' or any of these attributes' values are none. Otherwise, True. """ - if (not hasattr(data, 'multi_hand_landmarks') or - data.multi_hand_landmarks is None): + if data.hand_landmarks is None or not data.hand_landmarks: return False - if (not hasattr(data, 'multi_hand_world_landmarks') or - data.multi_hand_world_landmarks is None): + if data.hand_world_landmarks is None or not data.hand_world_landmarks: return False - if not hasattr(data, 'multi_handedness') or data.multi_handedness is None: + if data.handedness is None or not data.handedness: return False return True def _get_hand_data(all_image_paths: List[str], - min_detection_confidence: float) -> Optional[HandData]: + min_detection_confidence: float) -> List[Optional[HandData]]: """Computes hand data (landmarks and handedness) in the input image. Args: @@ -93,28 +97,36 @@ def _get_hand_data(all_image_paths: List[str], A HandData object. Returns None if no hand is detected. """ hand_data_result = [] - with mp_hands.Hands( - static_image_mode=True, - max_num_hands=1, - min_detection_confidence=min_detection_confidence) as hands: + hand_detector_model_buffer = model_util.load_tflite_model_buffer( + constants.HAND_DETECTOR_TFLITE_MODEL_FILE) + hand_landmarks_detector_model_buffer = model_util.load_tflite_model_buffer( + constants.HAND_LANDMARKS_DETECTOR_TFLITE_MODEL_FILE) + hand_landmarker_writer = metadata_writer.HandLandmarkerMetadataWriter( + hand_detector_model_buffer, hand_landmarks_detector_model_buffer) + hand_landmarker_options = _HandLandmarkerOptions( + base_options=base_options_module.BaseOptions( + model_asset_buffer=hand_landmarker_writer.populate()), + num_hands=1, + min_hand_detection_confidence=min_detection_confidence, + min_hand_presence_confidence=0.5, + min_tracking_confidence=1, + ) + with _HandLandmarker.create_from_options( + hand_landmarker_options) as hand_landmarker: for path in all_image_paths: tf.compat.v1.logging.info('Loading image %s', path) - image = data_util.load_image(path) - # Flip image around y-axis for correct handedness output - image = cv2.flip(image, 1) - data = hands.process(image) + image = _Image.create_from_file(path) + data = hand_landmarker.detect(image) if not _validate_data_sample(data): hand_data_result.append(None) continue - hand_landmarks = [[ - hand_landmark.x, hand_landmark.y, hand_landmark.z - ] for hand_landmark in data.multi_hand_landmarks[0].landmark] + hand_landmarks = [[hand_landmark.x, hand_landmark.y, hand_landmark.z] + for hand_landmark in data.hand_landmarks[0]] hand_world_landmarks = [[ hand_landmark.x, hand_landmark.y, hand_landmark.z - ] for hand_landmark in data.multi_hand_world_landmarks[0].landmark] + ] for hand_landmark in data.hand_world_landmarks[0]] handedness_scores = [ - handedness.score - for handedness in data.multi_handedness[0].classification + 
handedness.score for handedness in data.handedness[0] ] hand_data_result.append( HandData( diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/dataset_test.py b/mediapipe/model_maker/python/vision/gesture_recognizer/dataset_test.py index 76e70a58d..528d02edd 100644 --- a/mediapipe/model_maker/python/vision/gesture_recognizer/dataset_test.py +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/dataset_test.py @@ -12,21 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -import collections import os import shutil from typing import NamedTuple import unittest -from absl import flags from absl.testing import parameterized import tensorflow as tf from mediapipe.model_maker.python.vision.gesture_recognizer import dataset -from mediapipe.python.solutions import hands as mp_hands from mediapipe.tasks.python.test import test_utils - -FLAGS = flags.FLAGS +from mediapipe.tasks.python.vision import hand_landmarker _TEST_DATA_DIRNAME = 'raw_data' @@ -39,14 +35,14 @@ class DatasetTest(tf.test.TestCase, parameterized.TestCase): dirname=input_data_dir, hparams=dataset.HandDataPreprocessingParams()) train_data, test_data = data.split(0.5) - self.assertLen(train_data, 17) + self.assertLen(train_data, 16) for _, elem in enumerate(train_data.gen_tf_dataset(is_training=True)): self.assertEqual(elem[0].shape, (1, 128)) self.assertEqual(elem[1].shape, ([1, 4])) self.assertEqual(train_data.num_classes, 4) self.assertEqual(train_data.label_names, ['none', 'call', 'four', 'rock']) - self.assertLen(test_data, 18) + self.assertLen(test_data, 16) for _, elem in enumerate(test_data.gen_tf_dataset(is_training=True)): self.assertEqual(elem[0].shape, (1, 128)) self.assertEqual(elem[1].shape, ([1, 4])) @@ -60,7 +56,7 @@ class DatasetTest(tf.test.TestCase, parameterized.TestCase): for _, elem in enumerate(data.gen_tf_dataset(is_training=True)): self.assertEqual(elem[0].shape, (1, 128)) self.assertEqual(elem[1].shape, ([1, 4])) - self.assertLen(data, 35) + self.assertLen(data, 32) self.assertEqual(data.num_classes, 4) self.assertEqual(data.label_names, ['none', 'call', 'four', 'rock']) @@ -105,51 +101,42 @@ class DatasetTest(tf.test.TestCase, parameterized.TestCase): for _, elem in enumerate(data.gen_tf_dataset(is_training=True)): self.assertEqual(elem[0].shape, (1, 128)) self.assertEqual(elem[1].shape, ([1, 4])) - self.assertLen(data, 35) + self.assertLen(data, 32) self.assertEqual(data.num_classes, 4) self.assertEqual(data.label_names, ['NONE', 'CALL', 'FOUR', 'ROCK']) @parameterized.named_parameters( dict( - testcase_name='invalid_field_name_multi_hand_landmark', - hand=collections.namedtuple('Hand', [ - 'multi_hand_landmark', 'multi_hand_world_landmarks', - 'multi_handedness' - ])(1, 2, 3)), + testcase_name='none_handedness', + hand=hand_landmarker.HandLandmarkerResult( + handedness=None, hand_landmarks=[[2]], + hand_world_landmarks=[[3]])), dict( - testcase_name='invalid_field_name_multi_hand_world_landmarks', - hand=collections.namedtuple('Hand', [ - 'multi_hand_landmarks', 'multi_hand_world_landmark', - 'multi_handedness' - ])(1, 2, 3)), + testcase_name='none_hand_landmarks', + hand=hand_landmarker.HandLandmarkerResult( + handedness=[[1]], hand_landmarks=None, + hand_world_landmarks=[[3]])), dict( - testcase_name='invalid_field_name_multi_handed', - hand=collections.namedtuple('Hand', [ - 'multi_hand_landmarks', 'multi_hand_world_landmarks', - 'multi_handed' - ])(1, 2, 3)), + testcase_name='none_hand_world_landmarks', + 
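
The new dataset pipeline above drives the MediaPipe Tasks HandLandmarker directly instead of the legacy mp_hands solution. A minimal sketch of the same detection flow on its own, assuming a prebuilt hand_landmarker.task bundle and a local hand.jpg image (both file names are hypothetical):

from mediapipe.python._framework_bindings import image as image_module
from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.vision import hand_landmarker as hand_landmarker_module

options = hand_landmarker_module.HandLandmarkerOptions(
    base_options=base_options_module.BaseOptions(
        model_asset_path='hand_landmarker.task'),  # assumed local bundle
    num_hands=1,
    min_hand_detection_confidence=0.7)

with hand_landmarker_module.HandLandmarker.create_from_options(
    options) as landmarker:
  image = image_module.Image.create_from_file('hand.jpg')  # assumed image
  result = landmarker.detect(image)
  # Mirror the patch's validation: all three result fields must be non-empty.
  if (result.hand_landmarks and result.hand_world_landmarks and
      result.handedness):
    first_landmark = result.hand_landmarks[0][0]
    print(first_landmark.x, first_landmark.y, first_landmark.z)
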
hand=hand_landmarker.HandLandmarkerResult( + handedness=[[1]], hand_landmarks=[[2]], + hand_world_landmarks=None)), dict( - testcase_name='multi_hand_landmarks_is_none', - hand=collections.namedtuple('Hand', [ - 'multi_hand_landmarks', 'multi_hand_world_landmarks', - 'multi_handedness' - ])(None, 2, 3)), + testcase_name='empty_handedness', + hand=hand_landmarker.HandLandmarkerResult( + handedness=[], hand_landmarks=[[2]], hand_world_landmarks=[[3]])), dict( - testcase_name='multi_hand_world_landmarks_is_none', - hand=collections.namedtuple('Hand', [ - 'multi_hand_landmarks', 'multi_hand_world_landmarks', - 'multi_handedness' - ])(1, None, 3)), + testcase_name='empty_hand_landmarks', + hand=hand_landmarker.HandLandmarkerResult( + handedness=[[1]], hand_landmarks=[], hand_world_landmarks=[[3]])), dict( - testcase_name='multi_handedness_is_none', - hand=collections.namedtuple('Hand', [ - 'multi_hand_landmarks', 'multi_hand_world_landmarks', - 'multi_handedness' - ])(1, 2, None)), + testcase_name='empty_hand_world_landmarks', + hand=hand_landmarker.HandLandmarkerResult( + handedness=[[1]], hand_landmarks=[[2]], hand_world_landmarks=[])), ) def test_create_dataset_from_invalid_hand_data(self, hand: NamedTuple): with unittest.mock.patch.object( - mp_hands.Hands, 'process', return_value=hand): + hand_landmarker.HandLandmarker, 'detect', return_value=hand): input_data_dir = test_utils.get_test_data_path(_TEST_DATA_DIRNAME) with self.assertRaisesRegex(ValueError, 'No valid hand is detected'): dataset.Dataset.from_folder( diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/metadata_writer.py b/mediapipe/model_maker/python/vision/gesture_recognizer/metadata_writer.py index 58b67e072..b2e851afe 100644 --- a/mediapipe/model_maker/python/vision/gesture_recognizer/metadata_writer.py +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/metadata_writer.py @@ -62,6 +62,50 @@ def read_file(file_path: str, mode: str = "rb") -> Union[str, bytes]: return f.read() +class HandLandmarkerMetadataWriter: + """MetadataWriter to write the model asset bundle for HandLandmarker.""" + + def __init__( + self, + hand_detector_model_buffer: bytearray, + hand_landmarks_detector_model_buffer: bytearray, + ) -> None: + """Initializes HandLandmarkerMetadataWriter to write model asset bundle. + + Args: + hand_detector_model_buffer: A valid flatbuffer *with* metadata loaded from + the TFLite hand detector model file. + hand_landmarks_detector_model_buffer: A valid flatbuffer *with* metadata + loaded from the TFLite hand landmarks detector model file. + """ + self._hand_detector_model_buffer = hand_detector_model_buffer + self._hand_landmarks_detector_model_buffer = hand_landmarks_detector_model_buffer + self._temp_folder = tempfile.TemporaryDirectory() + + def __del__(self): + if os.path.exists(self._temp_folder.name): + self._temp_folder.cleanup() + + def populate(self): + """Creates the model asset bundle for hand landmarker task. 
+ + Returns: + Model asset bundle in bytes + """ + landmark_models = { + _HAND_DETECTOR_TFLITE_NAME: + self._hand_detector_model_buffer, + _HAND_LANDMARKS_DETECTOR_TFLITE_NAME: + self._hand_landmarks_detector_model_buffer + } + output_hand_landmarker_path = os.path.join(self._temp_folder.name, + _HAND_LANDMARKER_BUNDLE_NAME) + writer_utils.create_model_asset_bundle(landmark_models, + output_hand_landmarker_path) + hand_landmarker_model_buffer = read_file(output_hand_landmarker_path) + return hand_landmarker_model_buffer + + class MetadataWriter: """MetadataWriter to write the metadata and the model asset bundle.""" @@ -86,8 +130,8 @@ class MetadataWriter: custom_gesture_classifier_metadata_writer: Metadata writer to write custom gesture classifier metadata into the TFLite file. """ - self._hand_detector_model_buffer = hand_detector_model_buffer - self._hand_landmarks_detector_model_buffer = hand_landmarks_detector_model_buffer + self._hand_landmarker_metadata_writer = HandLandmarkerMetadataWriter( + hand_detector_model_buffer, hand_landmarks_detector_model_buffer) self._gesture_embedder_model_buffer = gesture_embedder_model_buffer self._canned_gesture_classifier_model_buffer = canned_gesture_classifier_model_buffer self._custom_gesture_classifier_metadata_writer = custom_gesture_classifier_metadata_writer @@ -147,16 +191,8 @@ class MetadataWriter: A tuple of (model_asset_bundle_in_bytes, metadata_json_content) """ # Creates the model asset bundle for hand landmarker task. - landmark_models = { - _HAND_DETECTOR_TFLITE_NAME: - self._hand_detector_model_buffer, - _HAND_LANDMARKS_DETECTOR_TFLITE_NAME: - self._hand_landmarks_detector_model_buffer - } - output_hand_landmarker_path = os.path.join(self._temp_folder.name, - _HAND_LANDMARKER_BUNDLE_NAME) - writer_utils.create_model_asset_bundle(landmark_models, - output_hand_landmarker_path) + hand_landmarker_model_buffer = self._hand_landmarker_metadata_writer.populate( + ) # Write metadata into custom gesture classifier model. self._custom_gesture_classifier_model_buffer, custom_gesture_classifier_metadata_json = self._custom_gesture_classifier_metadata_writer.populate( @@ -179,7 +215,7 @@ class MetadataWriter: # graph. gesture_recognizer_models = { _HAND_LANDMARKER_BUNDLE_NAME: - read_file(output_hand_landmarker_path), + hand_landmarker_model_buffer, _HAND_GESTURE_RECOGNIZER_BUNDLE_NAME: read_file(output_hand_gesture_recognizer_path), } diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/metadata_writer_test.py b/mediapipe/model_maker/python/vision/gesture_recognizer/metadata_writer_test.py index 83998141d..fd26b274d 100644 --- a/mediapipe/model_maker/python/vision/gesture_recognizer/metadata_writer_test.py +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/metadata_writer_test.py @@ -33,6 +33,23 @@ _CUSTOM_GESTURE_CLASSIFIER_PATH = test_utils.get_test_data_path( class MetadataWriterTest(tf.test.TestCase): + def test_hand_landmarker_metadata_writer(self): + # Use dummy model buffer for unit test only. 
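
A .task model asset bundle of the kind populate() returns is, at heart, an archive of named model buffers, which is why the test below can verify it by listing the archive members with zipfile. A rough sketch of the idea using plain zipfile and the same dummy buffers as the test; the real writer_utils helper may add bundle metadata beyond this:

import zipfile

models = {
    'hand_detector.tflite': b'\x11\x12',        # dummy buffer, as in the test
    'hand_landmarks_detector.tflite': b'\x22',  # dummy buffer
}

# Write each named buffer into the bundle archive.
with zipfile.ZipFile('hand_landmarker.task', 'w') as zf:
  for name, buffer in models.items():
    zf.writestr(name, buffer)

# Reading it back shows exactly the member names the test asserts on.
with zipfile.ZipFile('hand_landmarker.task') as zf:
  assert set(zf.namelist()) == set(models)
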
+ hand_detector_model_buffer = b"\x11\x12" + hand_landmarks_detector_model_buffer = b"\x22" + writer = metadata_writer.HandLandmarkerMetadataWriter( + hand_detector_model_buffer, hand_landmarks_detector_model_buffer) + model_bundle_content = writer.populate() + model_bundle_filepath = os.path.join(self.get_temp_dir(), + "hand_landmarker.task") + with open(model_bundle_filepath, "wb") as f: + f.write(model_bundle_content) + + with zipfile.ZipFile(model_bundle_filepath) as zf: + self.assertEqual( + set(zf.namelist()), + set(["hand_landmarks_detector.tflite", "hand_detector.tflite"])) + def test_write_metadata_and_create_model_asset_bundle_successful(self): # Use dummy model buffer for unit test only. hand_detector_model_buffer = b"\x11\x12" diff --git a/mediapipe/tasks/python/core/BUILD b/mediapipe/tasks/python/core/BUILD index 447189d6f..f14d59b99 100644 --- a/mediapipe/tasks/python/core/BUILD +++ b/mediapipe/tasks/python/core/BUILD @@ -23,15 +23,15 @@ py_library( srcs = [ "optional_dependencies.py", ], - deps = [ - "@org_tensorflow//tensorflow/tools/docs:doc_controls", - ], ) py_library( name = "base_options", srcs = ["base_options.py"], - visibility = ["//mediapipe/tasks:users"], + visibility = [ + "//mediapipe/model_maker/python/vision/gesture_recognizer:__subpackages__", + "//mediapipe/tasks:users", + ], deps = [ ":optional_dependencies", "//mediapipe/tasks/cc/core/proto:base_options_py_pb2", diff --git a/mediapipe/tasks/python/vision/BUILD b/mediapipe/tasks/python/vision/BUILD index 5f4aa38ff..eda8e290d 100644 --- a/mediapipe/tasks/python/vision/BUILD +++ b/mediapipe/tasks/python/vision/BUILD @@ -131,6 +131,10 @@ py_library( srcs = [ "hand_landmarker.py", ], + visibility = [ + "//mediapipe/model_maker/python/vision/gesture_recognizer:__subpackages__", + "//mediapipe/tasks:internal", + ], deps = [ "//mediapipe/framework/formats:classification_py_pb2", "//mediapipe/framework/formats:landmark_py_pb2", From 0fbaa8dc8a0220d682081d70fd01bef71709a316 Mon Sep 17 00:00:00 2001 From: Nikolay Chirkov Date: Thu, 8 Dec 2022 12:59:46 -0800 Subject: [PATCH 203/469] Internal change. PiperOrigin-RevId: 493973435 --- .../framework/formats/tensor_ahwb_gpu_test.cc | 143 ++++++++++++++++++ 1 file changed, 143 insertions(+) create mode 100644 mediapipe/framework/formats/tensor_ahwb_gpu_test.cc diff --git a/mediapipe/framework/formats/tensor_ahwb_gpu_test.cc b/mediapipe/framework/formats/tensor_ahwb_gpu_test.cc new file mode 100644 index 000000000..dd865a367 --- /dev/null +++ b/mediapipe/framework/formats/tensor_ahwb_gpu_test.cc @@ -0,0 +1,143 @@ + +#if !defined(MEDIAPIPE_NO_JNI) && \ + (__ANDROID_API__ >= 26 || \ + defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__)) +#include + +#include + +#include "mediapipe/framework/formats/tensor.h" +#include "mediapipe/framework/formats/tensor_data_types.h" +#include "mediapipe/gpu/gpu_test_base.h" +#include "mediapipe/gpu/shader_util.h" +#include "tensorflow/lite/delegates/gpu/gl/gl_call.h" +#include "testing/base/public/gunit.h" + +// The test creates OpenGL ES buffer, fills the buffer with incrementing values +// 0.0, 0.1, 0.2 etc. with the compute shader on GPU. +// Then the test requests the CPU view and compares the values. +// Float32 and Float16 tests are there. + +namespace { + +using mediapipe::Float16; +using mediapipe::Tensor; + +MATCHER_P(NearWithPrecision, precision, "") { + return std::abs(std::get<0>(arg) - std::get<1>(arg)) < precision; +} + +#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 + +// Utility function to fill the GPU buffer. 
+void FillGpuBuffer(GLuint name, std::size_t size, + const Tensor::ElementType fmt) { + std::string shader_source; + if (fmt == Tensor::ElementType::kFloat32) { + shader_source = R"( #version 310 es + precision highp float; + layout(local_size_x = 1, local_size_y = 1) in; + layout(std430, binding = 0) buffer Output {float elements[];} output_data; + void main() { + uint v = gl_GlobalInvocationID.x * 2u; + output_data.elements[v] = float(v) / 10.0; + output_data.elements[v + 1u] = float(v + 1u) / 10.0; + })"; + } else { + shader_source = R"( #version 310 es + precision highp float; + layout(local_size_x = 1, local_size_y = 1) in; + layout(std430, binding = 0) buffer Output {float elements[];} output_data; + void main() { + uint v = gl_GlobalInvocationID.x; + uint tmp = packHalf2x16(vec2((float(v)* 2.0 + 0.0) / 10.0, + (float(v) * 2.0 + 1.0) / 10.0)); + output_data.elements[v] = uintBitsToFloat(tmp); + })"; + } + GLuint shader; + MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glCreateShader, &shader, GL_COMPUTE_SHADER)); + const GLchar* sources[] = {shader_source.c_str()}; + MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glShaderSource, shader, 1, sources, nullptr)); + MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glCompileShader, shader)); + GLint is_compiled = 0; + MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glGetShaderiv, shader, GL_COMPILE_STATUS, + &is_compiled)); + if (is_compiled == GL_FALSE) { + GLint max_length = 0; + MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glGetShaderiv, shader, GL_INFO_LOG_LENGTH, + &max_length)); + std::vector error_log(max_length); + glGetShaderInfoLog(shader, max_length, &max_length, error_log.data()); + glDeleteShader(shader); + FAIL() << error_log.data(); + return; + } + GLuint to_buffer_program; + MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glCreateProgram, &to_buffer_program)); + MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glAttachShader, to_buffer_program, shader)); + MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glDeleteShader, shader)); + MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glLinkProgram, to_buffer_program)); + + MP_ASSERT_OK( + TFLITE_GPU_CALL_GL(glBindBufferBase, GL_SHADER_STORAGE_BUFFER, 0, name)); + MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glUseProgram, to_buffer_program)); + MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glDispatchCompute, size / 2, 1, 1)); + MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glBindBuffer, GL_SHADER_STORAGE_BUFFER, 0)); + MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glDeleteProgram, to_buffer_program)); +} + +class TensorAhwbGpuTest : public mediapipe::GpuTestBase { + public: +}; + +TEST_F(TensorAhwbGpuTest, TestGpuToCpuFloat32) { + Tensor::SetPreferredStorageType(Tensor::StorageType::kAhwb); + constexpr size_t num_elements = 20; + Tensor tensor{Tensor::ElementType::kFloat32, Tensor::Shape({num_elements})}; + RunInGlContext([&tensor] { + auto ssbo_view = tensor.GetOpenGlBufferWriteView(); + auto ssbo_name = ssbo_view.name(); + EXPECT_GT(ssbo_name, 0); + FillGpuBuffer(ssbo_name, tensor.shape().num_elements(), + tensor.element_type()); + }); + auto ptr = tensor.GetCpuReadView().buffer(); + EXPECT_NE(ptr, nullptr); + std::vector reference; + reference.resize(num_elements); + for (int i = 0; i < num_elements; i++) { + reference[i] = static_cast(i) / 10.0f; + } + EXPECT_THAT(absl::Span(ptr, num_elements), + testing::Pointwise(testing::FloatEq(), reference)); +} + +TEST_F(TensorAhwbGpuTest, TestGpuToCpuFloat16) { + Tensor::SetPreferredStorageType(Tensor::StorageType::kAhwb); + constexpr size_t num_elements = 20; + Tensor tensor{Tensor::ElementType::kFloat16, Tensor::Shape({num_elements})}; + RunInGlContext([&tensor] { + auto ssbo_view = tensor.GetOpenGlBufferWriteView(); + auto 
ssbo_name = ssbo_view.name(); + EXPECT_GT(ssbo_name, 0); + FillGpuBuffer(ssbo_name, tensor.shape().num_elements(), + tensor.element_type()); + }); + auto ptr = tensor.GetCpuReadView().buffer(); + EXPECT_NE(ptr, nullptr); + std::vector reference; + reference.resize(num_elements); + for (int i = 0; i < num_elements; i++) { + reference[i] = static_cast(i) / 10.0f; + } + // Precision is set to a reasonable value for Float16. + EXPECT_THAT(absl::Span(ptr, num_elements), + testing::Pointwise(NearWithPrecision(0.001), reference)); +} + +#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 +} // namespace + +#endif // !defined(MEDIAPIPE_NO_JNI) && (__ANDROID_API__ >= 26 || + // defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__)) From b4e1969e4381053038322275bfb8f15d855da9f2 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Thu, 8 Dec 2022 14:01:17 -0800 Subject: [PATCH 204/469] Add pip package builder for model_maker PiperOrigin-RevId: 493989013 --- mediapipe/model_maker/MANIFEST.in | 1 + mediapipe/model_maker/__init__.py | 6 + .../python/core/utils/file_util.py | 19 ++- mediapipe/model_maker/python/text/__init__.py | 13 ++ mediapipe/model_maker/requirements.txt | 6 +- mediapipe/model_maker/setup.py | 147 ++++++++++++++++++ 6 files changed, 184 insertions(+), 8 deletions(-) create mode 100644 mediapipe/model_maker/MANIFEST.in create mode 100644 mediapipe/model_maker/python/text/__init__.py create mode 100644 mediapipe/model_maker/setup.py diff --git a/mediapipe/model_maker/MANIFEST.in b/mediapipe/model_maker/MANIFEST.in new file mode 100644 index 000000000..54ce01aff --- /dev/null +++ b/mediapipe/model_maker/MANIFEST.in @@ -0,0 +1 @@ +recursive-include pip_src/mediapipe_model_maker/models * diff --git a/mediapipe/model_maker/__init__.py b/mediapipe/model_maker/__init__.py index 7ca2f9216..9899a145b 100644 --- a/mediapipe/model_maker/__init__.py +++ b/mediapipe/model_maker/__init__.py @@ -11,3 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + + +from mediapipe.model_maker.python.core.utils import quantization +from mediapipe.model_maker.python.vision import image_classifier +from mediapipe.model_maker.python.vision import gesture_recognizer +from mediapipe.model_maker.python.text import text_classifier diff --git a/mediapipe/model_maker/python/core/utils/file_util.py b/mediapipe/model_maker/python/core/utils/file_util.py index bccf928e2..66addad54 100644 --- a/mediapipe/model_maker/python/core/utils/file_util.py +++ b/mediapipe/model_maker/python/core/utils/file_util.py @@ -19,7 +19,7 @@ import os def get_absolute_path(file_path: str) -> str: - """Gets the absolute path of a file. + """Gets the absolute path of a file in the model_maker directory. Args: file_path: The path to a file relative to the `mediapipe` dir @@ -27,10 +27,17 @@ def get_absolute_path(file_path: str) -> str: Returns: The full path of the file """ - # Extract the file path before mediapipe/ as the `base_dir`. By joining it - # with the `path` which defines the relative path under mediapipe/, it - # yields to the absolute path of the model files directory. + # Extract the file path before and including 'model_maker' as the + # `mm_base_dir`. By joining it with the `path` after 'model_maker/', it + # yields to the absolute path of the model files directory. 
We must join + # on 'model_maker' because in the pypi package, the 'model_maker' directory + # is renamed to 'mediapipe_model_maker'. So we have to join on model_maker + # to ensure that the `mm_base_dir` path includes the renamed + # 'mediapipe_model_maker' directory. cwd = os.path.dirname(__file__) - base_dir = cwd[:cwd.rfind('mediapipe')] - absolute_path = os.path.join(base_dir, file_path) + cwd_stop_idx = cwd.rfind('model_maker') + len('model_maker') + mm_base_dir = cwd[:cwd_stop_idx] + file_path_start_idx = file_path.find('model_maker') + len('model_maker') + 1 + mm_relative_path = file_path[file_path_start_idx:] + absolute_path = os.path.join(mm_base_dir, mm_relative_path) return absolute_path diff --git a/mediapipe/model_maker/python/text/__init__.py b/mediapipe/model_maker/python/text/__init__.py new file mode 100644 index 000000000..7ca2f9216 --- /dev/null +++ b/mediapipe/model_maker/python/text/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/mediapipe/model_maker/requirements.txt b/mediapipe/model_maker/requirements.txt index 389ee484a..9b3c9f906 100644 --- a/mediapipe/model_maker/requirements.txt +++ b/mediapipe/model_maker/requirements.txt @@ -1,6 +1,8 @@ absl-py +mediapipe==0.9.1 numpy -opencv-contrib-python -tensorflow +opencv-python +tensorflow>=2.10 tensorflow-datasets tensorflow-hub +tf-models-official>=2.10.1 diff --git a/mediapipe/model_maker/setup.py b/mediapipe/model_maker/setup.py new file mode 100644 index 000000000..ea193db94 --- /dev/null +++ b/mediapipe/model_maker/setup.py @@ -0,0 +1,147 @@ +"""Copyright 2020-2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +Setup for Mediapipe-Model-Maker package with setuptools. 
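
The path arithmetic in get_absolute_path is easiest to follow with concrete values. A small worked sketch of the same rfind/find logic, using a hypothetical install location for the renamed mediapipe_model_maker package:

import os

# Hypothetical location of file_util.py inside an installed pip package.
cwd = '/site-packages/mediapipe_model_maker/python/core/utils'
file_path = 'mediapipe/model_maker/models/gesture_recognizer/gesture_embedder.tflite'

# Keep everything up to and including 'model_maker' from the module location.
mm_base_dir = cwd[:cwd.rfind('model_maker') + len('model_maker')]
# Keep everything after 'model_maker/' from the requested path.
mm_relative_path = file_path[file_path.find('model_maker') +
                             len('model_maker') + 1:]

print(os.path.join(mm_base_dir, mm_relative_path))
# /site-packages/mediapipe_model_maker/models/gesture_recognizer/gesture_embedder.tflite
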
+""" + +import glob +import os +import shutil +import subprocess +import sys +import setuptools + + +__version__ = 'dev' +MM_ROOT_PATH = os.path.dirname(os.path.abspath(__file__)) +# Build dir to copy all necessary files and build package +SRC_NAME = 'pip_src' +BUILD_DIR = os.path.join(MM_ROOT_PATH, SRC_NAME) +BUILD_MM_DIR = os.path.join(BUILD_DIR, 'mediapipe_model_maker') + + +def _parse_requirements(path): + with open(os.path.join(MM_ROOT_PATH, path)) as f: + return [ + line.rstrip() + for line in f + if not (line.isspace() or line.startswith('#')) + ] + + +def _copy_to_pip_src_dir(file): + """Copy a file from bazel-bin to the pip_src dir.""" + dst = file + dst_dir = os.path.dirname(dst) + if not os.path.exists(dst_dir): + os.makedirs(dst_dir) + src_file = os.path.join('../../bazel-bin/mediapipe/model_maker', file) + shutil.copyfile(src_file, file) + + +def _setup_build_dir(): + """Setup the BUILD_DIR directory to build the mediapipe_model_maker package. + + We need to create a new BUILD_DIR directory because any references to the path + `mediapipe/model_maker` needs to be renamed to `mediapipe_model_maker` to + avoid conflicting with the mediapipe package name. + This setup function performs the following actions: + 1. Copy python source code into BUILD_DIR and rename imports to + mediapipe_model_maker + 2. Download models from GCS into BUILD_DIR + """ + # Copy python source code into BUILD_DIR + if os.path.exists(BUILD_DIR): + shutil.rmtree(BUILD_DIR) + python_files = glob.glob('python/**/*.py', recursive=True) + python_files.append('__init__.py') + for python_file in python_files: + # Exclude test files from pip package + if '_test.py' in python_file: + continue + build_target_file = os.path.join(BUILD_MM_DIR, python_file) + with open(python_file, 'r') as file: + filedata = file.read() + # Rename all mediapipe.model_maker imports to mediapipe_model_maker + filedata = filedata.replace('from mediapipe.model_maker', + 'from mediapipe_model_maker') + os.makedirs(os.path.dirname(build_target_file), exist_ok=True) + with open(build_target_file, 'w') as file: + file.write(filedata) + + # Use bazel to download GCS model files + model_build_files = ['models/gesture_recognizer/BUILD'] + for model_build_file in model_build_files: + build_target_file = os.path.join(BUILD_MM_DIR, model_build_file) + os.makedirs(os.path.dirname(build_target_file), exist_ok=True) + shutil.copy(model_build_file, build_target_file) + external_files = [ + 'models/gesture_recognizer/canned_gesture_classifier.tflite', + 'models/gesture_recognizer/gesture_embedder.tflite', + 'models/gesture_recognizer/hand_landmark_full.tflite', + 'models/gesture_recognizer/palm_detection_full.tflite', + 'models/gesture_recognizer/gesture_embedder/keras_metadata.pb', + 'models/gesture_recognizer/gesture_embedder/saved_model.pb', + 'models/gesture_recognizer/gesture_embedder/variables/variables.data-00000-of-00001', + 'models/gesture_recognizer/gesture_embedder/variables/variables.index', + ] + for elem in external_files: + external_file = os.path.join(f'{SRC_NAME}/mediapipe_model_maker', elem) + sys.stderr.write('downloading file: %s\n' % external_file) + fetch_model_command = [ + 'bazel', + 'build', + external_file, + ] + if subprocess.call(fetch_model_command) != 0: + sys.exit(-1) + _copy_to_pip_src_dir(external_file) + +_setup_build_dir() + +setuptools.setup( + name='mediapipe-model-maker', + version=__version__, + url='https://github.com/google/mediapipe/tree/master/mediapipe/model_maker', + description='MediaPipe Model Maker is a 
simple, low-code solution for customizing on-device ML models', + author='The MediaPipe Authors', + author_email='mediapipe@google.com', + long_description='', + long_description_content_type='text/markdown', + packages=setuptools.find_packages(where=SRC_NAME), + package_dir={'': SRC_NAME}, + install_requires=_parse_requirements('requirements.txt'), + include_package_data=True, + classifiers=[ + 'Development Status :: 3 - Alpha', + 'Intended Audience :: Developers', + 'Intended Audience :: Education', + 'Intended Audience :: Science/Research', + 'License :: OSI Approved :: Apache Software License', + 'Operating System :: MacOS :: MacOS X', + 'Operating System :: Microsoft :: Windows', + 'Operating System :: POSIX :: Linux', + 'Programming Language :: Python :: 3.7', + 'Programming Language :: Python :: 3.8', + 'Programming Language :: Python :: 3.9', + 'Programming Language :: Python :: 3 :: Only', + 'Topic :: Scientific/Engineering', + 'Topic :: Scientific/Engineering :: Artificial Intelligence', + 'Topic :: Software Development', + 'Topic :: Software Development :: Libraries', + 'Topic :: Software Development :: Libraries :: Python Modules', + ], + license='Apache 2.0', + keywords=['mediapipe', 'model', 'maker'], +) From 05535db5f77bdc9c46df36855fe4064ded89d7cb Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Thu, 8 Dec 2022 15:01:34 -0800 Subject: [PATCH 205/469] Fix assertion failure in Hair Segmentation demo PiperOrigin-RevId: 494004801 --- mediapipe/graphs/hair_segmentation/BUILD | 1 + .../hair_segmentation_desktop_live.pbtxt | 9 ++++++++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/mediapipe/graphs/hair_segmentation/BUILD b/mediapipe/graphs/hair_segmentation/BUILD index b177726bf..945f02c62 100644 --- a/mediapipe/graphs/hair_segmentation/BUILD +++ b/mediapipe/graphs/hair_segmentation/BUILD @@ -43,6 +43,7 @@ cc_library( deps = [ "//mediapipe/calculators/core:flow_limiter_calculator", "//mediapipe/calculators/core:previous_loopback_calculator", + "//mediapipe/calculators/image:color_convert_calculator", "//mediapipe/calculators/image:image_transformation_calculator", "//mediapipe/calculators/image:recolor_calculator", "//mediapipe/calculators/image:set_alpha_calculator", diff --git a/mediapipe/graphs/hair_segmentation/hair_segmentation_desktop_live.pbtxt b/mediapipe/graphs/hair_segmentation/hair_segmentation_desktop_live.pbtxt index 36c6970e1..f48b26be0 100644 --- a/mediapipe/graphs/hair_segmentation/hair_segmentation_desktop_live.pbtxt +++ b/mediapipe/graphs/hair_segmentation/hair_segmentation_desktop_live.pbtxt @@ -60,7 +60,14 @@ node { tag_index: "LOOP" back_edge: true } - output_stream: "PREV_LOOP:previous_hair_mask" + output_stream: "PREV_LOOP:previous_hair_mask_rgb" +} + +# Converts the 4 channel hair mask to a single channel mask +node { + calculator: "ColorConvertCalculator" + input_stream: "RGB_IN:previous_hair_mask_rgb" + output_stream: "GRAY_OUT:previous_hair_mask" } # Embeds the hair mask generated from the previous round of hair segmentation From bea0caae6586343eb91e986b84aa00ef75fa67b1 Mon Sep 17 00:00:00 2001 From: Nikolay Chirkov Date: Thu, 8 Dec 2022 17:05:06 -0800 Subject: [PATCH 206/469] Tensor: Cpu -> Ahwb storage transfer PiperOrigin-RevId: 494033280 --- mediapipe/framework/formats/tensor.cc | 4 +-- mediapipe/framework/formats/tensor.h | 2 +- mediapipe/framework/formats/tensor_ahwb.cc | 3 +- .../framework/formats/tensor_ahwb_gpu_test.cc | 28 +++++++++++++++++++ 4 files changed, 32 insertions(+), 5 deletions(-) diff --git 
a/mediapipe/framework/formats/tensor.cc b/mediapipe/framework/formats/tensor.cc index 9e1406dbb..fdafbff5c 100644 --- a/mediapipe/framework/formats/tensor.cc +++ b/mediapipe/framework/formats/tensor.cc @@ -361,7 +361,7 @@ void Tensor::AllocateOpenGlBuffer() const { LOG_IF(FATAL, !gl_context_) << "GlContext is not bound to the thread."; glGenBuffers(1, &opengl_buffer_); glBindBuffer(GL_SHADER_STORAGE_BUFFER, opengl_buffer_); - if (!AllocateAhwbMapToSsbo()) { + if (!use_ahwb_ || !AllocateAhwbMapToSsbo()) { glBufferData(GL_SHADER_STORAGE_BUFFER, bytes(), NULL, GL_STREAM_COPY); } glBindBuffer(GL_SHADER_STORAGE_BUFFER, 0); @@ -610,7 +610,7 @@ Tensor::CpuWriteView Tensor::GetCpuWriteView() const { void Tensor::AllocateCpuBuffer() const { if (!cpu_buffer_) { #ifdef MEDIAPIPE_TENSOR_USE_AHWB - if (AllocateAHardwareBuffer()) return; + if (use_ahwb_ && AllocateAHardwareBuffer()) return; #endif // MEDIAPIPE_TENSOR_USE_AHWB #if MEDIAPIPE_METAL_ENABLED cpu_buffer_ = AllocateVirtualMemory(bytes()); diff --git a/mediapipe/framework/formats/tensor.h b/mediapipe/framework/formats/tensor.h index 151aa299d..9d3e90b6a 100644 --- a/mediapipe/framework/formats/tensor.h +++ b/mediapipe/framework/formats/tensor.h @@ -409,8 +409,8 @@ class Tensor { bool AllocateAHardwareBuffer(int size_alignment = 0) const; void CreateEglSyncAndFd() const; // Use Ahwb for other views: OpenGL / CPU buffer. - static inline bool use_ahwb_ = false; #endif // MEDIAPIPE_TENSOR_USE_AHWB + static inline bool use_ahwb_ = false; // Expects the target SSBO to be already bound. bool AllocateAhwbMapToSsbo() const; bool InsertAhwbToSsboFence() const; diff --git a/mediapipe/framework/formats/tensor_ahwb.cc b/mediapipe/framework/formats/tensor_ahwb.cc index 21bae9593..3c3ec8b17 100644 --- a/mediapipe/framework/formats/tensor_ahwb.cc +++ b/mediapipe/framework/formats/tensor_ahwb.cc @@ -214,7 +214,7 @@ Tensor::AHardwareBufferView Tensor::GetAHardwareBufferReadView() const { "supported."; CHECK(ahwb_ || !(valid_ & kValidOpenGlBuffer)) << "Interoperability bettween OpenGL buffer and AHardwareBuffer is not " - "supported on targe system."; + "supported on target system."; bool transfer = !ahwb_; CHECK(AllocateAHardwareBuffer()) << "AHardwareBuffer is not supported on the target system."; @@ -268,7 +268,6 @@ Tensor::AHardwareBufferView Tensor::GetAHardwareBufferWriteView( } bool Tensor::AllocateAHardwareBuffer(int size_alignment) const { - if (!use_ahwb_) return false; if (__builtin_available(android 26, *)) { if (ahwb_ == nullptr) { AHardwareBuffer_Desc desc = {}; diff --git a/mediapipe/framework/formats/tensor_ahwb_gpu_test.cc b/mediapipe/framework/formats/tensor_ahwb_gpu_test.cc index dd865a367..7ccd9c7f5 100644 --- a/mediapipe/framework/formats/tensor_ahwb_gpu_test.cc +++ b/mediapipe/framework/formats/tensor_ahwb_gpu_test.cc @@ -136,6 +136,34 @@ TEST_F(TensorAhwbGpuTest, TestGpuToCpuFloat16) { testing::Pointwise(NearWithPrecision(0.001), reference)); } +TEST_F(TensorAhwbGpuTest, TestReplacingCpuByAhwb) { + // Request the CPU view to get the memory to be allocated. + // Request Ahwb view then to transform the storage into Ahwb. 
+ Tensor::SetPreferredStorageType(Tensor::StorageType::kDefault); + constexpr size_t num_elements = 20; + Tensor tensor{Tensor::ElementType::kFloat32, Tensor::Shape({num_elements})}; + { + auto ptr = tensor.GetCpuWriteView().buffer(); + EXPECT_NE(ptr, nullptr); + for (int i = 0; i < num_elements; i++) { + ptr[i] = static_cast(i) / 10.0f; + } + } + { + auto view = tensor.GetAHardwareBufferReadView(); + EXPECT_NE(view.handle(), nullptr); + } + auto ptr = tensor.GetCpuReadView().buffer(); + EXPECT_NE(ptr, nullptr); + std::vector reference; + reference.resize(num_elements); + for (int i = 0; i < num_elements; i++) { + reference[i] = static_cast(i) / 10.0f; + } + EXPECT_THAT(absl::Span(ptr, num_elements), + testing::Pointwise(testing::FloatEq(), reference)); +} + #endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 } // namespace From 3aeec84ac016d9899a7829ad5651753942dcf275 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Fri, 9 Dec 2022 03:19:45 -0800 Subject: [PATCH 207/469] Internal change for profiling PiperOrigin-RevId: 494126771 --- mediapipe/framework/validated_graph_config.cc | 10 ++++++++++ mediapipe/framework/validated_graph_config.h | 12 ++++++++++++ 2 files changed, 22 insertions(+) diff --git a/mediapipe/framework/validated_graph_config.cc b/mediapipe/framework/validated_graph_config.cc index 01e3da83e..15eac3209 100644 --- a/mediapipe/framework/validated_graph_config.cc +++ b/mediapipe/framework/validated_graph_config.cc @@ -369,6 +369,7 @@ absl::Status ValidatedGraphConfig::Initialize( input_side_packets_.clear(); output_side_packets_.clear(); stream_to_producer_.clear(); + output_streams_to_consumer_nodes_.clear(); input_streams_.clear(); output_streams_.clear(); owned_packet_types_.clear(); @@ -719,6 +720,15 @@ absl::Status ValidatedGraphConfig::AddInputStreamsForNode( << " does not have a corresponding output stream."; } } + // Add this node as a consumer of this edge's output stream. + if (edge_info.upstream > -1) { + auto parent_node = output_streams_[edge_info.upstream].parent_node; + if (parent_node.type == NodeTypeInfo::NodeType::CALCULATOR) { + int this_idx = node_type_info->Node().index; + output_streams_to_consumer_nodes_[edge_info.upstream].push_back( + this_idx); + } + } edge_info.parent_node = node_type_info->Node(); edge_info.name = name; diff --git a/mediapipe/framework/validated_graph_config.h b/mediapipe/framework/validated_graph_config.h index 11f9553cd..95ecccbb4 100644 --- a/mediapipe/framework/validated_graph_config.h +++ b/mediapipe/framework/validated_graph_config.h @@ -282,6 +282,14 @@ class ValidatedGraphConfig { return output_streams_[iter->second].parent_node.index; } + std::vector OutputStreamToConsumers(int idx) const { + auto iter = output_streams_to_consumer_nodes_.find(idx); + if (iter == output_streams_to_consumer_nodes_.end()) { + return {}; + } + return iter->second; + } + // Returns the registered type name of the specified side packet if // it can be determined, otherwise an appropriate error is returned. absl::StatusOr RegisteredSidePacketTypeName( @@ -418,6 +426,10 @@ class ValidatedGraphConfig { // Mapping from stream name to the output_streams_ index which produces it. std::map stream_to_producer_; + + // Mapping from output streams to consumer node ids. Used for profiling. + std::map> output_streams_to_consumer_nodes_; + // Mapping from side packet name to the output_side_packets_ index // which produces it. 
std::map side_packet_to_producer_; From 4c4df2cf18955b2cc76f432b5f1321083d5bfb11 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Fri, 9 Dec 2022 04:11:05 -0800 Subject: [PATCH 208/469] Internal change for profiling PiperOrigin-RevId: 494135244 --- mediapipe/framework/profiler/graph_profiler.h | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/mediapipe/framework/profiler/graph_profiler.h b/mediapipe/framework/profiler/graph_profiler.h index 29969af2e..23caed4ec 100644 --- a/mediapipe/framework/profiler/graph_profiler.h +++ b/mediapipe/framework/profiler/graph_profiler.h @@ -232,6 +232,11 @@ class GraphProfiler : public std::enable_shared_from_this { const ProfilerConfig& profiler_config() { return profiler_config_; } + // Helper method to expose the config to other profilers. + const ValidatedGraphConfig* GetValidatedGraphConfig() { + return validated_graph_; + } + private: // This can be used to add packet info for the input streams to the graph. // It treats the stream defined by |stream_name| as a stream produced by a From 5bc1baf96acab858942d151d46b988ebe0577c00 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Fri, 9 Dec 2022 05:55:20 -0800 Subject: [PATCH 209/469] Internal change PiperOrigin-RevId: 494150467 --- mediapipe/framework/output_stream_shard.h | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mediapipe/framework/output_stream_shard.h b/mediapipe/framework/output_stream_shard.h index fdc5fe077..718174c45 100644 --- a/mediapipe/framework/output_stream_shard.h +++ b/mediapipe/framework/output_stream_shard.h @@ -127,6 +127,8 @@ class OutputStreamShard : public OutputStream { friend class GraphProfiler; // Accesses OutputStreamShard for profiling. friend class GraphTracer; + // Accesses OutputStreamShard for profiling. + friend class PerfettoTraceScope; // Accesses OutputStreamShard for post processing. friend class OutputStreamManager; }; From db3cb68d919693adb729437d1223d29c30736f27 Mon Sep 17 00:00:00 2001 From: Nikolay Chirkov Date: Fri, 9 Dec 2022 07:27:11 -0800 Subject: [PATCH 210/469] Internal change. 
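This adds TensorHardwareBufferView plus an AHardwareBuffer-backed storage for
the experimental Tensor v2 API. A single HardwareBufferCpuStorage serves both
a hardware-buffer view and a CPU view over the same allocation, so a tensor
can be written through a locked CPU pointer and handed off as an
AHardwareBuffer handle. A usage sketch distilled from the tests added below
(GetView is shown deducing the view type from its descriptor argument, per
the ViewT typedef in each descriptor; error handling elided):

    Tensor tensor{Tensor::Shape({1})};
    // Requesting the AHWB view allocates the backing AHardwareBuffer.
    MP_ASSERT_OK_AND_ASSIGN(
        auto hw_view,
        tensor.GetView(TensorHardwareBufferViewDescriptor{
            .buffer = {.format = TensorBufferDescriptor::Format::kFloat32}}));
    EXPECT_NE(hw_view->handle(), nullptr);
    // A const CPU view over the same storage locks the buffer for host reads.
    const auto& const_tensor = tensor;
    MP_ASSERT_OK_AND_ASSIGN(
        auto cpu_view,
        const_tensor.GetView(TensorCpuViewDescriptor{
            .buffer = {.format = TensorBufferDescriptor::Format::kFloat32}}));
    EXPECT_NE(cpu_view->data(), nullptr);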
PiperOrigin-RevId: 494166776 --- .../formats/tensor_hardware_buffer.h | 71 ++++++ .../tensor_hardware_buffer_cpu_storage.cc | 216 ++++++++++++++++++ ...tensor_hardware_buffer_cpu_storage_test.cc | 76 ++++++ 3 files changed, 363 insertions(+) create mode 100644 mediapipe/framework/formats/tensor_hardware_buffer.h create mode 100644 mediapipe/framework/formats/tensor_hardware_buffer_cpu_storage.cc create mode 100644 mediapipe/framework/formats/tensor_hardware_buffer_cpu_storage_test.cc diff --git a/mediapipe/framework/formats/tensor_hardware_buffer.h b/mediapipe/framework/formats/tensor_hardware_buffer.h new file mode 100644 index 000000000..fa0241bde --- /dev/null +++ b/mediapipe/framework/formats/tensor_hardware_buffer.h @@ -0,0 +1,71 @@ +#ifndef MEDIAPIPE_FRAMEWORK_FORMATS_TENSOR_HARDWARE_BUFFER_H_ +#define MEDIAPIPE_FRAMEWORK_FORMATS_TENSOR_HARDWARE_BUFFER_H_ + +#if !defined(MEDIAPIPE_NO_JNI) && \ + (__ANDROID_API__ >= 26 || \ + defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__)) + +#include + +#include + +#include "mediapipe/framework/formats/tensor_buffer.h" +#include "mediapipe/framework/formats/tensor_internal.h" +#include "mediapipe/framework/formats/tensor_v2.h" + +namespace mediapipe { + +// Supports: +// - float 16 and 32 bits +// - signed / unsigned integers 8,16,32 bits +class TensorHardwareBufferView; +struct TensorHardwareBufferViewDescriptor : public Tensor::ViewDescriptor { + using ViewT = TensorHardwareBufferView; + TensorBufferDescriptor buffer; +}; + +class TensorHardwareBufferView : public Tensor::View { + public: + TENSOR_UNIQUE_VIEW_TYPE_ID(); + ~TensorHardwareBufferView() = default; + + const TensorHardwareBufferViewDescriptor& descriptor() const override { + return descriptor_; + } + AHardwareBuffer* handle() const { return ahwb_handle_; } + + protected: + TensorHardwareBufferView(int access_capability, Tensor::View::Access access, + Tensor::View::State state, + const TensorHardwareBufferViewDescriptor& desc, + AHardwareBuffer* ahwb_handle) + : Tensor::View(kId, access_capability, access, state), + descriptor_(desc), + ahwb_handle_(ahwb_handle) {} + + private: + bool MatchDescriptor( + uint64_t view_type_id, + const Tensor::ViewDescriptor& base_descriptor) const override { + if (!Tensor::View::MatchDescriptor(view_type_id, base_descriptor)) + return false; + auto descriptor = + static_cast(base_descriptor); + return descriptor.buffer.format == descriptor_.buffer.format && + descriptor.buffer.size_alignment <= + descriptor_.buffer.size_alignment && + descriptor_.buffer.size_alignment % + descriptor.buffer.size_alignment == + 0; + } + const TensorHardwareBufferViewDescriptor& descriptor_; + AHardwareBuffer* ahwb_handle_ = nullptr; +}; + +} // namespace mediapipe + +#endif // !defined(MEDIAPIPE_NO_JNI) && \ + (__ANDROID_API__ >= 26 || \ + defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__)) + +#endif // MEDIAPIPE_FRAMEWORK_FORMATS_TENSOR_HARDWARE_BUFFER_H_ diff --git a/mediapipe/framework/formats/tensor_hardware_buffer_cpu_storage.cc b/mediapipe/framework/formats/tensor_hardware_buffer_cpu_storage.cc new file mode 100644 index 000000000..9c223ce2c --- /dev/null +++ b/mediapipe/framework/formats/tensor_hardware_buffer_cpu_storage.cc @@ -0,0 +1,216 @@ +#if !defined(MEDIAPIPE_NO_JNI) && \ + (__ANDROID_API__ >= 26 || \ + defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__)) + +#include + +#include "absl/memory/memory.h" +#include "absl/status/status.h" +#include "mediapipe/framework/formats/tensor_backend.h" +#include "mediapipe/framework/formats/tensor_cpu_buffer.h" +#include 
"mediapipe/framework/formats/tensor_hardware_buffer.h" +#include "mediapipe/framework/formats/tensor_v2.h" +#include "util/task/status_macros.h" + +namespace mediapipe { +namespace { + +class TensorCpuViewImpl : public TensorCpuView { + public: + TensorCpuViewImpl(int access_capabilities, Tensor::View::Access access, + Tensor::View::State state, + const TensorCpuViewDescriptor& descriptor, void* pointer, + AHardwareBuffer* ahwb_handle) + : TensorCpuView(access_capabilities, access, state, descriptor, pointer), + ahwb_handle_(ahwb_handle) {} + ~TensorCpuViewImpl() { + // If handle_ is null then this view is constructed in GetViews with no + // access. + if (ahwb_handle_) { + if (__builtin_available(android 26, *)) { + AHardwareBuffer_unlock(ahwb_handle_, nullptr); + } + } + } + + private: + AHardwareBuffer* ahwb_handle_; +}; + +class TensorHardwareBufferViewImpl : public TensorHardwareBufferView { + public: + TensorHardwareBufferViewImpl( + int access_capability, Tensor::View::Access access, + Tensor::View::State state, + const TensorHardwareBufferViewDescriptor& descriptor, + AHardwareBuffer* handle) + : TensorHardwareBufferView(access_capability, access, state, descriptor, + handle) {} + ~TensorHardwareBufferViewImpl() = default; +}; + +class HardwareBufferCpuStorage : public TensorStorage { + public: + ~HardwareBufferCpuStorage() { + if (!ahwb_handle_) return; + if (__builtin_available(android 26, *)) { + AHardwareBuffer_release(ahwb_handle_); + } + } + + static absl::Status CanProvide( + int access_capability, const Tensor::Shape& shape, uint64_t view_type_id, + const Tensor::ViewDescriptor& base_descriptor) { + // TODO: use AHardwareBuffer_isSupported for API >= 29. + static const bool is_ahwb_supported = [] { + if (__builtin_available(android 26, *)) { + AHardwareBuffer_Desc desc = {}; + // Aligned to the largest possible virtual memory page size. + constexpr uint32_t kPageSize = 16384; + desc.width = kPageSize; + desc.height = 1; + desc.layers = 1; + desc.format = AHARDWAREBUFFER_FORMAT_BLOB; + desc.usage = AHARDWAREBUFFER_USAGE_CPU_READ_OFTEN | + AHARDWAREBUFFER_USAGE_CPU_WRITE_OFTEN; + AHardwareBuffer* handle; + if (AHardwareBuffer_allocate(&desc, &handle) != 0) return false; + AHardwareBuffer_release(handle); + return true; + } + return false; + }(); + if (!is_ahwb_supported) { + return absl::UnavailableError( + "AHardwareBuffer is not supported on the platform."); + } + + if (view_type_id != TensorCpuView::kId && + view_type_id != TensorHardwareBufferView::kId) { + return absl::InvalidArgumentError( + "A view type is not supported by this storage."); + } + return absl::OkStatus(); + } + + std::vector> GetViews(uint64_t latest_version) { + std::vector> result; + auto update_state = latest_version == version_ + ? 
Tensor::View::State::kUpToDate + : Tensor::View::State::kOutdated; + if (ahwb_handle_) { + result.push_back( + std::unique_ptr(new TensorHardwareBufferViewImpl( + kAccessCapability, Tensor::View::Access::kNoAccess, update_state, + hw_descriptor_, ahwb_handle_))); + + result.push_back(std::unique_ptr(new TensorCpuViewImpl( + kAccessCapability, Tensor::View::Access::kNoAccess, update_state, + cpu_descriptor_, nullptr, nullptr))); + } + return result; + } + + absl::StatusOr> GetView( + Tensor::View::Access access, const Tensor::Shape& shape, + uint64_t latest_version, uint64_t view_type_id, + const Tensor::ViewDescriptor& base_descriptor, int access_capability) { + MP_RETURN_IF_ERROR( + CanProvide(access_capability, shape, view_type_id, base_descriptor)); + const auto& buffer_descriptor = + view_type_id == TensorHardwareBufferView::kId + ? static_cast( + base_descriptor) + .buffer + : static_cast(base_descriptor) + .buffer; + if (!ahwb_handle_) { + if (__builtin_available(android 26, *)) { + AHardwareBuffer_Desc desc = {}; + desc.width = TensorBufferSize(buffer_descriptor, shape); + desc.height = 1; + desc.layers = 1; + desc.format = AHARDWAREBUFFER_FORMAT_BLOB; + // TODO: Use access capabilities to set hints. + desc.usage = AHARDWAREBUFFER_USAGE_CPU_READ_OFTEN | + AHARDWAREBUFFER_USAGE_CPU_WRITE_OFTEN; + auto error = AHardwareBuffer_allocate(&desc, &ahwb_handle_); + if (error != 0) { + return absl::UnknownError( + absl::StrCat("Error allocating hardware buffer: ", error)); + } + // Fill all possible views to provide it as proto views. + hw_descriptor_.buffer = buffer_descriptor; + cpu_descriptor_.buffer = buffer_descriptor; + } + } + if (buffer_descriptor.format != hw_descriptor_.buffer.format || + buffer_descriptor.size_alignment > + hw_descriptor_.buffer.size_alignment || + hw_descriptor_.buffer.size_alignment % + buffer_descriptor.size_alignment > + 0) { + return absl::AlreadyExistsError( + "A view with different params is already allocated with this " + "storage"); + } + + absl::StatusOr> result; + if (view_type_id == TensorHardwareBufferView::kId) { + result = GetAhwbView(access, shape, base_descriptor); + } else { + result = GetCpuView(access, shape, base_descriptor); + } + if (result.ok()) version_ = latest_version; + return result; + } + + private: + absl::StatusOr> GetAhwbView( + Tensor::View::Access access, const Tensor::Shape& shape, + const Tensor::ViewDescriptor& base_descriptor) { + return std::unique_ptr(new TensorHardwareBufferViewImpl( + kAccessCapability, access, Tensor::View::State::kUpToDate, + hw_descriptor_, ahwb_handle_)); + } + + absl::StatusOr> GetCpuView( + Tensor::View::Access access, const Tensor::Shape& shape, + const Tensor::ViewDescriptor& base_descriptor) { + void* pointer = nullptr; + if (__builtin_available(android 26, *)) { + int error = + AHardwareBuffer_lock(ahwb_handle_, + access == Tensor::View::Access::kWriteOnly + ? AHARDWAREBUFFER_USAGE_CPU_WRITE_OFTEN + : AHARDWAREBUFFER_USAGE_CPU_READ_OFTEN, + -1, nullptr, &pointer); + if (error != 0) { + return absl::UnknownError( + absl::StrCat("Error locking hardware buffer: ", error)); + } + } + return std::unique_ptr( + new TensorCpuViewImpl(access == Tensor::View::Access::kWriteOnly + ? 
Tensor::View::AccessCapability::kWrite + : Tensor::View::AccessCapability::kRead, + access, Tensor::View::State::kUpToDate, + cpu_descriptor_, pointer, ahwb_handle_)); + } + + static constexpr int kAccessCapability = + Tensor::View::AccessCapability::kRead | + Tensor::View::AccessCapability::kWrite; + TensorHardwareBufferViewDescriptor hw_descriptor_; + AHardwareBuffer* ahwb_handle_ = nullptr; + + TensorCpuViewDescriptor cpu_descriptor_; + uint64_t version_ = 0; +}; +TENSOR_REGISTER_STORAGE(HardwareBufferCpuStorage); + +} // namespace +} // namespace mediapipe + +#endif // !defined(MEDIAPIPE_NO_JNI) && (__ANDROID_API__ >= 26 || + // defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__)) diff --git a/mediapipe/framework/formats/tensor_hardware_buffer_cpu_storage_test.cc b/mediapipe/framework/formats/tensor_hardware_buffer_cpu_storage_test.cc new file mode 100644 index 000000000..0afa9899f --- /dev/null +++ b/mediapipe/framework/formats/tensor_hardware_buffer_cpu_storage_test.cc @@ -0,0 +1,76 @@ + +#if !defined(MEDIAPIPE_NO_JNI) && \ + (__ANDROID_API__ >= 26 || \ + defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__)) +#include + +#include + +#include "mediapipe/framework/formats/tensor_cpu_buffer.h" +#include "mediapipe/framework/formats/tensor_hardware_buffer.h" +#include "mediapipe/framework/formats/tensor_v2.h" +#include "testing/base/public/gmock.h" +#include "testing/base/public/gunit.h" + +namespace mediapipe { + +namespace { + +class TensorHardwareBufferTest : public ::testing::Test { + public: + TensorHardwareBufferTest() {} + ~TensorHardwareBufferTest() override {} +}; + +TEST_F(TensorHardwareBufferTest, TestFloat32) { + Tensor tensor{Tensor::Shape({1})}; + { + MP_ASSERT_OK_AND_ASSIGN( + auto view, + tensor.GetView( + TensorHardwareBufferViewDescriptor{ + .buffer = {.format = + TensorBufferDescriptor::Format::kFloat32}})); + EXPECT_NE(view->handle(), nullptr); + } + { + const auto& const_tensor = tensor; + MP_ASSERT_OK_AND_ASSIGN( + auto view, + const_tensor.GetView( + TensorCpuViewDescriptor{ + .buffer = {.format = + TensorBufferDescriptor::Format::kFloat32}})); + EXPECT_NE(view->data(), nullptr); + } +} + +TEST_F(TensorHardwareBufferTest, TestInt8Padding) { + Tensor tensor{Tensor::Shape({1})}; + + { + MP_ASSERT_OK_AND_ASSIGN( + auto view, + tensor.GetView( + TensorHardwareBufferViewDescriptor{ + .buffer = {.format = TensorBufferDescriptor::Format::kInt8, + .size_alignment = 4}})); + EXPECT_NE(view->handle(), nullptr); + } + { + const auto& const_tensor = tensor; + MP_ASSERT_OK_AND_ASSIGN( + auto view, + const_tensor.GetView( + TensorCpuViewDescriptor{ + .buffer = {.format = TensorBufferDescriptor::Format::kInt8}})); + EXPECT_NE(view->data(), nullptr); + } +} + +} // namespace + +} // namespace mediapipe + +#endif // !defined(MEDIAPIPE_NO_JNI) && (__ANDROID_API__ >= 26 || + // defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__)) From 453d67de92d19abf2488a4400532708d734b20bb Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Fri, 9 Dec 2022 13:10:25 -0800 Subject: [PATCH 211/469] Add MergeDetectionsToVectorCalculator. 
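MergeToVectorCalculator is a template; this instantiation gathers per-stream
mediapipe::Detection inputs into a single vector output, mirroring the
existing MergeGpuBuffersToVectorCalculator. One plausible wiring as a
graph-config sketch (stream names are illustrative; see
merge_to_vector_calculator.h for the exact input contract):

    CalculatorGraphConfig config =
        ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
          input_stream: "detections_0"
          input_stream: "detections_1"
          node {
            calculator: "MergeDetectionsToVectorCalculator"
            input_stream: "detections_0"
            input_stream: "detections_1"
            output_stream: "merged_detections"
          }
        )pb");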
PiperOrigin-RevId: 494246359 --- mediapipe/calculators/core/BUILD | 1 + mediapipe/calculators/core/merge_to_vector_calculator.cc | 5 +++++ 2 files changed, 6 insertions(+) diff --git a/mediapipe/calculators/core/BUILD b/mediapipe/calculators/core/BUILD index 29bca5fa6..2c143a609 100644 --- a/mediapipe/calculators/core/BUILD +++ b/mediapipe/calculators/core/BUILD @@ -1323,6 +1323,7 @@ cc_library( "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:node", "//mediapipe/framework/api2:port", + "//mediapipe/framework/formats:detection_cc_proto", "//mediapipe/framework/formats:image", "@com_google_absl//absl/status", ], diff --git a/mediapipe/calculators/core/merge_to_vector_calculator.cc b/mediapipe/calculators/core/merge_to_vector_calculator.cc index 5f05ad725..fd053ed2b 100644 --- a/mediapipe/calculators/core/merge_to_vector_calculator.cc +++ b/mediapipe/calculators/core/merge_to_vector_calculator.cc @@ -15,6 +15,7 @@ limitations under the License. #include "mediapipe/calculators/core/merge_to_vector_calculator.h" +#include "mediapipe/framework/formats/detection.pb.h" #include "mediapipe/framework/formats/image.h" namespace mediapipe { @@ -27,5 +28,9 @@ typedef MergeToVectorCalculator MergeGpuBuffersToVectorCalculator; MEDIAPIPE_REGISTER_NODE(MergeGpuBuffersToVectorCalculator); +typedef MergeToVectorCalculator + MergeDetectionsToVectorCalculator; +MEDIAPIPE_REGISTER_NODE(MergeDetectionsToVectorCalculator); + } // namespace api2 } // namespace mediapipe From 69c3c4c181766e4e94bca9f1db6ce49315d8ac45 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Fri, 9 Dec 2022 18:08:26 -0800 Subject: [PATCH 212/469] Internal change PiperOrigin-RevId: 494305195 --- mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts | 1 + mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts | 1 + mediapipe/tasks/web/core/task_runner.ts | 1 + mediapipe/tasks/web/text/text_classifier/text_classifier.ts | 1 + mediapipe/tasks/web/text/text_embedder/text_embedder.ts | 1 + .../tasks/web/vision/gesture_recognizer/gesture_recognizer.ts | 1 + mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts | 1 + mediapipe/tasks/web/vision/image_classifier/image_classifier.ts | 1 + mediapipe/tasks/web/vision/image_embedder/image_embedder.ts | 1 + mediapipe/tasks/web/vision/object_detector/object_detector.ts | 1 + 10 files changed, 10 insertions(+) diff --git a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts index 265ba2b33..7bfca680a 100644 --- a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts +++ b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts @@ -94,6 +94,7 @@ export class AudioClassifier extends AudioTaskRunner { {baseOptions: {modelAssetPath}}); } + /** @hideconstructor */ constructor( wasmModule: WasmModule, glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { diff --git a/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts index 445dd5172..246cba883 100644 --- a/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts +++ b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts @@ -96,6 +96,7 @@ export class AudioEmbedder extends AudioTaskRunner { {baseOptions: {modelAssetPath}}); } + /** @hideconstructor */ constructor( wasmModule: WasmModule, glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { diff --git a/mediapipe/tasks/web/core/task_runner.ts b/mediapipe/tasks/web/core/task_runner.ts index 
6712c4d89..2011fadef 100644 --- a/mediapipe/tasks/web/core/task_runner.ts +++ b/mediapipe/tasks/web/core/task_runner.ts @@ -76,6 +76,7 @@ export abstract class TaskRunner { return createTaskRunner(type, initializeCanvas, fileset, options); } + /** @hideconstructor protected */ constructor( wasmModule: WasmModule, glCanvas?: HTMLCanvasElement|OffscreenCanvas|null, graphRunner?: GraphRunnerImageLib) { diff --git a/mediapipe/tasks/web/text/text_classifier/text_classifier.ts b/mediapipe/tasks/web/text/text_classifier/text_classifier.ts index 4a8588836..62708700a 100644 --- a/mediapipe/tasks/web/text/text_classifier/text_classifier.ts +++ b/mediapipe/tasks/web/text/text_classifier/text_classifier.ts @@ -92,6 +92,7 @@ export class TextClassifier extends TaskRunner { {baseOptions: {modelAssetPath}}); } + /** @hideconstructor */ constructor( wasmModule: WasmModule, glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { diff --git a/mediapipe/tasks/web/text/text_embedder/text_embedder.ts b/mediapipe/tasks/web/text/text_embedder/text_embedder.ts index cd5bc644e..611233e02 100644 --- a/mediapipe/tasks/web/text/text_embedder/text_embedder.ts +++ b/mediapipe/tasks/web/text/text_embedder/text_embedder.ts @@ -96,6 +96,7 @@ export class TextEmbedder extends TaskRunner { {baseOptions: {modelAssetPath}}); } + /** @hideconstructor */ constructor( wasmModule: WasmModule, glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts index 69a8118a6..b6b795076 100644 --- a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts +++ b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts @@ -127,6 +127,7 @@ export class GestureRecognizer extends {baseOptions: {modelAssetPath}}); } + /** @hideconstructor */ constructor( wasmModule: WasmModule, glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { diff --git a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts index 9a0823f23..2a0e8286c 100644 --- a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts +++ b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts @@ -115,6 +115,7 @@ export class HandLandmarker extends VisionTaskRunner { {baseOptions: {modelAssetPath}}); } + /** @hideconstructor */ constructor( wasmModule: WasmModule, glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { diff --git a/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts b/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts index 40e8b5099..36e7311fb 100644 --- a/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts +++ b/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts @@ -93,6 +93,7 @@ export class ImageClassifier extends VisionTaskRunner { {baseOptions: {modelAssetPath}}); } + /** @hideconstructor */ constructor( wasmModule: WasmModule, glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { diff --git a/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts b/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts index f8b0204ee..0c45ba5e7 100644 --- a/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts +++ b/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts @@ -95,6 +95,7 @@ export class ImageEmbedder extends VisionTaskRunner { {baseOptions: {modelAssetPath}}); } + /** @hideconstructor */ constructor( wasmModule: WasmModule, glCanvas?: 
HTMLCanvasElement|OffscreenCanvas|null) { diff --git a/mediapipe/tasks/web/vision/object_detector/object_detector.ts b/mediapipe/tasks/web/vision/object_detector/object_detector.ts index e2cfe0575..fbfaced12 100644 --- a/mediapipe/tasks/web/vision/object_detector/object_detector.ts +++ b/mediapipe/tasks/web/vision/object_detector/object_detector.ts @@ -92,6 +92,7 @@ export class ObjectDetector extends VisionTaskRunner { {baseOptions: {modelAssetPath}}); } + /** @hideconstructor */ constructor( wasmModule: WasmModule, glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { From edafef9fd8bfb34e91d8578f5ad68919b8cff702 Mon Sep 17 00:00:00 2001 From: Khanh LeViet Date: Fri, 9 Dec 2022 18:08:41 -0800 Subject: [PATCH 213/469] Updated issue templates. PiperOrigin-RevId: 494305235 --- .github/ISSUE_TEMPLATE/11-tasks-issue.md | 2 +- .github/ISSUE_TEMPLATE/12-model-maker-issue.md | 2 +- .../{10-solution-issue.md => 13-solution-issue.md} | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) rename .github/ISSUE_TEMPLATE/{10-solution-issue.md => 13-solution-issue.md} (81%) diff --git a/.github/ISSUE_TEMPLATE/11-tasks-issue.md b/.github/ISSUE_TEMPLATE/11-tasks-issue.md index 264371120..4e9ae721d 100644 --- a/.github/ISSUE_TEMPLATE/11-tasks-issue.md +++ b/.github/ISSUE_TEMPLATE/11-tasks-issue.md @@ -1,6 +1,6 @@ --- name: "Tasks Issue" -about: Use this template for assistance with using MediaPipe Tasks to deploy on-device ML solutions (e.g. gesture recognition etc.) on supported platforms. +about: Use this template for assistance with using MediaPipe Tasks (developers.google.com/mediapipe/solutions) to deploy on-device ML solutions (e.g. gesture recognition etc.) on supported platforms. labels: type:support --- diff --git a/.github/ISSUE_TEMPLATE/12-model-maker-issue.md b/.github/ISSUE_TEMPLATE/12-model-maker-issue.md index 258390d5e..31e8d7f1b 100644 --- a/.github/ISSUE_TEMPLATE/12-model-maker-issue.md +++ b/.github/ISSUE_TEMPLATE/12-model-maker-issue.md @@ -1,6 +1,6 @@ --- name: "Model Maker Issue" -about: Use this template for assistance with using MediaPipe Model Maker to create custom on-device ML solutions. +about: Use this template for assistance with using MediaPipe Model Maker (developers.google.com/mediapipe/solutions) to create custom on-device ML solutions. labels: type:support --- diff --git a/.github/ISSUE_TEMPLATE/10-solution-issue.md b/.github/ISSUE_TEMPLATE/13-solution-issue.md similarity index 81% rename from .github/ISSUE_TEMPLATE/10-solution-issue.md rename to .github/ISSUE_TEMPLATE/13-solution-issue.md index a5332cb36..9297edf6b 100644 --- a/.github/ISSUE_TEMPLATE/10-solution-issue.md +++ b/.github/ISSUE_TEMPLATE/13-solution-issue.md @@ -1,6 +1,6 @@ --- -name: "Solution Issue" -about: Use this template for assistance with a specific mediapipe solution, such as "Pose" or "Iris", including inference model usage/training, solution-specific calculators, etc. +name: "Solution (legacy) Issue" +about: Use this template for assistance with a specific Mediapipe solution (google.github.io/mediapipe/solutions), such as "Pose" or "Iris", including inference model usage/training, solution-specific calculators, etc. 
labels: type:support --- From e9bb51a524bc3c9e38aa7e689020172bea678069 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Fri, 9 Dec 2022 19:19:49 -0800 Subject: [PATCH 214/469] Internal change PiperOrigin-RevId: 494314595 --- .../mediapipe/apps/instantmotiontracking/GIFEditText.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/instantmotiontracking/GIFEditText.java b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/instantmotiontracking/GIFEditText.java index 10e6422ba..1b733ed82 100644 --- a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/instantmotiontracking/GIFEditText.java +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/instantmotiontracking/GIFEditText.java @@ -18,7 +18,7 @@ import android.content.ClipDescription; import android.content.Context; import android.net.Uri; import android.os.Bundle; -import android.support.v7.widget.AppCompatEditText; +import androidx.appcompat.widget.AppCompatEditText; import android.util.AttributeSet; import android.util.Log; import android.view.inputmethod.EditorInfo; From 421f789edea501d5fbfd7078d2d9534a628dd886 Mon Sep 17 00:00:00 2001 From: Hadon Nash Date: Sat, 10 Dec 2022 12:32:04 -0800 Subject: [PATCH 215/469] Internal change PiperOrigin-RevId: 494420725 --- mediapipe/framework/tool/BUILD | 2 + .../tool/calculator_graph_template.proto | 3 + mediapipe/framework/tool/proto_util_lite.cc | 103 +++++++++---- mediapipe/framework/tool/proto_util_lite.h | 28 +++- mediapipe/framework/tool/template_expander.cc | 136 ++++++++++++------ mediapipe/framework/tool/template_parser.cc | 128 ++++++++++++++++- mediapipe/framework/tool/testdata/BUILD | 10 ++ .../tool/testdata/frozen_generator.proto | 20 +++ 8 files changed, 348 insertions(+), 82 deletions(-) create mode 100644 mediapipe/framework/tool/testdata/frozen_generator.proto diff --git a/mediapipe/framework/tool/BUILD b/mediapipe/framework/tool/BUILD index 453b5a0e8..89cb802da 100644 --- a/mediapipe/framework/tool/BUILD +++ b/mediapipe/framework/tool/BUILD @@ -346,6 +346,7 @@ cc_library( "//mediapipe/framework/port:logging", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", + "//mediapipe/framework/port:statusor", "@com_google_absl//absl/strings", ], ) @@ -506,6 +507,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":proto_util_lite", + "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework/deps:proto_descriptor_cc_proto", "//mediapipe/framework/port:advanced_proto", "//mediapipe/framework/port:integral_types", diff --git a/mediapipe/framework/tool/calculator_graph_template.proto b/mediapipe/framework/tool/calculator_graph_template.proto index 27153f3f7..31c233812 100644 --- a/mediapipe/framework/tool/calculator_graph_template.proto +++ b/mediapipe/framework/tool/calculator_graph_template.proto @@ -27,6 +27,9 @@ message TemplateExpression { // The FieldDescriptor::Type of the modified field. optional mediapipe.FieldDescriptorProto.Type field_type = 5; + // The FieldDescriptor::Type of each map key in the path. + repeated mediapipe.FieldDescriptorProto.Type key_type = 6; + // Alternative value for the modified field, in protobuf binary format. 
optional string field_value = 7; } diff --git a/mediapipe/framework/tool/proto_util_lite.cc b/mediapipe/framework/tool/proto_util_lite.cc index 4628815ea..a810ce129 100644 --- a/mediapipe/framework/tool/proto_util_lite.cc +++ b/mediapipe/framework/tool/proto_util_lite.cc @@ -22,6 +22,7 @@ #include "mediapipe/framework/port/canonical_errors.h" #include "mediapipe/framework/port/logging.h" #include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/framework/port/statusor.h" #include "mediapipe/framework/tool/field_data.pb.h" #include "mediapipe/framework/type_map.h" @@ -87,12 +88,13 @@ absl::Status ReadPackedValues(WireFormatLite::WireType wire_type, // Extracts the data value(s) for one field from a serialized message. // The message with these field values removed is written to |out|. -absl::Status GetFieldValues(uint32 field_id, WireFormatLite::WireType wire_type, - CodedInputStream* in, CodedOutputStream* out, +absl::Status GetFieldValues(uint32 field_id, CodedInputStream* in, + CodedOutputStream* out, std::vector* field_values) { uint32 tag; while ((tag = in->ReadTag()) != 0) { int field_number = WireFormatLite::GetTagFieldNumber(tag); + WireFormatLite::WireType wire_type = WireFormatLite::GetTagWireType(tag); if (field_number == field_id) { if (!IsLengthDelimited(wire_type) && IsLengthDelimited(WireFormatLite::GetTagWireType(tag))) { @@ -131,9 +133,7 @@ absl::Status FieldAccess::SetMessage(const std::string& message) { CodedInputStream in(&ais); StringOutputStream sos(&message_); CodedOutputStream out(&sos); - WireFormatLite::WireType wire_type = - WireFormatLite::WireTypeForFieldType(field_type_); - return GetFieldValues(field_id_, wire_type, &in, &out, &field_values_); + return GetFieldValues(field_id_, &in, &out, &field_values_); } void FieldAccess::GetMessage(std::string* result) { @@ -149,18 +149,56 @@ std::vector* FieldAccess::mutable_field_values() { return &field_values_; } +namespace { +using ProtoPathEntry = ProtoUtilLite::ProtoPathEntry; + +// Returns the FieldAccess and index for a field-id or a map-id. +// Returns access to the field-id if the field index is found, +// to the map-id if the map entry is found, and to the field-id otherwise. +absl::StatusOr> AccessField( + const ProtoPathEntry& entry, FieldType field_type, + const FieldValue& message) { + FieldAccess result(entry.field_id, field_type); + if (entry.field_id >= 0) { + MP_RETURN_IF_ERROR(result.SetMessage(message)); + if (entry.index < result.mutable_field_values()->size()) { + return std::pair(result, entry.index); + } + } + if (entry.map_id >= 0) { + FieldAccess access(entry.map_id, field_type); + MP_RETURN_IF_ERROR(access.SetMessage(message)); + auto& field_values = *access.mutable_field_values(); + for (int index = 0; index < field_values.size(); ++index) { + FieldAccess key(entry.key_id, entry.key_type); + MP_RETURN_IF_ERROR(key.SetMessage(field_values[index])); + if (key.mutable_field_values()->at(0) == entry.key_value) { + return std::pair(std::move(access), index); + } + } + } + if (entry.field_id >= 0) { + return std::pair(result, entry.index); + } + return absl::InvalidArgumentError(absl::StrCat( + "ProtoPath field missing, field-id: ", entry.field_id, ", map-id: ", + entry.map_id, ", key: ", entry.key_value, " key_type: ", entry.key_type)); +} + +} // namespace + // Replaces a range of field values for one field nested within a protobuf. 
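// The leading ProtoPath entry may address a repeated field by index or a
// protobuf map field by key (see ProtoPathEntry); AccessField above resolves
// either form to a concrete field index before the traversal recurses.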
absl::Status ProtoUtilLite::ReplaceFieldRange( FieldValue* message, ProtoPath proto_path, int length, FieldType field_type, const std::vector& field_values) { - int field_id, index; - std::tie(field_id, index) = proto_path.front(); + ProtoPathEntry entry = proto_path.front(); proto_path.erase(proto_path.begin()); - FieldAccess access(field_id, !proto_path.empty() - ? WireFormatLite::TYPE_MESSAGE - : field_type); - MP_RETURN_IF_ERROR(access.SetMessage(*message)); - std::vector& v = *access.mutable_field_values(); + FieldType type = + !proto_path.empty() ? WireFormatLite::TYPE_MESSAGE : field_type; + ASSIGN_OR_RETURN(auto r, AccessField(entry, type, *message)); + FieldAccess& access = r.first; + int index = r.second; + std::vector& v = *access.mutable_field_values(); if (!proto_path.empty()) { RET_CHECK_NO_LOG(index >= 0 && index < v.size()); MP_RETURN_IF_ERROR(ReplaceFieldRange(&v[index], proto_path, length, @@ -180,19 +218,22 @@ absl::Status ProtoUtilLite::ReplaceFieldRange( absl::Status ProtoUtilLite::GetFieldRange( const FieldValue& message, ProtoPath proto_path, int length, FieldType field_type, std::vector* field_values) { - int field_id, index; - std::tie(field_id, index) = proto_path.front(); + ProtoPathEntry entry = proto_path.front(); proto_path.erase(proto_path.begin()); - FieldAccess access(field_id, !proto_path.empty() - ? WireFormatLite::TYPE_MESSAGE - : field_type); - MP_RETURN_IF_ERROR(access.SetMessage(message)); - std::vector& v = *access.mutable_field_values(); + FieldType type = + !proto_path.empty() ? WireFormatLite::TYPE_MESSAGE : field_type; + ASSIGN_OR_RETURN(auto r, AccessField(entry, type, message)); + FieldAccess& access = r.first; + int index = r.second; + std::vector& v = *access.mutable_field_values(); if (!proto_path.empty()) { RET_CHECK_NO_LOG(index >= 0 && index < v.size()); MP_RETURN_IF_ERROR( GetFieldRange(v[index], proto_path, length, field_type, field_values)); } else { + if (length == -1) { + length = v.size() - index; + } RET_CHECK_NO_LOG(index >= 0 && index <= v.size()); RET_CHECK_NO_LOG(index + length >= 0 && index + length <= v.size()); field_values->insert(field_values->begin(), v.begin() + index, @@ -206,19 +247,21 @@ absl::Status ProtoUtilLite::GetFieldCount(const FieldValue& message, ProtoPath proto_path, FieldType field_type, int* field_count) { - int field_id, index; - std::tie(field_id, index) = proto_path.back(); - proto_path.pop_back(); - std::vector parent; - if (proto_path.empty()) { - parent.push_back(std::string(message)); + ProtoPathEntry entry = proto_path.front(); + proto_path.erase(proto_path.begin()); + FieldType type = + !proto_path.empty() ? 
WireFormatLite::TYPE_MESSAGE : field_type; + ASSIGN_OR_RETURN(auto r, AccessField(entry, type, message)); + FieldAccess& access = r.first; + int index = r.second; + std::vector& v = *access.mutable_field_values(); + if (!proto_path.empty()) { + RET_CHECK_NO_LOG(index >= 0 && index < v.size()); + MP_RETURN_IF_ERROR( + GetFieldCount(v[index], proto_path, field_type, field_count)); } else { - MP_RETURN_IF_ERROR(ProtoUtilLite::GetFieldRange( - message, proto_path, 1, WireFormatLite::TYPE_MESSAGE, &parent)); + *field_count = v.size(); } - FieldAccess access(field_id, field_type); - MP_RETURN_IF_ERROR(access.SetMessage(parent[0])); - *field_count = access.mutable_field_values()->size(); return absl::OkStatus(); } diff --git a/mediapipe/framework/tool/proto_util_lite.h b/mediapipe/framework/tool/proto_util_lite.h index 7d3a263f3..d71ceac83 100644 --- a/mediapipe/framework/tool/proto_util_lite.h +++ b/mediapipe/framework/tool/proto_util_lite.h @@ -34,15 +34,31 @@ class ProtoUtilLite { // Defines field types and tag formats. using WireFormatLite = proto_ns::internal::WireFormatLite; - // Defines a sequence of nested field-number field-index pairs. - using ProtoPath = std::vector>; - // The serialized value for a protobuf field. using FieldValue = std::string; // The serialized data type for a protobuf field. using FieldType = WireFormatLite::FieldType; + // A field-id and index, or a map-id and key, or both. + struct ProtoPathEntry { + ProtoPathEntry(int id, int index) : field_id(id), index(index) {} + ProtoPathEntry(int id, int key_id, FieldType key_type, FieldValue key_value) + : map_id(id), + key_id(key_id), + key_type(key_type), + key_value(std::move(key_value)) {} + int field_id = -1; + int index = -1; + int map_id = -1; + int key_id = -1; + FieldType key_type; + FieldValue key_value; + }; + + // Defines a sequence of nested field-number field-index pairs. + using ProtoPath = std::vector; + class FieldAccess { public: // Provides access to a certain protobuf field. @@ -57,9 +73,11 @@ class ProtoUtilLite { // Returns the serialized values of the protobuf field. std::vector* mutable_field_values(); + uint32 field_id() const { return field_id_; } + private: - const uint32 field_id_; - const FieldType field_type_; + uint32 field_id_; + FieldType field_type_; std::string message_; std::vector field_values_; }; diff --git a/mediapipe/framework/tool/template_expander.cc b/mediapipe/framework/tool/template_expander.cc index 034e1a026..a91ea5adc 100644 --- a/mediapipe/framework/tool/template_expander.cc +++ b/mediapipe/framework/tool/template_expander.cc @@ -22,6 +22,7 @@ #include #include "absl/strings/ascii.h" +#include "absl/strings/match.h" #include "absl/strings/numbers.h" #include "absl/strings/str_join.h" #include "absl/strings/str_split.h" @@ -44,6 +45,7 @@ using WireFormatLite = ProtoUtilLite::WireFormatLite; using FieldValue = ProtoUtilLite::FieldValue; using FieldType = ProtoUtilLite::FieldType; using ProtoPath = ProtoUtilLite::ProtoPath; +using ProtoPathEntry = ProtoUtilLite::ProtoPathEntry; namespace { @@ -84,26 +86,87 @@ std::unique_ptr CloneMessage(const MessageLite& message) { return result; } -// Returns the (tag, index) pairs in a field path. -// For example, returns {{1, 1}, {2, 1}, {3, 1}} for path "/1[1]/2[1]/3[1]". 
-absl::Status ProtoPathSplit(const std::string& path, ProtoPath* result) { - absl::Status status; - std::vector ids = absl::StrSplit(path, '/'); - for (const std::string& id : ids) { - if (id.length() > 0) { - std::pair id_pair = - absl::StrSplit(id, absl::ByAnyChar("[]")); - int tag = 0; - int index = 0; - bool ok = absl::SimpleAtoi(id_pair.first, &tag) && - absl::SimpleAtoi(id_pair.second, &index); - if (!ok) { - status.Update(absl::InvalidArgumentError(path)); - } - result->push_back(std::make_pair(tag, index)); +// Parses one ProtoPathEntry. +// The parsed entry is appended to `result` and removed from `path`. +// ProtoPathEntry::key_value stores map key text. Use SetMapKeyTypes +// to serialize the key text to protobuf wire format. +absl::Status ParseEntry(absl::string_view& path, ProtoPath* result) { + bool ok = true; + int sb = path.find('['); + int eb = path.find(']'); + int field_id = -1; + ok &= absl::SimpleAtoi(path.substr(0, sb), &field_id); + auto selector = path.substr(sb + 1, eb - 1 - sb); + if (absl::StartsWith(selector, "@")) { + int eq = selector.find('='); + int key_id = -1; + ok &= absl::SimpleAtoi(selector.substr(1, eq - 1), &key_id); + auto key_text = selector.substr(eq + 1); + FieldType key_type = FieldType::TYPE_STRING; + result->push_back({field_id, key_id, key_type, std::string(key_text)}); + } else { + int index = 0; + ok &= absl::SimpleAtoi(selector, &index); + result->push_back({field_id, index}); + } + int end = path.find('/', eb); + if (end == std::string::npos) { + path = ""; + } else { + path = path.substr(end + 1); + } + return ok ? absl::OkStatus() + : absl::InvalidArgumentError( + absl::StrCat("Failed to parse ProtoPath entry: ", path)); +} + +// Specifies the FieldTypes for protobuf map keys in a ProtoPath. +// Each ProtoPathEntry::key_value is converted from text to the protobuf +// wire format for its key type. +absl::Status SetMapKeyTypes(const std::vector& key_types, + ProtoPath* result) { + int i = 0; + for (ProtoPathEntry& entry : *result) { + if (entry.map_id >= 0) { + FieldType key_type = key_types[i++]; + std::vector key_value; + MP_RETURN_IF_ERROR( + ProtoUtilLite::Serialize({entry.key_value}, key_type, &key_value)); + entry.key_type = key_type; + entry.key_value = key_value.front(); } } - return status; + return absl::OkStatus(); +} + +// Returns the (tag, index) pairs in a field path. +// For example, returns {{1, 1}, {2, 1}, {3, 1}} for "/1[1]/2[1]/3[1]", +// returns {{1, 1}, {2, 1, "INPUT_FRAMES"}} for "/1[1]/2[@1=INPUT_FRAMES]". +absl::Status ProtoPathSplit(const std::string& path, ProtoPath* result) { + result->clear(); + absl::string_view rest = path; + if (absl::StartsWith(rest, "/")) { + rest = rest.substr(1); + } + while (!rest.empty()) { + MP_RETURN_IF_ERROR(ParseEntry(rest, result)); + } + return absl::OkStatus(); +} + +// Parse the TemplateExpression.path field into a ProtoPath struct. +absl::Status ParseProtoPath(const TemplateExpression& rule, + std::string base_path, ProtoPath* result) { + ProtoPath base_entries; + MP_RETURN_IF_ERROR(ProtoPathSplit(base_path, &base_entries)); + MP_RETURN_IF_ERROR(ProtoPathSplit(rule.path(), result)); + std::vector key_types; + for (int type : rule.key_type()) { + key_types.push_back(static_cast(type)); + } + MP_RETURN_IF_ERROR(SetMapKeyTypes(key_types, result)); + result->erase(result->begin(), result->begin() + base_entries.size()); + return absl::OkStatus(); } // Returns true if one proto path is prefix by another. 
@@ -111,13 +174,6 @@ bool ProtoPathStartsWith(const std::string& path, const std::string& prefix) { return absl::StartsWith(path, prefix); } -// Returns the part of one proto path after a prefix proto path. -std::string ProtoPathRelative(const std::string& field_path, - const std::string& base_path) { - CHECK(ProtoPathStartsWith(field_path, base_path)); - return field_path.substr(base_path.length()); -} - // Returns the target ProtoUtilLite::FieldType of a rule. FieldType GetFieldType(const TemplateExpression& rule) { return static_cast(rule.field_type()); @@ -126,19 +182,10 @@ FieldType GetFieldType(const TemplateExpression& rule) { // Returns the count of field values at a ProtoPath. int FieldCount(const FieldValue& base, ProtoPath field_path, FieldType field_type) { - int field_id, index; - std::tie(field_id, index) = field_path.back(); - field_path.pop_back(); - std::vector parent; - if (field_path.empty()) { - parent.push_back(base); - } else { - MEDIAPIPE_CHECK_OK(ProtoUtilLite::GetFieldRange( - base, field_path, 1, WireFormatLite::TYPE_MESSAGE, &parent)); - } - ProtoUtilLite::FieldAccess access(field_id, field_type); - MEDIAPIPE_CHECK_OK(access.SetMessage(parent[0])); - return access.mutable_field_values()->size(); + int result = 0; + CHECK( + ProtoUtilLite::GetFieldCount(base, field_path, field_type, &result).ok()); + return result; } } // namespace @@ -229,9 +276,7 @@ class TemplateExpanderImpl { return absl::OkStatus(); } ProtoPath field_path; - absl::Status status = - ProtoPathSplit(ProtoPathRelative(rule.path(), base_path), &field_path); - if (!status.ok()) return status; + MP_RETURN_IF_ERROR(ParseProtoPath(rule, base_path, &field_path)); return ProtoUtilLite::GetFieldRange(output, field_path, 1, GetFieldType(rule), base); } @@ -242,12 +287,13 @@ class TemplateExpanderImpl { const std::vector& field_values, FieldValue* output) { if (!rule.has_path()) { - *output = field_values[0]; + if (!field_values.empty()) { + *output = field_values[0]; + } return absl::OkStatus(); } ProtoPath field_path; - RET_CHECK_OK( - ProtoPathSplit(ProtoPathRelative(rule.path(), base_path), &field_path)); + MP_RETURN_IF_ERROR(ParseProtoPath(rule, base_path, &field_path)); int field_count = 1; if (rule.has_field_value()) { // For a non-repeated field, only one value can be specified. @@ -257,7 +303,7 @@ class TemplateExpanderImpl { "Multiple values specified for non-repeated field: ", rule.path())); } // For a non-repeated field, the field value is stored only in the rule. 
- field_path[field_path.size() - 1].second = 0; + field_path[field_path.size() - 1].index = 0; field_count = 0; } return ProtoUtilLite::ReplaceFieldRange(output, field_path, field_count, diff --git a/mediapipe/framework/tool/template_parser.cc b/mediapipe/framework/tool/template_parser.cc index 1d81e7a78..5a0ceccd3 100644 --- a/mediapipe/framework/tool/template_parser.cc +++ b/mediapipe/framework/tool/template_parser.cc @@ -26,6 +26,7 @@ #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/strings/str_split.h" +#include "mediapipe/framework/calculator.pb.h" #include "mediapipe/framework/deps/proto_descriptor.pb.h" #include "mediapipe/framework/port/canonical_errors.h" #include "mediapipe/framework/port/integral_types.h" @@ -45,6 +46,9 @@ using mediapipe::proto_ns::Message; using mediapipe::proto_ns::OneofDescriptor; using mediapipe::proto_ns::Reflection; using mediapipe::proto_ns::TextFormat; +using ProtoPath = mediapipe::tool::ProtoUtilLite::ProtoPath; +using FieldType = mediapipe::tool::ProtoUtilLite::FieldType; +using FieldValue = mediapipe::tool::ProtoUtilLite::FieldValue; namespace mediapipe { @@ -1357,7 +1361,7 @@ absl::Status ProtoPathSplit(const std::string& path, if (!ok) { status.Update(absl::InvalidArgumentError(path)); } - result->push_back(std::make_pair(tag, index)); + result->push_back({tag, index}); } } return status; @@ -1381,7 +1385,7 @@ void StowFieldValue(Message* message, TemplateExpression* expression) { const Descriptor* descriptor = message->GetDescriptor(); ProtoUtilLite::ProtoPath path; MEDIAPIPE_CHECK_OK(ProtoPathSplit(expression->path(), &path)); - int field_number = path[path.size() - 1].first; + int field_number = path[path.size() - 1].field_id; const FieldDescriptor* field = descriptor->FindFieldByNumber(field_number); if (!field->is_repeated()) { std::vector field_values; @@ -1402,6 +1406,124 @@ static void StripQuotes(std::string* str) { } } +// Returns the field or extension for field number. +const FieldDescriptor* FindFieldByNumber(const Message* message, + int field_num) { + const FieldDescriptor* result = + message->GetDescriptor()->FindFieldByNumber(field_num); + if (result == nullptr) { + result = message->GetReflection()->FindKnownExtensionByNumber(field_num); + } + return result; +} + +// Returns the message value from a field at an index. +const Message* GetFieldMessage(const Message& message, + const FieldDescriptor* field, int index) { + if (field->type() != FieldDescriptor::TYPE_MESSAGE) { + return nullptr; + } + if (!field->is_repeated()) { + return &message.GetReflection()->GetMessage(message, field); + } + if (index < message.GetReflection()->FieldSize(message, field)) { + return &message.GetReflection()->GetRepeatedMessage(message, field, index); + } + return nullptr; +} + +// Serialize a ProtoPath as a readable string. +// For example, {{1, 1}, {2, 1}, {3, 1}} returns "/1[1]/2[1]/3[1]", +// and {{1, 1}, {2, 1, "INPUT_FRAMES"}} returns "/1[1]/2[@1=INPUT_FRAMES]". +std::string ProtoPathJoin(ProtoPath path) { + std::string result; + for (ProtoUtilLite::ProtoPathEntry& e : path) { + if (e.field_id >= 0) { + absl::StrAppend(&result, "/", e.field_id, "[", e.index, "]"); + } else if (e.map_id >= 0) { + absl::StrAppend(&result, "/", e.map_id, "[@", e.key_id, "=", e.key_value, + "]"); + } + } + return result; +} + +// Returns the protobuf map key types from a ProtoPath. 
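+// One FieldType is returned per map-keyed entry, in path order; these types
+// are recorded in TemplateExpression.key_type so the template expander can
+// re-serialize map keys when a rule is applied.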
+std::vector ProtoPathKeyTypes(ProtoPath path) { + std::vector result; + for (auto& entry : path) { + if (entry.map_id >= 0) { + result.push_back(entry.key_type); + } + } + return result; +} + +// Returns the text value for a string or numeric protobuf map key. +std::string GetMapKey(const Message& map_entry) { + auto key_field = map_entry.GetDescriptor()->FindFieldByName("key"); + auto reflection = map_entry.GetReflection(); + if (key_field->type() == FieldDescriptor::TYPE_STRING) { + return reflection->GetString(map_entry, key_field); + } else if (key_field->type() == FieldDescriptor::TYPE_INT32) { + return absl::StrCat(reflection->GetInt32(map_entry, key_field)); + } else if (key_field->type() == FieldDescriptor::TYPE_INT64) { + return absl::StrCat(reflection->GetInt64(map_entry, key_field)); + } + return ""; +} + +// Adjusts map-entries from indexes to keys. +// Protobuf map-entry order is intentionally not preserved. +mediapipe::Status KeyProtoMapEntries(Message* source) { + // Copy the rules from the source CalculatorGraphTemplate. + mediapipe::CalculatorGraphTemplate rules; + rules.ParsePartialFromString(source->SerializePartialAsString()); + // Only the "source" Message knows all extension types. + Message* config_0 = source->GetReflection()->MutableMessage( + source, source->GetDescriptor()->FindFieldByName("config"), nullptr); + for (int i = 0; i < rules.rule().size(); ++i) { + TemplateExpression* rule = rules.mutable_rule()->Mutable(i); + const Message* message = config_0; + ProtoPath path; + MP_RETURN_IF_ERROR(ProtoPathSplit(rule->path(), &path)); + for (int j = 0; j < path.size(); ++j) { + int field_id = path[j].field_id; + int field_index = path[j].index; + const FieldDescriptor* field = FindFieldByNumber(message, field_id); + if (field->is_map()) { + const Message* map_entry = + GetFieldMessage(*message, field, path[j].index); + int key_id = + map_entry->GetDescriptor()->FindFieldByName("key")->number(); + FieldType key_type = static_cast( + map_entry->GetDescriptor()->FindFieldByName("key")->type()); + std::string key_value = GetMapKey(*map_entry); + path[j] = {field_id, key_id, key_type, key_value}; + } + message = GetFieldMessage(*message, field, field_index); + if (!message) { + break; + } + } + if (!rule->path().empty()) { + *rule->mutable_path() = ProtoPathJoin(path); + for (FieldType key_type : ProtoPathKeyTypes(path)) { + *rule->mutable_key_type()->Add() = key_type; + } + } + } + // Copy the rules back into the source CalculatorGraphTemplate. + auto source_rules = + source->GetReflection()->GetMutableRepeatedFieldRef( + source, source->GetDescriptor()->FindFieldByName("rule")); + source_rules.Clear(); + for (auto& rule : rules.rule()) { + source_rules.Add(rule); + } + return absl::OkStatus(); +} + } // namespace class TemplateParser::Parser::MediaPipeParserImpl @@ -1416,6 +1538,8 @@ class TemplateParser::Parser::MediaPipeParserImpl // Copy the template rules into the output template "rule" field. success &= MergeFields(template_rules_, output).ok(); + // Replace map-entry indexes with map keys. 
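+  // Protobuf serialization does not preserve map-entry order, so rules must
+  // address map fields by key rather than by positional index.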
+ success &= KeyProtoMapEntries(output).ok(); return success; } diff --git a/mediapipe/framework/tool/testdata/BUILD b/mediapipe/framework/tool/testdata/BUILD index f9aab7b72..8300181b5 100644 --- a/mediapipe/framework/tool/testdata/BUILD +++ b/mediapipe/framework/tool/testdata/BUILD @@ -17,6 +17,7 @@ load( "//mediapipe/framework/tool:mediapipe_graph.bzl", "mediapipe_simple_subgraph", ) +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") licenses(["notice"]) @@ -58,3 +59,12 @@ mediapipe_simple_subgraph( "//mediapipe/framework:test_calculators", ], ) + +mediapipe_proto_library( + name = "frozen_generator_proto", + srcs = ["frozen_generator.proto"], + visibility = ["//mediapipe/framework:__subpackages__"], + deps = [ + "//mediapipe/framework:packet_generator_proto", + ], +) diff --git a/mediapipe/framework/tool/testdata/frozen_generator.proto b/mediapipe/framework/tool/testdata/frozen_generator.proto new file mode 100644 index 000000000..5f133f461 --- /dev/null +++ b/mediapipe/framework/tool/testdata/frozen_generator.proto @@ -0,0 +1,20 @@ +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/packet_generator.proto"; + +message FrozenGeneratorOptions { + extend mediapipe.PacketGeneratorOptions { + optional FrozenGeneratorOptions ext = 225748738; + } + + // Path to file containing serialized proto of type tensorflow::GraphDef. + optional string graph_proto_path = 1; + + // This map defines the which streams are fed to which tensors in the model. + map tag_to_tensor_names = 2; + + // Graph nodes to run to initialize the model. + repeated string initialization_op_names = 4; +} From 37d2e369605e87cf741220db1d3c1b4afb403def Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Mon, 12 Dec 2022 12:08:45 -0800 Subject: [PATCH 216/469] Internal change PiperOrigin-RevId: 494791433 --- .github/ISSUE_TEMPLATE/14-studio-issue.md | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) create mode 100644 .github/ISSUE_TEMPLATE/14-studio-issue.md diff --git a/.github/ISSUE_TEMPLATE/14-studio-issue.md b/.github/ISSUE_TEMPLATE/14-studio-issue.md new file mode 100644 index 000000000..5942b1eb1 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/14-studio-issue.md @@ -0,0 +1,19 @@ +--- +name: "Studio Issue" +about: Use this template for assistance with the MediaPipe Studio application. +labels: type:support + +--- +Please make sure that this is a MediaPipe Studio issue. + +**System information** (Please provide as much relevant information as possible) +- OS Platform and Distribution (e.g., Linux Ubuntu 16.04, Android 11, iOS 14.4): +- Browser and Version +- Any microphone or camera hardware +- URL that shows the problem + +**Describe the expected behavior:** + +**Other info / Complete Logs :** +Include any js console logs that would be helpful to diagnose the problem. +Large logs and files should be attached: From 3f66dde8fdb459be8552b837e83fb2a79c44566c Mon Sep 17 00:00:00 2001 From: Mark McDonald Date: Mon, 12 Dec 2022 17:33:08 -0800 Subject: [PATCH 217/469] Change `--site_path` default value to match the actual path. This did not match the URL we ended up using for MediaPipe, so needs to be set correctly in order to generate docs that match the real site. This change sets the default to be correct. 
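With the corrected default, plain invocations of the generator emit _toc.yaml
entries under the path the site actually serves; the flag can still be
overridden for staging builds (invocation sketch, value shown is the new
default):

    python docs/build_py_api_docs.py --site_path=/mediapipe/api/solutions/python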
PiperOrigin-RevId: 494874789 --- docs/build_py_api_docs.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/build_py_api_docs.py b/docs/build_py_api_docs.py index 46546012d..02eb04074 100644 --- a/docs/build_py_api_docs.py +++ b/docs/build_py_api_docs.py @@ -44,14 +44,14 @@ _OUTPUT_DIR = flags.DEFINE_string( _URL_PREFIX = flags.DEFINE_string( 'code_url_prefix', - 'https://github.com/google/mediapipe/tree/master/mediapipe', + 'https://github.com/google/mediapipe/blob/master/mediapipe', 'The url prefix for links to code.') _SEARCH_HINTS = flags.DEFINE_bool( 'search_hints', True, 'Include metadata search hints in the generated files') -_SITE_PATH = flags.DEFINE_string('site_path', '/mediapipe/api_docs/python', +_SITE_PATH = flags.DEFINE_string('site_path', '/mediapipe/api/solutions/python', 'Path prefix in the _toc.yaml') From fb2179761187f5a0c73c973d94690685170d9a21 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Mon, 12 Dec 2022 21:28:35 -0800 Subject: [PATCH 218/469] Internal change PiperOrigin-RevId: 494914168 --- mediapipe/calculators/image/image_cropping_calculator.cc | 3 ++- mediapipe/calculators/image/image_cropping_calculator_test.cc | 4 ++-- mediapipe/calculators/util/detections_to_rects_calculator.cc | 3 +++ .../calculators/util/detections_to_rects_calculator_test.cc | 3 +++ mediapipe/calculators/util/landmark_projection_calculator.cc | 2 ++ mediapipe/calculators/util/landmarks_smoothing_calculator.cc | 2 ++ mediapipe/calculators/util/rect_projection_calculator.cc | 2 ++ mediapipe/calculators/util/rect_to_render_data_calculator.cc | 3 +++ mediapipe/calculators/util/rect_to_render_scale_calculator.cc | 2 ++ mediapipe/calculators/util/rect_transformation_calculator.cc | 3 +++ .../calculators/util/world_landmark_projection_calculator.cc | 2 ++ .../calculators/video/tracked_detection_manager_calculator.cc | 2 ++ .../calculators/hand_landmarks_to_rect_calculator.cc | 2 ++ .../holistic_landmark/calculators/roi_tracking_calculator.cc | 2 ++ .../calculators/frame_annotation_to_rect_calculator.cc | 2 ++ .../cc/components/processors/image_preprocessing_graph.cc | 1 + .../calculators/landmarks_to_matrix_calculator.cc | 2 ++ .../calculators/landmarks_to_matrix_calculator_test.cc | 2 ++ .../tasks/cc/vision/gesture_recognizer/gesture_recognizer.cc | 2 ++ .../cc/vision/gesture_recognizer/gesture_recognizer_graph.cc | 1 + .../gesture_recognizer/hand_gesture_recognizer_graph.cc | 1 + .../tasks/cc/vision/hand_detector/hand_detector_graph.cc | 1 + .../tasks/cc/vision/hand_detector/hand_detector_graph_test.cc | 1 + .../calculators/hand_association_calculator.cc | 2 ++ .../calculators/hand_association_calculator_test.cc | 2 ++ .../calculators/hand_landmarks_deduplication_calculator.cc | 1 + mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker.cc | 2 ++ .../tasks/cc/vision/hand_landmarker/hand_landmarker_graph.cc | 1 + .../cc/vision/hand_landmarker/hand_landmarker_graph_test.cc | 1 + .../vision/hand_landmarker/hand_landmarks_detector_graph.cc | 1 + .../hand_landmarker/hand_landmarks_detector_graph_test.cc | 1 + .../tasks/cc/vision/image_classifier/image_classifier.cc | 1 + .../cc/vision/image_classifier/image_classifier_graph.cc | 1 + mediapipe/tasks/cc/vision/image_embedder/image_embedder.cc | 1 + .../tasks/cc/vision/image_embedder/image_embedder_graph.cc | 1 + mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc | 1 + .../tasks/cc/vision/image_segmenter/image_segmenter_graph.cc | 1 + mediapipe/tasks/cc/vision/object_detector/object_detector.cc | 1 + 
.../tasks/cc/vision/object_detector/object_detector_graph.cc | 1 + mediapipe/util/rectangle_util_test.cc | 1 + mediapipe/util/tracking/tracked_detection.cc | 2 ++ mediapipe/util/tracking/tracked_detection_manager.cc | 1 + mediapipe/util/tracking/tracked_detection_test.cc | 2 ++ 43 files changed, 70 insertions(+), 3 deletions(-) diff --git a/mediapipe/calculators/image/image_cropping_calculator.cc b/mediapipe/calculators/image/image_cropping_calculator.cc index 8c9305ffb..1a2b2e5b0 100644 --- a/mediapipe/calculators/image/image_cropping_calculator.cc +++ b/mediapipe/calculators/image/image_cropping_calculator.cc @@ -37,7 +37,8 @@ enum { ATTRIB_VERTEX, ATTRIB_TEXTURE_POSITION, NUM_ATTRIBUTES }; namespace mediapipe { namespace { - +using ::mediapipe::NormalizedRect; +using ::mediapipe::Rect; #if !MEDIAPIPE_DISABLE_GPU #endif // !MEDIAPIPE_DISABLE_GPU diff --git a/mediapipe/calculators/image/image_cropping_calculator_test.cc b/mediapipe/calculators/image/image_cropping_calculator_test.cc index b3f692889..3c565282b 100644 --- a/mediapipe/calculators/image/image_cropping_calculator_test.cc +++ b/mediapipe/calculators/image/image_cropping_calculator_test.cc @@ -195,11 +195,11 @@ TEST(ImageCroppingCalculatorTest, RedundantSpecWithInputStream) { auto cc = absl::make_unique<CalculatorContext>( calculator_state.get(), inputTags, tool::CreateTagMap({}).value()); auto& inputs = cc->Inputs(); - mediapipe::Rect rect = ParseTextProtoOrDie<mediapipe::Rect>( + Rect rect = ParseTextProtoOrDie<Rect>( R"pb( width: 1 height: 1 x_center: 40 y_center: 40 rotation: 0.5 )pb"); - inputs.Tag(kRectTag).Value() = MakePacket<mediapipe::Rect>(rect); + inputs.Tag(kRectTag).Value() = MakePacket<Rect>(rect); RectSpec expectRect = { .width = 1, .height = 1, diff --git a/mediapipe/calculators/util/detections_to_rects_calculator.cc b/mediapipe/calculators/util/detections_to_rects_calculator.cc index 73a67d322..3e566836c 100644 --- a/mediapipe/calculators/util/detections_to_rects_calculator.cc +++ b/mediapipe/calculators/util/detections_to_rects_calculator.cc @@ -37,6 +37,9 @@ constexpr char kNormRectTag[] = "NORM_RECT"; constexpr char kRectsTag[] = "RECTS"; constexpr char kNormRectsTag[] = "NORM_RECTS"; +using ::mediapipe::NormalizedRect; +using ::mediapipe::Rect; + constexpr float kMinFloat = std::numeric_limits<float>::lowest(); constexpr float kMaxFloat = std::numeric_limits<float>::max(); diff --git a/mediapipe/calculators/util/detections_to_rects_calculator_test.cc b/mediapipe/calculators/util/detections_to_rects_calculator_test.cc index 6caf792a7..63de60a60 100644 --- a/mediapipe/calculators/util/detections_to_rects_calculator_test.cc +++ b/mediapipe/calculators/util/detections_to_rects_calculator_test.cc @@ -39,6 +39,9 @@ constexpr char kImageSizeTag[] = "IMAGE_SIZE"; constexpr char kRectTag[] = "RECT"; constexpr char kDetectionTag[] = "DETECTION"; +using ::mediapipe::NormalizedRect; +using ::mediapipe::Rect; + MATCHER_P4(RectEq, x_center, y_center, width, height, "") { return testing::Value(arg.x_center(), testing::Eq(x_center)) && testing::Value(arg.y_center(), testing::Eq(y_center)) && diff --git a/mediapipe/calculators/util/landmark_projection_calculator.cc b/mediapipe/calculators/util/landmark_projection_calculator.cc index e27edea66..9f276da56 100644 --- a/mediapipe/calculators/util/landmark_projection_calculator.cc +++ b/mediapipe/calculators/util/landmark_projection_calculator.cc @@ -24,6 +24,8 @@ namespace mediapipe { +using ::mediapipe::NormalizedRect; + namespace { constexpr char kLandmarksTag[] = "NORM_LANDMARKS"; diff --git 
a/mediapipe/calculators/util/landmarks_smoothing_calculator.cc b/mediapipe/calculators/util/landmarks_smoothing_calculator.cc index 6673816e7..7a92cfb7e 100644 --- a/mediapipe/calculators/util/landmarks_smoothing_calculator.cc +++ b/mediapipe/calculators/util/landmarks_smoothing_calculator.cc @@ -35,7 +35,9 @@ constexpr char kObjectScaleRoiTag[] = "OBJECT_SCALE_ROI"; constexpr char kNormalizedFilteredLandmarksTag[] = "NORM_FILTERED_LANDMARKS"; constexpr char kFilteredLandmarksTag[] = "FILTERED_LANDMARKS"; +using ::mediapipe::NormalizedRect; using mediapipe::OneEuroFilter; +using ::mediapipe::Rect; using mediapipe::RelativeVelocityFilter; void NormalizedLandmarksToLandmarks( diff --git a/mediapipe/calculators/util/rect_projection_calculator.cc b/mediapipe/calculators/util/rect_projection_calculator.cc index dcc6e7391..69b28af87 100644 --- a/mediapipe/calculators/util/rect_projection_calculator.cc +++ b/mediapipe/calculators/util/rect_projection_calculator.cc @@ -23,6 +23,8 @@ namespace { constexpr char kNormRectTag[] = "NORM_RECT"; constexpr char kNormReferenceRectTag[] = "NORM_REFERENCE_RECT"; +using ::mediapipe::NormalizedRect; + } // namespace // Projects rectangle from reference coordinate system (defined by reference diff --git a/mediapipe/calculators/util/rect_to_render_data_calculator.cc b/mediapipe/calculators/util/rect_to_render_data_calculator.cc index 400be277d..bbc08255e 100644 --- a/mediapipe/calculators/util/rect_to_render_data_calculator.cc +++ b/mediapipe/calculators/util/rect_to_render_data_calculator.cc @@ -29,6 +29,9 @@ constexpr char kNormRectsTag[] = "NORM_RECTS"; constexpr char kRectsTag[] = "RECTS"; constexpr char kRenderDataTag[] = "RENDER_DATA"; +using ::mediapipe::NormalizedRect; +using ::mediapipe::Rect; + RenderAnnotation::Rectangle* NewRect( const RectToRenderDataCalculatorOptions& options, RenderData* render_data) { auto* annotation = render_data->add_render_annotations(); diff --git a/mediapipe/calculators/util/rect_to_render_scale_calculator.cc b/mediapipe/calculators/util/rect_to_render_scale_calculator.cc index d94615228..6ff6b3d51 100644 --- a/mediapipe/calculators/util/rect_to_render_scale_calculator.cc +++ b/mediapipe/calculators/util/rect_to_render_scale_calculator.cc @@ -24,6 +24,8 @@ constexpr char kNormRectTag[] = "NORM_RECT"; constexpr char kImageSizeTag[] = "IMAGE_SIZE"; constexpr char kRenderScaleTag[] = "RENDER_SCALE"; +using ::mediapipe::NormalizedRect; + } // namespace // A calculator to get scale for RenderData primitives. diff --git a/mediapipe/calculators/util/rect_transformation_calculator.cc b/mediapipe/calculators/util/rect_transformation_calculator.cc index 15bb26826..4783cb919 100644 --- a/mediapipe/calculators/util/rect_transformation_calculator.cc +++ b/mediapipe/calculators/util/rect_transformation_calculator.cc @@ -28,6 +28,9 @@ constexpr char kRectTag[] = "RECT"; constexpr char kRectsTag[] = "RECTS"; constexpr char kImageSizeTag[] = "IMAGE_SIZE"; +using ::mediapipe::NormalizedRect; +using ::mediapipe::Rect; + // Wraps around an angle in radians to within -M_PI and M_PI. 
inline float NormalizeRadians(float angle) { return angle - 2 * M_PI * std::floor((angle - (-M_PI)) / (2 * M_PI)); diff --git a/mediapipe/calculators/util/world_landmark_projection_calculator.cc b/mediapipe/calculators/util/world_landmark_projection_calculator.cc index bcd7352a2..e843d63bf 100644 --- a/mediapipe/calculators/util/world_landmark_projection_calculator.cc +++ b/mediapipe/calculators/util/world_landmark_projection_calculator.cc @@ -22,6 +22,8 @@ namespace mediapipe { +using ::mediapipe::NormalizedRect; + namespace { constexpr char kLandmarksTag[] = "LANDMARKS"; diff --git a/mediapipe/calculators/video/tracked_detection_manager_calculator.cc b/mediapipe/calculators/video/tracked_detection_manager_calculator.cc index c416fa9b0..48664fead 100644 --- a/mediapipe/calculators/video/tracked_detection_manager_calculator.cc +++ b/mediapipe/calculators/video/tracked_detection_manager_calculator.cc @@ -32,6 +32,8 @@ namespace mediapipe { namespace { +using ::mediapipe::NormalizedRect; + constexpr int kDetectionUpdateTimeOutMS = 5000; constexpr char kDetectionsTag[] = "DETECTIONS"; constexpr char kDetectionBoxesTag[] = "DETECTION_BOXES"; diff --git a/mediapipe/modules/hand_landmark/calculators/hand_landmarks_to_rect_calculator.cc b/mediapipe/modules/hand_landmark/calculators/hand_landmarks_to_rect_calculator.cc index 6f2c49d64..638678ff5 100644 --- a/mediapipe/modules/hand_landmark/calculators/hand_landmarks_to_rect_calculator.cc +++ b/mediapipe/modules/hand_landmark/calculators/hand_landmarks_to_rect_calculator.cc @@ -22,6 +22,8 @@ namespace mediapipe { +using ::mediapipe::NormalizedRect; + namespace { // NORM_LANDMARKS is either the full set of landmarks for the hand, or diff --git a/mediapipe/modules/holistic_landmark/calculators/roi_tracking_calculator.cc b/mediapipe/modules/holistic_landmark/calculators/roi_tracking_calculator.cc index 0da6cd7f7..49c7b93fb 100644 --- a/mediapipe/modules/holistic_landmark/calculators/roi_tracking_calculator.cc +++ b/mediapipe/modules/holistic_landmark/calculators/roi_tracking_calculator.cc @@ -34,6 +34,8 @@ constexpr char kRecropRectTag[] = "RECROP_RECT"; constexpr char kImageSizeTag[] = "IMAGE_SIZE"; constexpr char kTrackingRectTag[] = "TRACKING_RECT"; +using ::mediapipe::NormalizedRect; + // TODO: Use rect rotation. // Verifies that Intersection over Union of previous frame rect and current // frame re-crop rect is less than threshold. diff --git a/mediapipe/modules/objectron/calculators/frame_annotation_to_rect_calculator.cc b/mediapipe/modules/objectron/calculators/frame_annotation_to_rect_calculator.cc index 476f8cb54..1fe919c54 100644 --- a/mediapipe/modules/objectron/calculators/frame_annotation_to_rect_calculator.cc +++ b/mediapipe/modules/objectron/calculators/frame_annotation_to_rect_calculator.cc @@ -34,6 +34,8 @@ namespace { constexpr char kInputFrameAnnotationTag[] = "FRAME_ANNOTATION"; constexpr char kOutputNormRectsTag[] = "NORM_RECTS"; +using ::mediapipe::NormalizedRect; + } // namespace // A calculator that converts FrameAnnotation proto to NormalizedRect. 
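The changes above and below all follow one pattern: each file gains an explicit using-declaration so that an unqualified NormalizedRect (or Rect) resolves to the ::mediapipe proto even from nested namespaces such as mediapipe::tasks::vision. A minimal illustration of what the alias buys, written as a sketch for orientation rather than code from the patch (the function name is invented):

#include "mediapipe/framework/formats/rect.pb.h"

namespace mediapipe::tasks::vision {
// Pins the unqualified name to the framework proto; without it, lookup
// could bind to a same-named type introduced in this namespace later.
using ::mediapipe::NormalizedRect;

// Builds a full-image region of interest in normalized [0, 1] coordinates.
NormalizedRect MakeFullImageRoi() {
  NormalizedRect roi;
  roi.set_x_center(0.5f);
  roi.set_y_center(0.5f);
  roi.set_width(1.0f);
  roi.set_height(1.0f);
  return roi;
}
}  // namespace mediapipe::tasks::vision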
diff --git a/mediapipe/tasks/cc/components/processors/image_preprocessing_graph.cc b/mediapipe/tasks/cc/components/processors/image_preprocessing_graph.cc index b24b7f0cb..fefc1ec52 100644 --- a/mediapipe/tasks/cc/components/processors/image_preprocessing_graph.cc +++ b/mediapipe/tasks/cc/components/processors/image_preprocessing_graph.cc @@ -45,6 +45,7 @@ namespace components { namespace processors { namespace { +using ::mediapipe::NormalizedRect; using ::mediapipe::Tensor; using ::mediapipe::api2::Input; using ::mediapipe::api2::Output; diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator.cc index 277bb170a..088f97c29 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator.cc +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator.cc @@ -35,6 +35,8 @@ limitations under the License. namespace mediapipe { namespace api2 { +using ::mediapipe::NormalizedRect; + namespace { constexpr char kLandmarksTag[] = "LANDMARKS"; diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator_test.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator_test.cc index fe6f1162b..a1a44c8d1 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator_test.cc +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator_test.cc @@ -33,6 +33,8 @@ namespace mediapipe { namespace { +using ::mediapipe::NormalizedRect; + constexpr char kLandmarksTag[] = "LANDMARKS"; constexpr char kWorldLandmarksTag[] = "WORLD_LANDMARKS"; constexpr char kImageSizeTag[] = "IMAGE_SIZE"; diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.cc index e7fcf6fd9..01f444742 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.cc +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.cc @@ -57,6 +57,8 @@ namespace { using GestureRecognizerGraphOptionsProto = ::mediapipe::tasks::vision:: gesture_recognizer::proto::GestureRecognizerGraphOptions; +using ::mediapipe::NormalizedRect; + constexpr char kHandGestureSubgraphTypeName[] = "mediapipe.tasks.vision.gesture_recognizer.GestureRecognizerGraph"; diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer_graph.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer_graph.cc index 47d95100b..2d949c410 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer_graph.cc +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer_graph.cc @@ -46,6 +46,7 @@ namespace gesture_recognizer { namespace { +using ::mediapipe::NormalizedRect; using ::mediapipe::api2::Input; using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::Graph; diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/hand_gesture_recognizer_graph.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/hand_gesture_recognizer_graph.cc index d7e983d81..4db57e85b 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/hand_gesture_recognizer_graph.cc +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/hand_gesture_recognizer_graph.cc @@ -52,6 +52,7 @@ namespace gesture_recognizer { namespace { +using ::mediapipe::NormalizedRect; using ::mediapipe::api2::Input; using 
::mediapipe::api2::Output; using ::mediapipe::api2::builder::Graph; diff --git a/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph.cc b/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph.cc index c24548c9b..49958e36b 100644 --- a/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph.cc +++ b/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph.cc @@ -50,6 +50,7 @@ namespace hand_detector { namespace { +using ::mediapipe::NormalizedRect; using ::mediapipe::api2::Input; using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::Graph; diff --git a/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph_test.cc b/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph_test.cc index cbbc0e193..f4e5f8c7d 100644 --- a/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph_test.cc +++ b/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph_test.cc @@ -53,6 +53,7 @@ namespace { using ::file::Defaults; using ::file::GetTextProto; +using ::mediapipe::NormalizedRect; using ::mediapipe::api2::Input; using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::Graph; diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator.cc b/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator.cc index b6df80588..dffdbdd38 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator.cc +++ b/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator.cc @@ -27,6 +27,8 @@ limitations under the License. namespace mediapipe::api2 { +using ::mediapipe::NormalizedRect; + // HandAssociationCalculator accepts multiple inputs of vectors of // NormalizedRect. The output is a vector of NormalizedRect that contains // rects from the input vectors that don't overlap with each other. When two diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator_test.cc b/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator_test.cc index cb3130854..138164209 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator_test.cc +++ b/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator_test.cc @@ -26,6 +26,8 @@ limitations under the License. namespace mediapipe { namespace { +using ::mediapipe::NormalizedRect; + class HandAssociationCalculatorTest : public testing::Test { protected: HandAssociationCalculatorTest() { diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_landmarks_deduplication_calculator.cc b/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_landmarks_deduplication_calculator.cc index 266ce223f..d875de98f 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_landmarks_deduplication_calculator.cc +++ b/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_landmarks_deduplication_calculator.cc @@ -41,6 +41,7 @@ limitations under the License. 
namespace mediapipe::api2 { namespace { +using ::mediapipe::NormalizedRect; using ::mediapipe::api2::Input; using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::Source; diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker.cc b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker.cc index 2b818b2e5..3bb1ee8d8 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker.cc +++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker.cc @@ -46,6 +46,8 @@ namespace { using HandLandmarkerGraphOptionsProto = ::mediapipe::tasks::vision:: hand_landmarker::proto::HandLandmarkerGraphOptions; +using ::mediapipe::NormalizedRect; + constexpr char kHandLandmarkerGraphTypeName[] = "mediapipe.tasks.vision.hand_landmarker.HandLandmarkerGraph"; diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph.cc b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph.cc index 2c4133eb1..05ad97efe 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph.cc +++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph.cc @@ -49,6 +49,7 @@ namespace hand_landmarker { namespace { +using ::mediapipe::NormalizedRect; using ::mediapipe::api2::Input; using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::Graph; diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph_test.cc b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph_test.cc index f275486f5..c28df2c05 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph_test.cc +++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph_test.cc @@ -54,6 +54,7 @@ namespace { using ::file::Defaults; using ::file::GetTextProto; +using ::mediapipe::NormalizedRect; using ::mediapipe::api2::Input; using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::Graph; diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph.cc b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph.cc index 014830ba2..4ea066aab 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph.cc +++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph.cc @@ -53,6 +53,7 @@ namespace hand_landmarker { namespace { +using ::mediapipe::NormalizedRect; using ::mediapipe::api2::Input; using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::Graph; diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph_test.cc b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph_test.cc index d1e928ce7..f28907d2f 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph_test.cc +++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph_test.cc @@ -50,6 +50,7 @@ namespace { using ::file::Defaults; using ::file::GetTextProto; +using ::mediapipe::NormalizedRect; using ::mediapipe::api2::Input; using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::Graph; diff --git a/mediapipe/tasks/cc/vision/image_classifier/image_classifier.cc b/mediapipe/tasks/cc/vision/image_classifier/image_classifier.cc index 60f8f7ed4..763e0a320 100644 --- a/mediapipe/tasks/cc/vision/image_classifier/image_classifier.cc +++ b/mediapipe/tasks/cc/vision/image_classifier/image_classifier.cc @@ -58,6 +58,7 @@ constexpr char kSubgraphTypeName[] = "mediapipe.tasks.vision.image_classifier.ImageClassifierGraph"; constexpr int kMicroSecondsPerMilliSecond = 1000; +using 
::mediapipe::NormalizedRect; using ::mediapipe::tasks::components::containers::ConvertToClassificationResult; using ::mediapipe::tasks::components::containers::proto::ClassificationResult; using ::mediapipe::tasks::core::PacketMap; diff --git a/mediapipe/tasks/cc/vision/image_classifier/image_classifier_graph.cc b/mediapipe/tasks/cc/vision/image_classifier/image_classifier_graph.cc index 2d0379c66..0adcf842d 100644 --- a/mediapipe/tasks/cc/vision/image_classifier/image_classifier_graph.cc +++ b/mediapipe/tasks/cc/vision/image_classifier/image_classifier_graph.cc @@ -38,6 +38,7 @@ namespace image_classifier { namespace { +using ::mediapipe::NormalizedRect; using ::mediapipe::api2::Input; using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::GenericNode; diff --git a/mediapipe/tasks/cc/vision/image_embedder/image_embedder.cc b/mediapipe/tasks/cc/vision/image_embedder/image_embedder.cc index e3198090f..494b075a7 100644 --- a/mediapipe/tasks/cc/vision/image_embedder/image_embedder.cc +++ b/mediapipe/tasks/cc/vision/image_embedder/image_embedder.cc @@ -54,6 +54,7 @@ constexpr char kGraphTypeName[] = "mediapipe.tasks.vision.image_embedder.ImageEmbedderGraph"; constexpr int kMicroSecondsPerMilliSecond = 1000; +using ::mediapipe::NormalizedRect; using ::mediapipe::tasks::components::containers::ConvertToEmbeddingResult; using ::mediapipe::tasks::components::containers::proto::EmbeddingResult; using ::mediapipe::tasks::core::PacketMap; diff --git a/mediapipe/tasks/cc/vision/image_embedder/image_embedder_graph.cc b/mediapipe/tasks/cc/vision/image_embedder/image_embedder_graph.cc index 81ccb5361..95c4ff379 100644 --- a/mediapipe/tasks/cc/vision/image_embedder/image_embedder_graph.cc +++ b/mediapipe/tasks/cc/vision/image_embedder/image_embedder_graph.cc @@ -34,6 +34,7 @@ namespace image_embedder { namespace { +using ::mediapipe::NormalizedRect; using ::mediapipe::api2::Input; using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::GenericNode; diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc index bbee714c6..7130c72e2 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc @@ -44,6 +44,7 @@ constexpr int kMicroSecondsPerMilliSecond = 1000; using ::mediapipe::CalculatorGraphConfig; using ::mediapipe::Image; +using ::mediapipe::NormalizedRect; using ::mediapipe::tasks::vision::image_segmenter::proto::SegmenterOptions; using ImageSegmenterGraphOptionsProto = ::mediapipe::tasks::vision:: image_segmenter::proto::ImageSegmenterGraphOptions; diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc index 5531968c1..923cf2937 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc @@ -49,6 +49,7 @@ namespace image_segmenter { namespace { using ::mediapipe::Image; +using ::mediapipe::NormalizedRect; using ::mediapipe::api2::Input; using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::Graph; diff --git a/mediapipe/tasks/cc/vision/object_detector/object_detector.cc b/mediapipe/tasks/cc/vision/object_detector/object_detector.cc index e0222dd70..2477f8a44 100644 --- a/mediapipe/tasks/cc/vision/object_detector/object_detector.cc +++ b/mediapipe/tasks/cc/vision/object_detector/object_detector.cc @@ -57,6 +57,7 @@ constexpr 
char kSubgraphTypeName[] = "mediapipe.tasks.vision.ObjectDetectorGraph"; constexpr int kMicroSecondsPerMilliSecond = 1000; +using ::mediapipe::NormalizedRect; using ::mediapipe::tasks::components::containers::ConvertToDetectionResult; using ObjectDetectorOptionsProto = object_detector::proto::ObjectDetectorOptions; diff --git a/mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc b/mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc index fd95bb1ac..e5af7544d 100644 --- a/mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc +++ b/mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc @@ -52,6 +52,7 @@ namespace vision { namespace { +using ::mediapipe::NormalizedRect; using ::mediapipe::api2::Input; using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::Graph; diff --git a/mediapipe/util/rectangle_util_test.cc b/mediapipe/util/rectangle_util_test.cc index cd1946d45..3bc323f9f 100644 --- a/mediapipe/util/rectangle_util_test.cc +++ b/mediapipe/util/rectangle_util_test.cc @@ -20,6 +20,7 @@ namespace mediapipe { namespace { +using ::mediapipe::NormalizedRect; using ::testing::FloatNear; class RectangleUtilTest : public testing::Test { diff --git a/mediapipe/util/tracking/tracked_detection.cc b/mediapipe/util/tracking/tracked_detection.cc index 130a87640..80a6981a8 100644 --- a/mediapipe/util/tracking/tracked_detection.cc +++ b/mediapipe/util/tracking/tracked_detection.cc @@ -20,6 +20,8 @@ namespace mediapipe { namespace { +using ::mediapipe::NormalizedRect; + // Struct for carrying boundary information. struct NormalizedRectBounds { float left, right, top, bottom; diff --git a/mediapipe/util/tracking/tracked_detection_manager.cc b/mediapipe/util/tracking/tracked_detection_manager.cc index 597827f3c..a9e348ceb 100644 --- a/mediapipe/util/tracking/tracked_detection_manager.cc +++ b/mediapipe/util/tracking/tracked_detection_manager.cc @@ -21,6 +21,7 @@ namespace { +using ::mediapipe::NormalizedRect; using mediapipe::TrackedDetection; // Checks if a point is out of view. diff --git a/mediapipe/util/tracking/tracked_detection_test.cc b/mediapipe/util/tracking/tracked_detection_test.cc index 60b9df1b1..13efaab92 100644 --- a/mediapipe/util/tracking/tracked_detection_test.cc +++ b/mediapipe/util/tracking/tracked_detection_test.cc @@ -18,6 +18,8 @@ namespace mediapipe { +using ::mediapipe::NormalizedRect; + const float kErrorMargin = 1e-4f; TEST(TrackedDetectionTest, ConstructorWithoutBox) { From 78597c5b37a2ef8f3f005ed55f0a01676a08fb0b Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Tue, 13 Dec 2022 09:05:19 -0800 Subject: [PATCH 219/469] Internal changes. PiperOrigin-RevId: 495038477 --- mediapipe/framework/BUILD | 1 + 1 file changed, 1 insertion(+) diff --git a/mediapipe/framework/BUILD b/mediapipe/framework/BUILD index 3cc72b4f1..265ae9c6f 100644 --- a/mediapipe/framework/BUILD +++ b/mediapipe/framework/BUILD @@ -21,6 +21,7 @@ licenses(["notice"]) package(default_visibility = ["//visibility:private"]) +# The MediaPipe internal package group. No mediapipe users should be added to this group. 
package_group( name = "mediapipe_internal", packages = [ From db404b1a8593a8b316cc4930dc1bcc845fc3df62 Mon Sep 17 00:00:00 2001 From: Hadon Nash Date: Tue, 13 Dec 2022 10:21:07 -0800 Subject: [PATCH 220/469] Internal change PiperOrigin-RevId: 495058817 --- mediapipe/framework/tool/proto_util_lite.h | 7 +- mediapipe/framework/tool/template_parser.cc | 181 +++++++++++++++----- 2 files changed, 144 insertions(+), 44 deletions(-) diff --git a/mediapipe/framework/tool/proto_util_lite.h b/mediapipe/framework/tool/proto_util_lite.h index d71ceac83..15e321eeb 100644 --- a/mediapipe/framework/tool/proto_util_lite.h +++ b/mediapipe/framework/tool/proto_util_lite.h @@ -48,11 +48,16 @@ class ProtoUtilLite { key_id(key_id), key_type(key_type), key_value(std::move(key_value)) {} + bool operator==(const ProtoPathEntry& o) const { + return field_id == o.field_id && index == o.index && map_id == o.map_id && + key_id == o.key_id && key_type == o.key_type && + key_value == o.key_value; + } int field_id = -1; int index = -1; int map_id = -1; int key_id = -1; - FieldType key_type; + FieldType key_type = FieldType::MAX_FIELD_TYPE; FieldValue key_value; }; diff --git a/mediapipe/framework/tool/template_parser.cc b/mediapipe/framework/tool/template_parser.cc index 5a0ceccd3..cf23f3443 100644 --- a/mediapipe/framework/tool/template_parser.cc +++ b/mediapipe/framework/tool/template_parser.cc @@ -1367,26 +1367,132 @@ absl::Status ProtoPathSplit(const std::string& path, return status; } +// Returns a message serialized deterministically. +bool DeterministicallySerialize(const Message& proto, std::string* result) { + proto_ns::io::StringOutputStream stream(result); + proto_ns::io::CodedOutputStream output(&stream); + output.SetSerializationDeterministic(true); + return proto.SerializeToCodedStream(&output); +} + // Serialize one field of a message. void SerializeField(const Message* message, const FieldDescriptor* field, std::vector<ProtoUtilLite::FieldValue>* result) { ProtoUtilLite::FieldValue message_bytes; - CHECK(message->SerializePartialToString(&message_bytes)); + CHECK(DeterministicallySerialize(*message, &message_bytes)); ProtoUtilLite::FieldAccess access( field->number(), static_cast<ProtoUtilLite::FieldType>(field->type())); MEDIAPIPE_CHECK_OK(access.SetMessage(message_bytes)); *result = *access.mutable_field_values(); } +// Serialize a ProtoPath as a readable string. +// For example, {{1, 1}, {2, 1}, {3, 1}} returns "/1[1]/2[1]/3[1]", +// and {{1, 1}, {2, 1, "INPUT_FRAMES"}} returns "/1[1]/2[@1=INPUT_FRAMES]". +std::string ProtoPathJoin(ProtoPath path) { + std::string result; + for (ProtoUtilLite::ProtoPathEntry& e : path) { + if (e.field_id >= 0) { + absl::StrAppend(&result, "/", e.field_id, "[", e.index, "]"); + } else if (e.map_id >= 0) { + absl::StrAppend(&result, "/", e.map_id, "[@", e.key_id, "=", e.key_value, + "]"); + } + } + return result; +} + +// Returns the message value from a field at an index. +const Message* GetFieldMessage(const Message& message, + const FieldDescriptor* field, int index) { + if (field->type() != FieldDescriptor::TYPE_MESSAGE) { + return nullptr; + } + if (!field->is_repeated()) { + return &message.GetReflection()->GetMessage(message, field); + } + if (index < message.GetReflection()->FieldSize(message, field)) { + return &message.GetReflection()->GetRepeatedMessage(message, field, index); + } + return nullptr; +} + +// Returns all FieldDescriptors including extensions.
+std::vector<const FieldDescriptor*> GetFields(const Message* src) { + std::vector<const FieldDescriptor*> result; + src->GetDescriptor()->file()->pool()->FindAllExtensions(src->GetDescriptor(), + &result); + for (int i = 0; i < src->GetDescriptor()->field_count(); ++i) { + result.push_back(src->GetDescriptor()->field(i)); + } + return result; +} + +// Orders map entries in dst to match src. +void OrderMapEntries(const Message* src, Message* dst, + std::set<const Message*>* seen = nullptr) { + std::unique_ptr<std::set<const Message*>> seen_owner; + if (!seen) { + seen_owner = std::make_unique<std::set<const Message*>>(); + seen = seen_owner.get(); + } + if (seen->count(src) > 0) { + return; + } else { + seen->insert(src); + } + for (auto field : GetFields(src)) { + if (field->is_map()) { + dst->GetReflection()->ClearField(dst, field); + for (int j = 0; j < src->GetReflection()->FieldSize(*src, field); ++j) { + const Message& entry = + src->GetReflection()->GetRepeatedMessage(*src, field, j); + dst->GetReflection()->AddMessage(dst, field)->CopyFrom(entry); + } + } + if (field->type() == FieldDescriptor::TYPE_MESSAGE) { + if (field->is_repeated()) { + for (int j = 0; j < src->GetReflection()->FieldSize(*src, field); ++j) { + OrderMapEntries( + &src->GetReflection()->GetRepeatedMessage(*src, field, j), + dst->GetReflection()->MutableRepeatedMessage(dst, field, j), + seen); + } + } else { + OrderMapEntries(&src->GetReflection()->GetMessage(*src, field), + dst->GetReflection()->MutableMessage(dst, field), seen); + } + } + } +} + +// Copies a Message, keeping map entries in order. +std::unique_ptr<Message> CloneMessage(const Message* message) { + std::unique_ptr<Message> result(message->New()); + result->CopyFrom(*message); + OrderMapEntries(message, result.get()); + return result; +} + +using MessageMap = std::map<std::string, std::unique_ptr<Message>>; + // For a non-repeated field, move the most recently parsed field value // into the most recently parsed template expression. -void StowFieldValue(Message* message, TemplateExpression* expression) { +void StowFieldValue(Message* message, TemplateExpression* expression, + MessageMap* stowed_messages) { const Reflection* reflection = message->GetReflection(); const Descriptor* descriptor = message->GetDescriptor(); ProtoUtilLite::ProtoPath path; MEDIAPIPE_CHECK_OK(ProtoPathSplit(expression->path(), &path)); int field_number = path[path.size() - 1].field_id; const FieldDescriptor* field = descriptor->FindFieldByNumber(field_number); + + // Save each stowed message unserialized, preserving map entry order. + if (!field->is_repeated() && field->type() == FieldDescriptor::TYPE_MESSAGE) { + (*stowed_messages)[ProtoPathJoin(path)] = + CloneMessage(GetFieldMessage(*message, field, 0)); + } + if (!field->is_repeated()) { std::vector<ProtoUtilLite::FieldValue> field_values; SerializeField(message, field, &field_values); @@ -1417,37 +1523,6 @@ const FieldDescriptor* FindFieldByNumber(const Message* message, return result; } -// Returns the message value from a field at an index. -const Message* GetFieldMessage(const Message& message, - const FieldDescriptor* field, int index) { - if (field->type() != FieldDescriptor::TYPE_MESSAGE) { - return nullptr; - } - if (!field->is_repeated()) { - return &message.GetReflection()->GetMessage(message, field); - } - if (index < message.GetReflection()->FieldSize(message, field)) { - return &message.GetReflection()->GetRepeatedMessage(message, field, index); - } - return nullptr; -} - -// Serialize a ProtoPath as a readable string. -// For example, {{1, 1}, {2, 1}, {3, 1}} returns "/1[1]/2[1]/3[1]", -// and {{1, 1}, {2, 1, "INPUT_FRAMES"}} returns "/1[1]/2[@1=INPUT_FRAMES]".
-std::string ProtoPathJoin(ProtoPath path) { - std::string result; - for (ProtoUtilLite::ProtoPathEntry& e : path) { - if (e.field_id >= 0) { - absl::StrAppend(&result, "/", e.field_id, "[", e.index, "]"); - } else if (e.map_id >= 0) { - absl::StrAppend(&result, "/", e.map_id, "[@", e.key_id, "=", e.key_value, - "]"); - } - } - return result; -} - // Returns the protobuf map key types from a ProtoPath. std::vector<FieldType> ProtoPathKeyTypes(ProtoPath path) { std::vector<FieldType> result; @@ -1473,9 +1548,29 @@ std::string GetMapKey(const Message& map_entry) { return ""; } +// Returns a Message stored in CalculatorGraphTemplate::field_value. +Message* FindStowedMessage(MessageMap* stowed_messages, ProtoPath proto_path) { + auto it = stowed_messages->find(ProtoPathJoin(proto_path)); + return (it != stowed_messages->end()) ? it->second.get() : nullptr; +} + +const Message* GetNestedMessage(const Message& message, + const FieldDescriptor* field, + ProtoPath proto_path, + MessageMap* stowed_messages) { + if (field->type() != FieldDescriptor::TYPE_MESSAGE) { + return nullptr; + } + const Message* result = FindStowedMessage(stowed_messages, proto_path); + if (!result) { + result = GetFieldMessage(message, field, proto_path.back().index); + } + return result; +} + // Adjusts map-entries from indexes to keys. // Protobuf map-entry order is intentionally not preserved. -mediapipe::Status KeyProtoMapEntries(Message* source) { +absl::Status KeyProtoMapEntries(Message* source, MessageMap* stowed_messages) { // Copy the rules from the source CalculatorGraphTemplate. mediapipe::CalculatorGraphTemplate rules; rules.ParsePartialFromString(source->SerializePartialAsString()); @@ -1489,11 +1584,14 @@ mediapipe::Status KeyProtoMapEntries(Message* source) { MP_RETURN_IF_ERROR(ProtoPathSplit(rule->path(), &path)); for (int j = 0; j < path.size(); ++j) { int field_id = path[j].field_id; - int field_index = path[j].index; const FieldDescriptor* field = FindFieldByNumber(message, field_id); + ProtoPath prefix = {path.begin(), path.begin() + j + 1}; + message = GetNestedMessage(*message, field, prefix, stowed_messages); + if (!message) { + break; + } if (field->is_map()) { - const Message* map_entry = - GetFieldMessage(*message, field, path[j].index); + const Message* map_entry = message; int key_id = map_entry->GetDescriptor()->FindFieldByName("key")->number(); FieldType key_type = static_cast<FieldType>( map_entry->GetDescriptor()->FindFieldByName("key")->type()); std::string key_value = GetMapKey(*map_entry); path[j] = {field_id, key_id, key_type, key_value}; } - message = GetFieldMessage(*message, field, field_index); - if (!message) { - break; - } } if (!rule->path().empty()) { *rule->mutable_path() = ProtoPathJoin(path); } @@ -1539,7 +1633,7 @@ class TemplateParser::Parser::MediaPipeParserImpl // Copy the template rules into the output template "rule" field. success &= MergeFields(template_rules_, output).ok(); // Replace map-entry indexes with map keys.
- success &= KeyProtoMapEntries(output).ok(); + success &= KeyProtoMapEntries(output, &stowed_messages_).ok(); return success; } @@ -1565,7 +1659,7 @@ class TemplateParser::Parser::MediaPipeParserImpl DO(ConsumeFieldTemplate(message)); } else { DO(ConsumeField(message)); - StowFieldValue(message, expression); + StowFieldValue(message, expression, &stowed_messages_); } DO(ConsumeEndTemplate()); return true; @@ -1776,6 +1870,7 @@ } mediapipe::CalculatorGraphTemplate template_rules_; + std::map<std::string, std::unique_ptr<Message>> stowed_messages_; }; #undef DO From d5ff060bfa6930b9b6b1826b43ca0434b69050a9 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Tue, 13 Dec 2022 16:01:06 -0800 Subject: [PATCH 221/469] Internal change PiperOrigin-RevId: 495149484 --- mediapipe/graphs/object_detection_3d/calculators/BUILD | 1 + 1 file changed, 1 insertion(+) diff --git a/mediapipe/graphs/object_detection_3d/calculators/BUILD b/mediapipe/graphs/object_detection_3d/calculators/BUILD index 783fff187..d4c5c496b 100644 --- a/mediapipe/graphs/object_detection_3d/calculators/BUILD +++ b/mediapipe/graphs/object_detection_3d/calculators/BUILD @@ -22,6 +22,7 @@ package(default_visibility = ["//visibility:public"]) mediapipe_proto_library( name = "gl_animation_overlay_calculator_proto", srcs = ["gl_animation_overlay_calculator.proto"], + def_options_lib = False, visibility = ["//visibility:public"], exports = [ "//mediapipe/gpu:gl_animation_overlay_calculator_proto", From 904a537b027a98c69c744cd4944f06e63c3e882d Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Tue, 13 Dec 2022 16:08:54 -0800 Subject: [PATCH 222/469] Internal change PiperOrigin-RevId: 495151410 --- mediapipe/framework/BUILD | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/mediapipe/framework/BUILD b/mediapipe/framework/BUILD index 265ae9c6f..872944acd 100644 --- a/mediapipe/framework/BUILD +++ b/mediapipe/framework/BUILD @@ -57,12 +57,12 @@ mediapipe_proto_library( srcs = ["calculator.proto"], visibility = ["//visibility:public"], deps = [ - "//mediapipe/framework:calculator_options_proto", - "//mediapipe/framework:mediapipe_options_proto", - "//mediapipe/framework:packet_factory_proto", - "//mediapipe/framework:packet_generator_proto", - "//mediapipe/framework:status_handler_proto", - "//mediapipe/framework:stream_handler_proto", + ":calculator_options_proto", + ":mediapipe_options_proto", + ":packet_factory_proto", + ":packet_generator_proto", + ":status_handler_proto", + ":stream_handler_proto", "@com_google_protobuf//:any_proto", ], ) @@ -79,8 +79,8 @@ mediapipe_proto_library( srcs = ["calculator_contract_test.proto"], visibility = ["//mediapipe/framework:__subpackages__"], deps = [ - "//mediapipe/framework:calculator_options_proto", - "//mediapipe/framework:calculator_proto", + ":calculator_options_proto", + ":calculator_proto", ], ) @@ -89,8 +89,8 @@ mediapipe_proto_library( srcs = ["calculator_profile.proto"], visibility = ["//visibility:public"], deps = [ - "//mediapipe/framework:calculator_options_proto", - "//mediapipe/framework:calculator_proto", + ":calculator_options_proto", + ":calculator_proto", ], ) @@ -126,14 +126,14 @@ mediapipe_proto_library( name = "status_handler_proto", srcs = ["status_handler.proto"], visibility = [":mediapipe_internal"], - deps = ["//mediapipe/framework:mediapipe_options_proto"], + deps = [":mediapipe_options_proto"], ) mediapipe_proto_library( name = "stream_handler_proto", srcs = ["stream_handler.proto"], visibility = 
[":mediapipe_internal"], - deps = ["//mediapipe/framework:mediapipe_options_proto"], + deps = [":mediapipe_options_proto"], ) mediapipe_proto_library( @@ -142,8 +142,8 @@ mediapipe_proto_library( srcs = ["test_calculators.proto"], visibility = [":mediapipe_internal"], deps = [ - "//mediapipe/framework:calculator_options_proto", - "//mediapipe/framework:calculator_proto", + ":calculator_options_proto", + ":calculator_proto", ], ) @@ -151,7 +151,7 @@ mediapipe_proto_library( name = "thread_pool_executor_proto", srcs = ["thread_pool_executor.proto"], visibility = [":mediapipe_internal"], - deps = ["//mediapipe/framework:mediapipe_options_proto"], + deps = [":mediapipe_options_proto"], ) # It is for pure-native Android builds where the library can't have any dependency on libandroid.so From b9d020cb7d32e936943b963c401cc3aeb9f88407 Mon Sep 17 00:00:00 2001 From: Hadon Nash Date: Tue, 13 Dec 2022 16:58:12 -0800 Subject: [PATCH 223/469] Internal change PiperOrigin-RevId: 495163109 --- mediapipe/framework/scheduler.cc | 11 ++++++++--- mediapipe/framework/scheduler.h | 2 +- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/mediapipe/framework/scheduler.cc b/mediapipe/framework/scheduler.cc index afef4f383..854c10fd5 100644 --- a/mediapipe/framework/scheduler.cc +++ b/mediapipe/framework/scheduler.cc @@ -117,7 +117,7 @@ void Scheduler::SubmitWaitingTasksOnQueues() { // Note: state_mutex_ is held when this function is entered or // exited. void Scheduler::HandleIdle() { - if (handling_idle_) { + if (++handling_idle_ > 1) { // Someone is already inside this method. // Note: This can happen in the sections below where we unlock the mutex // and make more nodes runnable: the nodes can run and become idle again @@ -127,7 +127,6 @@ void Scheduler::HandleIdle() { VLOG(2) << "HandleIdle: already in progress"; return; } - handling_idle_ = true; while (IsIdle() && (state_ == STATE_RUNNING || state_ == STATE_CANCELLING)) { // Remove active sources that are closed. @@ -165,11 +164,17 @@ void Scheduler::HandleIdle() { } } + // If HandleIdle has been called again, then continue scheduling. + if (handling_idle_ > 1) { + handling_idle_ = 1; + continue; + } + // Nothing left to do. break; } - handling_idle_ = false; + handling_idle_ = 0; } // Note: state_mutex_ is held when this function is entered or exited. diff --git a/mediapipe/framework/scheduler.h b/mediapipe/framework/scheduler.h index dd1572d99..b59467b9f 100644 --- a/mediapipe/framework/scheduler.h +++ b/mediapipe/framework/scheduler.h @@ -302,7 +302,7 @@ class Scheduler { // - We need it to be reentrant, which Mutex does not support. // - We want simultaneous calls to return immediately instead of waiting, // and Mutex's TryLock is not guaranteed to work. - bool handling_idle_ ABSL_GUARDED_BY(state_mutex_) = false; + int handling_idle_ ABSL_GUARDED_BY(state_mutex_) = 0; // Mutex for the scheduler state and related things. 
// Note: state_ is declared as atomic so that its getter methods don't need From 6fa0a58529ab60bd93bb622e4c97a0e796bb6276 Mon Sep 17 00:00:00 2001 From: Camillo Lugaresi Date: Wed, 14 Dec 2022 00:34:22 -0800 Subject: [PATCH 224/469] Internal change PiperOrigin-RevId: 495235951 --- .../framework/GraphTextureFrame.java | 47 +++++++++++++++---- .../framework/jni/graph_texture_frame_jni.cc | 7 +++ .../framework/jni/graph_texture_frame_jni.h | 3 ++ 3 files changed, 47 insertions(+), 10 deletions(-) diff --git a/mediapipe/java/com/google/mediapipe/framework/GraphTextureFrame.java b/mediapipe/java/com/google/mediapipe/framework/GraphTextureFrame.java index efaec34a7..586b5c0a0 100644 --- a/mediapipe/java/com/google/mediapipe/framework/GraphTextureFrame.java +++ b/mediapipe/java/com/google/mediapipe/framework/GraphTextureFrame.java @@ -14,6 +14,10 @@ package com.google.mediapipe.framework; +import com.google.common.flogger.FluentLogger; +import java.util.HashSet; +import java.util.Set; + /** * A {@link TextureFrame} that represents a texture produced by MediaPipe. * @@ -21,6 +25,7 @@ package com.google.mediapipe.framework; * method. */ public class GraphTextureFrame implements TextureFrame { + private static final FluentLogger logger = FluentLogger.forEnclosingClass(); private long nativeBufferHandle; // We cache these to be able to get them without a JNI call. private int textureName; @@ -30,6 +35,7 @@ public class GraphTextureFrame implements TextureFrame { // True when created with PacketGetter.getTextureFrameDeferredSync(). This will result in gpuWait // when calling getTextureName(). private final boolean deferredSync; + private final Set<Long> activeConsumerContextHandleSet = new HashSet<>(); GraphTextureFrame(long nativeHandle, long timestamp) { this(nativeHandle, timestamp, false); @@ -54,17 +60,19 @@ public class GraphTextureFrame implements TextureFrame { * condition if release() is called after the if-check for nativeBufferHandle is already passed. */ @Override - public int getTextureName() { + public synchronized int getTextureName() { // Return special texture id 0 if handle is 0 i.e. frame is already released. if (nativeBufferHandle == 0) { return 0; } - // Gpu wait only if deferredSync is true, such as when this GraphTextureFrame is created using - // PacketGetter.getTextureFrameDeferredSync(). - if (deferredSync) { - // Note that, if a CPU wait has already been done, the sync point will have been - // cleared and this will turn into a no-op. See GlFenceSyncPoint::Wait. - nativeGpuWait(nativeBufferHandle); + if (activeConsumerContextHandleSet.add(nativeGetCurrentExternalContextHandle())) { + // Gpu wait only if deferredSync is true, such as when this GraphTextureFrame is created using + // PacketGetter.getTextureFrameDeferredSync(). + if (deferredSync) { + // Note that, if a CPU wait has already been done, the sync point will have been + // cleared and this will turn into a no-op. See GlFenceSyncPoint::Wait. + nativeGpuWait(nativeBufferHandle); + } } return textureName; } @@ -92,9 +100,14 @@ *

The consumer calls this when it is done using the texture. */ @Override - public void release() { - GlSyncToken consumerToken = - new GraphGlSyncToken(nativeCreateSyncTokenForCurrentExternalContext(nativeBufferHandle)); + public synchronized void release() { + GlSyncToken consumerToken = null; + // Note that this remove should be moved to the other overload of release when b/68808951 is + // addressed. + if (activeConsumerContextHandleSet.remove(nativeGetCurrentExternalContextHandle())) { + consumerToken = + new GraphGlSyncToken(nativeCreateSyncTokenForCurrentExternalContext(nativeBufferHandle)); + } release(consumerToken); } @@ -113,12 +126,24 @@ public class GraphTextureFrame implements TextureFrame { long token = consumerSyncToken == null ? 0 : consumerSyncToken.nativeToken(); nativeReleaseBuffer(nativeBufferHandle, token); nativeBufferHandle = 0; + } else if (consumerSyncToken != null) { + logger.atWarning().log("release with sync token, but handle is 0"); } if (consumerSyncToken != null) { consumerSyncToken.release(); } } + @Override + protected void finalize() throws Throwable { + if (nativeBufferHandle != 0) { + logger.atWarning().log("release was not called before finalize"); + } + if (!activeConsumerContextHandleSet.isEmpty()) { + logger.atWarning().log("active consumers did not release with sync before finalize"); + } + } + private native void nativeReleaseBuffer(long nativeHandle, long consumerSyncToken); private native int nativeGetTextureName(long nativeHandle); @@ -128,4 +153,6 @@ public class GraphTextureFrame implements TextureFrame { private native void nativeGpuWait(long nativeHandle); private native long nativeCreateSyncTokenForCurrentExternalContext(long nativeHandle); + + private native long nativeGetCurrentExternalContextHandle(); } diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/graph_texture_frame_jni.cc b/mediapipe/java/com/google/mediapipe/framework/jni/graph_texture_frame_jni.cc index 84df89260..963ea522e 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/graph_texture_frame_jni.cc +++ b/mediapipe/java/com/google/mediapipe/framework/jni/graph_texture_frame_jni.cc @@ -15,6 +15,7 @@ #include "mediapipe/java/com/google/mediapipe/framework/jni/graph_texture_frame_jni.h" #include "mediapipe/gpu/gl_calculator_helper.h" +#include "mediapipe/gpu/gl_context.h" #include "mediapipe/gpu/gl_texture_buffer.h" #include "mediapipe/java/com/google/mediapipe/framework/jni/jni_util.h" @@ -84,3 +85,9 @@ JNIEXPORT jlong JNICALL GRAPH_TEXTURE_FRAME_METHOD( } return reinterpret_cast(token); } + +JNIEXPORT jlong JNICALL GRAPH_TEXTURE_FRAME_METHOD( + nativeGetCurrentExternalContextHandle)(JNIEnv* env, jobject thiz) { + return reinterpret_cast( + mediapipe::GlContext::GetCurrentNativeContext()); +} diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/graph_texture_frame_jni.h b/mediapipe/java/com/google/mediapipe/framework/jni/graph_texture_frame_jni.h index 45637bb31..02903c664 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/graph_texture_frame_jni.h +++ b/mediapipe/java/com/google/mediapipe/framework/jni/graph_texture_frame_jni.h @@ -44,6 +44,9 @@ JNIEXPORT jlong JNICALL GRAPH_TEXTURE_FRAME_METHOD( nativeCreateSyncTokenForCurrentExternalContext)(JNIEnv* env, jobject thiz, jlong nativeHandle); +JNIEXPORT jlong JNICALL GRAPH_TEXTURE_FRAME_METHOD( + nativeGetCurrentExternalContextHandle)(JNIEnv* env, jobject thiz); + #ifdef __cplusplus } // extern "C" #endif // __cplusplus From db6ea38cf69a72149e9b8e5e8868c6e3f33a4ac8 Mon Sep 17 00:00:00 
2001 From: Camillo Lugaresi Date: Wed, 14 Dec 2022 00:37:52 -0800 Subject: [PATCH 225/469] Internal change PiperOrigin-RevId: 495236576 --- .../framework/GraphTextureFrame.java | 42 +++++++++++++++---- .../mediapipe/framework/TextureFrame.java | 14 +++++++ .../framework/jni/graph_texture_frame_jni.cc | 16 ++++--- .../framework/jni/graph_texture_frame_jni.h | 5 ++- 4 files changed, 61 insertions(+), 16 deletions(-) diff --git a/mediapipe/java/com/google/mediapipe/framework/GraphTextureFrame.java b/mediapipe/java/com/google/mediapipe/framework/GraphTextureFrame.java index 586b5c0a0..63ea7854b 100644 --- a/mediapipe/java/com/google/mediapipe/framework/GraphTextureFrame.java +++ b/mediapipe/java/com/google/mediapipe/framework/GraphTextureFrame.java @@ -36,6 +36,7 @@ public class GraphTextureFrame implements TextureFrame { // when calling getTextureName(). private final boolean deferredSync; private final Set<Long> activeConsumerContextHandleSet = new HashSet<>(); + private int refCount = 1; GraphTextureFrame(long nativeHandle, long timestamp) { this(nativeHandle, timestamp, false); @@ -94,6 +95,17 @@ public class GraphTextureFrame implements TextureFrame { return timestamp; } + @Override + public boolean supportsRetain() { + return true; + } + + @Override + public synchronized void retain() { + // TODO: check that refCount is > 0 and handle is not 0. + refCount++; + } + /** * Releases a reference to the underlying buffer. * @@ -121,22 +133,32 @@ ... - public void release(GlSyncToken consumerSyncToken) { - if (nativeBufferHandle != 0) { - long token = consumerSyncToken == null ? 0 : consumerSyncToken.nativeToken(); - nativeReleaseBuffer(nativeBufferHandle, token); - nativeBufferHandle = 0; - } else if (consumerSyncToken != null) { - logger.atWarning().log("release with sync token, but handle is 0"); + public synchronized void release(GlSyncToken consumerSyncToken) { + if (nativeBufferHandle == 0) { + if (consumerSyncToken != null) { + logger.atWarning().log("release with sync token, but handle is 0"); + } + return; } + if (consumerSyncToken != null) { + long token = consumerSyncToken.nativeToken(); + nativeDidRead(nativeBufferHandle, token); + // We should remove the token's context from activeConsumerContextHandleSet here, but for now + // we do it in the release(void) overload.
consumerSyncToken.release(); } + + refCount--; + if (refCount <= 0) { + nativeReleaseBuffer(nativeBufferHandle); + nativeBufferHandle = 0; + } } @Override protected void finalize() throws Throwable { - if (nativeBufferHandle != 0) { + if (refCount >= 0 || nativeBufferHandle != 0) { logger.atWarning().log("release was not called before finalize"); } if (!activeConsumerContextHandleSet.isEmpty()) { @@ -144,7 +166,7 @@ public class GraphTextureFrame implements TextureFrame { } } - private native void nativeReleaseBuffer(long nativeHandle, long consumerSyncToken); + private native void nativeReleaseBuffer(long nativeHandle); private native int nativeGetTextureName(long nativeHandle); private native int nativeGetWidth(long nativeHandle); @@ -155,4 +177,6 @@ public class GraphTextureFrame implements TextureFrame { private native long nativeCreateSyncTokenForCurrentExternalContext(long nativeHandle); private native long nativeGetCurrentExternalContextHandle(); + + private native void nativeDidRead(long nativeHandle, long consumerSyncToken); } diff --git a/mediapipe/java/com/google/mediapipe/framework/TextureFrame.java b/mediapipe/java/com/google/mediapipe/framework/TextureFrame.java index babfd2958..76eaf39df 100644 --- a/mediapipe/java/com/google/mediapipe/framework/TextureFrame.java +++ b/mediapipe/java/com/google/mediapipe/framework/TextureFrame.java @@ -59,4 +59,18 @@ public interface TextureFrame extends TextureReleaseCallback { */ @Override void release(GlSyncToken syncToken); + + /** + * If this method returns true, this object supports the retain method, and can be used with + * multiple consumers. Call retain for each additional consumer beyond the first; each consumer + * should call release. + */ + default boolean supportsRetain() { + return false; + } + + /** Increments the reference count. Only available with some implementations of TextureFrame. 
*/ + default void retain() { + throw new UnsupportedOperationException(); + } } diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/graph_texture_frame_jni.cc b/mediapipe/java/com/google/mediapipe/framework/jni/graph_texture_frame_jni.cc index 963ea522e..dd99cccd4 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/graph_texture_frame_jni.cc +++ b/mediapipe/java/com/google/mediapipe/framework/jni/graph_texture_frame_jni.cc @@ -22,14 +22,9 @@ using mediapipe::GlTextureBufferSharedPtr; JNIEXPORT void JNICALL GRAPH_TEXTURE_FRAME_METHOD(nativeReleaseBuffer)( - JNIEnv* env, jobject thiz, jlong nativeHandle, jlong consumerSyncToken) { + JNIEnv* env, jobject thiz, jlong nativeHandle) { GlTextureBufferSharedPtr* buffer = reinterpret_cast<GlTextureBufferSharedPtr*>(nativeHandle); - if (consumerSyncToken) { - mediapipe::GlSyncToken& token = - *reinterpret_cast<mediapipe::GlSyncToken*>(consumerSyncToken); - (*buffer)->DidRead(token); - } delete buffer; } @@ -91,3 +86,12 @@ return reinterpret_cast<jlong>( mediapipe::GlContext::GetCurrentNativeContext()); } + +JNIEXPORT void JNICALL GRAPH_TEXTURE_FRAME_METHOD(nativeDidRead)( + JNIEnv* env, jobject thiz, jlong nativeHandle, jlong consumerSyncToken) { + GlTextureBufferSharedPtr* buffer = + reinterpret_cast<GlTextureBufferSharedPtr*>(nativeHandle); + mediapipe::GlSyncToken& token = + *reinterpret_cast<mediapipe::GlSyncToken*>(consumerSyncToken); + (*buffer)->DidRead(token); +} diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/graph_texture_frame_jni.h b/mediapipe/java/com/google/mediapipe/framework/jni/graph_texture_frame_jni.h index 02903c664..41c531fff 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/graph_texture_frame_jni.h +++ b/mediapipe/java/com/google/mediapipe/framework/jni/graph_texture_frame_jni.h @@ -26,7 +26,7 @@ extern "C" { // Releases a native mediapipe::GpuBuffer. JNIEXPORT void JNICALL GRAPH_TEXTURE_FRAME_METHOD(nativeReleaseBuffer)( - JNIEnv* env, jobject thiz, jlong nativeHandle, jlong consumerSyncToken); + JNIEnv* env, jobject thiz, jlong nativeHandle); JNIEXPORT jint JNICALL GRAPH_TEXTURE_FRAME_METHOD(nativeGetTextureName)( JNIEnv* env, jobject thiz, jlong nativeHandle); @@ -44,6 +44,9 @@ JNIEXPORT jlong JNICALL GRAPH_TEXTURE_FRAME_METHOD( nativeCreateSyncTokenForCurrentExternalContext)(JNIEnv* env, jobject thiz, jlong nativeHandle); +JNIEXPORT void JNICALL GRAPH_TEXTURE_FRAME_METHOD(nativeDidRead)( + JNIEnv* env, jobject thiz, jlong nativeHandle, jlong consumerSyncToken); + JNIEXPORT jlong JNICALL GRAPH_TEXTURE_FRAME_METHOD( nativeGetCurrentExternalContextHandle)(JNIEnv* env, jobject thiz); From 7efb3bcf81081c822c76bb1d7e4867e5f1f66115 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Wed, 14 Dec 2022 19:13:41 +0530 Subject: [PATCH 226/469] Added iOS task error codes --- mediapipe/tasks/ios/common/BUILD | 26 +++ .../tasks/ios/common/sources/MPPCommon.h | 179 ++++++++++++++++++ 2 files changed, 205 insertions(+) create mode 100644 mediapipe/tasks/ios/common/BUILD create mode 100644 mediapipe/tasks/ios/common/sources/MPPCommon.h diff --git a/mediapipe/tasks/ios/common/BUILD b/mediapipe/tasks/ios/common/BUILD new file mode 100644 index 000000000..0d00c423f --- /dev/null +++ b/mediapipe/tasks/ios/common/BUILD @@ -0,0 +1,26 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+package(default_visibility = ["//mediapipe/tasks:internal"])
+
+licenses(["notice"])
+
+objc_library(
+    name = "MPPCommon",
+    hdrs = [
+        "sources/MPPCommon.h",
+    ],
+    module_name = "MPPCommon",
+)
+
diff --git a/mediapipe/tasks/ios/common/sources/MPPCommon.h b/mediapipe/tasks/ios/common/sources/MPPCommon.h
new file mode 100644
index 000000000..427b4cb75
--- /dev/null
+++ b/mediapipe/tasks/ios/common/sources/MPPCommon.h
@@ -0,0 +1,179 @@
+// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#import <Foundation/Foundation.h>
+
+NS_ASSUME_NONNULL_BEGIN
+
+/**
+ * @enum MPPTasksErrorCode
+ * This enum specifies error codes for Mediapipe Task Library.
+ * It maintains a 1:1 mapping to MediaPipeTasksStatus of the C++ library.
+ */
+typedef NS_ENUM(NSUInteger, MPPTasksErrorCode) {
+
+  // Generic error codes.
+
+  // Unspecified error.
+  MPPTasksErrorCodeError = 1,
+  // Invalid argument specified.
+  MPPTasksErrorCodeInvalidArgumentError = 2,
+  // Invalid FlatBuffer file or buffer specified.
+  MPPTasksErrorCodeInvalidFlatBufferError = 3,
+  // Model contains a builtin op that isn't supported by the OpResolver or
+  // delegates.
+  MPPTasksErrorCodeUnsupportedBuiltinOp = 4,
+  // Model contains a custom op that isn't supported by the OpResolver or
+  // delegates.
+  MPPTasksErrorCodeUnsupportedCustomOp = 5,
+
+  // File I/O error codes.
+
+  // No such file.
+  MPPTasksErrorCodeFileNotFoundError = 100,
+  // Permission issue.
+  MPPTasksErrorCodeFilePermissionDeniedError,
+  // I/O error when reading file.
+  MPPTasksErrorCodeFileReadError,
+  // I/O error when mmap-ing file.
+  MPPTasksErrorCodeFileMmapError,
+  // ZIP I/O error when unpacMPPTasksErrorCodeing the zip file.
+  MPPTasksErrorCodeFileZipError,
+
+  // TensorFlow Lite metadata error codes.
+
+  // Unexpected schema version (aMPPTasksErrorCodea file_identifier) in the Metadata FlatBuffer.
+  MPPTasksErrorCodeMetadataInvalidSchemaVersionError = 200,
+  // No such associated file within metadata, or file has not been pacMPPTasksErrorCodeed.
+  MPPTasksErrorCodeMetadataAssociatedFileNotFoundError,
+  // ZIP I/O error when unpacMPPTasksErrorCodeing an associated file.
+  MPPTasksErrorCodeMetadataAssociatedFileZipError,
+  // Inconsistency error between the metadata and actual TF Lite model.
+  // E.g.: number of labels and output tensor values differ.
+  MPPTasksErrorCodeMetadataInconsistencyError,
+  // Invalid process units specified.
+  // E.g.: multiple ProcessUnits with the same type for a given tensor.
+  MPPTasksErrorCodeMetadataInvalidProcessUnitsError,
+  // Inconsistency error with the number of labels.
+ // E.g.: label files for different locales have a different number of labels. + MPPTasksErrorCodeMetadataNumLabelsMismatchError, + // Score calibration parameters parsing error. + // E.g.: too many parameters provided in the corresponding associated file. + MPPTasksErrorCodeMetadataMalformedScoreCalibrationError, + // Unexpected number of subgraphs for the current task. + // E.g.: image classification expects a single subgraph. + MPPTasksErrorCodeMetadataInvalidNumSubgraphsError, + // A given tensor requires NormalizationOptions but none were found. + // E.g.: float input tensor requires normalization to preprocess input images. + MPPTasksErrorCodeMetadataMissingNormalizationOptionsError, + // Invalid ContentProperties specified. + // E.g. expected ImageProperties, got BoundingBoxProperties. + MPPTasksErrorCodeMetadataInvalidContentPropertiesError, + // Metadata is mandatory but was not found. + // E.g. current task requires TFLite Model Metadata but none was found. + MPPTasksErrorCodeMetadataNotFoundError, + // Associated TENSOR_AXIS_LABELS or TENSOR_VALUE_LABELS file is mandatory but + // none was found or it was empty. + // E.g. current task requires labels but none were found. + MPPTasksErrorCodeMetadataMissingLabelsError, + // The ProcessingUnit for tokenizer is not correctly configured. + // E.g BertTokenizer doesn't have a valid vocab file associated. + MPPTasksErrorCodeMetadataInvalidTokenizerError, + + // Input tensor(s) error codes. + + // Unexpected number of input tensors for the current task. + // E.g. current task expects a single input tensor. + MPPTasksErrorCodeInvalidNumInputTensorsError = 300, + // Unexpected input tensor dimensions for the current task. + // E.g.: only 4D input tensors supported. + MPPTasksErrorCodeInvalidInputTensorDimensionsError, + // Unexpected input tensor type for the current task. + // E.g.: current task expects a uint8 pixel image as input. + MPPTasksErrorCodeInvalidInputTensorTypeError, + // Unexpected input tensor bytes size. + // E.g.: size in bytes does not correspond to the expected number of pixels. + MPPTasksErrorCodeInvalidInputTensorSizeError, + // No correct input tensor found for the model. + // E.g.: input tensor name is not part of the text model's input tensors. + MPPTasksErrorCodeInputTensorNotFoundError, + + // Output tensor(s) error codes. + + // Unexpected output tensor dimensions for the current task. + // E.g.: only a batch size of 1 is supported. + MPPTasksErrorCodeInvalidOutputTensorDimensionsError = 400, + // Unexpected input tensor type for the current task. + // E.g.: multi-head model with different output tensor types. + MPPTasksErrorCodeInvalidOutputTensorTypeError, + // No correct output tensor found for the model. + // E.g.: output tensor name is not part of the text model's output tensors. + MPPTasksErrorCodeOutputTensorNotFoundError, + // Unexpected number of output tensors for the current task. + // E.g.: current task expects a single output tensor. + MPPTasksErrorCodeInvalidNumOutputTensorsError, + + // Image processing error codes. + + // Unspecified image processing failures. + MPPTasksErrorCodeImageProcessingError = 500, + // Unexpected input or output buffer metadata. + // E.g.: rotate RGBA buffer to Grayscale buffer by 90 degrees. + MPPTasksErrorCodeImageProcessingInvalidArgumentError, + // Image processing operation failures. + // E.g. libyuv rotation failed for an unknown reason. + MPPTasksErrorCodeImageProcessingBackendError, + + // Task runner error codes. 
+  MPPTasksErrorCodeRunnerError = 600,
+  // Task runner is not initialized.
+  MPPTasksErrorCodeRunnerInitializationError,
+  // Task runner is not started successfully.
+  MPPTasksErrorCodeRunnerFailsToStartError,
+  // Task runner is not started.
+  MPPTasksErrorCodeRunnerNotStartedError,
+  // Task runner API is called in the wrong processing mode.
+  MPPTasksErrorCodeRunnerApiCalledInWrongModeError,
+  // Task runner receives/produces invalid MediaPipe packet timestamp.
+  MPPTasksErrorCodeRunnerInvalidTimestampError,
+  // Task runner receives unexpected MediaPipe graph input packet.
+  // E.g. The packet type doesn't match the graph input stream's data type.
+  MPPTasksErrorCodeRunnerUnexpectedInputError,
+  // Task runner produces unexpected MediaPipe graph output packet.
+  // E.g. The number of output packets is not equal to the number of graph
+  // output streams.
+  MPPTasksErrorCodeRunnerUnexpectedOutputError,
+  // Task runner is not closed successfully.
+  MPPTasksErrorCodeRunnerFailsToCloseError,
+  // Task runner's model resources cache service is unavailable or the
+  // targeting model resources bundle is not found.
+  MPPTasksErrorCodeRunnerModelResourcesCacheServiceError,
+
+  // Task graph error codes.
+  MPPTasksErrorCodeGraphError = 700,
+  // Task graph is not implemented.
+  MPPTasksErrorCodeTaskGraphNotImplementedError,
+  // Task graph config is invalid.
+  MPPTasksErrorCodeInvalidTaskGraphConfigError,
+
+  MPPTasksErrorCodeFirst = MPPTasksErrorCodeError,
+
+  /**
+   * The last error code in TFLSupportErrorCode (for internal use only).
+   */
+  MPPTasksErrorCodeLast = MPPTasksErrorCodeInvalidTaskGraphConfigError,
+
+} NS_SWIFT_NAME(TasksErrorCode);
+
+NS_ASSUME_NONNULL_END

From e9fb6c28f5d69e07e92ce9f88c234d7e2a0081f3 Mon Sep 17 00:00:00 2001
From: Prianka Liz Kariat
Date: Wed, 14 Dec 2022 19:14:02 +0530
Subject: [PATCH 227/469] Added task options

---
 .../tasks/ios/core/sources/MPPTaskOptions.h   | 48 +++++++++++++++++++
 .../tasks/ios/core/sources/MPPTaskOptions.m   | 36 ++++++++++++++
 2 files changed, 84 insertions(+)
 create mode 100644 mediapipe/tasks/ios/core/sources/MPPTaskOptions.h
 create mode 100644 mediapipe/tasks/ios/core/sources/MPPTaskOptions.m

diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskOptions.h b/mediapipe/tasks/ios/core/sources/MPPTaskOptions.h
new file mode 100644
index 000000000..0195f3654
--- /dev/null
+++ b/mediapipe/tasks/ios/core/sources/MPPTaskOptions.h
@@ -0,0 +1,48 @@
+// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#import <Foundation/Foundation.h>
+#import "mediapipe/tasks/ios/core/sources/MPPBaseOptions.h"
+
+NS_ASSUME_NONNULL_BEGIN
+
+/**
+ * MediaPipe Tasks options base class. Any MediaPipe task-specific options class should extend
+ * this class.
+ */
+NS_SWIFT_NAME(TaskOptions)
+@interface MPPTaskOptions : NSObject
+/**
+ * Base options for configuring the Mediapipe task.
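The MPPTasksErrorCode enum above reserves a numeric block per failure domain: 1–99 generic, 100s file I/O, 200s TensorFlow Lite metadata, 300s input tensors, 400s output tensors, 500s image processing, 600s task runner, and 700s task graph. The helper below is purely hypothetical — it is not part of the iOS API or of this patch series, and is written in Java only for illustration — but it makes that range convention explicit for readers mapping raw codes back to a category.

```java
/** Hypothetical helper bucketing a raw MPPTasksErrorCode value by its reserved range. */
final class TasksErrorCodeRanges {
  private TasksErrorCodeRanges() {}

  static String categoryOf(int rawCode) {
    if (rawCode >= 700) return "task graph";
    if (rawCode >= 600) return "task runner";
    if (rawCode >= 500) return "image processing";
    if (rawCode >= 400) return "output tensor(s)";
    if (rawCode >= 300) return "input tensor(s)";
    if (rawCode >= 200) return "TensorFlow Lite metadata";
    if (rawCode >= 100) return "file I/O";
    return "generic";
  }
}
```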
+ */ +@property(nonatomic, copy) MPPBaseOptions *baseOptions; + +/** + * Initializes a new `MPPTaskOptions` with the absolute path to the model file + * stored locally on the device, set to the given the model path. + * + * @discussion The external model file must be a single standalone TFLite file. It could be packed + * with TFLite Model Metadata[1] and associated files if they exist. Failure to provide the + * necessary metadata and associated files might result in errors. Check the [documentation] + * (https://www.tensorflow.org/lite/convert/metadata) for each task about the specific requirement. + * + * @param modelPath An absolute path to a TensorFlow Lite model file stored locally on the device. + * + * @return An instance of `MPPTaskOptions` initialized to the given model path. + */ +- (instancetype)initWithModelPath:(NSString *)modelPath; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskOptions.m b/mediapipe/tasks/ios/core/sources/MPPTaskOptions.m new file mode 100644 index 000000000..e45364d55 --- /dev/null +++ b/mediapipe/tasks/ios/core/sources/MPPTaskOptions.m @@ -0,0 +1,36 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "mediapipe/tasks/ios/core/sources/MPPTaskOptions.h" +#import "mediapipe/tasks/ios/core/sources/MPPBaseOptions.h" + +@implementation MPPTaskOptions + +- (instancetype)init { + self = [super init]; + if (self) { + _baseOptions = [[MPPBaseOptions alloc] init]; + } + return self; +} + +- (instancetype)initWithModelPath:(NSString *)modelPath { + self = [self init]; + if (self) { + _baseOptions.modelAssetPath = modelPath; + } + return self; +} + +@end From 22bb87d9e0346cfcdc7e4e2d61baef0f7c987912 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Wed, 14 Dec 2022 19:14:11 +0530 Subject: [PATCH 228/469] Added iOS task result --- .../tasks/ios/core/sources/MPPTaskResult.h | 34 +++++++++++++++++++ .../tasks/ios/core/sources/MPPTaskResult.m | 27 +++++++++++++++ 2 files changed, 61 insertions(+) create mode 100644 mediapipe/tasks/ios/core/sources/MPPTaskResult.h create mode 100644 mediapipe/tasks/ios/core/sources/MPPTaskResult.m diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskResult.h b/mediapipe/tasks/ios/core/sources/MPPTaskResult.h new file mode 100644 index 000000000..22171a852 --- /dev/null +++ b/mediapipe/tasks/ios/core/sources/MPPTaskResult.h @@ -0,0 +1,34 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#import <Foundation/Foundation.h>
+
+NS_ASSUME_NONNULL_BEGIN
+
+/**
+ * MediaPipe Tasks options base class. Any MediaPipe task-specific options class should extend
+ * this class.
+ */
+NS_SWIFT_NAME(TaskResult)
+@interface MPPTaskResult : NSObject
+/**
+ * Base options for configuring the Mediapipe task.
+ */
+@property(nonatomic, assign, readonly) long timeStamp;
+
+- (instancetype)initWithTimeStamp:(long)timeStamp;
+
+@end
+
+NS_ASSUME_NONNULL_END
diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskResult.m b/mediapipe/tasks/ios/core/sources/MPPTaskResult.m
new file mode 100644
index 000000000..ad74c009d
--- /dev/null
+++ b/mediapipe/tasks/ios/core/sources/MPPTaskResult.m
@@ -0,0 +1,27 @@
+// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#import "mediapipe/tasks/ios/core/sources/MPPTaskResult.h"
+
+@implementation MPPTaskResult
+
+- (instancetype)initWithTimeStamp:(long)timeStamp {
+  self = [self init];
+  if (self) {
+    _timeStamp = timeStamp;
+  }
+  return self;
+}
+
+@end
From 0aedff06596a7ee43588489e8dd8ad8d2d24a7b2 Mon Sep 17 00:00:00 2001
From: Prianka Liz Kariat
Date: Wed, 14 Dec 2022 19:14:49 +0530
Subject: [PATCH 229/469] Added target for task options

---
 mediapipe/tasks/ios/core/BUILD | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/mediapipe/tasks/ios/core/BUILD b/mediapipe/tasks/ios/core/BUILD
index 3f1193e46..cee0fa4eb 100644
--- a/mediapipe/tasks/ios/core/BUILD
+++ b/mediapipe/tasks/ios/core/BUILD
@@ -21,3 +21,12 @@ objc_library(
     srcs = ["sources/MPPBaseOptions.m"],
     hdrs = ["sources/MPPBaseOptions.h"],
 )
+
+objc_library(
+    name = "MPPTaskOptions",
+    srcs = ["sources/MPPTaskOptions.m"],
+    hdrs = ["sources/MPPTaskOptions.h"],
+    deps = [
+        ":MPPBaseOptions",
+    ],
+)
From c0fed7df3116db8778052b29de6ab906a95083fa Mon Sep 17 00:00:00 2001
From: Prianka Liz Kariat
Date: Wed, 14 Dec 2022 19:15:01 +0530
Subject: [PATCH 230/469] Added target for task result

---
 mediapipe/tasks/ios/core/BUILD | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/mediapipe/tasks/ios/core/BUILD b/mediapipe/tasks/ios/core/BUILD
index cee0fa4eb..7b648945e 100644
--- a/mediapipe/tasks/ios/core/BUILD
+++ b/mediapipe/tasks/ios/core/BUILD
@@ -30,3 +30,9 @@ objc_library(
         ":MPPBaseOptions",
     ],
 )
+
+objc_library(
+    name = "MPPTaskResult",
+    srcs = ["sources/MPPTaskResult.m"],
+    hdrs = ["sources/MPPTaskResult.h"],
+)
From 174f2869a335a075764f1364130b7d9529b93a29 Mon Sep 17 00:00:00 2001
From: Jiuqiang Tang
Date: Wed, 14 Dec 2022 08:31:49 -0800
Subject: [PATCH 231/469] Internal changes

PiperOrigin-RevId: 495322170
---
 mediapipe/framework/BUILD | 1 -
 1 file changed, 1 deletion(-)

diff --git a/mediapipe/framework/BUILD b/mediapipe/framework/BUILD
index 872944acd..0dd694760 100644
--- a/mediapipe/framework/BUILD
+++ b/mediapipe/framework/BUILD
@@ -1,4 +1,3 @@
-#
 # Copyright 2019 The MediaPipe Authors.
# # Licensed under the Apache License, Version 2.0 (the "License"); From e9e173f9fa37948bcb9a028f7822c44773a2bbcf Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Wed, 14 Dec 2022 18:09:20 -0800 Subject: [PATCH 232/469] Internal change PiperOrigin-RevId: 495468694 --- mediapipe/framework/api2/port.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mediapipe/framework/api2/port.h b/mediapipe/framework/api2/port.h index e63d3651e..eee542640 100644 --- a/mediapipe/framework/api2/port.h +++ b/mediapipe/framework/api2/port.h @@ -557,8 +557,8 @@ class OutputSidePacketAccess { if (output_) output_->Set(ToOldPacket(std::move(packet))); } - void Set(const T& payload) { Set(MakePacket(payload)); } - void Set(T&& payload) { Set(MakePacket(std::move(payload))); } + void Set(const T& payload) { Set(api2::MakePacket(payload)); } + void Set(T&& payload) { Set(api2::MakePacket(std::move(payload))); } private: OutputSidePacketAccess(OutputSidePacket* output) : output_(output) {} From d526b20e19339712e12db73a2f07d07a2c919b01 Mon Sep 17 00:00:00 2001 From: Nikolay Chirkov Date: Wed, 14 Dec 2022 19:52:13 -0800 Subject: [PATCH 233/469] Internal change. PiperOrigin-RevId: 495483878 --- .../formats/tensor_hardware_buffer.h | 71 ------ .../tensor_hardware_buffer_cpu_storage.cc | 216 ------------------ ...tensor_hardware_buffer_cpu_storage_test.cc | 76 ------ 3 files changed, 363 deletions(-) delete mode 100644 mediapipe/framework/formats/tensor_hardware_buffer.h delete mode 100644 mediapipe/framework/formats/tensor_hardware_buffer_cpu_storage.cc delete mode 100644 mediapipe/framework/formats/tensor_hardware_buffer_cpu_storage_test.cc diff --git a/mediapipe/framework/formats/tensor_hardware_buffer.h b/mediapipe/framework/formats/tensor_hardware_buffer.h deleted file mode 100644 index fa0241bde..000000000 --- a/mediapipe/framework/formats/tensor_hardware_buffer.h +++ /dev/null @@ -1,71 +0,0 @@ -#ifndef MEDIAPIPE_FRAMEWORK_FORMATS_TENSOR_HARDWARE_BUFFER_H_ -#define MEDIAPIPE_FRAMEWORK_FORMATS_TENSOR_HARDWARE_BUFFER_H_ - -#if !defined(MEDIAPIPE_NO_JNI) && \ - (__ANDROID_API__ >= 26 || \ - defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__)) - -#include - -#include - -#include "mediapipe/framework/formats/tensor_buffer.h" -#include "mediapipe/framework/formats/tensor_internal.h" -#include "mediapipe/framework/formats/tensor_v2.h" - -namespace mediapipe { - -// Supports: -// - float 16 and 32 bits -// - signed / unsigned integers 8,16,32 bits -class TensorHardwareBufferView; -struct TensorHardwareBufferViewDescriptor : public Tensor::ViewDescriptor { - using ViewT = TensorHardwareBufferView; - TensorBufferDescriptor buffer; -}; - -class TensorHardwareBufferView : public Tensor::View { - public: - TENSOR_UNIQUE_VIEW_TYPE_ID(); - ~TensorHardwareBufferView() = default; - - const TensorHardwareBufferViewDescriptor& descriptor() const override { - return descriptor_; - } - AHardwareBuffer* handle() const { return ahwb_handle_; } - - protected: - TensorHardwareBufferView(int access_capability, Tensor::View::Access access, - Tensor::View::State state, - const TensorHardwareBufferViewDescriptor& desc, - AHardwareBuffer* ahwb_handle) - : Tensor::View(kId, access_capability, access, state), - descriptor_(desc), - ahwb_handle_(ahwb_handle) {} - - private: - bool MatchDescriptor( - uint64_t view_type_id, - const Tensor::ViewDescriptor& base_descriptor) const override { - if (!Tensor::View::MatchDescriptor(view_type_id, base_descriptor)) - return false; - auto descriptor = - static_cast(base_descriptor); - 
return descriptor.buffer.format == descriptor_.buffer.format && - descriptor.buffer.size_alignment <= - descriptor_.buffer.size_alignment && - descriptor_.buffer.size_alignment % - descriptor.buffer.size_alignment == - 0; - } - const TensorHardwareBufferViewDescriptor& descriptor_; - AHardwareBuffer* ahwb_handle_ = nullptr; -}; - -} // namespace mediapipe - -#endif // !defined(MEDIAPIPE_NO_JNI) && \ - (__ANDROID_API__ >= 26 || \ - defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__)) - -#endif // MEDIAPIPE_FRAMEWORK_FORMATS_TENSOR_HARDWARE_BUFFER_H_ diff --git a/mediapipe/framework/formats/tensor_hardware_buffer_cpu_storage.cc b/mediapipe/framework/formats/tensor_hardware_buffer_cpu_storage.cc deleted file mode 100644 index 9c223ce2c..000000000 --- a/mediapipe/framework/formats/tensor_hardware_buffer_cpu_storage.cc +++ /dev/null @@ -1,216 +0,0 @@ -#if !defined(MEDIAPIPE_NO_JNI) && \ - (__ANDROID_API__ >= 26 || \ - defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__)) - -#include - -#include "absl/memory/memory.h" -#include "absl/status/status.h" -#include "mediapipe/framework/formats/tensor_backend.h" -#include "mediapipe/framework/formats/tensor_cpu_buffer.h" -#include "mediapipe/framework/formats/tensor_hardware_buffer.h" -#include "mediapipe/framework/formats/tensor_v2.h" -#include "util/task/status_macros.h" - -namespace mediapipe { -namespace { - -class TensorCpuViewImpl : public TensorCpuView { - public: - TensorCpuViewImpl(int access_capabilities, Tensor::View::Access access, - Tensor::View::State state, - const TensorCpuViewDescriptor& descriptor, void* pointer, - AHardwareBuffer* ahwb_handle) - : TensorCpuView(access_capabilities, access, state, descriptor, pointer), - ahwb_handle_(ahwb_handle) {} - ~TensorCpuViewImpl() { - // If handle_ is null then this view is constructed in GetViews with no - // access. - if (ahwb_handle_) { - if (__builtin_available(android 26, *)) { - AHardwareBuffer_unlock(ahwb_handle_, nullptr); - } - } - } - - private: - AHardwareBuffer* ahwb_handle_; -}; - -class TensorHardwareBufferViewImpl : public TensorHardwareBufferView { - public: - TensorHardwareBufferViewImpl( - int access_capability, Tensor::View::Access access, - Tensor::View::State state, - const TensorHardwareBufferViewDescriptor& descriptor, - AHardwareBuffer* handle) - : TensorHardwareBufferView(access_capability, access, state, descriptor, - handle) {} - ~TensorHardwareBufferViewImpl() = default; -}; - -class HardwareBufferCpuStorage : public TensorStorage { - public: - ~HardwareBufferCpuStorage() { - if (!ahwb_handle_) return; - if (__builtin_available(android 26, *)) { - AHardwareBuffer_release(ahwb_handle_); - } - } - - static absl::Status CanProvide( - int access_capability, const Tensor::Shape& shape, uint64_t view_type_id, - const Tensor::ViewDescriptor& base_descriptor) { - // TODO: use AHardwareBuffer_isSupported for API >= 29. - static const bool is_ahwb_supported = [] { - if (__builtin_available(android 26, *)) { - AHardwareBuffer_Desc desc = {}; - // Aligned to the largest possible virtual memory page size. 
- constexpr uint32_t kPageSize = 16384; - desc.width = kPageSize; - desc.height = 1; - desc.layers = 1; - desc.format = AHARDWAREBUFFER_FORMAT_BLOB; - desc.usage = AHARDWAREBUFFER_USAGE_CPU_READ_OFTEN | - AHARDWAREBUFFER_USAGE_CPU_WRITE_OFTEN; - AHardwareBuffer* handle; - if (AHardwareBuffer_allocate(&desc, &handle) != 0) return false; - AHardwareBuffer_release(handle); - return true; - } - return false; - }(); - if (!is_ahwb_supported) { - return absl::UnavailableError( - "AHardwareBuffer is not supported on the platform."); - } - - if (view_type_id != TensorCpuView::kId && - view_type_id != TensorHardwareBufferView::kId) { - return absl::InvalidArgumentError( - "A view type is not supported by this storage."); - } - return absl::OkStatus(); - } - - std::vector> GetViews(uint64_t latest_version) { - std::vector> result; - auto update_state = latest_version == version_ - ? Tensor::View::State::kUpToDate - : Tensor::View::State::kOutdated; - if (ahwb_handle_) { - result.push_back( - std::unique_ptr(new TensorHardwareBufferViewImpl( - kAccessCapability, Tensor::View::Access::kNoAccess, update_state, - hw_descriptor_, ahwb_handle_))); - - result.push_back(std::unique_ptr(new TensorCpuViewImpl( - kAccessCapability, Tensor::View::Access::kNoAccess, update_state, - cpu_descriptor_, nullptr, nullptr))); - } - return result; - } - - absl::StatusOr> GetView( - Tensor::View::Access access, const Tensor::Shape& shape, - uint64_t latest_version, uint64_t view_type_id, - const Tensor::ViewDescriptor& base_descriptor, int access_capability) { - MP_RETURN_IF_ERROR( - CanProvide(access_capability, shape, view_type_id, base_descriptor)); - const auto& buffer_descriptor = - view_type_id == TensorHardwareBufferView::kId - ? static_cast( - base_descriptor) - .buffer - : static_cast(base_descriptor) - .buffer; - if (!ahwb_handle_) { - if (__builtin_available(android 26, *)) { - AHardwareBuffer_Desc desc = {}; - desc.width = TensorBufferSize(buffer_descriptor, shape); - desc.height = 1; - desc.layers = 1; - desc.format = AHARDWAREBUFFER_FORMAT_BLOB; - // TODO: Use access capabilities to set hints. - desc.usage = AHARDWAREBUFFER_USAGE_CPU_READ_OFTEN | - AHARDWAREBUFFER_USAGE_CPU_WRITE_OFTEN; - auto error = AHardwareBuffer_allocate(&desc, &ahwb_handle_); - if (error != 0) { - return absl::UnknownError( - absl::StrCat("Error allocating hardware buffer: ", error)); - } - // Fill all possible views to provide it as proto views. 
- hw_descriptor_.buffer = buffer_descriptor; - cpu_descriptor_.buffer = buffer_descriptor; - } - } - if (buffer_descriptor.format != hw_descriptor_.buffer.format || - buffer_descriptor.size_alignment > - hw_descriptor_.buffer.size_alignment || - hw_descriptor_.buffer.size_alignment % - buffer_descriptor.size_alignment > - 0) { - return absl::AlreadyExistsError( - "A view with different params is already allocated with this " - "storage"); - } - - absl::StatusOr> result; - if (view_type_id == TensorHardwareBufferView::kId) { - result = GetAhwbView(access, shape, base_descriptor); - } else { - result = GetCpuView(access, shape, base_descriptor); - } - if (result.ok()) version_ = latest_version; - return result; - } - - private: - absl::StatusOr> GetAhwbView( - Tensor::View::Access access, const Tensor::Shape& shape, - const Tensor::ViewDescriptor& base_descriptor) { - return std::unique_ptr(new TensorHardwareBufferViewImpl( - kAccessCapability, access, Tensor::View::State::kUpToDate, - hw_descriptor_, ahwb_handle_)); - } - - absl::StatusOr> GetCpuView( - Tensor::View::Access access, const Tensor::Shape& shape, - const Tensor::ViewDescriptor& base_descriptor) { - void* pointer = nullptr; - if (__builtin_available(android 26, *)) { - int error = - AHardwareBuffer_lock(ahwb_handle_, - access == Tensor::View::Access::kWriteOnly - ? AHARDWAREBUFFER_USAGE_CPU_WRITE_OFTEN - : AHARDWAREBUFFER_USAGE_CPU_READ_OFTEN, - -1, nullptr, &pointer); - if (error != 0) { - return absl::UnknownError( - absl::StrCat("Error locking hardware buffer: ", error)); - } - } - return std::unique_ptr( - new TensorCpuViewImpl(access == Tensor::View::Access::kWriteOnly - ? Tensor::View::AccessCapability::kWrite - : Tensor::View::AccessCapability::kRead, - access, Tensor::View::State::kUpToDate, - cpu_descriptor_, pointer, ahwb_handle_)); - } - - static constexpr int kAccessCapability = - Tensor::View::AccessCapability::kRead | - Tensor::View::AccessCapability::kWrite; - TensorHardwareBufferViewDescriptor hw_descriptor_; - AHardwareBuffer* ahwb_handle_ = nullptr; - - TensorCpuViewDescriptor cpu_descriptor_; - uint64_t version_ = 0; -}; -TENSOR_REGISTER_STORAGE(HardwareBufferCpuStorage); - -} // namespace -} // namespace mediapipe - -#endif // !defined(MEDIAPIPE_NO_JNI) && (__ANDROID_API__ >= 26 || - // defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__)) diff --git a/mediapipe/framework/formats/tensor_hardware_buffer_cpu_storage_test.cc b/mediapipe/framework/formats/tensor_hardware_buffer_cpu_storage_test.cc deleted file mode 100644 index 0afa9899f..000000000 --- a/mediapipe/framework/formats/tensor_hardware_buffer_cpu_storage_test.cc +++ /dev/null @@ -1,76 +0,0 @@ - -#if !defined(MEDIAPIPE_NO_JNI) && \ - (__ANDROID_API__ >= 26 || \ - defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__)) -#include - -#include - -#include "mediapipe/framework/formats/tensor_cpu_buffer.h" -#include "mediapipe/framework/formats/tensor_hardware_buffer.h" -#include "mediapipe/framework/formats/tensor_v2.h" -#include "testing/base/public/gmock.h" -#include "testing/base/public/gunit.h" - -namespace mediapipe { - -namespace { - -class TensorHardwareBufferTest : public ::testing::Test { - public: - TensorHardwareBufferTest() {} - ~TensorHardwareBufferTest() override {} -}; - -TEST_F(TensorHardwareBufferTest, TestFloat32) { - Tensor tensor{Tensor::Shape({1})}; - { - MP_ASSERT_OK_AND_ASSIGN( - auto view, - tensor.GetView( - TensorHardwareBufferViewDescriptor{ - .buffer = {.format = - TensorBufferDescriptor::Format::kFloat32}})); - EXPECT_NE(view->handle(), 
nullptr); - } - { - const auto& const_tensor = tensor; - MP_ASSERT_OK_AND_ASSIGN( - auto view, - const_tensor.GetView( - TensorCpuViewDescriptor{ - .buffer = {.format = - TensorBufferDescriptor::Format::kFloat32}})); - EXPECT_NE(view->data(), nullptr); - } -} - -TEST_F(TensorHardwareBufferTest, TestInt8Padding) { - Tensor tensor{Tensor::Shape({1})}; - - { - MP_ASSERT_OK_AND_ASSIGN( - auto view, - tensor.GetView( - TensorHardwareBufferViewDescriptor{ - .buffer = {.format = TensorBufferDescriptor::Format::kInt8, - .size_alignment = 4}})); - EXPECT_NE(view->handle(), nullptr); - } - { - const auto& const_tensor = tensor; - MP_ASSERT_OK_AND_ASSIGN( - auto view, - const_tensor.GetView( - TensorCpuViewDescriptor{ - .buffer = {.format = TensorBufferDescriptor::Format::kInt8}})); - EXPECT_NE(view->data(), nullptr); - } -} - -} // namespace - -} // namespace mediapipe - -#endif // !defined(MEDIAPIPE_NO_JNI) && (__ANDROID_API__ >= 26 || - // defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__)) From bf91c5240782364792739cef7deabbc60c6db77e Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Thu, 15 Dec 2022 10:21:07 +0530 Subject: [PATCH 234/469] Fixed typos --- mediapipe/tasks/ios/common/sources/MPPCommon.h | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/mediapipe/tasks/ios/common/sources/MPPCommon.h b/mediapipe/tasks/ios/common/sources/MPPCommon.h index 427b4cb75..b3d715520 100644 --- a/mediapipe/tasks/ios/common/sources/MPPCommon.h +++ b/mediapipe/tasks/ios/common/sources/MPPCommon.h @@ -48,16 +48,16 @@ typedef NS_ENUM(NSUInteger, MPPTasksErrorCode) { MPPTasksErrorCodeFileReadError, // I/O error when mmap-ing file. MPPTasksErrorCodeFileMmapError, - // ZIP I/O error when unpacMPPTasksErrorCodeing the zip file. + // ZIP I/O error when unpacking the zip file. MPPTasksErrorCodeFileZipError, // TensorFlow Lite metadata error codes. - // Unexpected schema version (aMPPTasksErrorCodea file_identifier) in the Metadata FlatBuffer. + // Unexpected schema version (aka file_identifier) in the Metadata FlatBuffer. MPPTasksErrorCodeMetadataInvalidSchemaVersionError = 200, - // No such associated file within metadata, or file has not been pacMPPTasksErrorCodeed. + // No such associated file within metadata, or file has not been packed. MPPTasksErrorCodeMetadataAssociatedFileNotFoundError, - // ZIP I/O error when unpacMPPTasksErrorCodeing an associated file. + // ZIP I/O error when unpacking an associated file. MPPTasksErrorCodeMetadataAssociatedFileZipError, // Inconsistency error between the metadata and actual TF Lite model. // E.g.: number of labels and output tensor values differ. @@ -167,11 +167,10 @@ typedef NS_ENUM(NSUInteger, MPPTasksErrorCode) { // Task graph config is invalid. MPPTasksErrorCodeInvalidTaskGraphConfigError, + // The first error code in MPPTasksErrorCode (for internal use only). MPPTasksErrorCodeFirst = MPPTasksErrorCodeError, - /** - * The last error code in TFLSupportErrorCode (for internal use only). - */ + // The last error code in MPPTasksErrorCode (for internal use only). 
MPPTasksErrorCodeLast = MPPTasksErrorCodeInvalidTaskGraphConfigError, } NS_SWIFT_NAME(TasksErrorCode); From fe7fbc0b38b23a0639d816ffdd3fc64da0734c9b Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Thu, 15 Dec 2022 10:21:14 +0530 Subject: [PATCH 235/469] Fixed comment --- mediapipe/tasks/ios/core/sources/MPPTaskOptions.h | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskOptions.h b/mediapipe/tasks/ios/core/sources/MPPTaskOptions.h index 0195f3654..6a00de6f5 100644 --- a/mediapipe/tasks/ios/core/sources/MPPTaskOptions.h +++ b/mediapipe/tasks/ios/core/sources/MPPTaskOptions.h @@ -22,6 +22,7 @@ NS_ASSUME_NONNULL_BEGIN * this class. */ NS_SWIFT_NAME(TaskOptions) + @interface MPPTaskOptions : NSObject /** * Base options for configuring the Mediapipe task. @@ -32,10 +33,9 @@ NS_SWIFT_NAME(TaskOptions) * Initializes a new `MPPTaskOptions` with the absolute path to the model file * stored locally on the device, set to the given the model path. * - * @discussion The external model file must be a single standalone TFLite file. It could be packed - * with TFLite Model Metadata[1] and associated files if they exist. Failure to provide the - * necessary metadata and associated files might result in errors. Check the [documentation] - * (https://www.tensorflow.org/lite/convert/metadata) for each task about the specific requirement. + * @discussion The external model file must be a single standalone TFLite file. It must be packed + * with TFLite Model Metadata[1] and associated files. Failure to provide the + * necessary metadata and associated files will result in errors. * * @param modelPath An absolute path to a TensorFlow Lite model file stored locally on the device. * From 9ab010758421f6e8cea9d840ff597181c67070a8 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Thu, 15 Dec 2022 10:21:22 +0530 Subject: [PATCH 236/469] Added new line --- mediapipe/tasks/ios/core/sources/MPPTaskResult.h | 1 + 1 file changed, 1 insertion(+) diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskResult.h b/mediapipe/tasks/ios/core/sources/MPPTaskResult.h index 22171a852..89555fe32 100644 --- a/mediapipe/tasks/ios/core/sources/MPPTaskResult.h +++ b/mediapipe/tasks/ios/core/sources/MPPTaskResult.h @@ -21,6 +21,7 @@ NS_ASSUME_NONNULL_BEGIN * this class. */ NS_SWIFT_NAME(TaskResult) + @interface MPPTaskResult : NSObject /** * Base options for configuring the Mediapipe task. From 163b13d7de654bd996e26a3c2f6659ca9d481833 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Thu, 15 Dec 2022 10:23:27 +0530 Subject: [PATCH 237/469] Updated comments --- mediapipe/tasks/ios/core/sources/MPPTaskResult.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskResult.h b/mediapipe/tasks/ios/core/sources/MPPTaskResult.h index 89555fe32..f1707a767 100644 --- a/mediapipe/tasks/ios/core/sources/MPPTaskResult.h +++ b/mediapipe/tasks/ios/core/sources/MPPTaskResult.h @@ -17,14 +17,14 @@ NS_ASSUME_NONNULL_BEGIN /** - * MediaPipe Tasks options base class. Any MediaPipe task-specific options class should extend + * MediaPipe Tasks result base class. Any MediaPipe task result class should extend * this class. */ NS_SWIFT_NAME(TaskResult) @interface MPPTaskResult : NSObject /** - * Base options for configuring the Mediapipe task. + * Timestamp that is associated with the task result object. 
*/ @property(nonatomic, assign, readonly) long timeStamp; From 5ab17fe686ab2fd20936f3351f7df6c619ff9684 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Thu, 15 Dec 2022 10:28:50 +0530 Subject: [PATCH 238/469] Removed convenience initializer --- mediapipe/tasks/ios/core/sources/MPPTaskOptions.h | 14 -------------- mediapipe/tasks/ios/core/sources/MPPTaskOptions.m | 8 -------- 2 files changed, 22 deletions(-) diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskOptions.h b/mediapipe/tasks/ios/core/sources/MPPTaskOptions.h index 6a00de6f5..ee2f7d032 100644 --- a/mediapipe/tasks/ios/core/sources/MPPTaskOptions.h +++ b/mediapipe/tasks/ios/core/sources/MPPTaskOptions.h @@ -29,20 +29,6 @@ NS_SWIFT_NAME(TaskOptions) */ @property(nonatomic, copy) MPPBaseOptions *baseOptions; -/** - * Initializes a new `MPPTaskOptions` with the absolute path to the model file - * stored locally on the device, set to the given the model path. - * - * @discussion The external model file must be a single standalone TFLite file. It must be packed - * with TFLite Model Metadata[1] and associated files. Failure to provide the - * necessary metadata and associated files will result in errors. - * - * @param modelPath An absolute path to a TensorFlow Lite model file stored locally on the device. - * - * @return An instance of `MPPTaskOptions` initialized to the given model path. - */ -- (instancetype)initWithModelPath:(NSString *)modelPath; - @end NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskOptions.m b/mediapipe/tasks/ios/core/sources/MPPTaskOptions.m index e45364d55..e3cf6684a 100644 --- a/mediapipe/tasks/ios/core/sources/MPPTaskOptions.m +++ b/mediapipe/tasks/ios/core/sources/MPPTaskOptions.m @@ -25,12 +25,4 @@ return self; } -- (instancetype)initWithModelPath:(NSString *)modelPath { - self = [self init]; - if (self) { - _baseOptions.modelAssetPath = modelPath; - } - return self; -} - @end From 6db5eabe0b4ec6090f4dc45c241ab24aa0f2d59e Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Thu, 15 Dec 2022 00:41:27 -0800 Subject: [PATCH 239/469] Internal change PiperOrigin-RevId: 495525736 --- docs/solutions/holistic.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/solutions/holistic.md b/docs/solutions/holistic.md index 8c552834e..11589425d 100644 --- a/docs/solutions/holistic.md +++ b/docs/solutions/holistic.md @@ -259,6 +259,7 @@ mp_holistic = mp.solutions.holistic # For static images: IMAGE_FILES = [] +BG_COLOR = (192, 192, 192) # gray with mp_holistic.Holistic( static_image_mode=True, model_complexity=2, From 675420341fca61cceaa9d6b8054b858c0695bd6e Mon Sep 17 00:00:00 2001 From: Ayush Gupta Date: Thu, 15 Dec 2022 16:06:54 +0530 Subject: [PATCH 240/469] Internal Change --- .github/bot_config.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/bot_config.yml b/.github/bot_config.yml index 8ad724168..74a60e4b9 100644 --- a/.github/bot_config.yml +++ b/.github/bot_config.yml @@ -15,4 +15,5 @@ # A list of assignees assignees: - - sureshdagooglecom + - kuaashish + - ayushgdev From 299aa03302d66d1ed449eaf10e01702b633538ac Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Thu, 15 Dec 2022 09:20:22 -0800 Subject: [PATCH 241/469] Internal change PiperOrigin-RevId: 495613573 --- .../audioclassifier/AudioClassifier.java | 2 + .../audio/audioembedder/AudioEmbedder.java | 2 + .../com/google/mediapipe/tasks/core/BUILD | 12 +++ .../google/mediapipe/tasks/core/TaskInfo.java | 12 ++- .../mediapipe/tasks/core/TaskRunner.java | 29 +++++- 
.../core/logging/TasksStatsDummyLogger.java | 78 +++++++++++++++ .../tasks/core/logging/TasksStatsLogger.java | 98 +++++++++++++++++++ .../text/textclassifier/TextClassifier.java | 1 + .../tasks/text/textembedder/TextEmbedder.java | 1 + .../gesturerecognizer/GestureRecognizer.java | 2 + .../vision/handlandmarker/HandLandmarker.java | 2 + .../imageclassifier/ImageClassifier.java | 2 + .../vision/imageembedder/ImageEmbedder.java | 2 + .../vision/objectdetector/ObjectDetector.java | 2 + 14 files changed, 239 insertions(+), 6 deletions(-) create mode 100644 mediapipe/tasks/java/com/google/mediapipe/tasks/core/logging/TasksStatsDummyLogger.java create mode 100644 mediapipe/tasks/java/com/google/mediapipe/tasks/core/logging/TasksStatsLogger.java diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioclassifier/AudioClassifier.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioclassifier/AudioClassifier.java index d78685fe3..4e5cd7655 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioclassifier/AudioClassifier.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioclassifier/AudioClassifier.java @@ -203,6 +203,8 @@ public final class AudioClassifier extends BaseAudioTaskApi { TaskRunner.create( context, TaskInfo.builder() + .setTaskName(AudioClassifier.class.getSimpleName()) + .setTaskRunningModeName(options.runningMode().name()) .setTaskGraphName(TASK_GRAPH_NAME) .setInputStreams(INPUT_STREAMS) .setOutputStreams(OUTPUT_STREAMS) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioembedder/AudioEmbedder.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioembedder/AudioEmbedder.java index 4bc505d84..077f28ca2 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioembedder/AudioEmbedder.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioembedder/AudioEmbedder.java @@ -200,6 +200,8 @@ public final class AudioEmbedder extends BaseAudioTaskApi { TaskRunner.create( context, TaskInfo.builder() + .setTaskName(AudioEmbedder.class.getSimpleName()) + .setTaskRunningModeName(options.runningMode().name()) .setTaskGraphName(TASK_GRAPH_NAME) .setInputStreams(INPUT_STREAMS) .setOutputStreams(OUTPUT_STREAMS) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/BUILD index 31f885267..3eb28d38b 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/BUILD @@ -22,6 +22,7 @@ android_library( ], manifest = "AndroidManifest.xml", deps = [ + ":logging", "//mediapipe/calculators/core:flow_limiter_calculator_java_proto_lite", "//mediapipe/calculators/tensor:inference_calculator_java_proto_lite", "//mediapipe/framework:calculator_java_proto_lite", @@ -37,6 +38,17 @@ android_library( ], ) +android_library( + name = "logging", + srcs = glob( + ["logging/*.java"], + ), + deps = [ + "//third_party:autovalue", + "@maven//:com_google_guava_guava", + ], +) + load("//mediapipe/tasks/java/com/google/mediapipe/tasks:mediapipe_tasks_aar.bzl", "mediapipe_tasks_core_aar") mediapipe_tasks_core_aar( diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskInfo.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskInfo.java index 12f8be8ba..310f5739c 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskInfo.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskInfo.java @@ -32,6 +32,12 @@ public abstract 
class TaskInfo { /** Builder for {@link TaskInfo}. */ @AutoValue.Builder public abstract static class Builder { + /** Sets the MediaPipe task name. */ + public abstract Builder setTaskName(String value); + + /** Sets the MediaPipe task running mode name. */ + public abstract Builder setTaskRunningModeName(String value); + /** Sets the MediaPipe task graph name. */ public abstract Builder setTaskGraphName(String value); @@ -71,6 +77,10 @@ public abstract class TaskInfo { } } + abstract String taskName(); + + abstract String taskRunningModeName(); + abstract String taskGraphName(); abstract T taskOptions(); @@ -82,7 +92,7 @@ public abstract class TaskInfo { abstract Boolean enableFlowLimiting(); public static Builder builder() { - return new AutoValue_TaskInfo.Builder(); + return new AutoValue_TaskInfo.Builder().setTaskName("").setTaskRunningModeName(""); } /* Returns a list of the output stream names without the stream tags. */ diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskRunner.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskRunner.java index e6fc91cf6..1a128c538 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskRunner.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskRunner.java @@ -21,6 +21,8 @@ import com.google.mediapipe.framework.AndroidPacketCreator; import com.google.mediapipe.framework.Graph; import com.google.mediapipe.framework.MediaPipeException; import com.google.mediapipe.framework.Packet; +import com.google.mediapipe.tasks.core.logging.TasksStatsLogger; +import com.google.mediapipe.tasks.core.logging.TasksStatsDummyLogger; import java.util.Map; import java.util.concurrent.atomic.AtomicBoolean; @@ -34,6 +36,7 @@ public class TaskRunner implements AutoCloseable { private final Graph graph; private final ModelResourcesCache modelResourcesCache; private final AndroidPacketCreator packetCreator; + private final TasksStatsLogger statsLogger; private long lastSeenTimestamp = Long.MIN_VALUE; private ErrorListener errorListener; @@ -51,6 +54,8 @@ public class TaskRunner implements AutoCloseable { Context context, TaskInfo taskInfo, OutputHandler outputHandler) { + TasksStatsLogger statsLogger = + TasksStatsDummyLogger.create(context, taskInfo.taskName(), taskInfo.taskRunningModeName()); AndroidAssetUtil.initializeNativeAssetManager(context); Graph mediapipeGraph = new Graph(); mediapipeGraph.loadBinaryGraph(taskInfo.generateGraphConfig()); @@ -58,12 +63,15 @@ public class TaskRunner implements AutoCloseable { mediapipeGraph.setServiceObject(new ModelResourcesCacheService(), graphModelResourcesCache); mediapipeGraph.addMultiStreamCallback( taskInfo.outputStreamNames(), - outputHandler::run, - /*observeTimestampBounds=*/ outputHandler.handleTimestampBoundChanges()); + packets -> { + outputHandler.run(packets); + statsLogger.recordInvocationEnd(packets.get(0).getTimestamp()); + }, + /* observeTimestampBounds= */ outputHandler.handleTimestampBoundChanges()); mediapipeGraph.startRunningGraph(); // Waits until all calculators are opened and the graph is fully started. mediapipeGraph.waitUntilGraphIdle(); - return new TaskRunner(mediapipeGraph, graphModelResourcesCache, outputHandler); + return new TaskRunner(mediapipeGraph, graphModelResourcesCache, outputHandler, statsLogger); } /** @@ -91,7 +99,10 @@ public class TaskRunner implements AutoCloseable { * @param inputs a map contains (input stream {@link String}, data {@link Packet}) pairs. 
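The taskName and taskRunningModeName fields added above default to empty strings, so existing callers keep compiling; individual tasks opt in the way the AudioClassifier and AudioEmbedder hunks earlier in this patch do. A condensed sketch of that call-site pattern follows. TASK_GRAPH_NAME, INPUT_STREAMS, OUTPUT_STREAMS, options, context, and outputHandler all stand for the enclosing task class's own members, and the flow-limiting value is an arbitrary placeholder.

```java
// Sketch of the TaskInfo wiring used by the task APIs in this patch.
TaskRunner runner =
    TaskRunner.create(
        context,
        TaskInfo.builder()
            .setTaskName(AudioClassifier.class.getSimpleName())
            .setTaskRunningModeName(options.runningMode().name())
            .setTaskGraphName(TASK_GRAPH_NAME)
            .setInputStreams(INPUT_STREAMS)
            .setOutputStreams(OUTPUT_STREAMS)
            .setTaskOptions(options)
            .setEnableFlowLimiting(false)  // Placeholder; stream-mode tasks may enable this.
            .build(),
        outputHandler);
```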
*/ public synchronized TaskResult process(Map inputs) { - addPackets(inputs, generateSyntheticTimestamp()); + long syntheticInputTimestamp = generateSyntheticTimestamp(); + // TODO: Support recording GPU input arrival. + statsLogger.recordCpuInputArrival(syntheticInputTimestamp); + addPackets(inputs, syntheticInputTimestamp); graph.waitUntilGraphIdle(); lastSeenTimestamp = outputHandler.getLatestOutputTimestamp(); return outputHandler.retrieveCachedTaskResult(); @@ -112,6 +123,7 @@ public class TaskRunner implements AutoCloseable { */ public synchronized TaskResult process(Map inputs, long inputTimestamp) { validateInputTimstamp(inputTimestamp); + statsLogger.recordCpuInputArrival(inputTimestamp); addPackets(inputs, inputTimestamp); graph.waitUntilGraphIdle(); return outputHandler.retrieveCachedTaskResult(); @@ -132,6 +144,7 @@ public class TaskRunner implements AutoCloseable { */ public synchronized void send(Map inputs, long inputTimestamp) { validateInputTimstamp(inputTimestamp); + statsLogger.recordCpuInputArrival(inputTimestamp); addPackets(inputs, inputTimestamp); } @@ -145,6 +158,7 @@ public class TaskRunner implements AutoCloseable { graphStarted.set(false); graph.closeAllPacketSources(); graph.waitUntilGraphDone(); + statsLogger.logSessionEnd(); } catch (MediaPipeException e) { reportError(e); } @@ -154,6 +168,7 @@ public class TaskRunner implements AutoCloseable { // Waits until all calculators are opened and the graph is fully restarted. graph.waitUntilGraphIdle(); graphStarted.set(true); + statsLogger.logSessionStart(); } catch (MediaPipeException e) { reportError(e); } @@ -169,6 +184,7 @@ public class TaskRunner implements AutoCloseable { graphStarted.set(false); graph.closeAllPacketSources(); graph.waitUntilGraphDone(); + statsLogger.logSessionEnd(); if (modelResourcesCache != null) { modelResourcesCache.release(); } @@ -247,12 +263,15 @@ public class TaskRunner implements AutoCloseable { private TaskRunner( Graph graph, ModelResourcesCache modelResourcesCache, - OutputHandler outputHandler) { + OutputHandler outputHandler, + TasksStatsLogger statsLogger) { this.outputHandler = outputHandler; this.graph = graph; this.modelResourcesCache = modelResourcesCache; this.packetCreator = new AndroidPacketCreator(graph); + this.statsLogger = statsLogger; graphStarted.set(true); + this.statsLogger.logSessionStart(); } /** Reports error. */ diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/logging/TasksStatsDummyLogger.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/logging/TasksStatsDummyLogger.java new file mode 100644 index 000000000..c10b5d224 --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/logging/TasksStatsDummyLogger.java @@ -0,0 +1,78 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.tasks.core.logging; + +import android.content.Context; + +/** A dummy MediaPipe Tasks stats logger that has all methods as no-ops. 
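TaskRunner now brackets every invocation between recordCpuInputArrival() and recordInvocationEnd(), keyed by the packet timestamp, and brackets the graph lifetime with logSessionStart()/logSessionEnd(). The dummy logger defined here discards those signals; the class below is a hypothetical sketch — not part of this patch — of how a real implementation could turn the same timestamp pairing into per-invocation latency.

```java
import java.util.concurrent.ConcurrentHashMap;

/** Hypothetical latency-tracking counterpart to the no-op dummy logger. */
final class LatencyTrackingStatsLogger {
  // Packet timestamp -> wall-clock arrival time in milliseconds.
  private final ConcurrentHashMap<Long, Long> arrivalTimeMs = new ConcurrentHashMap<>();
  private long finishedCount = 0;
  private long totalLatencyMs = 0;

  void recordCpuInputArrival(long packetTimestamp) {
    arrivalTimeMs.put(packetTimestamp, System.currentTimeMillis());
  }

  synchronized void recordInvocationEnd(long packetTimestamp) {
    Long start = arrivalTimeMs.remove(packetTimestamp);
    if (start != null) {  // Unmatched ends (e.g. dropped inputs) are ignored.
      totalLatencyMs += System.currentTimeMillis() - start;
      finishedCount++;
    }
  }

  synchronized long averageLatencyMs() {
    return finishedCount == 0 ? 0 : totalLatencyMs / finishedCount;
  }
}
```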
*/ +public class TasksStatsDummyLogger implements TasksStatsLogger { + + /** + * Creates the MediaPipe Tasks stats dummy logger. + * + * @param context a {@link Context}. + * @param taskNameStr the task api name. + * @param taskRunningModeStr the task running mode string representation. + */ + public static TasksStatsDummyLogger create( + Context context, String taskNameStr, String taskRunningModeStr) { + return new TasksStatsDummyLogger(); + } + + private TasksStatsDummyLogger() {} + + /** Logs the start of a MediaPipe Tasks API session. */ + @Override + public void logSessionStart() {} + + /** + * Records MediaPipe Tasks API receiving CPU input data. + * + * @param packetTimestamp the input packet timestamp that acts as the identifier of the api + * invocation. + */ + @Override + public void recordCpuInputArrival(long packetTimestamp) {} + + /** + * Records MediaPipe Tasks API receiving GPU input data. + * + * @param packetTimestamp the input packet timestamp that acts as the identifier of the api + * invocation. + */ + @Override + public void recordGpuInputArrival(long packetTimestamp) {} + + /** + * Records the end of a Mediapipe Tasks API invocation. + * + * @param packetTimestamp the output packet timestamp that acts as the identifier of the api + * invocation. + */ + @Override + public void recordInvocationEnd(long packetTimestamp) {} + + /** Logs the MediaPipe Tasks API periodic invocation report. */ + @Override + public void logInvocationReport(StatsSnapshot stats) {} + + /** Logs the Tasks API session end event. */ + @Override + public void logSessionEnd() {} + + /** Logs the MediaPipe Tasks API initialization error. */ + @Override + public void logInitError() {} +} diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/logging/TasksStatsLogger.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/logging/TasksStatsLogger.java new file mode 100644 index 000000000..c726e7d0d --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/logging/TasksStatsLogger.java @@ -0,0 +1,98 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.tasks.core.logging; + +import com.google.auto.value.AutoValue; + +/** The stats logger interface that defines what MediaPipe Tasks events to log. */ +public interface TasksStatsLogger { + /** Task stats snapshot. 
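Because every override above is a no-op, TasksStatsDummyLogger acts as a null object: TaskRunner can call the logging hooks unconditionally, with no null checks and no build-flavor branching at the call sites. A small usage sketch, in which `context` stands for a valid android.content.Context and the two strings are arbitrary examples of a task name and running-mode name:

```java
TasksStatsLogger logger =
    TasksStatsDummyLogger.create(context, "AudioClassifier", "AUDIO_STREAM");
logger.logSessionStart();           // no-op
logger.recordCpuInputArrival(42L);  // no-op
logger.recordInvocationEnd(42L);    // no-op
logger.logSessionEnd();             // no-op
```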
*/ + @AutoValue + abstract static class StatsSnapshot { + static StatsSnapshot create( + int cpuInputCount, + int gpuInputCount, + int finishedCount, + int droppedCount, + long totalLatencyMs, + long peakLatencyMs, + long elapsedTimeMs) { + return new AutoValue_TasksStatsLogger_StatsSnapshot( + cpuInputCount, + gpuInputCount, + finishedCount, + droppedCount, + totalLatencyMs, + peakLatencyMs, + elapsedTimeMs); + } + + static StatsSnapshot createDefault() { + return new AutoValue_TasksStatsLogger_StatsSnapshot(0, 0, 0, 0, 0, 0, 0); + } + + abstract int cpuInputCount(); + + abstract int gpuInputCount(); + + abstract int finishedCount(); + + abstract int droppedCount(); + + abstract long totalLatencyMs(); + + abstract long peakLatencyMs(); + + abstract long elapsedTimeMs(); + } + + /** Logs the start of a MediaPipe Tasks API session. */ + public void logSessionStart(); + + /** + * Records MediaPipe Tasks API receiving CPU input data. + * + * @param packetTimestamp the input packet timestamp that acts as the identifier of the api + * invocation. + */ + public void recordCpuInputArrival(long packetTimestamp); + + /** + * Records MediaPipe Tasks API receiving GPU input data. + * + * @param packetTimestamp the input packet timestamp that acts as the identifier of the api + * invocation. + */ + public void recordGpuInputArrival(long packetTimestamp); + + /** + * Records the end of a Mediapipe Tasks API invocation. + * + * @param packetTimestamp the output packet timestamp that acts as the identifier of the api + * invocation. + */ + public void recordInvocationEnd(long packetTimestamp); + + /** Logs the MediaPipe Tasks API periodic invocation report. */ + public void logInvocationReport(StatsSnapshot stats); + + /** Logs the Tasks API session end event. */ + public void logSessionEnd(); + + /** Logs the MediaPipe Tasks API initialization error. */ + public void logInitError(); + + // TODO: Logs more error types. 
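StatsSnapshot deliberately carries only raw counters and leaves derived metrics to the consumer, so a periodic report can compute, say, an average latency just before calling logInvocationReport(). The values below are made-up placeholders, and the sketch assumes code living in the same package, since create() and the accessors are package-private.

```java
TasksStatsLogger.StatsSnapshot stats =
    TasksStatsLogger.StatsSnapshot.create(
        /* cpuInputCount= */ 120,
        /* gpuInputCount= */ 0,
        /* finishedCount= */ 118,
        /* droppedCount= */ 2,
        /* totalLatencyMs= */ 2360,
        /* peakLatencyMs= */ 45,
        /* elapsedTimeMs= */ 60000);
// Derived metric: mean end-to-end latency over finished invocations.
long averageLatencyMs =
    stats.finishedCount() == 0 ? 0 : stats.totalLatencyMs() / stats.finishedCount();
```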
+} diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textclassifier/TextClassifier.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textclassifier/TextClassifier.java index 0ea91a9f8..edb78a191 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textclassifier/TextClassifier.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textclassifier/TextClassifier.java @@ -169,6 +169,7 @@ public final class TextClassifier implements AutoCloseable { TaskRunner.create( context, TaskInfo.builder() + .setTaskName(TextClassifier.class.getSimpleName()) .setTaskGraphName(TASK_GRAPH_NAME) .setInputStreams(INPUT_STREAMS) .setOutputStreams(OUTPUT_STREAMS) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textembedder/TextEmbedder.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textembedder/TextEmbedder.java index 9b464d0e8..28f351d4b 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textembedder/TextEmbedder.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textembedder/TextEmbedder.java @@ -159,6 +159,7 @@ public final class TextEmbedder implements AutoCloseable { TaskRunner.create( context, TaskInfo.builder() + .setTaskName(TextEmbedder.class.getSimpleName()) .setTaskGraphName(TASK_GRAPH_NAME) .setInputStreams(INPUT_STREAMS) .setOutputStreams(OUTPUT_STREAMS) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizer.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizer.java index e9e74a067..a933d2f65 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizer.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizer.java @@ -194,6 +194,8 @@ public final class GestureRecognizer extends BaseVisionTaskApi { TaskRunner.create( context, TaskInfo.builder() + .setTaskName(GestureRecognizer.class.getSimpleName()) + .setTaskRunningModeName(recognizerOptions.runningMode().name()) .setTaskGraphName(TASK_GRAPH_NAME) .setInputStreams(INPUT_STREAMS) .setOutputStreams(OUTPUT_STREAMS) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmarker.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmarker.java index a9270d347..1d08ab928 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmarker.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmarker.java @@ -183,6 +183,8 @@ public final class HandLandmarker extends BaseVisionTaskApi { TaskRunner.create( context, TaskInfo.builder() + .setTaskName(HandLandmarker.class.getSimpleName()) + .setTaskRunningModeName(landmarkerOptions.runningMode().name()) .setTaskGraphName(TASK_GRAPH_NAME) .setInputStreams(INPUT_STREAMS) .setOutputStreams(OUTPUT_STREAMS) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifier.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifier.java index 8990f46fd..38482797c 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifier.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifier.java @@ -197,6 +197,8 @@ public final class ImageClassifier extends BaseVisionTaskApi { TaskRunner.create( context, TaskInfo.builder() + 
.setTaskName(ImageClassifier.class.getSimpleName()) + .setTaskRunningModeName(options.runningMode().name()) .setTaskGraphName(TASK_GRAPH_NAME) .setInputStreams(INPUT_STREAMS) .setOutputStreams(OUTPUT_STREAMS) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedder.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedder.java index af053d860..488927257 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedder.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedder.java @@ -180,6 +180,8 @@ public final class ImageEmbedder extends BaseVisionTaskApi { TaskRunner.create( context, TaskInfo.builder() + .setTaskName(ImageEmbedder.class.getSimpleName()) + .setTaskRunningModeName(options.runningMode().name()) .setTaskGraphName(TASK_GRAPH_NAME) .setInputStreams(INPUT_STREAMS) .setOutputStreams(OUTPUT_STREAMS) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetector.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetector.java index 769b9137f..d706189ee 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetector.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetector.java @@ -190,6 +190,8 @@ public final class ObjectDetector extends BaseVisionTaskApi { TaskRunner.create( context, TaskInfo.builder() + .setTaskName(ObjectDetector.class.getSimpleName()) + .setTaskRunningModeName(detectorOptions.runningMode().name()) .setTaskGraphName(TASK_GRAPH_NAME) .setInputStreams(INPUT_STREAMS) .setOutputStreams(OUTPUT_STREAMS) From fd50b6aa2f6d1a8f69163fbda4db763bbd2862f4 Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Thu, 15 Dec 2022 11:49:06 -0800 Subject: [PATCH 242/469] Add a new python unit test to test creating mediapipe Image from cvmat. 
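For context, a minimal sketch of the conversion path the new test exercises, assuming a local `hands.jpg` asset and the top-level `mp.Image` binding shown in the pybind docstrings later in this series:

```python
import cv2
import mediapipe as mp
import numpy as np

# cv2.imread returns BGR by default; convert to RGB before wrapping in mp.Image.
mat = cv2.cvtColor(cv2.imread('hands.jpg'), cv2.COLOR_BGR2RGB).astype(np.uint8)
rgb_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=mat)

# The wrapped image exposes the same pixel data as the source ndarray.
assert np.array_equal(mat, rgb_image.numpy_view())
```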
PiperOrigin-RevId: 495655719 --- mediapipe/python/image_test.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/mediapipe/python/image_test.py b/mediapipe/python/image_test.py index 117d20974..cd9124948 100644 --- a/mediapipe/python/image_test.py +++ b/mediapipe/python/image_test.py @@ -28,6 +28,8 @@ import PIL.Image from mediapipe.python._framework_bindings import image from mediapipe.python._framework_bindings import image_frame +TEST_IMAGE_PATH = 'mediapipe/python/solutions/testdata' + Image = image.Image ImageFormat = image_frame.ImageFormat @@ -187,5 +189,26 @@ class ImageTest(absltest.TestCase): gc.collect() self.assertEqual(sys.getrefcount(rgb_image), initial_ref_count) + def test_image_create_from_cvmat(self): + image_path = os.path.join(os.path.dirname(__file__), + 'solutions/testdata/hands.jpg') + mat = cv2.imread(image_path).astype(np.uint8) + mat = cv2.cvtColor(mat, cv2.COLOR_BGR2RGB) + rgb_image = Image(image_format=ImageFormat.SRGB, data=mat) + self.assertEqual(rgb_image.width, 720) + self.assertEqual(rgb_image.height, 382) + self.assertEqual(rgb_image.channels, 3) + self.assertEqual(rgb_image.image_format, ImageFormat.SRGB) + self.assertTrue(np.array_equal(mat, rgb_image.numpy_view())) + + def test_image_create_from_file(self): + image_path = os.path.join(os.path.dirname(__file__), + 'solutions/testdata/hands.jpg') + loaded_image = Image.create_from_file(image_path) + self.assertEqual(loaded_image.width, 720) + self.assertEqual(loaded_image.height, 382) + self.assertEqual(loaded_image.channels, 3) + self.assertEqual(loaded_image.image_format, ImageFormat.SRGB) + if __name__ == '__main__': absltest.main() From 62f0034033ca9b4c0106bd7987669ce3604d2571 Mon Sep 17 00:00:00 2001 From: Khanh LeViet Date: Thu, 15 Dec 2022 14:17:23 -0800 Subject: [PATCH 243/469] Internal change PiperOrigin-RevId: 495694817 --- .github/ISSUE_TEMPLATE/13-solution-issue.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/ISSUE_TEMPLATE/13-solution-issue.md b/.github/ISSUE_TEMPLATE/13-solution-issue.md index 9297edf6b..bf0d613c9 100644 --- a/.github/ISSUE_TEMPLATE/13-solution-issue.md +++ b/.github/ISSUE_TEMPLATE/13-solution-issue.md @@ -1,6 +1,6 @@ --- name: "Solution (legacy) Issue" -about: Use this template for assistance with a specific Mediapipe solution (google.github.io/mediapipe/solutions), such as "Pose" or "Iris", including inference model usage/training, solution-specific calculators, etc. +about: Use this template for assistance with a specific Mediapipe solution (google.github.io/mediapipe/solutions) such as "Pose", including inference model usage/training, solution-specific calculators etc. labels: type:support --- From 8d2473c751ca53fd22164a51b66bbd7fa13375f1 Mon Sep 17 00:00:00 2001 From: Mark McDonald Date: Thu, 15 Dec 2022 15:42:00 -0800 Subject: [PATCH 244/469] Update `Image` docs to improve rendering. The [API docs](https://developers.google.com/mediapipe/api/solutions/python/mp/Image) have a few rendering issues. e.g., the doc generator will turn ``` This block: Anything here ``` Into a table with heading `This block` and `Anything here` as a plain-text cell. In order to render code as code, it needs to be in backticks. They can also be in `>>> code()` format, and we can try to run them ([doctests](https://docs.python.org/3/library/doctest.html)). I'll have a dashboard ready soon that shows areas we can improve. 
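For illustration, a hedged sketch of the doctest format mentioned above (the snippet is hypothetical, not taken from the current docs); the generator would execute the `>>>` lines and compare against the expected output:

```python
>>> import mediapipe as mp
>>> import numpy as np
>>> image = mp.Image(image_format=mp.ImageFormat.SRGB,
...                  data=np.zeros((30, 60, 3), dtype=np.uint8))
>>> image.width
60
>>> image.height
30
```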
PiperOrigin-RevId: 495715576 --- mediapipe/python/pybind/image.cc | 46 ++++++++++++++++++++------------ 1 file changed, 29 insertions(+), 17 deletions(-) diff --git a/mediapipe/python/pybind/image.cc b/mediapipe/python/pybind/image.cc index 5d8663143..e5fa24e8c 100644 --- a/mediapipe/python/pybind/image.cc +++ b/mediapipe/python/pybind/image.cc @@ -48,16 +48,19 @@ void ImageSubmodule(pybind11::module* module) { become immutable after creation. Creation examples: - import cv2 - cv_mat = cv2.imread(input_file)[:, :, ::-1] - rgb_frame = mp.Image(format=ImageFormat.SRGB, data=cv_mat) - gray_frame = mp.Image( - format=ImageFormat.GRAY, data=cv2.cvtColor(cv_mat, cv2.COLOR_RGB2GRAY)) - from PIL import Image - pil_img = Image.new('RGB', (60, 30), color = 'red') - image = mp.Image( - format=mp.ImageFormat.SRGB, data=np.asarray(pil_img)) + ```python + import cv2 + cv_mat = cv2.imread(input_file)[:, :, ::-1] + rgb_frame = mp.Image(format=ImageFormat.SRGB, data=cv_mat) + gray_frame = mp.Image( + format=ImageFormat.GRAY, data=cv2.cvtColor(cv_mat, cv2.COLOR_RGB2GRAY)) + + from PIL import Image + pil_img = Image.new('RGB', (60, 30), color = 'red') + image = mp.Image( + format=mp.ImageFormat.SRGB, data=np.asarray(pil_img)) + ``` The pixel data in an Image can be retrieved as a numpy ndarray by calling `Image.numpy_view()`. The returned numpy ndarray is a reference to the @@ -65,15 +68,18 @@ void ImageSubmodule(pybind11::module* module) { numpy ndarray, it's required to obtain a copy of it. Pixel data retrieval examples: - for channel in range(num_channel): - for col in range(width): - for row in range(height): - print(image[row, col, channel]) - output_ndarray = image.numpy_view() - print(output_ndarray[0, 0, 0]) - copied_ndarray = np.copy(output_ndarray) - copied_ndarray[0,0,0] = 0 + ```python + for channel in range(num_channel): + for col in range(width): + for row in range(height): + print(image[row, col, channel]) + + output_ndarray = image.numpy_view() + print(output_ndarray[0, 0, 0]) + copied_ndarray = np.copy(output_ndarray) + copied_ndarray[0,0,0] = 0 + ``` )doc", py::dynamic_attr()); @@ -156,9 +162,11 @@ void ImageSubmodule(pybind11::module* module) { An unwritable numpy ndarray. Examples: + ``` output_ndarray = image.numpy_view() copied_ndarray = np.copy(output_ndarray) copied_ndarray[0,0,0] = 0 + ``` )doc"); image.def( @@ -191,10 +199,12 @@ void ImageSubmodule(pybind11::module* module) { IndexError: If the index is invalid or out of bounds. Examples: + ``` for channel in range(num_channel): for col in range(width): for row in range(height): print(image[row, col, channel]) + ``` )doc"); image @@ -224,7 +234,9 @@ void ImageSubmodule(pybind11::module* module) { A boolean. Examples: + ``` image.is_aligned(16) + ``` )doc"); image.def_static( From 6bf5648430108154b2c2f24c1e8edb78275deac9 Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Thu, 15 Dec 2022 17:40:29 -0800 Subject: [PATCH 245/469] Fix the documentation of the constructor of Image and ImageFrame Python classes. 
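The docstrings previously showed `format=`, but the pybind constructors only accept `image_format=`. A minimal sketch of the corrected calls:

```python
import mediapipe as mp
import numpy as np

data = np.zeros((30, 60, 3), dtype=np.uint8)
# `image_format=` is the actual keyword argument; `format=` raises a TypeError.
image = mp.Image(image_format=mp.ImageFormat.SRGB, data=data)
frame = mp.ImageFrame(image_format=mp.ImageFormat.SRGB, data=data)
```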
PiperOrigin-RevId: 495739875 --- mediapipe/python/pybind/image.cc | 7 ++++--- mediapipe/python/pybind/image_frame.cc | 7 ++++--- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/mediapipe/python/pybind/image.cc b/mediapipe/python/pybind/image.cc index e5fa24e8c..1bcca12ff 100644 --- a/mediapipe/python/pybind/image.cc +++ b/mediapipe/python/pybind/image.cc @@ -52,14 +52,15 @@ void ImageSubmodule(pybind11::module* module) { ```python import cv2 cv_mat = cv2.imread(input_file)[:, :, ::-1] - rgb_frame = mp.Image(format=ImageFormat.SRGB, data=cv_mat) + rgb_frame = mp.Image(image_format=ImageFormat.SRGB, data=cv_mat) gray_frame = mp.Image( - format=ImageFormat.GRAY, data=cv2.cvtColor(cv_mat, cv2.COLOR_RGB2GRAY)) + image_format=ImageFormat.GRAY, + data=cv2.cvtColor(cv_mat, cv2.COLOR_RGB2GRAY)) from PIL import Image pil_img = Image.new('RGB', (60, 30), color = 'red') image = mp.Image( - format=mp.ImageFormat.SRGB, data=np.asarray(pil_img)) + image_format=mp.ImageFormat.SRGB, data=np.asarray(pil_img)) ``` The pixel data in an Image can be retrieved as a numpy ndarray by calling diff --git a/mediapipe/python/pybind/image_frame.cc b/mediapipe/python/pybind/image_frame.cc index a7fc6bfe4..bc7a9753d 100644 --- a/mediapipe/python/pybind/image_frame.cc +++ b/mediapipe/python/pybind/image_frame.cc @@ -83,14 +83,15 @@ void ImageFrameSubmodule(pybind11::module* module) { Creation examples: import cv2 cv_mat = cv2.imread(input_file)[:, :, ::-1] - rgb_frame = mp.ImageFrame(format=ImageFormat.SRGB, data=cv_mat) + rgb_frame = mp.ImageFrame(image_format=ImageFormat.SRGB, data=cv_mat) gray_frame = mp.ImageFrame( - format=ImageFormat.GRAY, data=cv2.cvtColor(cv_mat, cv2.COLOR_RGB2GRAY)) + image_format=ImageFormat.GRAY, + data=cv2.cvtColor(cv_mat, cv2.COLOR_RGB2GRAY)) from PIL import Image pil_img = Image.new('RGB', (60, 30), color = 'red') image_frame = mp.ImageFrame( - format=mp.ImageFormat.SRGB, data=np.asarray(pil_img)) + image_format=mp.ImageFormat.SRGB, data=np.asarray(pil_img)) The pixel data in an ImageFrame can be retrieved as a numpy ndarray by calling `ImageFrame.numpy_view()`. The returned numpy ndarray is a reference to the From 0a1f050f1fbff3b70c351178eab9ff6b94fcb1db Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Thu, 15 Dec 2022 17:50:48 -0800 Subject: [PATCH 246/469] Internal change PiperOrigin-RevId: 495741383 --- mediapipe/calculators/audio/BUILD | 4 ++-- mediapipe/calculators/internal/BUILD | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/mediapipe/calculators/audio/BUILD b/mediapipe/calculators/audio/BUILD index 555f7543f..4a8f0f598 100644 --- a/mediapipe/calculators/audio/BUILD +++ b/mediapipe/calculators/audio/BUILD @@ -12,12 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library") + licenses(["notice"]) package(default_visibility = ["//visibility:private"]) -load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library") - proto_library( name = "mfcc_mel_calculators_proto", srcs = ["mfcc_mel_calculators.proto"], diff --git a/mediapipe/calculators/internal/BUILD b/mediapipe/calculators/internal/BUILD index caade2dc3..8647e3f3f 100644 --- a/mediapipe/calculators/internal/BUILD +++ b/mediapipe/calculators/internal/BUILD @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-licenses(["notice"]) - load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library") +licenses(["notice"]) + package(default_visibility = ["//visibility:private"]) proto_library( From d5562241cc50ec34a04f1fb4f4172df7dbe008bf Mon Sep 17 00:00:00 2001 From: Nikolay Chirkov Date: Thu, 15 Dec 2022 18:32:10 -0800 Subject: [PATCH 247/469] Tensor: Interoperability GPU/Cpu -> Ahwb by transforming the underlying storage into Ahwb with releasing previously Cpu/Gpu resources. PiperOrigin-RevId: 495748104 --- mediapipe/framework/formats/tensor.h | 2 +- mediapipe/framework/formats/tensor_ahwb.cc | 19 ++++++------ .../framework/formats/tensor_ahwb_gpu_test.cc | 30 +++++++++++++++++++ 3 files changed, 41 insertions(+), 10 deletions(-) diff --git a/mediapipe/framework/formats/tensor.h b/mediapipe/framework/formats/tensor.h index 9d3e90b6a..f5a99cde1 100644 --- a/mediapipe/framework/formats/tensor.h +++ b/mediapipe/framework/formats/tensor.h @@ -408,8 +408,8 @@ class Tensor { mutable std::function release_callback_; bool AllocateAHardwareBuffer(int size_alignment = 0) const; void CreateEglSyncAndFd() const; - // Use Ahwb for other views: OpenGL / CPU buffer. #endif // MEDIAPIPE_TENSOR_USE_AHWB + // Use Ahwb for other views: OpenGL / CPU buffer. static inline bool use_ahwb_ = false; // Expects the target SSBO to be already bound. bool AllocateAhwbMapToSsbo() const; diff --git a/mediapipe/framework/formats/tensor_ahwb.cc b/mediapipe/framework/formats/tensor_ahwb.cc index 3c3ec8b17..363c5efd0 100644 --- a/mediapipe/framework/formats/tensor_ahwb.cc +++ b/mediapipe/framework/formats/tensor_ahwb.cc @@ -212,9 +212,6 @@ Tensor::AHardwareBufferView Tensor::GetAHardwareBufferReadView() const { CHECK(!(valid_ & kValidOpenGlTexture2d)) << "Tensor conversion between OpenGL texture and AHardwareBuffer is not " "supported."; - CHECK(ahwb_ || !(valid_ & kValidOpenGlBuffer)) - << "Interoperability bettween OpenGL buffer and AHardwareBuffer is not " - "supported on target system."; bool transfer = !ahwb_; CHECK(AllocateAHardwareBuffer()) << "AHardwareBuffer is not supported on the target system."; @@ -315,7 +312,13 @@ void Tensor::MoveCpuOrSsboToAhwb() const { ahwb_, AHARDWAREBUFFER_USAGE_CPU_WRITE_RARELY, -1, nullptr, &dest); CHECK(error == 0) << "AHardwareBuffer_lock " << error; } - if (valid_ & kValidOpenGlBuffer) { + if (valid_ & kValidCpu) { + std::memcpy(dest, cpu_buffer_, bytes()); + // Free CPU memory because next time AHWB is mapped instead. + free(cpu_buffer_); + cpu_buffer_ = nullptr; + valid_ &= ~kValidCpu; + } else if (valid_ & kValidOpenGlBuffer) { gl_context_->Run([this, dest]() { glBindBuffer(GL_SHADER_STORAGE_BUFFER, opengl_buffer_); const void* src = glMapBufferRange(GL_SHADER_STORAGE_BUFFER, 0, bytes(), @@ -326,11 +329,9 @@ void Tensor::MoveCpuOrSsboToAhwb() const { }); opengl_buffer_ = GL_INVALID_INDEX; gl_context_ = nullptr; - } else if (valid_ & kValidCpu) { - std::memcpy(dest, cpu_buffer_, bytes()); - // Free CPU memory because next time AHWB is mapped instead. - free(cpu_buffer_); - cpu_buffer_ = nullptr; + // Reset OpenGL Buffer validness. The OpenGL buffer will be allocated on top + // of the Ahwb at the next request to the OpenGlBufferView. 
+ valid_ &= ~kValidOpenGlBuffer; } else { LOG(FATAL) << "Can't convert tensor with mask " << valid_ << " into AHWB."; } diff --git a/mediapipe/framework/formats/tensor_ahwb_gpu_test.cc b/mediapipe/framework/formats/tensor_ahwb_gpu_test.cc index 7ccd9c7f5..a6ca00949 100644 --- a/mediapipe/framework/formats/tensor_ahwb_gpu_test.cc +++ b/mediapipe/framework/formats/tensor_ahwb_gpu_test.cc @@ -152,6 +152,36 @@ TEST_F(TensorAhwbGpuTest, TestReplacingCpuByAhwb) { { auto view = tensor.GetAHardwareBufferReadView(); EXPECT_NE(view.handle(), nullptr); + view.SetReadingFinishedFunc([](bool) { return true; }); + } + auto ptr = tensor.GetCpuReadView().buffer(); + EXPECT_NE(ptr, nullptr); + std::vector reference; + reference.resize(num_elements); + for (int i = 0; i < num_elements; i++) { + reference[i] = static_cast(i) / 10.0f; + } + EXPECT_THAT(absl::Span(ptr, num_elements), + testing::Pointwise(testing::FloatEq(), reference)); +} + +TEST_F(TensorAhwbGpuTest, TestReplacingGpuByAhwb) { + // Request the GPU view to get the ssbo allocated internally. + // Request Ahwb view then to transform the storage into Ahwb. + Tensor::SetPreferredStorageType(Tensor::StorageType::kDefault); + constexpr size_t num_elements = 20; + Tensor tensor{Tensor::ElementType::kFloat32, Tensor::Shape({num_elements})}; + RunInGlContext([&tensor] { + auto ssbo_view = tensor.GetOpenGlBufferWriteView(); + auto ssbo_name = ssbo_view.name(); + EXPECT_GT(ssbo_name, 0); + FillGpuBuffer(ssbo_name, tensor.shape().num_elements(), + tensor.element_type()); + }); + { + auto view = tensor.GetAHardwareBufferReadView(); + EXPECT_NE(view.handle(), nullptr); + view.SetReadingFinishedFunc([](bool) { return true; }); } auto ptr = tensor.GetCpuReadView().buffer(); EXPECT_NE(ptr, nullptr); From b45554623af211792ab394459088506593295a4c Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Fri, 16 Dec 2022 13:39:31 -0800 Subject: [PATCH 248/469] Fix typo in GetVectorItemCalculator doc PiperOrigin-RevId: 495951016 --- mediapipe/calculators/core/get_vector_item_calculator.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mediapipe/calculators/core/get_vector_item_calculator.h b/mediapipe/calculators/core/get_vector_item_calculator.h index 25d90bfe6..ee886b381 100644 --- a/mediapipe/calculators/core/get_vector_item_calculator.h +++ b/mediapipe/calculators/core/get_vector_item_calculator.h @@ -47,7 +47,7 @@ namespace api2 { // calculator: "Get{SpecificType}VectorItemCalculator" // input_stream: "VECTOR:vector" // input_stream: "INDEX:index" -// input_stream: "ITEM:item" +// output_stream: "ITEM:item" // options { // [mediapipe.GetVectorItemCalculatorOptions.ext] { // item_index: 5 From 7ce4bb72d4da5fae2d52ee85d0d10dae5dd96f31 Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Mon, 19 Dec 2022 09:00:46 -0800 Subject: [PATCH 249/469] Replace numpy.float with the builtin float type as numpy removes its own float type in v1.24. 
PiperOrigin-RevId: 496412858 --- mediapipe/python/packet_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mediapipe/python/packet_test.py b/mediapipe/python/packet_test.py index e1a4c12af..16fc37c87 100644 --- a/mediapipe/python/packet_test.py +++ b/mediapipe/python/packet_test.py @@ -157,7 +157,7 @@ class PacketTest(absltest.TestCase): p.timestamp = 0 self.assertAlmostEqual(packet_getter.get_float(p), 0.42) self.assertEqual(p.timestamp, 0) - p2 = packet_creator.create_float(np.float(0.42)) + p2 = packet_creator.create_float(float(0.42)) p2.timestamp = 0 self.assertAlmostEqual(packet_getter.get_float(p2), 0.42) self.assertEqual(p2.timestamp, 0) From 482247697407106cffdc7beb646477443b573557 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Mon, 19 Dec 2022 11:05:23 -0800 Subject: [PATCH 250/469] Internal change PiperOrigin-RevId: 496443946 --- mediapipe/tasks/web/core/fileset_resolver.ts | 24 +++++++------------- 1 file changed, 8 insertions(+), 16 deletions(-) diff --git a/mediapipe/tasks/web/core/fileset_resolver.ts b/mediapipe/tasks/web/core/fileset_resolver.ts index d4691243b..9917035a4 100644 --- a/mediapipe/tasks/web/core/fileset_resolver.ts +++ b/mediapipe/tasks/web/core/fileset_resolver.ts @@ -44,22 +44,14 @@ async function isSimdSupported(): Promise { } async function createFileset( - taskName: string, basePath: string = '.'): Promise { - if (await isSimdSupported()) { - return { - wasmLoaderPath: - `${basePath}/${taskName}_wasm_internal.js`, - wasmBinaryPath: - `${basePath}/${taskName}_wasm_internal.wasm`, - }; - } else { - return { - wasmLoaderPath: - `${basePath}/${taskName}_wasm_nosimd_internal.js`, - wasmBinaryPath: - `${basePath}/${taskName}_wasm_nosimd_internal.wasm`, - }; - } + taskName: string, basePath: string = ''): Promise { + const suffix = + await isSimdSupported() ? 
'wasm_internal' : 'wasm_nosimd_internal'; + + return { + wasmLoaderPath: `${basePath}/${taskName}_${suffix}.js`, + wasmBinaryPath: `${basePath}/${taskName}_${suffix}.wasm`, + }; } // tslint:disable:class-as-namespace From 3e6cd5d2bf403299886bfdcb77079d92c2d794b5 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Mon, 19 Dec 2022 11:54:57 -0800 Subject: [PATCH 251/469] Add support for customizing gesture recognizer layers PiperOrigin-RevId: 496456160 --- .../gesture_recognizer/gesture_recognizer.py | 15 +++++++---- .../gesture_recognizer_test.py | 26 +++++++++++++++++++ .../gesture_recognizer/model_options.py | 6 +++++ 3 files changed, 42 insertions(+), 5 deletions(-) diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer.py b/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer.py index f297d8640..556d2fcd7 100644 --- a/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer.py +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer.py @@ -173,15 +173,20 @@ class GestureRecognizer(classifier.Classifier): batch_size=None, dtype=tf.float32, name='hand_embedding') - - x = tf.keras.layers.BatchNormalization()(inputs) - x = tf.keras.layers.ReLU()(x) + x = inputs dropout_rate = self._model_options.dropout_rate - x = tf.keras.layers.Dropout(rate=dropout_rate, name='dropout')(x) + for i, width in enumerate(self._model_options.layer_widths): + x = tf.keras.layers.BatchNormalization()(x) + x = tf.keras.layers.ReLU()(x) + x = tf.keras.layers.Dropout(rate=dropout_rate)(x) + x = tf.keras.layers.Dense(width, name=f'custom_gesture_recognizer_{i}')(x) + x = tf.keras.layers.BatchNormalization()(x) + x = tf.keras.layers.ReLU()(x) + x = tf.keras.layers.Dropout(rate=dropout_rate)(x) outputs = tf.keras.layers.Dense( self._num_classes, activation='softmax', - name='custom_gesture_recognizer')( + name='custom_gesture_recognizer_out')( x) self._model = tf.keras.Model(inputs=inputs, outputs=outputs) diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py b/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py index 280fc6a82..08fda4fea 100644 --- a/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py @@ -60,6 +60,32 @@ class GestureRecognizerTest(tf.test.TestCase): self._test_accuracy(model) + @unittest_mock.patch.object( + tf.keras.layers, 'Dense', wraps=tf.keras.layers.Dense) + def test_gesture_recognizer_model_layer_widths(self, mock_dense): + layer_widths = [64, 32] + model_options = gesture_recognizer.ModelOptions(layer_widths=layer_widths) + hparams = gesture_recognizer.HParams( + export_dir=tempfile.mkdtemp(), epochs=2) + gesture_recognizer_options = gesture_recognizer.GestureRecognizerOptions( + model_options=model_options, hparams=hparams) + model = gesture_recognizer.GestureRecognizer.create( + train_data=self._train_data, + validation_data=self._validation_data, + options=gesture_recognizer_options) + expected_calls = [ + unittest_mock.call(w, name=f'custom_gesture_recognizer_{i}') + for i, w in enumerate(layer_widths) + ] + expected_calls.append( + unittest_mock.call( + len(self._train_data.label_names), + activation='softmax', + name='custom_gesture_recognizer_out')) + self.assertLen(mock_dense.call_args_list, len(expected_calls)) + mock_dense.assert_has_calls(expected_calls) + self._test_accuracy(model) + def 
test_export_gesture_recognizer_model(self): model_options = gesture_recognizer.ModelOptions() hparams = gesture_recognizer.HParams( diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/model_options.py b/mediapipe/model_maker/python/vision/gesture_recognizer/model_options.py index 79a84c792..1870437d4 100644 --- a/mediapipe/model_maker/python/vision/gesture_recognizer/model_options.py +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/model_options.py @@ -14,6 +14,7 @@ """Configurable model options for gesture recognizer models.""" import dataclasses +from typing import List @dataclasses.dataclass @@ -23,5 +24,10 @@ class GestureRecognizerModelOptions: Attributes: dropout_rate: The fraction of the input units to drop, used in dropout layer. + layer_widths: A list of hidden layer widths for the gesture model. Each + element in the list will create a new hidden layer with the specified + width. The hidden layers are separated with BatchNorm, Dropout, and ReLU. + Defaults to an empty list(no hidden layers). """ dropout_rate: float = 0.05 + layer_widths: List[int] = dataclasses.field(default_factory=list) From ef3fa67bf423e2d1c2ffba2bab01cc1c7b5d2ba5 Mon Sep 17 00:00:00 2001 From: Nikolay Chirkov Date: Mon, 19 Dec 2022 12:36:07 -0800 Subject: [PATCH 252/469] Automatic selection of the tensor's storage type by recording previously requested views. PiperOrigin-RevId: 496466136 --- mediapipe/framework/formats/BUILD | 6 ++- mediapipe/framework/formats/tensor.cc | 35 +++++------------- mediapipe/framework/formats/tensor.h | 37 ++++++++++++++++--- mediapipe/framework/formats/tensor_ahwb.cc | 15 ++++++++ mediapipe/framework/formats/tensor_internal.h | 10 ++--- 5 files changed, 67 insertions(+), 36 deletions(-) diff --git a/mediapipe/framework/formats/BUILD b/mediapipe/framework/formats/BUILD index fdb698c48..fdd9b8909 100644 --- a/mediapipe/framework/formats/BUILD +++ b/mediapipe/framework/formats/BUILD @@ -428,7 +428,10 @@ cc_library( "tensor.cc", "tensor_ahwb.cc", ], - hdrs = ["tensor.h"], + hdrs = [ + "tensor.h", + "tensor_internal.h", + ], copts = select({ "//mediapipe:apple": [ "-x objective-c++", @@ -452,6 +455,7 @@ cc_library( ], }), deps = [ + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", "@com_google_absl//absl/synchronization", "//mediapipe/framework:port", diff --git a/mediapipe/framework/formats/tensor.cc b/mediapipe/framework/formats/tensor.cc index fdafbff5c..3f11d368a 100644 --- a/mediapipe/framework/formats/tensor.cc +++ b/mediapipe/framework/formats/tensor.cc @@ -246,10 +246,10 @@ Tensor::OpenGlTexture2dView::GetLayoutDimensions(const Tensor::Shape& shape, return Tensor::OpenGlTexture2dView::Layout::kAligned; } } - // The best performance of a compute shader can be achived with textures' + // The best performance of a compute shader can be achieved with textures' // width multiple of 256. Making minimum fixed width of 256 waste memory for // small tensors. The optimal balance memory-vs-performance is power of 2. - // The texture width and height are choosen to be closer to square. + // The texture width and height are chosen to be closer to square. 
float power = std::log2(std::sqrt(static_cast(num_pixels))); w = 1 << static_cast(power); int h = (num_pixels + w - 1) / w; @@ -326,7 +326,7 @@ Tensor::OpenGlBufferView Tensor::GetOpenGlBufferReadView() const { auto lock(absl::make_unique(&view_mutex_)); AllocateOpenGlBuffer(); if (!(valid_ & kValidOpenGlBuffer)) { - // If the call succeds then AHWB -> SSBO are synchronized so any usage of + // If the call succeeds then AHWB -> SSBO are synchronized so any usage of // the SSBO is correct after this call. if (!InsertAhwbToSsboFence()) { glBindBuffer(GL_SHADER_STORAGE_BUFFER, opengl_buffer_); @@ -348,8 +348,10 @@ Tensor::OpenGlBufferView Tensor::GetOpenGlBufferReadView() const { }; } -Tensor::OpenGlBufferView Tensor::GetOpenGlBufferWriteView() const { +Tensor::OpenGlBufferView Tensor::GetOpenGlBufferWriteView( + uint64_t source_location_hash) const { auto lock(absl::make_unique(&view_mutex_)); + TrackAhwbUsage(source_location_hash); AllocateOpenGlBuffer(); valid_ = kValidOpenGlBuffer; return {opengl_buffer_, std::move(lock), nullptr}; @@ -385,6 +387,7 @@ void Tensor::Move(Tensor* src) { src->element_type_ = ElementType::kNone; // Mark as invalidated. cpu_buffer_ = src->cpu_buffer_; src->cpu_buffer_ = nullptr; + ahwb_tracking_key_ = src->ahwb_tracking_key_; #if MEDIAPIPE_METAL_ENABLED device_ = src->device_; src->device_ = nil; @@ -589,8 +592,10 @@ Tensor::CpuReadView Tensor::GetCpuReadView() const { return {cpu_buffer_, std::move(lock)}; } -Tensor::CpuWriteView Tensor::GetCpuWriteView() const { +Tensor::CpuWriteView Tensor::GetCpuWriteView( + uint64_t source_location_hash) const { auto lock = absl::make_unique(&view_mutex_); + TrackAhwbUsage(source_location_hash); AllocateCpuBuffer(); valid_ = kValidCpu; #ifdef MEDIAPIPE_TENSOR_USE_AHWB @@ -620,24 +625,4 @@ void Tensor::AllocateCpuBuffer() const { } } -void Tensor::SetPreferredStorageType(StorageType type) { -#ifdef MEDIAPIPE_TENSOR_USE_AHWB - if (__builtin_available(android 26, *)) { - use_ahwb_ = type == StorageType::kAhwb; - VLOG(4) << "Tensor: use of AHardwareBuffer is " - << (use_ahwb_ ? "allowed" : "not allowed"); - } -#else - VLOG(4) << "Tensor: use of AHardwareBuffer is not allowed"; -#endif // MEDIAPIPE_TENSOR_USE_AHWB -} - -Tensor::StorageType Tensor::GetPreferredStorageType() { -#ifdef MEDIAPIPE_TENSOR_USE_AHWB - return use_ahwb_ ? 
StorageType::kAhwb : StorageType::kDefault; -#else - return StorageType::kDefault; -#endif // MEDIAPIPE_TENSOR_USE_AHWB -} - } // namespace mediapipe diff --git a/mediapipe/framework/formats/tensor.h b/mediapipe/framework/formats/tensor.h index f5a99cde1..8a6f02e9d 100644 --- a/mediapipe/framework/formats/tensor.h +++ b/mediapipe/framework/formats/tensor.h @@ -24,8 +24,9 @@ #include #include -#include "absl/memory/memory.h" +#include "absl/container/flat_hash_set.h" #include "absl/synchronization/mutex.h" +#include "mediapipe/framework/formats/tensor_internal.h" #include "mediapipe/framework/port.h" #if MEDIAPIPE_METAL_ENABLED @@ -48,6 +49,22 @@ #include "mediapipe/gpu/gl_context.h" #endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30 +#if defined __has_builtin +#if __has_builtin(__builtin_LINE) +#define builtin_LINE __builtin_LINE +#endif +#if __has_builtin(__builtin_FILE) +#define builtin_FILE __builtin_FILE +#endif +#endif + +#ifndef builtin_LINE +#define builtin_LINE() 0 +#endif +#ifndef builtin_FILE +#define builtin_FILE() "" +#endif + namespace mediapipe { // Tensor is a container of multi-dimensional data that supports sharing the @@ -65,7 +82,7 @@ namespace mediapipe { // GLuint buffer = view.buffer(); // Then the buffer can be bound to the GPU command buffer. // ...binding the buffer to the command buffer... -// ...commiting command buffer and releasing the view... +// ...committing command buffer and releasing the view... // // The following request for the CPU view will be blocked until the GPU view is // released and the GPU task is finished. @@ -161,7 +178,9 @@ class Tensor { using CpuReadView = CpuView; CpuReadView GetCpuReadView() const; using CpuWriteView = CpuView; - CpuWriteView GetCpuWriteView() const; + CpuWriteView GetCpuWriteView( + uint64_t source_location_hash = + tensor_internal::FnvHash64(builtin_FILE(), builtin_LINE())) const; #if MEDIAPIPE_METAL_ENABLED // TODO: id vs. MtlBufferView. @@ -305,7 +324,9 @@ class Tensor { // A valid OpenGL context must be bound to the calling thread due to possible // GPU resource allocation. OpenGlBufferView GetOpenGlBufferReadView() const; - OpenGlBufferView GetOpenGlBufferWriteView() const; + OpenGlBufferView GetOpenGlBufferWriteView( + uint64_t source_location_hash = + tensor_internal::FnvHash64(builtin_FILE(), builtin_LINE())) const; #endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 const Shape& shape() const { return shape_; } @@ -410,7 +431,11 @@ class Tensor { void CreateEglSyncAndFd() const; #endif // MEDIAPIPE_TENSOR_USE_AHWB // Use Ahwb for other views: OpenGL / CPU buffer. - static inline bool use_ahwb_ = false; + mutable bool use_ahwb_ = false; + mutable uint64_t ahwb_tracking_key_ = 0; + // TODO: Tracks all unique tensors. Can grow to a large number. LRU + // can be more predicted. + static inline absl::flat_hash_set ahwb_usage_track_; // Expects the target SSBO to be already bound. bool AllocateAhwbMapToSsbo() const; bool InsertAhwbToSsboFence() const; @@ -419,6 +444,8 @@ class Tensor { void* MapAhwbToCpuRead() const; void* MapAhwbToCpuWrite() const; void MoveCpuOrSsboToAhwb() const; + // Set current tracking key, set "use ahwb" if the key is already marked. 
+ void TrackAhwbUsage(uint64_t key) const; #if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30 mutable std::shared_ptr gl_context_; diff --git a/mediapipe/framework/formats/tensor_ahwb.cc b/mediapipe/framework/formats/tensor_ahwb.cc index 363c5efd0..466811be7 100644 --- a/mediapipe/framework/formats/tensor_ahwb.cc +++ b/mediapipe/framework/formats/tensor_ahwb.cc @@ -265,6 +265,10 @@ Tensor::AHardwareBufferView Tensor::GetAHardwareBufferWriteView( } bool Tensor::AllocateAHardwareBuffer(int size_alignment) const { + // Mark current tracking key as Ahwb-use. + ahwb_usage_track_.insert(ahwb_tracking_key_); + use_ahwb_ = true; + if (__builtin_available(android 26, *)) { if (ahwb_ == nullptr) { AHardwareBuffer_Desc desc = {}; @@ -447,6 +451,16 @@ void* Tensor::MapAhwbToCpuWrite() const { return nullptr; } +void Tensor::TrackAhwbUsage(uint64_t source_location_hash) const { + if (ahwb_tracking_key_ == 0) { + ahwb_tracking_key_ = source_location_hash; + for (int dim : shape_.dims) { + ahwb_tracking_key_ = tensor_internal::FnvHash64(ahwb_tracking_key_, dim); + } + } + use_ahwb_ = ahwb_usage_track_.contains(ahwb_tracking_key_); +} + #else // MEDIAPIPE_TENSOR_USE_AHWB bool Tensor::AllocateAhwbMapToSsbo() const { return false; } @@ -455,6 +469,7 @@ void Tensor::MoveAhwbStuff(Tensor* src) {} void Tensor::ReleaseAhwbStuff() {} void* Tensor::MapAhwbToCpuRead() const { return nullptr; } void* Tensor::MapAhwbToCpuWrite() const { return nullptr; } +void Tensor::TrackAhwbUsage(uint64_t key) const {} #endif // MEDIAPIPE_TENSOR_USE_AHWB diff --git a/mediapipe/framework/formats/tensor_internal.h b/mediapipe/framework/formats/tensor_internal.h index 1231a991c..c223c5b1d 100644 --- a/mediapipe/framework/formats/tensor_internal.h +++ b/mediapipe/framework/formats/tensor_internal.h @@ -18,8 +18,6 @@ #include #include -#include "mediapipe/framework/tool/type_util.h" - namespace mediapipe { // Generates unique view id at compile-time using FILE and LINE. @@ -41,10 +39,12 @@ namespace tensor_internal { // https://en.wikipedia.org/wiki/Fowler%E2%80%93Noll%E2%80%93Vo_hash_function constexpr uint64_t kFnvPrime = 0x00000100000001B3; constexpr uint64_t kFnvOffsetBias = 0xcbf29ce484222325; -constexpr uint64_t FnvHash64(const char* str, uint64_t hash = kFnvOffsetBias) { - return (str[0] == 0) ? hash : FnvHash64(str + 1, (hash ^ str[0]) * kFnvPrime); +constexpr uint64_t FnvHash64(uint64_t value1, uint64_t value2) { + return (value2 ^ value1) * kFnvPrime; +} +constexpr uint64_t FnvHash64(const char* str, uint64_t hash = kFnvOffsetBias) { + return (str[0] == 0) ? hash : FnvHash64(str + 1, FnvHash64(hash, str[0])); } - template struct TypeList { static constexpr std::size_t size{sizeof...(Ts)}; From ea0bebc22608b9bbc2c0173460a418122bea4861 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Mon, 19 Dec 2022 14:48:47 -0800 Subject: [PATCH 253/469] Add BGR -> RGB color conversion to ColorConvertCalculator. 
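A hypothetical node config using the new conversion, in the same pbtxt style as the calculator's own documentation (the stream names here are made up):

```
node {
  calculator: "ColorConvertCalculator"
  input_stream: "BGR_IN:bgr_frames"
  output_stream: "RGB_OUT:rgb_frames"
}
```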
PiperOrigin-RevId: 496497002 --- .../calculators/image/color_convert_calculator.cc | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/mediapipe/calculators/image/color_convert_calculator.cc b/mediapipe/calculators/image/color_convert_calculator.cc index bdac932bb..4781f1ea1 100644 --- a/mediapipe/calculators/image/color_convert_calculator.cc +++ b/mediapipe/calculators/image/color_convert_calculator.cc @@ -38,6 +38,7 @@ void SetColorChannel(int channel, uint8 value, cv::Mat* mat) { constexpr char kRgbaInTag[] = "RGBA_IN"; constexpr char kRgbInTag[] = "RGB_IN"; +constexpr char kBgrInTag[] = "BGR_IN"; constexpr char kBgraInTag[] = "BGRA_IN"; constexpr char kGrayInTag[] = "GRAY_IN"; constexpr char kRgbaOutTag[] = "RGBA_OUT"; @@ -57,6 +58,7 @@ constexpr char kGrayOutTag[] = "GRAY_OUT"; // RGB -> RGBA // RGBA -> BGRA // BGRA -> RGBA +// BGR -> RGB // // This calculator only supports a single input stream and output stream at a // time. If more than one input stream or output stream is present, the @@ -69,6 +71,7 @@ constexpr char kGrayOutTag[] = "GRAY_OUT"; // RGB_IN: The input video stream (ImageFrame, SRGB). // BGRA_IN: The input video stream (ImageFrame, SBGRA). // GRAY_IN: The input video stream (ImageFrame, GRAY8). +// BGR_IN: The input video stream (ImageFrame, SBGR). // // Output streams: // RGBA_OUT: The output video stream (ImageFrame, SRGBA). @@ -122,6 +125,10 @@ absl::Status ColorConvertCalculator::GetContract(CalculatorContract* cc) { cc->Inputs().Tag(kBgraInTag).Set(); } + if (cc->Inputs().HasTag(kBgrInTag)) { + cc->Inputs().Tag(kBgrInTag).Set(); + } + if (cc->Outputs().HasTag(kRgbOutTag)) { cc->Outputs().Tag(kRgbOutTag).Set(); } @@ -194,6 +201,11 @@ absl::Status ColorConvertCalculator::Process(CalculatorContext* cc) { return ConvertAndOutput(kRgbaInTag, kBgraOutTag, ImageFormat::SBGRA, cv::COLOR_RGBA2BGRA, cc); } + // BGR -> RGB + if (cc->Inputs().HasTag(kBgrInTag) && cc->Outputs().HasTag(kRgbOutTag)) { + return ConvertAndOutput(kBgrInTag, kRgbOutTag, ImageFormat::SRGB, + cv::COLOR_BGR2RGB, cc); + } return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) << "Unsupported image format conversion."; From 6842f2c7c6657e7645ecddd26c92504e1b797f84 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Mon, 19 Dec 2022 17:12:06 -0800 Subject: [PATCH 254/469] Use the proper namespace for builder test PiperOrigin-RevId: 496526588 --- mediapipe/framework/api2/builder_test.cc | 131 ++++++++++++----------- 1 file changed, 66 insertions(+), 65 deletions(-) diff --git a/mediapipe/framework/api2/builder_test.cc b/mediapipe/framework/api2/builder_test.cc index 3bf3ec198..361f740c4 100644 --- a/mediapipe/framework/api2/builder_test.cc +++ b/mediapipe/framework/api2/builder_test.cc @@ -15,12 +15,17 @@ #include "mediapipe/framework/port/parse_text_proto.h" #include "mediapipe/framework/port/status_matchers.h" -namespace mediapipe { -namespace api2 { -namespace test { +namespace mediapipe::api2::builder { +namespace { + +using ::mediapipe::api2::test::Bar; +using ::mediapipe::api2::test::FloatAdder; +using ::mediapipe::api2::test::Foo; +using ::mediapipe::api2::test::Foo2; +using ::mediapipe::api2::test::FooBar1; TEST(BuilderTest, BuildGraph) { - builder::Graph graph; + Graph graph; auto& foo = graph.AddNode("Foo"); auto& bar = graph.AddNode("Bar"); graph.In("IN").SetName("base") >> foo.In("BASE"); @@ -49,22 +54,21 @@ TEST(BuilderTest, BuildGraph) { } TEST(BuilderTest, CopyableSource) { - builder::Graph graph; - builder::Source a = graph[Input("A")]; + Graph graph; + Source a = 
graph[Input("A")]; a.SetName("a"); - builder::Source b = graph[Input("B")]; + Source b = graph[Input("B")]; b.SetName("b"); - builder::SideSource side_a = graph[SideInput("SIDE_A")]; + SideSource side_a = graph[SideInput("SIDE_A")]; side_a.SetName("side_a"); - builder::SideSource side_b = graph[SideInput("SIDE_B")]; + SideSource side_b = graph[SideInput("SIDE_B")]; side_b.SetName("side_b"); - builder::Destination out = graph[Output("OUT")]; - builder::SideDestination side_out = - graph[SideOutput("SIDE_OUT")]; + Destination out = graph[Output("OUT")]; + SideDestination side_out = graph[SideOutput("SIDE_OUT")]; - builder::Source input = a; + Source input = a; input = b; - builder::SideSource side_input = side_b; + SideSource side_input = side_b; side_input = side_a; input >> out; @@ -83,28 +87,27 @@ TEST(BuilderTest, CopyableSource) { } TEST(BuilderTest, BuildGraphWithFunctions) { - builder::Graph graph; + Graph graph; - builder::Source base = graph[Input("IN")]; + Source base = graph[Input("IN")]; base.SetName("base"); - builder::SideSource side = graph[SideInput("SIDE")]; + SideSource side = graph[SideInput("SIDE")]; side.SetName("side"); - auto foo_fn = [](builder::Source base, builder::SideSource side, - builder::Graph& graph) { + auto foo_fn = [](Source base, SideSource side, Graph& graph) { auto& foo = graph.AddNode("Foo"); base >> foo[Input("BASE")]; side >> foo[SideInput("SIDE")]; return foo[Output("OUT")]; }; - builder::Source foo_out = foo_fn(base, side, graph); + Source foo_out = foo_fn(base, side, graph); - auto bar_fn = [](builder::Source in, builder::Graph& graph) { + auto bar_fn = [](Source in, Graph& graph) { auto& bar = graph.AddNode("Bar"); in >> bar[Input("IN")]; return bar[Output("OUT")]; }; - builder::Source bar_out = bar_fn(foo_out, graph); + Source bar_out = bar_fn(foo_out, graph); bar_out.SetName("out"); bar_out >> graph[Output("OUT")]; @@ -131,7 +134,7 @@ TEST(BuilderTest, BuildGraphWithFunctions) { template void BuildGraphTypedTest() { - builder::Graph graph; + Graph graph; auto& foo = graph.AddNode(); auto& bar = graph.AddNode(); graph.In("IN").SetName("base") >> foo.In(MPP_TAG("BASE")); @@ -161,12 +164,12 @@ void BuildGraphTypedTest() { EXPECT_THAT(graph.GetConfig(), EqualsProto(expected)); } -TEST(BuilderTest, BuildGraphTyped) { BuildGraphTypedTest(); } +TEST(BuilderTest, BuildGraphTyped) { BuildGraphTypedTest(); } -TEST(BuilderTest, BuildGraphTyped2) { BuildGraphTypedTest(); } +TEST(BuilderTest, BuildGraphTyped2) { BuildGraphTypedTest(); } TEST(BuilderTest, FanOut) { - builder::Graph graph; + Graph graph; auto& foo = graph.AddNode("Foo"); auto& adder = graph.AddNode("FloatAdder"); graph.In("IN").SetName("base") >> foo.In("BASE"); @@ -194,9 +197,9 @@ TEST(BuilderTest, FanOut) { } TEST(BuilderTest, TypedMultiple) { - builder::Graph graph; - auto& foo = graph.AddNode(); - auto& adder = graph.AddNode(); + Graph graph; + auto& foo = graph.AddNode(); + auto& adder = graph.AddNode(); graph.In("IN").SetName("base") >> foo.In(MPP_TAG("BASE")); foo.Out(MPP_TAG("OUT")) >> adder.In(MPP_TAG("IN"))[0]; foo.Out(MPP_TAG("OUT")) >> adder.In(MPP_TAG("IN"))[1]; @@ -222,8 +225,8 @@ TEST(BuilderTest, TypedMultiple) { } TEST(BuilderTest, TypedByPorts) { - builder::Graph graph; - auto& foo = graph.AddNode(); + Graph graph; + auto& foo = graph.AddNode(); auto& adder = graph.AddNode(); graph[FooBar1::kIn].SetName("base") >> foo[Foo::kBase]; @@ -251,7 +254,7 @@ TEST(BuilderTest, TypedByPorts) { } TEST(BuilderTest, PacketGenerator) { - builder::Graph graph; + Graph graph; auto& generator 
= graph.AddPacketGenerator("FloatGenerator"); graph.SideIn("IN") >> generator.SideIn("IN"); generator.SideOut("OUT") >> graph.SideOut("OUT"); @@ -270,7 +273,7 @@ TEST(BuilderTest, PacketGenerator) { } TEST(BuilderTest, EmptyTag) { - builder::Graph graph; + Graph graph; auto& foo = graph.AddNode("Foo"); graph.In("A").SetName("a") >> foo.In("")[0]; graph.In("C").SetName("c") >> foo.In("")[2]; @@ -302,7 +305,7 @@ TEST(BuilderTest, StringLikeTags) { const std::string kB = "B"; constexpr absl::string_view kC = "C"; - builder::Graph graph; + Graph graph; auto& foo = graph.AddNode("Foo"); graph.In(kA).SetName("a") >> foo.In(kA); graph.In(kB).SetName("b") >> foo.In(kB); @@ -324,7 +327,7 @@ TEST(BuilderTest, StringLikeTags) { } TEST(BuilderTest, GraphIndexes) { - builder::Graph graph; + Graph graph; auto& foo = graph.AddNode("Foo"); graph.In(0).SetName("a") >> foo.In("")[0]; graph.In(1).SetName("c") >> foo.In("")[2]; @@ -376,28 +379,27 @@ class AnyAndSameTypeCalculator : public NodeIntf { }; TEST(BuilderTest, AnyAndSameTypeHandledProperly) { - builder::Graph graph; - builder::Source any_input = graph[Input{"GRAPH_ANY_INPUT"}]; - builder::Source int_input = graph[Input{"GRAPH_INT_INPUT"}]; + Graph graph; + Source any_input = graph[Input{"GRAPH_ANY_INPUT"}]; + Source int_input = graph[Input{"GRAPH_INT_INPUT"}]; auto& node = graph.AddNode("AnyAndSameTypeCalculator"); any_input >> node[AnyAndSameTypeCalculator::kAnyTypeInput]; int_input >> node[AnyAndSameTypeCalculator::kIntInput]; - builder::Source any_type_output = + Source any_type_output = node[AnyAndSameTypeCalculator::kAnyTypeOutput]; any_type_output.SetName("any_type_output"); - builder::Source same_type_output = + Source same_type_output = node[AnyAndSameTypeCalculator::kSameTypeOutput]; same_type_output.SetName("same_type_output"); - builder::Source recursive_same_type_output = + Source recursive_same_type_output = node[AnyAndSameTypeCalculator::kRecursiveSameTypeOutput]; recursive_same_type_output.SetName("recursive_same_type_output"); - builder::Source same_int_output = - node[AnyAndSameTypeCalculator::kSameIntOutput]; + Source same_int_output = node[AnyAndSameTypeCalculator::kSameIntOutput]; same_int_output.SetName("same_int_output"); - builder::Source recursive_same_int_type_output = + Source recursive_same_int_type_output = node[AnyAndSameTypeCalculator::kRecursiveSameIntOutput]; recursive_same_int_type_output.SetName("recursive_same_int_type_output"); @@ -420,13 +422,13 @@ TEST(BuilderTest, AnyAndSameTypeHandledProperly) { } TEST(BuilderTest, AnyTypeCanBeCast) { - builder::Graph graph; - builder::Source any_input = + Graph graph; + Source any_input = graph.In("GRAPH_ANY_INPUT").Cast(); auto& node = graph.AddNode("AnyAndSameTypeCalculator"); any_input >> node[AnyAndSameTypeCalculator::kAnyTypeInput]; - builder::Source any_type_output = + Source any_type_output = node[AnyAndSameTypeCalculator::kAnyTypeOutput].Cast(); any_type_output.SetName("any_type_output"); @@ -446,11 +448,11 @@ TEST(BuilderTest, AnyTypeCanBeCast) { } TEST(BuilderTest, MultiPortIsCastToMultiPort) { - builder::Graph graph; - builder::MultiSource any_input = graph.In("ANY_INPUT"); - builder::MultiSource int_input = any_input.Cast(); - builder::MultiDestination any_output = graph.Out("ANY_OUTPUT"); - builder::MultiDestination int_output = any_output.Cast(); + Graph graph; + MultiSource any_input = graph.In("ANY_INPUT"); + MultiSource int_input = any_input.Cast(); + MultiDestination any_output = graph.Out("ANY_OUTPUT"); + MultiDestination int_output = any_output.Cast(); 
int_input >> int_output; CalculatorGraphConfig expected = @@ -462,11 +464,11 @@ TEST(BuilderTest, MultiPortIsCastToMultiPort) { } TEST(BuilderTest, MultiPortCanBeSlicedToSinglePort) { - builder::Graph graph; - builder::MultiSource any_multi_input = graph.In("ANY_INPUT"); - builder::Source any_input = any_multi_input; - builder::MultiDestination any_multi_output = graph.Out("ANY_OUTPUT"); - builder::Destination any_output = any_multi_output; + Graph graph; + MultiSource any_multi_input = graph.In("ANY_INPUT"); + Source any_input = any_multi_input; + MultiDestination any_multi_output = graph.Out("ANY_OUTPUT"); + Destination any_output = any_multi_output; any_input >> any_output; CalculatorGraphConfig expected = @@ -478,11 +480,11 @@ TEST(BuilderTest, MultiPortCanBeSlicedToSinglePort) { } TEST(BuilderTest, SinglePortAccessWorksThroughSlicing) { - builder::Graph graph; - builder::Source int_input = graph.In("INT_INPUT").Cast(); - builder::Source any_input = graph.In("ANY_OUTPUT"); - builder::Destination int_output = graph.Out("INT_OUTPUT").Cast(); - builder::Destination any_output = graph.Out("ANY_OUTPUT"); + Graph graph; + Source int_input = graph.In("INT_INPUT").Cast(); + Source any_input = graph.In("ANY_OUTPUT"); + Destination int_output = graph.Out("INT_OUTPUT").Cast(); + Destination any_output = graph.Out("ANY_OUTPUT"); int_input >> int_output; any_input >> any_output; @@ -496,6 +498,5 @@ TEST(BuilderTest, SinglePortAccessWorksThroughSlicing) { EXPECT_THAT(graph.GetConfig(), EqualsProto(expected)); } -} // namespace test -} // namespace api2 -} // namespace mediapipe +} // namespace +} // namespace mediapipe::api2::builder From f5f2fee0b9b2cccdf9fd04c6cb4b96fd8c1bc7ee Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Mon, 19 Dec 2022 17:16:04 -0800 Subject: [PATCH 255/469] Switch to Cast where possible and reduce usage of operator[](port). 
PiperOrigin-RevId: 496527250 --- mediapipe/framework/api2/builder_test.cc | 36 ++++++++++++------------ 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/mediapipe/framework/api2/builder_test.cc b/mediapipe/framework/api2/builder_test.cc index 361f740c4..d8522b3c8 100644 --- a/mediapipe/framework/api2/builder_test.cc +++ b/mediapipe/framework/api2/builder_test.cc @@ -55,16 +55,16 @@ TEST(BuilderTest, BuildGraph) { TEST(BuilderTest, CopyableSource) { Graph graph; - Source a = graph[Input("A")]; + Source a = graph.In("A").Cast(); a.SetName("a"); - Source b = graph[Input("B")]; + Source b = graph.In("B").Cast(); b.SetName("b"); - SideSource side_a = graph[SideInput("SIDE_A")]; + SideSource side_a = graph.SideIn("SIDE_A").Cast(); side_a.SetName("side_a"); - SideSource side_b = graph[SideInput("SIDE_B")]; + SideSource side_b = graph.SideIn("SIDE_B").Cast(); side_b.SetName("side_b"); - Destination out = graph[Output("OUT")]; - SideDestination side_out = graph[SideOutput("SIDE_OUT")]; + Destination out = graph.Out("OUT").Cast(); + SideDestination side_out = graph.SideOut("SIDE_OUT").Cast(); Source input = a; input = b; @@ -89,28 +89,28 @@ TEST(BuilderTest, CopyableSource) { TEST(BuilderTest, BuildGraphWithFunctions) { Graph graph; - Source base = graph[Input("IN")]; + Source base = graph.In("IN").Cast(); base.SetName("base"); - SideSource side = graph[SideInput("SIDE")]; + SideSource side = graph.SideIn("SIDE").Cast(); side.SetName("side"); auto foo_fn = [](Source base, SideSource side, Graph& graph) { auto& foo = graph.AddNode("Foo"); - base >> foo[Input("BASE")]; - side >> foo[SideInput("SIDE")]; - return foo[Output("OUT")]; + base >> foo.In("BASE"); + side >> foo.SideIn("SIDE"); + return foo.Out("OUT")[0].Cast(); }; Source foo_out = foo_fn(base, side, graph); auto bar_fn = [](Source in, Graph& graph) { auto& bar = graph.AddNode("Bar"); - in >> bar[Input("IN")]; - return bar[Output("OUT")]; + in >> bar.In("IN"); + return bar.Out("OUT")[0].Cast(); }; Source bar_out = bar_fn(foo_out, graph); bar_out.SetName("out"); - bar_out >> graph[Output("OUT")]; + bar_out >> graph.Out("OUT"); CalculatorGraphConfig expected = mediapipe::ParseTextProtoOrDie(R"pb( @@ -229,10 +229,10 @@ TEST(BuilderTest, TypedByPorts) { auto& foo = graph.AddNode(); auto& adder = graph.AddNode(); - graph[FooBar1::kIn].SetName("base") >> foo[Foo::kBase]; + graph.In(FooBar1::kIn).SetName("base") >> foo[Foo::kBase]; foo[Foo::kOut] >> adder[FloatAdder::kIn][0]; foo[Foo::kOut] >> adder[FloatAdder::kIn][1]; - adder[FloatAdder::kOut].SetName("out") >> graph[FooBar1::kOut]; + adder[FloatAdder::kOut].SetName("out") >> graph.Out(FooBar1::kOut); CalculatorGraphConfig expected = mediapipe::ParseTextProtoOrDie(R"pb( @@ -380,8 +380,8 @@ class AnyAndSameTypeCalculator : public NodeIntf { TEST(BuilderTest, AnyAndSameTypeHandledProperly) { Graph graph; - Source any_input = graph[Input{"GRAPH_ANY_INPUT"}]; - Source int_input = graph[Input{"GRAPH_INT_INPUT"}]; + Source any_input = graph.In("GRAPH_ANY_INPUT"); + Source int_input = graph.In("GRAPH_INT_INPUT").Cast(); auto& node = graph.AddNode("AnyAndSameTypeCalculator"); any_input >> node[AnyAndSameTypeCalculator::kAnyTypeInput]; From 994eb28d2c007ebc09795b300cedf0abe7130507 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Mon, 19 Dec 2022 18:05:30 -0800 Subject: [PATCH 256/469] Chain SetName calls where possible PiperOrigin-RevId: 496534328 --- mediapipe/framework/api2/builder_test.cc | 28 ++++++++++-------------- 1 file changed, 12 insertions(+), 16 deletions(-) diff --git 
a/mediapipe/framework/api2/builder_test.cc b/mediapipe/framework/api2/builder_test.cc index d8522b3c8..b01c2b759 100644 --- a/mediapipe/framework/api2/builder_test.cc +++ b/mediapipe/framework/api2/builder_test.cc @@ -55,14 +55,12 @@ TEST(BuilderTest, BuildGraph) { TEST(BuilderTest, CopyableSource) { Graph graph; - Source a = graph.In("A").Cast(); - a.SetName("a"); - Source b = graph.In("B").Cast(); - b.SetName("b"); - SideSource side_a = graph.SideIn("SIDE_A").Cast(); - side_a.SetName("side_a"); - SideSource side_b = graph.SideIn("SIDE_B").Cast(); - side_b.SetName("side_b"); + Source a = graph.In("A").SetName("a").Cast(); + Source b = graph.In("B").SetName("b").Cast(); + SideSource side_a = + graph.SideIn("SIDE_A").SetName("side_a").Cast(); + SideSource side_b = + graph.SideIn("SIDE_B").SetName("side_b").Cast(); Destination out = graph.Out("OUT").Cast(); SideDestination side_out = graph.SideOut("SIDE_OUT").Cast(); @@ -89,10 +87,8 @@ TEST(BuilderTest, CopyableSource) { TEST(BuilderTest, BuildGraphWithFunctions) { Graph graph; - Source base = graph.In("IN").Cast(); - base.SetName("base"); - SideSource side = graph.SideIn("SIDE").Cast(); - side.SetName("side"); + Source base = graph.In("IN").SetName("base").Cast(); + SideSource side = graph.SideIn("SIDE").SetName("side").Cast(); auto foo_fn = [](Source base, SideSource side, Graph& graph) { auto& foo = graph.AddNode("Foo"); @@ -108,9 +104,8 @@ TEST(BuilderTest, BuildGraphWithFunctions) { return bar.Out("OUT")[0].Cast(); }; Source bar_out = bar_fn(foo_out, graph); - bar_out.SetName("out"); - bar_out >> graph.Out("OUT"); + bar_out.SetName("out") >> graph.Out("OUT"); CalculatorGraphConfig expected = mediapipe::ParseTextProtoOrDie(R"pb( @@ -429,8 +424,9 @@ TEST(BuilderTest, AnyTypeCanBeCast) { auto& node = graph.AddNode("AnyAndSameTypeCalculator"); any_input >> node[AnyAndSameTypeCalculator::kAnyTypeInput]; Source any_type_output = - node[AnyAndSameTypeCalculator::kAnyTypeOutput].Cast(); - any_type_output.SetName("any_type_output"); + node[AnyAndSameTypeCalculator::kAnyTypeOutput] + .SetName("any_type_output") + .Cast(); any_type_output >> graph.Out("GRAPH_ANY_OUTPUT").Cast(); From 90678040057ee23dc2ad29c6982010a260c7b7cd Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Mon, 19 Dec 2022 19:39:00 -0800 Subject: [PATCH 257/469] Fix the missing logging component issue of mediapipe tasks core. 
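The fix works because Bazel's `glob(["*.java"])` only matches files directly in the package directory, while `glob(["**/*.java"])` also descends into subdirectories (as long as they are not separate packages). A small Starlark sketch under an assumed file layout (the exact paths are illustrative):

```python
# Assumed layout under tasks/core/:
#   TaskRunner.java                  matched by "*.java" and "**/*.java"
#   logging/TasksStatsLogger.java    matched only by "**/*.java"
srcs = glob(["**/*.java"])  # recursive glob now picks up the logging sources
```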
PiperOrigin-RevId: 496548340 --- mediapipe/tasks/java/com/google/mediapipe/tasks/core/BUILD | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/BUILD index 3eb28d38b..5f7101776 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/BUILD @@ -53,7 +53,7 @@ load("//mediapipe/tasks/java/com/google/mediapipe/tasks:mediapipe_tasks_aar.bzl" mediapipe_tasks_core_aar( name = "tasks_core", - srcs = glob(["*.java"]) + [ + srcs = glob(["**/*.java"]) + [ "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:java_src", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:java_src", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/utils:java_src", From 4682416f0f426e8302b4181a7085713ac1c6e38c Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Mon, 19 Dec 2022 22:07:55 -0800 Subject: [PATCH 258/469] Internal change PiperOrigin-RevId: 496568835 --- mediapipe/calculators/core/BUILD | 14 ++--- mediapipe/calculators/image/BUILD | 4 +- mediapipe/calculators/util/BUILD | 2 +- mediapipe/framework/BUILD | 70 ++++++++++++------------ mediapipe/framework/formats/BUILD | 24 ++++---- mediapipe/framework/formats/motion/BUILD | 2 +- mediapipe/framework/stream_handler/BUILD | 14 ++--- mediapipe/framework/tool/BUILD | 18 +++--- mediapipe/gpu/BUILD | 14 ++--- 9 files changed, 81 insertions(+), 81 deletions(-) diff --git a/mediapipe/calculators/core/BUILD b/mediapipe/calculators/core/BUILD index 2c143a609..b3378a74e 100644 --- a/mediapipe/calculators/core/BUILD +++ b/mediapipe/calculators/core/BUILD @@ -567,7 +567,7 @@ cc_library( name = "packet_thinner_calculator", srcs = ["packet_thinner_calculator.cc"], deps = [ - "//mediapipe/calculators/core:packet_thinner_calculator_cc_proto", + ":packet_thinner_calculator_cc_proto", "//mediapipe/framework:calculator_context", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:video_stream_header", @@ -584,7 +584,7 @@ cc_test( srcs = ["packet_thinner_calculator_test.cc"], deps = [ ":packet_thinner_calculator", - "//mediapipe/calculators/core:packet_thinner_calculator_cc_proto", + ":packet_thinner_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_runner", "//mediapipe/framework/formats:video_stream_header", @@ -762,7 +762,7 @@ cc_library( srcs = ["packet_resampler_calculator.cc"], hdrs = ["packet_resampler_calculator.h"], deps = [ - "//mediapipe/calculators/core:packet_resampler_calculator_cc_proto", + ":packet_resampler_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:collection_item_id", "//mediapipe/framework/deps:mathutil", @@ -786,7 +786,7 @@ cc_test( ], deps = [ ":packet_resampler_calculator", - "//mediapipe/calculators/core:packet_resampler_calculator_cc_proto", + ":packet_resampler_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_runner", "//mediapipe/framework/formats:video_stream_header", @@ -852,10 +852,10 @@ cc_test( name = "flow_limiter_calculator_test", srcs = ["flow_limiter_calculator_test.cc"], deps = [ + ":counting_source_calculator", ":flow_limiter_calculator", ":flow_limiter_calculator_cc_proto", - "//mediapipe/calculators/core:counting_source_calculator", - "//mediapipe/calculators/core:pass_through_calculator", + ":pass_through_calculator", 
"//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_runner", "//mediapipe/framework:test_calculators", @@ -1302,7 +1302,7 @@ cc_test( srcs = ["packet_sequencer_calculator_test.cc"], deps = [ ":packet_sequencer_calculator", - "//mediapipe/calculators/core:pass_through_calculator", + ":pass_through_calculator", "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:subgraph", diff --git a/mediapipe/calculators/image/BUILD b/mediapipe/calculators/image/BUILD index 530dd3d4a..9aae8cfbc 100644 --- a/mediapipe/calculators/image/BUILD +++ b/mediapipe/calculators/image/BUILD @@ -378,8 +378,8 @@ cc_library( name = "scale_image_calculator", srcs = ["scale_image_calculator.cc"], deps = [ + ":scale_image_calculator_cc_proto", ":scale_image_utils", - "//mediapipe/calculators/image:scale_image_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:image_format_cc_proto", "//mediapipe/framework/formats:image_frame", @@ -747,8 +747,8 @@ cc_test( tags = ["desktop_only_test"], deps = [ ":affine_transformation", + ":image_transformation_calculator", ":warp_affine_calculator", - "//mediapipe/calculators/image:image_transformation_calculator", "//mediapipe/calculators/tensor:image_to_tensor_converter", "//mediapipe/calculators/tensor:image_to_tensor_utils", "//mediapipe/calculators/util:from_image_calculator", diff --git a/mediapipe/calculators/util/BUILD b/mediapipe/calculators/util/BUILD index 1529ead8a..a679a80fd 100644 --- a/mediapipe/calculators/util/BUILD +++ b/mediapipe/calculators/util/BUILD @@ -22,8 +22,8 @@ cc_library( name = "alignment_points_to_rects_calculator", srcs = ["alignment_points_to_rects_calculator.cc"], deps = [ + ":detections_to_rects_calculator", ":detections_to_rects_calculator_cc_proto", - "//mediapipe/calculators/util:detections_to_rects_calculator", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_options_cc_proto", "//mediapipe/framework/formats:detection_cc_proto", diff --git a/mediapipe/framework/BUILD b/mediapipe/framework/BUILD index 0dd694760..082ea9994 100644 --- a/mediapipe/framework/BUILD +++ b/mediapipe/framework/BUILD @@ -226,13 +226,13 @@ cc_library( ":mediapipe_internal", ], deps = [ + ":calculator_cc_proto", ":graph_service", + ":mediapipe_options_cc_proto", + ":packet_generator_cc_proto", ":packet_type", ":port", - "//mediapipe/framework:calculator_cc_proto", - "//mediapipe/framework:mediapipe_options_cc_proto", - "//mediapipe/framework:packet_generator_cc_proto", - "//mediapipe/framework:status_handler_cc_proto", + ":status_handler_cc_proto", "//mediapipe/framework/port:any_proto", "//mediapipe/framework/port:status", "//mediapipe/framework/tool:options_map", @@ -328,10 +328,10 @@ cc_library( ":thread_pool_executor", ":timestamp", ":validated_graph_config", - "//mediapipe/framework:calculator_cc_proto", - "//mediapipe/framework:packet_generator_cc_proto", - "//mediapipe/framework:status_handler_cc_proto", - "//mediapipe/framework:thread_pool_executor_cc_proto", + ":calculator_cc_proto", + ":packet_generator_cc_proto", + ":status_handler_cc_proto", + ":thread_pool_executor_cc_proto", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:fixed_array", "@com_google_absl//absl/container:flat_hash_map", @@ -391,6 +391,7 @@ cc_library( visibility = [":mediapipe_internal"], deps = [ ":calculator_base", + ":calculator_cc_proto", ":calculator_context", ":calculator_context_manager", 
":calculator_state", @@ -407,10 +408,9 @@ cc_library( ":packet_set", ":packet_type", ":port", + ":stream_handler_cc_proto", ":timestamp", ":validated_graph_config", - "//mediapipe/framework:calculator_cc_proto", - "//mediapipe/framework:stream_handler_cc_proto", "//mediapipe/framework/port:core_proto", "//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:logging", @@ -466,6 +466,7 @@ cc_library( hdrs = ["calculator_state.h"], visibility = [":mediapipe_internal"], deps = [ + ":calculator_cc_proto", ":counter", ":counter_factory", ":graph_service", @@ -475,7 +476,6 @@ cc_library( ":packet", ":packet_set", ":port", - "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework/port:any_proto", "//mediapipe/framework/port:logging", "//mediapipe/framework/tool:options_map", @@ -583,7 +583,7 @@ cc_library( hdrs = ["executor.h"], visibility = ["//visibility:public"], deps = [ - "//mediapipe/framework:mediapipe_options_cc_proto", + ":mediapipe_options_cc_proto", "//mediapipe/framework/deps:registration", "//mediapipe/framework/port:status", "//mediapipe/framework/port:statusor", @@ -670,11 +670,11 @@ cc_library( ":collection_item_id", ":input_stream_manager", ":input_stream_shard", + ":mediapipe_options_cc_proto", ":mediapipe_profiling", ":packet", ":packet_set", ":packet_type", - "//mediapipe/framework:mediapipe_options_cc_proto", "//mediapipe/framework/deps:registration", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", @@ -784,12 +784,12 @@ cc_library( ":calculator_context_manager", ":collection", ":collection_item_id", + ":mediapipe_options_cc_proto", ":output_stream_manager", ":output_stream_shard", ":packet_set", ":packet_type", ":timestamp", - "//mediapipe/framework:mediapipe_options_cc_proto", "//mediapipe/framework/deps:registration", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:status", @@ -875,10 +875,10 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":packet", + ":packet_generator_cc_proto", ":packet_set", ":packet_type", ":port", - "//mediapipe/framework:packet_generator_cc_proto", "//mediapipe/framework/deps:registration", "//mediapipe/framework/port:core_proto", "//mediapipe/framework/port:status", @@ -896,13 +896,13 @@ cc_library( ":delegating_executor", ":executor", ":packet", + ":packet_factory_cc_proto", ":packet_generator", + ":packet_generator_cc_proto", ":packet_type", ":port", ":thread_pool_executor", ":validated_graph_config", - "//mediapipe/framework:packet_factory_cc_proto", - "//mediapipe/framework:packet_generator_cc_proto", "//mediapipe/framework/port:core_proto", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:ret_check", @@ -1019,10 +1019,10 @@ cc_library( hdrs = ["status_handler.h"], visibility = ["//visibility:public"], deps = [ + ":mediapipe_options_cc_proto", ":packet_set", ":packet_type", ":port", - "//mediapipe/framework:mediapipe_options_cc_proto", "//mediapipe/framework/deps:registration", "//mediapipe/framework/port:status", "@com_google_absl//absl/memory", @@ -1035,10 +1035,10 @@ cc_library( hdrs = ["subgraph.h"], visibility = ["//visibility:public"], deps = [ + ":calculator_cc_proto", ":graph_service", ":graph_service_manager", ":port", - "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework/deps:registration", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", @@ -1097,7 +1097,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":executor", - "//mediapipe/framework:thread_pool_executor_cc_proto", 
+ ":thread_pool_executor_cc_proto", "//mediapipe/framework/deps:thread_options", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:status", @@ -1162,22 +1162,22 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":calculator_base", + ":calculator_cc_proto", ":calculator_contract", ":graph_service_manager", ":legacy_calculator_support", ":packet", ":packet_generator", + ":packet_generator_cc_proto", ":packet_set", ":packet_type", ":port", ":status_handler", + ":status_handler_cc_proto", + ":stream_handler_cc_proto", ":subgraph", + ":thread_pool_executor_cc_proto", ":timestamp", - "//mediapipe/framework:calculator_cc_proto", - "//mediapipe/framework:packet_generator_cc_proto", - "//mediapipe/framework:status_handler_cc_proto", - "//mediapipe/framework:stream_handler_cc_proto", - "//mediapipe/framework:thread_pool_executor_cc_proto", "//mediapipe/framework/port:core_proto", "//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:logging", @@ -1202,11 +1202,11 @@ cc_test( name = "validated_graph_config_test", srcs = ["validated_graph_config_test.cc"], deps = [ + ":calculator_cc_proto", ":calculator_framework", ":graph_service", ":graph_service_manager", ":validated_graph_config", - "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework/api2:node", "//mediapipe/framework/api2:port", "//mediapipe/framework/port:gtest_main", @@ -1233,6 +1233,7 @@ cc_test( linkstatic = 1, deps = [ ":calculator_base", + ":calculator_cc_proto", ":calculator_context", ":calculator_context_manager", ":calculator_registry", @@ -1242,7 +1243,6 @@ cc_test( ":output_stream_shard", ":packet_set", ":packet_type", - "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:status", "//mediapipe/framework/tool:status_util", @@ -1256,11 +1256,11 @@ cc_test( srcs = ["calculator_contract_test.cc"], linkstatic = 1, deps = [ + ":calculator_cc_proto", ":calculator_contract", ":calculator_contract_test_cc_proto", - "//mediapipe/framework:calculator_cc_proto", - "//mediapipe/framework:packet_generator_cc_proto", - "//mediapipe/framework:status_handler_cc_proto", + ":packet_generator_cc_proto", + ":status_handler_cc_proto", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:parse_text_proto", ], @@ -1368,6 +1368,7 @@ cc_test( srcs = ["calculator_context_test.cc"], linkstatic = 1, deps = [ + ":calculator_cc_proto", ":calculator_context", ":calculator_context_manager", ":calculator_state", @@ -1376,7 +1377,6 @@ cc_test( ":output_stream_shard", ":packet_set", ":packet_type", - "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:status", @@ -1403,6 +1403,7 @@ cc_test( ":executor", ":input_stream_handler", ":lifetime_tracker", + ":mediapipe_options_cc_proto", ":output_stream_poller", ":packet_set", ":packet_type", @@ -1410,13 +1411,12 @@ cc_test( ":subgraph", ":test_calculators", ":thread_pool_executor", + ":thread_pool_executor_cc_proto", ":timestamp", ":type_map", "//mediapipe/calculators/core:counting_source_calculator", "//mediapipe/calculators/core:mux_calculator", "//mediapipe/calculators/core:pass_through_calculator", - "//mediapipe/framework:mediapipe_options_cc_proto", - "//mediapipe/framework:thread_pool_executor_cc_proto", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:parse_text_proto", @@ -1481,12 +1481,12 @@ cc_test( ], visibility 
= ["//visibility:public"], deps = [ + ":calculator_cc_proto", ":calculator_framework", ":test_calculators", "//mediapipe/calculators/core:counting_source_calculator", "//mediapipe/calculators/core:mux_calculator", "//mediapipe/calculators/core:pass_through_calculator", - "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:parse_text_proto", @@ -1630,8 +1630,8 @@ cc_test( srcs = ["packet_generator_test.cc"], deps = [ ":packet_generator", + ":packet_generator_cc_proto", ":packet_type", - "//mediapipe/framework:packet_generator_cc_proto", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/tool:validate_type", "@com_google_absl//absl/strings", diff --git a/mediapipe/framework/formats/BUILD b/mediapipe/framework/formats/BUILD index fdd9b8909..f5a043f10 100644 --- a/mediapipe/framework/formats/BUILD +++ b/mediapipe/framework/formats/BUILD @@ -26,7 +26,7 @@ licenses(["notice"]) mediapipe_proto_library( name = "detection_proto", srcs = ["detection.proto"], - deps = ["//mediapipe/framework/formats:location_data_proto"], + deps = [":location_data_proto"], ) mediapipe_register_type( @@ -38,7 +38,7 @@ mediapipe_register_type( "::std::vector<::mediapipe::Detection>", "::std::vector<::mediapipe::DetectionList>", ], - deps = ["//mediapipe/framework/formats:detection_cc_proto"], + deps = [":detection_cc_proto"], ) mediapipe_proto_library( @@ -105,8 +105,8 @@ cc_library( srcs = ["matrix.cc"], hdrs = ["matrix.h"], deps = [ + ":matrix_data_cc_proto", "//mediapipe/framework:port", - "//mediapipe/framework/formats:matrix_data_cc_proto", "//mediapipe/framework/port:core_proto", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:ret_check", @@ -142,7 +142,7 @@ cc_library( srcs = ["image_frame.cc"], hdrs = ["image_frame.h"], deps = [ - "//mediapipe/framework/formats:image_format_cc_proto", + ":image_format_cc_proto", "@com_google_absl//absl/base", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/memory", @@ -166,8 +166,8 @@ cc_library( srcs = ["image_frame_opencv.cc"], hdrs = ["image_frame_opencv.h"], deps = [ + ":image_format_cc_proto", ":image_frame", - "//mediapipe/framework/formats:image_format_cc_proto", "//mediapipe/framework/port:opencv_core", ], ) @@ -194,7 +194,7 @@ cc_library( deps = [ "@com_google_protobuf//:protobuf", "//mediapipe/framework/formats/annotation:locus_cc_proto", - "//mediapipe/framework/formats:location_data_cc_proto", + ":location_data_cc_proto", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", @@ -245,7 +245,7 @@ cc_library( name = "video_stream_header", hdrs = ["video_stream_header.h"], deps = [ - "//mediapipe/framework/formats:image_format_cc_proto", + ":image_format_cc_proto", ], ) @@ -263,9 +263,9 @@ cc_test( size = "small", srcs = ["image_frame_opencv_test.cc"], deps = [ + ":image_format_cc_proto", ":image_frame", ":image_frame_opencv", - "//mediapipe/framework/formats:image_format_cc_proto", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:logging", @@ -324,8 +324,8 @@ cc_library( "//conditions:default": [], }), deps = [ - "//mediapipe/framework/formats:image_format_cc_proto", - "//mediapipe/framework/formats:image_frame", + ":image_format_cc_proto", + ":image_frame", "@com_google_absl//absl/synchronization", "//mediapipe/framework:port", "//mediapipe/framework:type_map", @@ -354,7 +354,7 @@ cc_library( hdrs = 
["image_multi_pool.h"], deps = [ ":image", - "//mediapipe/framework/formats:image_frame_pool", + ":image_frame_pool", "//mediapipe/framework:port", "//mediapipe/framework/port:logging", "@com_google_absl//absl/memory", @@ -390,7 +390,7 @@ cc_library( ], deps = [ ":image", - "//mediapipe/framework/formats:image_format_cc_proto", + ":image_format_cc_proto", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:opencv_core", "//mediapipe/framework/port:statusor", diff --git a/mediapipe/framework/formats/motion/BUILD b/mediapipe/framework/formats/motion/BUILD index f1bbc0289..c9bb8b4ff 100644 --- a/mediapipe/framework/formats/motion/BUILD +++ b/mediapipe/framework/formats/motion/BUILD @@ -38,11 +38,11 @@ cc_library( srcs = ["optical_flow_field.cc"], hdrs = ["optical_flow_field.h"], deps = [ + ":optical_flow_field_data_cc_proto", "//mediapipe/framework:type_map", "//mediapipe/framework/deps:mathutil", "//mediapipe/framework/formats:location", "//mediapipe/framework/formats:location_opencv", - "//mediapipe/framework/formats/motion:optical_flow_field_data_cc_proto", "//mediapipe/framework/port:file_helpers", "//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:logging", diff --git a/mediapipe/framework/stream_handler/BUILD b/mediapipe/framework/stream_handler/BUILD index 01ef6ee86..68a9af52d 100644 --- a/mediapipe/framework/stream_handler/BUILD +++ b/mediapipe/framework/stream_handler/BUILD @@ -88,8 +88,8 @@ cc_library( srcs = ["default_input_stream_handler.cc"], hdrs = ["default_input_stream_handler.h"], deps = [ + ":default_input_stream_handler_cc_proto", "//mediapipe/framework:input_stream_handler", - "//mediapipe/framework/stream_handler:default_input_stream_handler_cc_proto", "@com_google_absl//absl/strings", ], alwayslink = 1, @@ -110,8 +110,8 @@ cc_library( srcs = ["fixed_size_input_stream_handler.cc"], deps = [ ":default_input_stream_handler", + ":fixed_size_input_stream_handler_cc_proto", "//mediapipe/framework:input_stream_handler", - "//mediapipe/framework/stream_handler:fixed_size_input_stream_handler_cc_proto", ], alwayslink = 1, ) @@ -159,13 +159,13 @@ cc_library( name = "sync_set_input_stream_handler", srcs = ["sync_set_input_stream_handler.cc"], deps = [ + ":sync_set_input_stream_handler_cc_proto", "//mediapipe/framework:collection", "//mediapipe/framework:collection_item_id", "//mediapipe/framework:input_stream_handler", "//mediapipe/framework:mediapipe_options_cc_proto", "//mediapipe/framework:packet_set", "//mediapipe/framework:timestamp", - "//mediapipe/framework/stream_handler:sync_set_input_stream_handler_cc_proto", "//mediapipe/framework/tool:tag_map", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", @@ -177,10 +177,10 @@ cc_library( name = "timestamp_align_input_stream_handler", srcs = ["timestamp_align_input_stream_handler.cc"], deps = [ + ":timestamp_align_input_stream_handler_cc_proto", "//mediapipe/framework:collection_item_id", "//mediapipe/framework:input_stream_handler", "//mediapipe/framework:timestamp", - "//mediapipe/framework/stream_handler:timestamp_align_input_stream_handler_cc_proto", "//mediapipe/framework/tool:validate_name", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", @@ -243,6 +243,7 @@ cc_test( srcs = ["set_input_stream_handler_test.cc"], deps = [ ":fixed_size_input_stream_handler", + ":fixed_size_input_stream_handler_cc_proto", ":mux_input_stream_handler", "//mediapipe/calculators/core:mux_calculator", "//mediapipe/calculators/core:pass_through_calculator", @@ 
-251,7 +252,6 @@ cc_test( "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:ret_check", - "//mediapipe/framework/stream_handler:fixed_size_input_stream_handler_cc_proto", ], ) @@ -272,13 +272,13 @@ cc_test( srcs = ["fixed_size_input_stream_handler_test.cc"], deps = [ ":fixed_size_input_stream_handler", + ":fixed_size_input_stream_handler_cc_proto", "//mediapipe/calculators/core:counting_source_calculator", "//mediapipe/calculators/core:pass_through_calculator", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:parse_text_proto", - "//mediapipe/framework/stream_handler:fixed_size_input_stream_handler_cc_proto", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/synchronization", ], @@ -289,11 +289,11 @@ cc_test( srcs = ["sync_set_input_stream_handler_test.cc"], deps = [ ":sync_set_input_stream_handler", + ":sync_set_input_stream_handler_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:test_calculators", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:parse_text_proto", - "//mediapipe/framework/stream_handler:sync_set_input_stream_handler_cc_proto", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", diff --git a/mediapipe/framework/tool/BUILD b/mediapipe/framework/tool/BUILD index 89cb802da..193343a90 100644 --- a/mediapipe/framework/tool/BUILD +++ b/mediapipe/framework/tool/BUILD @@ -299,6 +299,7 @@ mediapipe_cc_test( requires_full_emulation = False, deps = [ ":node_chain_subgraph_cc_proto", + ":node_chain_subgraph_options_lib", ":options_field_util", ":options_registry", ":options_syntax_util", @@ -313,7 +314,6 @@ mediapipe_cc_test( "//mediapipe/framework/port:status", "//mediapipe/framework/testdata:night_light_calculator_cc_proto", "//mediapipe/framework/testdata:night_light_calculator_options_lib", - "//mediapipe/framework/tool:node_chain_subgraph_options_lib", "//mediapipe/util:header_util", "@com_google_absl//absl/strings", ], @@ -422,9 +422,9 @@ cc_library( srcs = ["source.cc"], visibility = ["//visibility:public"], deps = [ + ":source_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:status", - "//mediapipe/framework/tool:source_cc_proto", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/strings", ], @@ -485,13 +485,13 @@ cc_library( hdrs = ["template_expander.h"], visibility = ["//visibility:public"], deps = [ + ":calculator_graph_template_cc_proto", ":proto_util_lite", "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:numbers", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", - "//mediapipe/framework/tool:calculator_graph_template_cc_proto", "@com_google_absl//absl/strings", ], ) @@ -506,6 +506,7 @@ cc_library( ], visibility = ["//visibility:public"], deps = [ + ":calculator_graph_template_cc_proto", ":proto_util_lite", "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework/deps:proto_descriptor_cc_proto", @@ -515,7 +516,6 @@ cc_library( "//mediapipe/framework/port:map_util", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", - "//mediapipe/framework/tool:calculator_graph_template_cc_proto", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", @@ -661,8 
+661,8 @@ cc_library( hdrs = ["simulation_clock_executor.h"], visibility = ["//visibility:public"], deps = [ + ":simulation_clock", "//mediapipe/framework:thread_pool_executor", - "//mediapipe/framework/tool:simulation_clock", ], ) @@ -789,10 +789,10 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":name_util", + ":switch_container_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", - "//mediapipe/framework/tool:switch_container_cc_proto", ], ) @@ -805,6 +805,7 @@ cc_library( deps = [ ":container_util", ":options_util", + ":switch_container_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:collection_item_id", "//mediapipe/framework/deps:mathutil", @@ -814,7 +815,6 @@ cc_library( "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", "//mediapipe/framework/stream_handler:immediate_input_stream_handler", - "//mediapipe/framework/tool:switch_container_cc_proto", "@com_google_absl//absl/strings", ], alwayslink = 1, @@ -841,6 +841,7 @@ cc_library( ], deps = [ ":container_util", + ":switch_container_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:collection_item_id", "//mediapipe/framework:input_stream_shard", @@ -850,7 +851,6 @@ cc_library( "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", "//mediapipe/framework/stream_handler:immediate_input_stream_handler", - "//mediapipe/framework/tool:switch_container_cc_proto", ], alwayslink = 1, ) @@ -893,6 +893,7 @@ cc_library( ":container_util", ":name_util", ":subgraph_expansion", + ":switch_container_cc_proto", ":switch_demux_calculator", ":switch_mux_calculator", "//mediapipe/calculators/core:packet_sequencer_calculator", @@ -904,7 +905,6 @@ cc_library( "//mediapipe/framework/port:core_proto", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", - "//mediapipe/framework/tool:switch_container_cc_proto", "@com_google_absl//absl/strings", ], alwayslink = 1, diff --git a/mediapipe/gpu/BUILD b/mediapipe/gpu/BUILD index 009eb3f9e..cc5e50dfc 100644 --- a/mediapipe/gpu/BUILD +++ b/mediapipe/gpu/BUILD @@ -564,6 +564,7 @@ cc_library( name = "gpu_shared_data_internal_stub", visibility = ["//visibility:private"], deps = [ + ":gl_context_options_cc_proto", ":graph_support", "//mediapipe/framework:calculator_context", "//mediapipe/framework:calculator_node", @@ -571,7 +572,6 @@ cc_library( "//mediapipe/framework:port", "//mediapipe/framework/deps:no_destructor", "//mediapipe/framework/port:ret_check", - "//mediapipe/gpu:gl_context_options_cc_proto", ], ) @@ -592,7 +592,7 @@ cc_library( }), visibility = ["//visibility:private"], deps = [ - "//mediapipe/gpu:gl_context_options_cc_proto", + ":gl_context_options_cc_proto", ":graph_support", "//mediapipe/framework:calculator_context", "//mediapipe/framework:executor", @@ -833,10 +833,10 @@ cc_library( deps = [ ":gl_base", ":gl_simple_shaders", + ":scale_mode_cc_proto", ":shader_util", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", - "//mediapipe/gpu:scale_mode_cc_proto", ], ) @@ -907,8 +907,8 @@ proto_library( srcs = ["gl_scaler_calculator.proto"], visibility = ["//visibility:public"], deps = [ + ":scale_mode_proto", "//mediapipe/framework:calculator_proto", - "//mediapipe/gpu:scale_mode_proto", ], ) @@ -930,6 +930,7 @@ cc_library( deps = [ ":gl_calculator_helper", ":gl_quad_renderer", + ":gl_scaler_calculator_cc_proto", ":gl_simple_shaders", ":shader_util", 
"//mediapipe/framework:calculator_framework", @@ -937,7 +938,6 @@ cc_library( "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", "//mediapipe/framework/tool:options_util", - "//mediapipe/gpu:gl_scaler_calculator_cc_proto", ], alwayslink = 1, ) @@ -950,13 +950,13 @@ cc_library( ":egl_surface_holder", ":gl_calculator_helper", ":gl_quad_renderer", + ":gl_surface_sink_calculator_cc_proto", ":gpu_buffer", ":shader_util", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:node", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", - "//mediapipe/gpu:gl_surface_sink_calculator_cc_proto", "@com_google_absl//absl/synchronization", ], alwayslink = 1, @@ -966,8 +966,8 @@ proto_library( name = "gl_surface_sink_calculator_proto", srcs = ["gl_surface_sink_calculator.proto"], deps = [ + ":scale_mode_proto", "//mediapipe/framework:calculator_proto", - "//mediapipe/gpu:scale_mode_proto", ], ) From 8c013647c87cc5784cd545df5f92afd33c6fe941 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Tue, 20 Dec 2022 04:47:09 -0800 Subject: [PATCH 259/469] Internal change PiperOrigin-RevId: 496629682 --- mediapipe/calculators/core/BUILD | 14 ++--- mediapipe/calculators/image/BUILD | 4 +- mediapipe/calculators/util/BUILD | 2 +- mediapipe/framework/BUILD | 70 ++++++++++++------------ mediapipe/framework/formats/BUILD | 24 ++++---- mediapipe/framework/formats/motion/BUILD | 2 +- mediapipe/framework/stream_handler/BUILD | 14 ++--- mediapipe/framework/tool/BUILD | 18 +++--- mediapipe/gpu/BUILD | 14 ++--- 9 files changed, 81 insertions(+), 81 deletions(-) diff --git a/mediapipe/calculators/core/BUILD b/mediapipe/calculators/core/BUILD index b3378a74e..2c143a609 100644 --- a/mediapipe/calculators/core/BUILD +++ b/mediapipe/calculators/core/BUILD @@ -567,7 +567,7 @@ cc_library( name = "packet_thinner_calculator", srcs = ["packet_thinner_calculator.cc"], deps = [ - ":packet_thinner_calculator_cc_proto", + "//mediapipe/calculators/core:packet_thinner_calculator_cc_proto", "//mediapipe/framework:calculator_context", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:video_stream_header", @@ -584,7 +584,7 @@ cc_test( srcs = ["packet_thinner_calculator_test.cc"], deps = [ ":packet_thinner_calculator", - ":packet_thinner_calculator_cc_proto", + "//mediapipe/calculators/core:packet_thinner_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_runner", "//mediapipe/framework/formats:video_stream_header", @@ -762,7 +762,7 @@ cc_library( srcs = ["packet_resampler_calculator.cc"], hdrs = ["packet_resampler_calculator.h"], deps = [ - ":packet_resampler_calculator_cc_proto", + "//mediapipe/calculators/core:packet_resampler_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:collection_item_id", "//mediapipe/framework/deps:mathutil", @@ -786,7 +786,7 @@ cc_test( ], deps = [ ":packet_resampler_calculator", - ":packet_resampler_calculator_cc_proto", + "//mediapipe/calculators/core:packet_resampler_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_runner", "//mediapipe/framework/formats:video_stream_header", @@ -852,10 +852,10 @@ cc_test( name = "flow_limiter_calculator_test", srcs = ["flow_limiter_calculator_test.cc"], deps = [ - ":counting_source_calculator", ":flow_limiter_calculator", ":flow_limiter_calculator_cc_proto", - ":pass_through_calculator", + 
"//mediapipe/calculators/core:counting_source_calculator", + "//mediapipe/calculators/core:pass_through_calculator", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_runner", "//mediapipe/framework:test_calculators", @@ -1302,7 +1302,7 @@ cc_test( srcs = ["packet_sequencer_calculator_test.cc"], deps = [ ":packet_sequencer_calculator", - ":pass_through_calculator", + "//mediapipe/calculators/core:pass_through_calculator", "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:subgraph", diff --git a/mediapipe/calculators/image/BUILD b/mediapipe/calculators/image/BUILD index 9aae8cfbc..530dd3d4a 100644 --- a/mediapipe/calculators/image/BUILD +++ b/mediapipe/calculators/image/BUILD @@ -378,8 +378,8 @@ cc_library( name = "scale_image_calculator", srcs = ["scale_image_calculator.cc"], deps = [ - ":scale_image_calculator_cc_proto", ":scale_image_utils", + "//mediapipe/calculators/image:scale_image_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:image_format_cc_proto", "//mediapipe/framework/formats:image_frame", @@ -747,8 +747,8 @@ cc_test( tags = ["desktop_only_test"], deps = [ ":affine_transformation", - ":image_transformation_calculator", ":warp_affine_calculator", + "//mediapipe/calculators/image:image_transformation_calculator", "//mediapipe/calculators/tensor:image_to_tensor_converter", "//mediapipe/calculators/tensor:image_to_tensor_utils", "//mediapipe/calculators/util:from_image_calculator", diff --git a/mediapipe/calculators/util/BUILD b/mediapipe/calculators/util/BUILD index a679a80fd..1529ead8a 100644 --- a/mediapipe/calculators/util/BUILD +++ b/mediapipe/calculators/util/BUILD @@ -22,8 +22,8 @@ cc_library( name = "alignment_points_to_rects_calculator", srcs = ["alignment_points_to_rects_calculator.cc"], deps = [ - ":detections_to_rects_calculator", ":detections_to_rects_calculator_cc_proto", + "//mediapipe/calculators/util:detections_to_rects_calculator", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_options_cc_proto", "//mediapipe/framework/formats:detection_cc_proto", diff --git a/mediapipe/framework/BUILD b/mediapipe/framework/BUILD index 082ea9994..0dd694760 100644 --- a/mediapipe/framework/BUILD +++ b/mediapipe/framework/BUILD @@ -226,13 +226,13 @@ cc_library( ":mediapipe_internal", ], deps = [ - ":calculator_cc_proto", ":graph_service", - ":mediapipe_options_cc_proto", - ":packet_generator_cc_proto", ":packet_type", ":port", - ":status_handler_cc_proto", + "//mediapipe/framework:calculator_cc_proto", + "//mediapipe/framework:mediapipe_options_cc_proto", + "//mediapipe/framework:packet_generator_cc_proto", + "//mediapipe/framework:status_handler_cc_proto", "//mediapipe/framework/port:any_proto", "//mediapipe/framework/port:status", "//mediapipe/framework/tool:options_map", @@ -328,10 +328,10 @@ cc_library( ":thread_pool_executor", ":timestamp", ":validated_graph_config", - ":calculator_cc_proto", - ":packet_generator_cc_proto", - ":status_handler_cc_proto", - ":thread_pool_executor_cc_proto", + "//mediapipe/framework:calculator_cc_proto", + "//mediapipe/framework:packet_generator_cc_proto", + "//mediapipe/framework:status_handler_cc_proto", + "//mediapipe/framework:thread_pool_executor_cc_proto", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:fixed_array", "@com_google_absl//absl/container:flat_hash_map", @@ -391,7 +391,6 @@ cc_library( visibility = [":mediapipe_internal"], 
deps = [ ":calculator_base", - ":calculator_cc_proto", ":calculator_context", ":calculator_context_manager", ":calculator_state", @@ -408,9 +407,10 @@ cc_library( ":packet_set", ":packet_type", ":port", - ":stream_handler_cc_proto", ":timestamp", ":validated_graph_config", + "//mediapipe/framework:calculator_cc_proto", + "//mediapipe/framework:stream_handler_cc_proto", "//mediapipe/framework/port:core_proto", "//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:logging", @@ -466,7 +466,6 @@ cc_library( hdrs = ["calculator_state.h"], visibility = [":mediapipe_internal"], deps = [ - ":calculator_cc_proto", ":counter", ":counter_factory", ":graph_service", @@ -476,6 +475,7 @@ cc_library( ":packet", ":packet_set", ":port", + "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework/port:any_proto", "//mediapipe/framework/port:logging", "//mediapipe/framework/tool:options_map", @@ -583,7 +583,7 @@ cc_library( hdrs = ["executor.h"], visibility = ["//visibility:public"], deps = [ - ":mediapipe_options_cc_proto", + "//mediapipe/framework:mediapipe_options_cc_proto", "//mediapipe/framework/deps:registration", "//mediapipe/framework/port:status", "//mediapipe/framework/port:statusor", @@ -670,11 +670,11 @@ cc_library( ":collection_item_id", ":input_stream_manager", ":input_stream_shard", - ":mediapipe_options_cc_proto", ":mediapipe_profiling", ":packet", ":packet_set", ":packet_type", + "//mediapipe/framework:mediapipe_options_cc_proto", "//mediapipe/framework/deps:registration", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", @@ -784,12 +784,12 @@ cc_library( ":calculator_context_manager", ":collection", ":collection_item_id", - ":mediapipe_options_cc_proto", ":output_stream_manager", ":output_stream_shard", ":packet_set", ":packet_type", ":timestamp", + "//mediapipe/framework:mediapipe_options_cc_proto", "//mediapipe/framework/deps:registration", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:status", @@ -875,10 +875,10 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":packet", - ":packet_generator_cc_proto", ":packet_set", ":packet_type", ":port", + "//mediapipe/framework:packet_generator_cc_proto", "//mediapipe/framework/deps:registration", "//mediapipe/framework/port:core_proto", "//mediapipe/framework/port:status", @@ -896,13 +896,13 @@ cc_library( ":delegating_executor", ":executor", ":packet", - ":packet_factory_cc_proto", ":packet_generator", - ":packet_generator_cc_proto", ":packet_type", ":port", ":thread_pool_executor", ":validated_graph_config", + "//mediapipe/framework:packet_factory_cc_proto", + "//mediapipe/framework:packet_generator_cc_proto", "//mediapipe/framework/port:core_proto", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:ret_check", @@ -1019,10 +1019,10 @@ cc_library( hdrs = ["status_handler.h"], visibility = ["//visibility:public"], deps = [ - ":mediapipe_options_cc_proto", ":packet_set", ":packet_type", ":port", + "//mediapipe/framework:mediapipe_options_cc_proto", "//mediapipe/framework/deps:registration", "//mediapipe/framework/port:status", "@com_google_absl//absl/memory", @@ -1035,10 +1035,10 @@ cc_library( hdrs = ["subgraph.h"], visibility = ["//visibility:public"], deps = [ - ":calculator_cc_proto", ":graph_service", ":graph_service_manager", ":port", + "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework/deps:registration", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", @@ -1097,7 +1097,7 @@ cc_library( 
visibility = ["//visibility:public"], deps = [ ":executor", - ":thread_pool_executor_cc_proto", + "//mediapipe/framework:thread_pool_executor_cc_proto", "//mediapipe/framework/deps:thread_options", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:status", @@ -1162,22 +1162,22 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":calculator_base", - ":calculator_cc_proto", ":calculator_contract", ":graph_service_manager", ":legacy_calculator_support", ":packet", ":packet_generator", - ":packet_generator_cc_proto", ":packet_set", ":packet_type", ":port", ":status_handler", - ":status_handler_cc_proto", - ":stream_handler_cc_proto", ":subgraph", - ":thread_pool_executor_cc_proto", ":timestamp", + "//mediapipe/framework:calculator_cc_proto", + "//mediapipe/framework:packet_generator_cc_proto", + "//mediapipe/framework:status_handler_cc_proto", + "//mediapipe/framework:stream_handler_cc_proto", + "//mediapipe/framework:thread_pool_executor_cc_proto", "//mediapipe/framework/port:core_proto", "//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:logging", @@ -1202,11 +1202,11 @@ cc_test( name = "validated_graph_config_test", srcs = ["validated_graph_config_test.cc"], deps = [ - ":calculator_cc_proto", ":calculator_framework", ":graph_service", ":graph_service_manager", ":validated_graph_config", + "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework/api2:node", "//mediapipe/framework/api2:port", "//mediapipe/framework/port:gtest_main", @@ -1233,7 +1233,6 @@ cc_test( linkstatic = 1, deps = [ ":calculator_base", - ":calculator_cc_proto", ":calculator_context", ":calculator_context_manager", ":calculator_registry", @@ -1243,6 +1242,7 @@ cc_test( ":output_stream_shard", ":packet_set", ":packet_type", + "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:status", "//mediapipe/framework/tool:status_util", @@ -1256,11 +1256,11 @@ cc_test( srcs = ["calculator_contract_test.cc"], linkstatic = 1, deps = [ - ":calculator_cc_proto", ":calculator_contract", ":calculator_contract_test_cc_proto", - ":packet_generator_cc_proto", - ":status_handler_cc_proto", + "//mediapipe/framework:calculator_cc_proto", + "//mediapipe/framework:packet_generator_cc_proto", + "//mediapipe/framework:status_handler_cc_proto", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:parse_text_proto", ], @@ -1368,7 +1368,6 @@ cc_test( srcs = ["calculator_context_test.cc"], linkstatic = 1, deps = [ - ":calculator_cc_proto", ":calculator_context", ":calculator_context_manager", ":calculator_state", @@ -1377,6 +1376,7 @@ cc_test( ":output_stream_shard", ":packet_set", ":packet_type", + "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:status", @@ -1403,7 +1403,6 @@ cc_test( ":executor", ":input_stream_handler", ":lifetime_tracker", - ":mediapipe_options_cc_proto", ":output_stream_poller", ":packet_set", ":packet_type", @@ -1411,12 +1410,13 @@ cc_test( ":subgraph", ":test_calculators", ":thread_pool_executor", - ":thread_pool_executor_cc_proto", ":timestamp", ":type_map", "//mediapipe/calculators/core:counting_source_calculator", "//mediapipe/calculators/core:mux_calculator", "//mediapipe/calculators/core:pass_through_calculator", + "//mediapipe/framework:mediapipe_options_cc_proto", + "//mediapipe/framework:thread_pool_executor_cc_proto", "//mediapipe/framework/port:gtest_main", 
"//mediapipe/framework/port:logging", "//mediapipe/framework/port:parse_text_proto", @@ -1481,12 +1481,12 @@ cc_test( ], visibility = ["//visibility:public"], deps = [ - ":calculator_cc_proto", ":calculator_framework", ":test_calculators", "//mediapipe/calculators/core:counting_source_calculator", "//mediapipe/calculators/core:mux_calculator", "//mediapipe/calculators/core:pass_through_calculator", + "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:parse_text_proto", @@ -1630,8 +1630,8 @@ cc_test( srcs = ["packet_generator_test.cc"], deps = [ ":packet_generator", - ":packet_generator_cc_proto", ":packet_type", + "//mediapipe/framework:packet_generator_cc_proto", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/tool:validate_type", "@com_google_absl//absl/strings", diff --git a/mediapipe/framework/formats/BUILD b/mediapipe/framework/formats/BUILD index f5a043f10..fdd9b8909 100644 --- a/mediapipe/framework/formats/BUILD +++ b/mediapipe/framework/formats/BUILD @@ -26,7 +26,7 @@ licenses(["notice"]) mediapipe_proto_library( name = "detection_proto", srcs = ["detection.proto"], - deps = [":location_data_proto"], + deps = ["//mediapipe/framework/formats:location_data_proto"], ) mediapipe_register_type( @@ -38,7 +38,7 @@ mediapipe_register_type( "::std::vector<::mediapipe::Detection>", "::std::vector<::mediapipe::DetectionList>", ], - deps = [":detection_cc_proto"], + deps = ["//mediapipe/framework/formats:detection_cc_proto"], ) mediapipe_proto_library( @@ -105,8 +105,8 @@ cc_library( srcs = ["matrix.cc"], hdrs = ["matrix.h"], deps = [ - ":matrix_data_cc_proto", "//mediapipe/framework:port", + "//mediapipe/framework/formats:matrix_data_cc_proto", "//mediapipe/framework/port:core_proto", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:ret_check", @@ -142,7 +142,7 @@ cc_library( srcs = ["image_frame.cc"], hdrs = ["image_frame.h"], deps = [ - ":image_format_cc_proto", + "//mediapipe/framework/formats:image_format_cc_proto", "@com_google_absl//absl/base", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/memory", @@ -166,8 +166,8 @@ cc_library( srcs = ["image_frame_opencv.cc"], hdrs = ["image_frame_opencv.h"], deps = [ - ":image_format_cc_proto", ":image_frame", + "//mediapipe/framework/formats:image_format_cc_proto", "//mediapipe/framework/port:opencv_core", ], ) @@ -194,7 +194,7 @@ cc_library( deps = [ "@com_google_protobuf//:protobuf", "//mediapipe/framework/formats/annotation:locus_cc_proto", - ":location_data_cc_proto", + "//mediapipe/framework/formats:location_data_cc_proto", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", @@ -245,7 +245,7 @@ cc_library( name = "video_stream_header", hdrs = ["video_stream_header.h"], deps = [ - ":image_format_cc_proto", + "//mediapipe/framework/formats:image_format_cc_proto", ], ) @@ -263,9 +263,9 @@ cc_test( size = "small", srcs = ["image_frame_opencv_test.cc"], deps = [ - ":image_format_cc_proto", ":image_frame", ":image_frame_opencv", + "//mediapipe/framework/formats:image_format_cc_proto", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:logging", @@ -324,8 +324,8 @@ cc_library( "//conditions:default": [], }), deps = [ - ":image_format_cc_proto", - ":image_frame", + "//mediapipe/framework/formats:image_format_cc_proto", + "//mediapipe/framework/formats:image_frame", 
"@com_google_absl//absl/synchronization", "//mediapipe/framework:port", "//mediapipe/framework:type_map", @@ -354,7 +354,7 @@ cc_library( hdrs = ["image_multi_pool.h"], deps = [ ":image", - ":image_frame_pool", + "//mediapipe/framework/formats:image_frame_pool", "//mediapipe/framework:port", "//mediapipe/framework/port:logging", "@com_google_absl//absl/memory", @@ -390,7 +390,7 @@ cc_library( ], deps = [ ":image", - ":image_format_cc_proto", + "//mediapipe/framework/formats:image_format_cc_proto", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:opencv_core", "//mediapipe/framework/port:statusor", diff --git a/mediapipe/framework/formats/motion/BUILD b/mediapipe/framework/formats/motion/BUILD index c9bb8b4ff..f1bbc0289 100644 --- a/mediapipe/framework/formats/motion/BUILD +++ b/mediapipe/framework/formats/motion/BUILD @@ -38,11 +38,11 @@ cc_library( srcs = ["optical_flow_field.cc"], hdrs = ["optical_flow_field.h"], deps = [ - ":optical_flow_field_data_cc_proto", "//mediapipe/framework:type_map", "//mediapipe/framework/deps:mathutil", "//mediapipe/framework/formats:location", "//mediapipe/framework/formats:location_opencv", + "//mediapipe/framework/formats/motion:optical_flow_field_data_cc_proto", "//mediapipe/framework/port:file_helpers", "//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:logging", diff --git a/mediapipe/framework/stream_handler/BUILD b/mediapipe/framework/stream_handler/BUILD index 68a9af52d..01ef6ee86 100644 --- a/mediapipe/framework/stream_handler/BUILD +++ b/mediapipe/framework/stream_handler/BUILD @@ -88,8 +88,8 @@ cc_library( srcs = ["default_input_stream_handler.cc"], hdrs = ["default_input_stream_handler.h"], deps = [ - ":default_input_stream_handler_cc_proto", "//mediapipe/framework:input_stream_handler", + "//mediapipe/framework/stream_handler:default_input_stream_handler_cc_proto", "@com_google_absl//absl/strings", ], alwayslink = 1, @@ -110,8 +110,8 @@ cc_library( srcs = ["fixed_size_input_stream_handler.cc"], deps = [ ":default_input_stream_handler", - ":fixed_size_input_stream_handler_cc_proto", "//mediapipe/framework:input_stream_handler", + "//mediapipe/framework/stream_handler:fixed_size_input_stream_handler_cc_proto", ], alwayslink = 1, ) @@ -159,13 +159,13 @@ cc_library( name = "sync_set_input_stream_handler", srcs = ["sync_set_input_stream_handler.cc"], deps = [ - ":sync_set_input_stream_handler_cc_proto", "//mediapipe/framework:collection", "//mediapipe/framework:collection_item_id", "//mediapipe/framework:input_stream_handler", "//mediapipe/framework:mediapipe_options_cc_proto", "//mediapipe/framework:packet_set", "//mediapipe/framework:timestamp", + "//mediapipe/framework/stream_handler:sync_set_input_stream_handler_cc_proto", "//mediapipe/framework/tool:tag_map", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", @@ -177,10 +177,10 @@ cc_library( name = "timestamp_align_input_stream_handler", srcs = ["timestamp_align_input_stream_handler.cc"], deps = [ - ":timestamp_align_input_stream_handler_cc_proto", "//mediapipe/framework:collection_item_id", "//mediapipe/framework:input_stream_handler", "//mediapipe/framework:timestamp", + "//mediapipe/framework/stream_handler:timestamp_align_input_stream_handler_cc_proto", "//mediapipe/framework/tool:validate_name", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", @@ -243,7 +243,6 @@ cc_test( srcs = ["set_input_stream_handler_test.cc"], deps = [ ":fixed_size_input_stream_handler", - 
":fixed_size_input_stream_handler_cc_proto", ":mux_input_stream_handler", "//mediapipe/calculators/core:mux_calculator", "//mediapipe/calculators/core:pass_through_calculator", @@ -252,6 +251,7 @@ cc_test( "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/stream_handler:fixed_size_input_stream_handler_cc_proto", ], ) @@ -272,13 +272,13 @@ cc_test( srcs = ["fixed_size_input_stream_handler_test.cc"], deps = [ ":fixed_size_input_stream_handler", - ":fixed_size_input_stream_handler_cc_proto", "//mediapipe/calculators/core:counting_source_calculator", "//mediapipe/calculators/core:pass_through_calculator", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:parse_text_proto", + "//mediapipe/framework/stream_handler:fixed_size_input_stream_handler_cc_proto", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/synchronization", ], @@ -289,11 +289,11 @@ cc_test( srcs = ["sync_set_input_stream_handler_test.cc"], deps = [ ":sync_set_input_stream_handler", - ":sync_set_input_stream_handler_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:test_calculators", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:parse_text_proto", + "//mediapipe/framework/stream_handler:sync_set_input_stream_handler_cc_proto", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", diff --git a/mediapipe/framework/tool/BUILD b/mediapipe/framework/tool/BUILD index 193343a90..89cb802da 100644 --- a/mediapipe/framework/tool/BUILD +++ b/mediapipe/framework/tool/BUILD @@ -299,7 +299,6 @@ mediapipe_cc_test( requires_full_emulation = False, deps = [ ":node_chain_subgraph_cc_proto", - ":node_chain_subgraph_options_lib", ":options_field_util", ":options_registry", ":options_syntax_util", @@ -314,6 +313,7 @@ mediapipe_cc_test( "//mediapipe/framework/port:status", "//mediapipe/framework/testdata:night_light_calculator_cc_proto", "//mediapipe/framework/testdata:night_light_calculator_options_lib", + "//mediapipe/framework/tool:node_chain_subgraph_options_lib", "//mediapipe/util:header_util", "@com_google_absl//absl/strings", ], @@ -422,9 +422,9 @@ cc_library( srcs = ["source.cc"], visibility = ["//visibility:public"], deps = [ - ":source_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:status", + "//mediapipe/framework/tool:source_cc_proto", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/strings", ], @@ -485,13 +485,13 @@ cc_library( hdrs = ["template_expander.h"], visibility = ["//visibility:public"], deps = [ - ":calculator_graph_template_cc_proto", ":proto_util_lite", "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:numbers", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", + "//mediapipe/framework/tool:calculator_graph_template_cc_proto", "@com_google_absl//absl/strings", ], ) @@ -506,7 +506,6 @@ cc_library( ], visibility = ["//visibility:public"], deps = [ - ":calculator_graph_template_cc_proto", ":proto_util_lite", "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework/deps:proto_descriptor_cc_proto", @@ -516,6 +515,7 @@ cc_library( "//mediapipe/framework/port:map_util", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", + 
"//mediapipe/framework/tool:calculator_graph_template_cc_proto", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", @@ -661,8 +661,8 @@ cc_library( hdrs = ["simulation_clock_executor.h"], visibility = ["//visibility:public"], deps = [ - ":simulation_clock", "//mediapipe/framework:thread_pool_executor", + "//mediapipe/framework/tool:simulation_clock", ], ) @@ -789,10 +789,10 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":name_util", - ":switch_container_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", + "//mediapipe/framework/tool:switch_container_cc_proto", ], ) @@ -805,7 +805,6 @@ cc_library( deps = [ ":container_util", ":options_util", - ":switch_container_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:collection_item_id", "//mediapipe/framework/deps:mathutil", @@ -815,6 +814,7 @@ cc_library( "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", "//mediapipe/framework/stream_handler:immediate_input_stream_handler", + "//mediapipe/framework/tool:switch_container_cc_proto", "@com_google_absl//absl/strings", ], alwayslink = 1, @@ -841,7 +841,6 @@ cc_library( ], deps = [ ":container_util", - ":switch_container_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:collection_item_id", "//mediapipe/framework:input_stream_shard", @@ -851,6 +850,7 @@ cc_library( "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", "//mediapipe/framework/stream_handler:immediate_input_stream_handler", + "//mediapipe/framework/tool:switch_container_cc_proto", ], alwayslink = 1, ) @@ -893,7 +893,6 @@ cc_library( ":container_util", ":name_util", ":subgraph_expansion", - ":switch_container_cc_proto", ":switch_demux_calculator", ":switch_mux_calculator", "//mediapipe/calculators/core:packet_sequencer_calculator", @@ -905,6 +904,7 @@ cc_library( "//mediapipe/framework/port:core_proto", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", + "//mediapipe/framework/tool:switch_container_cc_proto", "@com_google_absl//absl/strings", ], alwayslink = 1, diff --git a/mediapipe/gpu/BUILD b/mediapipe/gpu/BUILD index cc5e50dfc..009eb3f9e 100644 --- a/mediapipe/gpu/BUILD +++ b/mediapipe/gpu/BUILD @@ -564,7 +564,6 @@ cc_library( name = "gpu_shared_data_internal_stub", visibility = ["//visibility:private"], deps = [ - ":gl_context_options_cc_proto", ":graph_support", "//mediapipe/framework:calculator_context", "//mediapipe/framework:calculator_node", @@ -572,6 +571,7 @@ cc_library( "//mediapipe/framework:port", "//mediapipe/framework/deps:no_destructor", "//mediapipe/framework/port:ret_check", + "//mediapipe/gpu:gl_context_options_cc_proto", ], ) @@ -592,7 +592,7 @@ cc_library( }), visibility = ["//visibility:private"], deps = [ - ":gl_context_options_cc_proto", + "//mediapipe/gpu:gl_context_options_cc_proto", ":graph_support", "//mediapipe/framework:calculator_context", "//mediapipe/framework:executor", @@ -833,10 +833,10 @@ cc_library( deps = [ ":gl_base", ":gl_simple_shaders", - ":scale_mode_cc_proto", ":shader_util", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", + "//mediapipe/gpu:scale_mode_cc_proto", ], ) @@ -907,8 +907,8 @@ proto_library( srcs = ["gl_scaler_calculator.proto"], visibility = ["//visibility:public"], deps = [ - ":scale_mode_proto", "//mediapipe/framework:calculator_proto", + 
"//mediapipe/gpu:scale_mode_proto", ], ) @@ -930,7 +930,6 @@ cc_library( deps = [ ":gl_calculator_helper", ":gl_quad_renderer", - ":gl_scaler_calculator_cc_proto", ":gl_simple_shaders", ":shader_util", "//mediapipe/framework:calculator_framework", @@ -938,6 +937,7 @@ cc_library( "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", "//mediapipe/framework/tool:options_util", + "//mediapipe/gpu:gl_scaler_calculator_cc_proto", ], alwayslink = 1, ) @@ -950,13 +950,13 @@ cc_library( ":egl_surface_holder", ":gl_calculator_helper", ":gl_quad_renderer", - ":gl_surface_sink_calculator_cc_proto", ":gpu_buffer", ":shader_util", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:node", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", + "//mediapipe/gpu:gl_surface_sink_calculator_cc_proto", "@com_google_absl//absl/synchronization", ], alwayslink = 1, @@ -966,8 +966,8 @@ proto_library( name = "gl_surface_sink_calculator_proto", srcs = ["gl_surface_sink_calculator.proto"], deps = [ - ":scale_mode_proto", "//mediapipe/framework:calculator_proto", + "//mediapipe/gpu:scale_mode_proto", ], ) From e405c2b67d68e6c99fbd7bbf4731ce4a387201f7 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Tue, 20 Dec 2022 10:59:23 -0800 Subject: [PATCH 260/469] Internal change PiperOrigin-RevId: 496702117 --- .../calculators/image/affine_transformation_runner_gl.cc | 6 +++--- .../tensor/image_to_tensor_converter_gl_texture.cc | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/mediapipe/calculators/image/affine_transformation_runner_gl.cc b/mediapipe/calculators/image/affine_transformation_runner_gl.cc index c38fc8e07..361dfc902 100644 --- a/mediapipe/calculators/image/affine_transformation_runner_gl.cc +++ b/mediapipe/calculators/image/affine_transformation_runner_gl.cc @@ -92,8 +92,8 @@ class GlTextureWarpAffineRunner constexpr GLchar kVertShader[] = R"( in vec4 position; - in mediump vec4 texture_coordinate; - out mediump vec2 sample_coordinate; + in highp vec4 texture_coordinate; + out highp vec2 sample_coordinate; uniform mat4 transform_matrix; void main() { @@ -104,7 +104,7 @@ class GlTextureWarpAffineRunner )"; constexpr GLchar kFragShader[] = R"( - DEFAULT_PRECISION(mediump, float) + DEFAULT_PRECISION(highp, float) in vec2 sample_coordinate; uniform sampler2D input_texture; diff --git a/mediapipe/calculators/tensor/image_to_tensor_converter_gl_texture.cc b/mediapipe/calculators/tensor/image_to_tensor_converter_gl_texture.cc index 5efd34041..165df8970 100644 --- a/mediapipe/calculators/tensor/image_to_tensor_converter_gl_texture.cc +++ b/mediapipe/calculators/tensor/image_to_tensor_converter_gl_texture.cc @@ -68,8 +68,8 @@ class GlProcessor : public ImageToTensorConverter { constexpr GLchar kExtractSubRectVertexShader[] = R"( in vec4 position; - in mediump vec4 texture_coordinate; - out mediump vec2 sample_coordinate; + in highp vec4 texture_coordinate; + out highp vec2 sample_coordinate; uniform mat4 transform_matrix; void main() { @@ -86,7 +86,7 @@ class GlProcessor : public ImageToTensorConverter { )"; constexpr GLchar kExtractSubRectFragBody[] = R"( - DEFAULT_PRECISION(mediump, float) + DEFAULT_PRECISION(highp, float) // Provided by kExtractSubRectVertexShader. 
      in vec2 sample_coordinate;

From e997a19289d85071775751b453aa2e1b982f3891 Mon Sep 17 00:00:00 2001
From: Prianka Liz Kariat
Date: Wed, 21 Dec 2022 01:22:32 +0530
Subject: [PATCH 261/469] Added common utils and string helpers

---
 mediapipe/tasks/ios/common/utils/BUILD        |  41 ++++++
 .../ios/common/utils/sources/MPPCommonUtils.h |  78 ++++++++++
 .../common/utils/sources/MPPCommonUtils.mm    | 137 ++++++++++++++++++
 .../common/utils/sources/NSString+Helpers.h   |  28 ++++
 .../common/utils/sources/NSString+Helpers.mm  |  27 ++++
 5 files changed, 311 insertions(+)
 create mode 100644 mediapipe/tasks/ios/common/utils/BUILD
 create mode 100644 mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h
 create mode 100644 mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.mm
 create mode 100644 mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h
 create mode 100644 mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.mm

diff --git a/mediapipe/tasks/ios/common/utils/BUILD b/mediapipe/tasks/ios/common/utils/BUILD
new file mode 100644
index 000000000..f2ffda39e
--- /dev/null
+++ b/mediapipe/tasks/ios/common/utils/BUILD
@@ -0,0 +1,41 @@
+# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+package(default_visibility = ["//mediapipe/tasks:internal"])
+
+licenses(["notice"])
+
+objc_library(
+    name = "MPPCommonUtils",
+    srcs = ["sources/MPPCommonUtils.mm"],
+    hdrs = ["sources/MPPCommonUtils.h"],
+    deps = [
+        "//mediapipe/tasks/cc:common",
+        "//mediapipe/tasks/ios/common:MPPCommon",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/strings:cord",
+    ],
+)
+
+objc_library(
+    name = "NSStringHelpers",
+    srcs = ["sources/NSString+Helpers.mm"],
+    hdrs = ["sources/NSString+Helpers.h"],
+    copts = [
+        "-ObjC++",
+        "-std=c++17",
+    ],
+)
+
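A side note on the rule above: NSStringHelpers compiles with "-ObjC++" and "-std=c++17" because its .mm source mixes Objective-C with C++. The file itself is not shown in this hunk; code of roughly the following shape is what forces those copts. This is a hypothetical sketch, and FromCppString is an invented name used only for illustration:

    // Hypothetical Objective-C++ fragment; the std::string parameter is what
    // requires compiling as Objective-C++ rather than plain Objective-C.
    #import <Foundation/Foundation.h>
    #include <string>

    static NSString *FromCppString(const std::string &s) {
      return [NSString stringWithUTF8String:s.c_str()];
    }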
+@interface MPPCommonUtils : NSObject
+
+/**
+ * Creates and saves an NSError in the MediaPipe task library domain, with the given code and
+ * description.
+ *
+ * @param code Error code.
+ * @param description Error description.
+ * @param error Pointer to the memory location where the created error should be saved. If `nil`,
+ * no error will be saved.
+ */
++ (void)createCustomError:(NSError **)error
+                 withCode:(NSUInteger)code
+              description:(NSString *)description;
+
+/**
+ * Creates and saves an NSError with the given domain, code and description.
+ *
+ * @param error Pointer to the memory location where the created error should be saved. If `nil`,
+ * no error will be saved.
+ * @param domain Error domain.
+ * @param code Error code.
+ * @param description Error description.
+ */
++ (void)createCustomError:(NSError **)error
+               withDomain:(NSString *)domain
+                     code:(NSUInteger)code
+              description:(NSString *)description;
+
+/**
+ * Converts an absl status to an NSError.
+ *
+ * @param status absl status.
+ * @param error Pointer to the memory location where the created error should be saved. If `nil`,
+ * no error will be saved.
+ */
++ (BOOL)checkCppError:(const absl::Status &)status toError:(NSError **)error;
+
+/**
+ * Allocates a block of memory with the specified size and returns a pointer to it. If memory
+ * cannot be allocated because of an invalid memSize, it saves an error. In other cases, it
+ * terminates program execution.
+ *
+ * @param memSize Size of memory to be allocated.
+ * @param error Pointer to the memory location where errors, if any, should be saved. If `nil`, no
+ * error will be saved.
+ *
+ * @return Pointer to the allocated block of memory on successful allocation. nil in case an
+ * error is encountered because of an invalid memSize. If failure is due to any other reason, the
+ * method terminates program execution.
+ */
++ (void *)mallocWithSize:(size_t)memSize error:(NSError **)error;
+@end
+
+NS_ASSUME_NONNULL_END
diff --git a/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.mm b/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.mm
new file mode 100644
index 000000000..574f2ef9a
--- /dev/null
+++ b/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.mm
@@ -0,0 +1,137 @@
+// Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h"
+
+#import "mediapipe/tasks/ios/common/sources/MPPCommon.h"
+
+#include <string>
+
+#include "absl/status/status.h"  // from @com_google_absl
+#include "absl/strings/cord.h"   // from @com_google_absl
+
+#include "mediapipe/tasks/cc/common.h"
+
+/** Error domain of MediaPipe task library errors. */
+NSString *const MPPTasksErrorDomain = @"org.mediapipe.tasks";
+
+@implementation MPPCommonUtils
+
++ (void)createCustomError:(NSError **)error
+                 withCode:(NSUInteger)code
+              description:(NSString *)description {
+  [MPPCommonUtils createCustomError:error
+                         withDomain:MPPTasksErrorDomain
+                               code:code
+                        description:description];
+}
+
++ (void)createCustomError:(NSError **)error
+               withDomain:(NSString *)domain
+                     code:(NSUInteger)code
+              description:(NSString *)description {
+  if (error) {
+    *error = [NSError errorWithDomain:domain
+                                 code:code
+                             userInfo:@{NSLocalizedDescriptionKey : description}];
+  }
+}
+
++ (void *)mallocWithSize:(size_t)memSize error:(NSError **)error {
+  if (!memSize) {
+    [MPPCommonUtils createCustomError:error
+                             withCode:MPPTasksErrorCodeInvalidArgumentError
+                          description:@"memSize cannot be zero."];
+    return NULL;
+  }
+
+  void *allocedMemory = malloc(memSize);
+  if (!allocedMemory) {
+    exit(-1);
+  }
+
+  return allocedMemory;
+}
+
++ (BOOL)checkCppError:(const absl::Status &)status toError:(NSError *_Nullable *)error {
+  if (status.ok()) {
+    return YES;
+  }
+  // The payload of an absl::Status created by the MediaPipe task library stores an appropriate
+  // value of the enum MediaPipeTasksStatus. The integer value corresponding to the
+  // MediaPipeTasksStatus enum stored in the payload is extracted here to later map to the
+  // appropriate error code to be returned. In cases where the enum is not stored in the payload
+  // (the payload is NULL or the payload string cannot be converted to an integer), we set the
+  // error code value to be 1 (MPPTasksErrorCodeError of MPPTasksErrorCode, used in the iOS
+  // library to signify any errors not falling into other categories.) Since the payload is of
+  // type absl::Cord that can be type cast into an absl::optional<absl::string_view>, we use the
+  // std::stoi function to convert it into an integer code if possible.
+  NSUInteger genericErrorCode = MPPTasksErrorCodeError;
+  NSUInteger errorCode;
+  try {
+    // Try converting the payload to an integer if the payload is not empty. Otherwise convert a
+    // string signifying the generic error code MPPTasksErrorCodeError to an integer.
+    errorCode =
+        (NSUInteger)std::stoi(static_cast<absl::optional<absl::string_view>>(
+                                  status.GetPayload(mediapipe::tasks::kMediaPipeTasksPayload))
+                                  .value_or(std::to_string(genericErrorCode)));
+  } catch (std::invalid_argument &e) {
+    // If a non-empty payload string cannot be converted to an integer, set the error code to
+    // 1 (kError).
+    errorCode = MPPTasksErrorCodeError;
+  }
+
+  // If errorCode is outside the range of enum values possible or is
+  // MPPTasksErrorCodeError, we try to map the absl::Status::code() to assign
+  // an appropriate MPPTasksErrorCode in default cases. Note:
+  // The mapping to absl::Status::code() is done to generate a more specific error code than
+  // MPPTasksErrorCodeError in cases when the payload can't be mapped to
+  // MPPTasksErrorCode. This can happen when an absl::Status returned by the TFLite library is in
+  // turn returned without modification by MediaPipe cc library methods.
+  if (errorCode > MPPTasksErrorCodeLast || errorCode <= MPPTasksErrorCodeFirst) {
+    switch (status.code()) {
+      case absl::StatusCode::kInternal:
+        errorCode = MPPTasksErrorCodeError;
+        break;
+      case absl::StatusCode::kInvalidArgument:
+        errorCode = MPPTasksErrorCodeInvalidArgumentError;
+        break;
+      case absl::StatusCode::kNotFound:
+        errorCode = MPPTasksErrorCodeError;
+        break;
+      default:
+        errorCode = MPPTasksErrorCodeError;
+        break;
+    }
+  }
+
+  // Creates the NSError with the appropriate error code
+  // (MPPTasksErrorCode) and message.
+  // MPPTasksErrorCode has a one-to-one
+  // mapping with MediaPipeTasksStatus starting from the value 1 (MPPTasksErrorCodeError),
+  // and hence will be correctly initialized if directly cast from the integer code derived from
+  // MediaPipeTasksStatus stored in its payload. MPPTasksErrorCode omits kOk = 0 of
+  // MediaPipeTasksStatus.
+  //
+  // Stores a string including the absl status code and message (if non-empty) as the
+  // error message. See
+  // https://github.com/abseil/abseil-cpp/blob/master/absl/status/status.h#L514
+  // for an explanation. absl::Status::message() can also be used but is not always
+  // guaranteed to be non-empty.
+  NSString *description = [NSString
+      stringWithCString:status.ToString(absl::StatusToStringMode::kWithNoExtraData).c_str()
+               encoding:NSUTF8StringEncoding];
+  [MPPCommonUtils createCustomError:error withCode:errorCode description:description];
+  return NO;
+}
+
+@end
diff --git a/mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h b/mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h
new file mode 100644
index 000000000..aac7485da
--- /dev/null
+++ b/mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h
@@ -0,0 +1,28 @@
+// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#import <Foundation/Foundation.h>
+#include <string>
+
+NS_ASSUME_NONNULL_BEGIN
+
+@interface NSString (Helpers)
+
+@property(readonly) std::string cppString;
+
++ (NSString *)stringWithCppString:(std::string)text;
+
+@end
+
+NS_ASSUME_NONNULL_END
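The two utilities in this patch are meant to be used together at the C++/Objective-C boundary. A minimal sketch of the intended call pattern, where RunPipeline() is a hypothetical placeholder standing in for any MediaPipe C++ call that returns an absl::Status:

// Sketch only. RunPipeline() is a placeholder, not a real MediaPipe API.
NSString *graphName = @"my_graph";
absl::Status status = RunPipeline(graphName.cppString);  // NSString -> std::string
NSError *error = nil;
if (![MPPCommonUtils checkCppError:status toError:&error]) {
  // error now carries an MPPTasksErrorCode and the status message.
  NSLog(@"MediaPipe task failed: %@", error.localizedDescription);
}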
diff --git a/mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.mm b/mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.mm
new file mode 100644
index 000000000..183ed4365
--- /dev/null
+++ b/mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.mm
@@ -0,0 +1,27 @@
+// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h"
+
+@implementation NSString (Helpers)
+
+- (std::string)cppString {
+  return std::string(self.UTF8String, [self lengthOfBytesUsingEncoding:NSUTF8StringEncoding]);
+}
+
++ (NSString *)stringWithCppString:(std::string)text {
+  return [NSString stringWithCString:text.c_str() encoding:[NSString defaultCStringEncoding]];
+}
+
+@end
From 03bfbca53940d71b68c5c6e5c6a697abbf9fe5fe Mon Sep 17 00:00:00 2001
From: Prianka Liz Kariat
Date: Wed, 21 Dec 2022 01:22:44 +0530
Subject: [PATCH 262/469] Added classifier options

---
 .../tasks/ios/components/processors/BUILD     | 24 +++++++++++
 .../processors/sources/MPPClassifierOptions.h | 42 +++++++++++++++++++
 .../processors/sources/MPPClassifierOptions.m | 40 ++++++++++++++++++
 3 files changed, 106 insertions(+)
 create mode 100644 mediapipe/tasks/ios/components/processors/BUILD
 create mode 100644 mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h
 create mode 100644 mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.m

diff --git a/mediapipe/tasks/ios/components/processors/BUILD b/mediapipe/tasks/ios/components/processors/BUILD
new file mode 100644
index 000000000..6d1cfdf59
--- /dev/null
+++ b/mediapipe/tasks/ios/components/processors/BUILD
@@ -0,0 +1,24 @@
+# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+package(default_visibility = ["//mediapipe/tasks:internal"])
+
+licenses(["notice"])
+
+objc_library(
+    name = "MPPClassifierOptions",
+    srcs = ["sources/MPPClassifierOptions.m"],
+    hdrs = ["sources/MPPClassifierOptions.h"],
+)
+
diff --git a/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h b/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h
new file mode 100644
index 000000000..8c4981642
--- /dev/null
+++ b/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h
@@ -0,0 +1,42 @@
+// Copyright 2022 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#import <Foundation/Foundation.h>
+
+NS_ASSUME_NONNULL_BEGIN
+
+/**
+ * Holds settings for any single iOS MediaPipe classification task.
+ */
+NS_SWIFT_NAME(ClassifierOptions)
+@interface MPPClassifierOptions : NSObject <NSCopying>
+
+/** If set, all classes in this list will be filtered out from the results. */
+@property(nonatomic, copy) NSArray<NSString *> *labelDenyList;
+
+/** If set, all classes not in this list will be filtered out from the results. */
+@property(nonatomic, copy) NSArray<NSString *> *labelAllowList;
+
+/** Locale for display names. */
+@property(nonatomic, copy) NSString *displayNamesLocale;
+
+/** Results with a score greater than this value are returned. */
+@property(nonatomic) float scoreThreshold;
+
+/** Limit to the number of classes that can be returned in results. */
+@property(nonatomic) NSInteger maxResults;
+
+@end
+
+NS_ASSUME_NONNULL_END
diff --git a/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.m b/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.m
new file mode 100644
index 000000000..52dce23e4
--- /dev/null
+++ b/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.m
@@ -0,0 +1,40 @@
+// Copyright 2022 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#import "mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h"
+
+@implementation MPPClassifierOptions
+
+- (instancetype)init {
+  self = [super init];
+  if (self) {
+    self.maxResults = -1;
+    self.scoreThreshold = 0;
+  }
+  return self;
+}
+
+- (id)copyWithZone:(NSZone *)zone {
+  MPPClassifierOptions *classifierOptions = [[MPPClassifierOptions alloc] init];
+
+  classifierOptions.scoreThreshold = self.scoreThreshold;
+  classifierOptions.maxResults = self.maxResults;
+  classifierOptions.labelDenyList = self.labelDenyList;
+  classifierOptions.labelAllowList = self.labelAllowList;
+  classifierOptions.displayNamesLocale = self.displayNamesLocale;
+
+  return classifierOptions;
+}
+
+@end
From c56ef735d7f8bf4fb7482c2b7dd01a61c3d0ffc4 Mon Sep 17 00:00:00 2001
From: Prianka Liz Kariat
Date: Wed, 21 Dec 2022 01:22:57 +0530
Subject: [PATCH 263/469] Added classifier options helpers

---
 .../ios/components/processors/utils/BUILD     | 29 ++++++++++++
 .../sources/MPPClassifierOptions+Helpers.h    | 25 ++++++++++
 .../sources/MPPClassifierOptions+Helpers.mm   | 38 +++++++++++++++
 3 files changed, 92 insertions(+)
 create mode 100644 mediapipe/tasks/ios/components/processors/utils/BUILD
 create mode 100644 mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.h
 create mode 100644 mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.mm

diff --git a/mediapipe/tasks/ios/components/processors/utils/BUILD b/mediapipe/tasks/ios/components/processors/utils/BUILD
new file mode 100644
index 000000000..820c6bb56
--- /dev/null
+++ b/mediapipe/tasks/ios/components/processors/utils/BUILD
@@ -0,0 +1,29 @@
+# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and +# limitations under the License. + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +objc_library( + name = "MPPClassifierOptionsHelpers", + srcs = ["sources/MPPClassifierOptions+Helpers.mm"], + hdrs = ["sources/MPPClassifierOptions+Helpers.h"], + deps = [ + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto", + "//mediapipe/tasks/ios/components/processors:MPPClassifierOptions", + "//mediapipe/tasks/ios/common/utils:NSStringHelpers", + ] +) + diff --git a/mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.h b/mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.h new file mode 100644 index 000000000..6644a6255 --- /dev/null +++ b/mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.h @@ -0,0 +1,25 @@ +// Copyright 2022 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h" +#import "mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h" + +NS_ASSUME_NONNULL_BEGIN + +@interface MPPClassifierOptions (Helpers) +- (void)copyToProto: + (mediapipe::tasks::components::processors::proto::ClassifierOptions *)classifierOptionsProto; +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.mm b/mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.mm new file mode 100644 index 000000000..25e657599 --- /dev/null +++ b/mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.mm @@ -0,0 +1,38 @@ +// Copyright 2022 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
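Before the implementation, a short sketch of how this category is expected to be driven; the option values here are illustrative only:

// Sketch: bridging Objective-C classifier options into the C++ proto.
MPPClassifierOptions *options = [[MPPClassifierOptions alloc] init];
options.maxResults = 3;
options.scoreThreshold = 0.5f;
options.labelAllowList = @[ @"cat", @"dog" ];

mediapipe::tasks::components::processors::proto::ClassifierOptions optionsProto;
[options copyToProto:&optionsProto];
// optionsProto now carries max_results, score_threshold and the allowlist.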
+
+#import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h"
+#import "mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.h"
+
+namespace {
+using ClassifierOptionsProto = ::mediapipe::tasks::components::processors::proto::ClassifierOptions;
+}
+
+@implementation MPPClassifierOptions (Helpers)
+- (void)copyToProto:(ClassifierOptionsProto *)classifierOptionsProto {
+  if (self.displayNamesLocale) {
+    classifierOptionsProto->set_display_names_locale(self.displayNamesLocale.cppString);
+  }
+  classifierOptionsProto->set_max_results((int)self.maxResults);
+  classifierOptionsProto->set_score_threshold(self.scoreThreshold);
+  for (NSString *category in self.labelAllowList) {
+    classifierOptionsProto->add_category_allowlist(category.cppString);
+  }
+
+  for (NSString *category in self.labelDenyList) {
+    classifierOptionsProto->add_category_denylist(category.cppString);
+  }
+}
+
+@end
From 6d02108bf5244f190dc07e035913694864138467 Mon Sep 17 00:00:00 2001
From: Prianka Liz Kariat
Date: Wed, 21 Dec 2022 01:23:29 +0530
Subject: [PATCH 264/469] Added task info

---
 .../tasks/ios/core/sources/MPPTaskInfo.h  |  69 +++++++
 .../tasks/ios/core/sources/MPPTaskInfo.mm | 136 ++++++++++++++++++
 2 files changed, 205 insertions(+)
 create mode 100644 mediapipe/tasks/ios/core/sources/MPPTaskInfo.h
 create mode 100644 mediapipe/tasks/ios/core/sources/MPPTaskInfo.mm

diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskInfo.h b/mediapipe/tasks/ios/core/sources/MPPTaskInfo.h
new file mode 100644
index 000000000..fca660fae
--- /dev/null
+++ b/mediapipe/tasks/ios/core/sources/MPPTaskInfo.h
@@ -0,0 +1,69 @@
+// Copyright 2022 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#import <Foundation/Foundation.h>
+#include "mediapipe/framework/calculator.pb.h"
+#import "mediapipe/tasks/ios/core/sources/MPPTaskOptions.h"
+#import "mediapipe/tasks/ios/core/sources/MPPTaskOptionsProtocol.h"
+
+
+NS_ASSUME_NONNULL_BEGIN
+
+/**
+ * Holds all needed information to initialize a MediaPipe task.
+ */
+@interface MPPTaskInfo : NSObject <NSCopying>
+
+@property(nonatomic, copy, nonnull) NSString *taskGraphName;
+
+/**
+ * Task-specific options that are derived from MPPTaskOptions and conform to
+ * MPPTaskOptionsProtocol.
+ */
+@property(nonatomic, copy) id<MPPTaskOptionsProtocol> taskOptions;
+
+/**
+ * List of task graph input stream info strings in the form TAG:name.
+ */
+@property(nonatomic, copy) NSArray<NSString *> *inputStreams;
+
+/**
+ * List of task graph output stream info in the form TAG:name.
+ */
+@property(nonatomic, copy) NSArray<NSString *> *outputStreams;
+
+/**
+ * If the task requires a flow limiter.
+ */
+@property(nonatomic) BOOL enableFlowLimiting;
+
++ (instancetype)new NS_UNAVAILABLE;
+
+- (instancetype)initWithTaskGraphName:(NSString *)taskGraphName
+                         inputStreams:(NSArray<NSString *> *)inputStreams
+                        outputStreams:(NSArray<NSString *> *)outputStreams
+                          taskOptions:(id<MPPTaskOptionsProtocol>)taskOptions
+                   enableFlowLimiting:(BOOL)enableFlowLimiting
+                                error:(NSError **)error;
+
+/**
+ * Creates a MediaPipe Task protobuf message from the MPPTaskInfo instance.
+ */ +- (mediapipe::CalculatorGraphConfig)generateGraphConfig; + +- (instancetype)init NS_UNAVAILABLE; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskInfo.mm b/mediapipe/tasks/ios/core/sources/MPPTaskInfo.mm new file mode 100644 index 000000000..7d2fd6f28 --- /dev/null +++ b/mediapipe/tasks/ios/core/sources/MPPTaskInfo.mm @@ -0,0 +1,136 @@ +// Copyright 2022 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "mediapipe/tasks/ios/core/sources/MPPTaskInfo.h" +#import "mediapipe/tasks/ios/common/sources/MPPCommon.h" +#import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h" +#import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h" + +#include "mediapipe/calculators/core/flow_limiter_calculator.pb.h" +#include "mediapipe/framework/calculator.pb.h" +#include "mediapipe/framework/calculator_options.pb.h" + +namespace { +using CalculatorGraphConfig = ::mediapipe::CalculatorGraphConfig; +using Node = ::mediapipe::CalculatorGraphConfig::Node; +using ::mediapipe::InputStreamInfo; +using ::mediapipe::CalculatorOptions; +using ::mediapipe::FlowLimiterCalculatorOptions; +} // namespace + +@implementation MPPTaskInfo + +- (instancetype)initWithTaskGraphName:(NSString *)taskGraphName + inputStreams:(NSArray *)inputStreams + outputStreams:(NSArray *)outputStreams + taskOptions:(id)taskOptions + enableFlowLimiting:(BOOL)enableFlowLimiting + error:(NSError **)error { + if (!taskGraphName || !inputStreams.count || !outputStreams.count) { + [MPPCommonUtils + createCustomError:error + withCode:MPPTasksErrorCodeInvalidArgumentError + description: + @"Task graph's name, input streams, and output streams should be non-empty."]; + } + + self = [super init]; + + if (self) { + _taskGraphName = taskGraphName; + _inputStreams = inputStreams; + _outputStreams = outputStreams; + _taskOptions = taskOptions; + _enableFlowLimiting = enableFlowLimiting; + } + return self; +} + +- (id)copyWithZone:(NSZone *)zone { + MPPTaskInfo *taskInfo = [[MPPTaskInfo alloc] init]; + + taskInfo.taskGraphName = self.taskGraphName; + taskInfo.inputStreams = self.inputStreams; + taskInfo.outputStreams = self.outputStreams; + taskInfo.taskOptions = self.taskOptions; + taskInfo.enableFlowLimiting = self.enableFlowLimiting; + + return taskInfo; +} + +- (CalculatorGraphConfig)generateGraphConfig { + CalculatorGraphConfig graph_config; + + Node *task_subgraph_node = graph_config.add_node(); + task_subgraph_node->set_calculator(self.taskGraphName.cppString); + [self.taskOptions copyToProto:task_subgraph_node->mutable_options()]; + + for (NSString *outputStream in self.outputStreams) { + auto cpp_output_stream = std::string(outputStream.cppString); + task_subgraph_node->add_output_stream(cpp_output_stream); + graph_config.add_output_stream(cpp_output_stream); + } + + if (self.enableFlowLimiting) { + Node *flow_limit_calculator_node = graph_config.add_node(); + + flow_limit_calculator_node->set_calculator("FlowLimiterCalculator"); + + InputStreamInfo 
*input_stream_info = flow_limit_calculator_node->add_input_stream_info();
+    input_stream_info->set_tag_index("FINISHED");
+    input_stream_info->set_back_edge(true);
+
+    FlowLimiterCalculatorOptions *flow_limit_calculator_options =
+        flow_limit_calculator_node->mutable_options()->MutableExtension(
+            FlowLimiterCalculatorOptions::ext);
+    flow_limit_calculator_options->set_max_in_flight(1);
+    flow_limit_calculator_options->set_max_in_queue(1);
+
+    for (NSString *inputStream in self.inputStreams) {
+      graph_config.add_input_stream(inputStream.cppString);
+
+      NSString *strippedInputStream = [MPPTaskInfo stripTagIndex:inputStream];
+      flow_limit_calculator_node->add_input_stream(strippedInputStream.cppString);
+
+      NSString *taskInputStream = [MPPTaskInfo addStreamNamePrefix:inputStream];
+      task_subgraph_node->add_input_stream(taskInputStream.cppString);
+
+      NSString *strippedTaskInputStream = [MPPTaskInfo stripTagIndex:taskInputStream];
+      flow_limit_calculator_node->add_output_stream(strippedTaskInputStream.cppString);
+    }
+
+    NSString *firstOutputStream = self.outputStreams[0];
+    auto finished_output_stream = "FINISHED:" + firstOutputStream.cppString;
+    flow_limit_calculator_node->add_input_stream(finished_output_stream);
+  } else {
+    for (NSString *inputStream in self.inputStreams) {
+      auto cpp_input_stream = inputStream.cppString;
+      task_subgraph_node->add_input_stream(cpp_input_stream);
+      graph_config.add_input_stream(cpp_input_stream);
+    }
+  }
+
+  return graph_config;
+}
+
++ (NSString *)stripTagIndex:(NSString *)tagIndexName {
+  return [tagIndexName componentsSeparatedByString:@":"][1];
+}
+
++ (NSString *)addStreamNamePrefix:(NSString *)tagIndexName {
+  NSArray *splits = [tagIndexName componentsSeparatedByString:@":"];
+  return [NSString stringWithFormat:@"%@:throttled_%@", splits[0], splits[1]];
+}
+
+@end
From e9fc3713f0bd0ab0fceb9ea07e78373fc8c50efd Mon Sep 17 00:00:00 2001
From: Prianka Liz Kariat
Date: Wed, 21 Dec 2022 01:23:41 +0530
Subject: [PATCH 265/469] Added iOS task options protocol

---
 .../ios/core/sources/MPPTaskOptionsProtocol.h | 32 +++++++++++++++
 1 file changed, 32 insertions(+)
 create mode 100644 mediapipe/tasks/ios/core/sources/MPPTaskOptionsProtocol.h

diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskOptionsProtocol.h b/mediapipe/tasks/ios/core/sources/MPPTaskOptionsProtocol.h
new file mode 100644
index 000000000..c6f115451
--- /dev/null
+++ b/mediapipe/tasks/ios/core/sources/MPPTaskOptionsProtocol.h
@@ -0,0 +1,32 @@
+// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#import <Foundation/Foundation.h>
+#include "mediapipe/framework/calculator_options.pb.h"
+
+NS_ASSUME_NONNULL_BEGIN
+
+/**
+ * Any MediaPipe task options should conform to this protocol.
+ */
+@protocol MPPTaskOptionsProtocol
+
+/**
+ * Copies the iOS MediaPipe task options to an object of mediapipe::CalculatorOptions proto.
+ */
+ */ +- (void)copyToProto:(mediapipe::CalculatorOptions *)optionsProto; + +@end + +NS_ASSUME_NONNULL_END From e9fc3713f0bd0ab0fceb9ea07e78373fc8c50efd Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Wed, 21 Dec 2022 01:23:51 +0530 Subject: [PATCH 266/469] Added iOS task runner --- .../tasks/ios/core/sources/MPPTaskRunner.h | 47 ++++++++++++++++ .../tasks/ios/core/sources/MPPTaskRunner.mm | 56 +++++++++++++++++++ 2 files changed, 103 insertions(+) create mode 100644 mediapipe/tasks/ios/core/sources/MPPTaskRunner.h create mode 100644 mediapipe/tasks/ios/core/sources/MPPTaskRunner.mm diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskRunner.h b/mediapipe/tasks/ios/core/sources/MPPTaskRunner.h new file mode 100644 index 000000000..64e34b82e --- /dev/null +++ b/mediapipe/tasks/ios/core/sources/MPPTaskRunner.h @@ -0,0 +1,47 @@ +// Copyright 2022 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import + +#include "mediapipe/framework/calculator.pb.h" +#include "mediapipe/tasks/cc/core/task_runner.h" + + +NS_ASSUME_NONNULL_BEGIN + +/** + * This class is used to create and call appropriate methods on the C++ Task Runner. + */ +@interface MPPTaskRunner : NSObject +/** + * Initializes a new `MPPTaskRunner` with the mediapipe task graph config proto. + * + * @param graphConfig A mediapipe task graph config proto. + * + * @return An instance of `MPPTaskRunner` initialized to the given graph config proto. + */ +- (instancetype)initWithCalculatorGraphConfig:(mediapipe::CalculatorGraphConfig)graphConfig + error:(NSError **)error; + +- (absl::StatusOr)process:(const mediapipe::tasks::core::PacketMap&)packetMap error:(NSError **)error; + +- (void)close; + +- (instancetype)init NS_UNAVAILABLE; + ++ (instancetype)new NS_UNAVAILABLE; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskRunner.mm b/mediapipe/tasks/ios/core/sources/MPPTaskRunner.mm new file mode 100644 index 000000000..404f6c582 --- /dev/null +++ b/mediapipe/tasks/ios/core/sources/MPPTaskRunner.mm @@ -0,0 +1,56 @@ +// Copyright 2022 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskRunner.mm b/mediapipe/tasks/ios/core/sources/MPPTaskRunner.mm
new file mode 100644
index 000000000..404f6c582
--- /dev/null
+++ b/mediapipe/tasks/ios/core/sources/MPPTaskRunner.mm
@@ -0,0 +1,56 @@
+// Copyright 2022 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#import "mediapipe/tasks/ios/core/sources/MPPTaskRunner.h"
+#import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h"
+
+namespace {
+using ::mediapipe::CalculatorGraphConfig;
+using ::mediapipe::Packet;
+using ::mediapipe::tasks::core::PacketMap;
+using TaskRunnerCpp = ::mediapipe::tasks::core::TaskRunner;
+}  // namespace
+
+@interface MPPTaskRunner () {
+  // Cpp Task Runner
+  std::unique_ptr<TaskRunnerCpp> _cppTaskRunner;
+}
+@end
+
+@implementation MPPTaskRunner
+
+- (instancetype)initWithCalculatorGraphConfig:(CalculatorGraphConfig)graphConfig
+                                        error:(NSError **)error {
+  self = [super init];
+  if (self) {
+    auto taskRunnerResult = TaskRunnerCpp::Create(std::move(graphConfig));
+
+    if (![MPPCommonUtils checkCppError:taskRunnerResult.status() toError:error]) {
+      return nil;
+    }
+
+    _cppTaskRunner = std::move(taskRunnerResult.value());
+  }
+  return self;
+}
+
+- (absl::StatusOr<PacketMap>)process:(const PacketMap &)packetMap {
+  return _cppTaskRunner->Process(packetMap);
+}
+
+- (void)close {
+  _cppTaskRunner->Close();
+}
+
+@end
From 4fedea60a93adb6ac9db50212b7f06f29758576e Mon Sep 17 00:00:00 2001
From: Prianka Liz Kariat
Date: Wed, 21 Dec 2022 01:24:02 +0530
Subject: [PATCH 267/469] Added text packet creator

---
 .../ios/core/sources/MPPTextPacketCreator.h   | 26 +++++++++++++++++
 .../ios/core/sources/MPPTextPacketCreator.mm  | 29 +++++++++++++++++++
 2 files changed, 55 insertions(+)
 create mode 100644 mediapipe/tasks/ios/core/sources/MPPTextPacketCreator.h
 create mode 100644 mediapipe/tasks/ios/core/sources/MPPTextPacketCreator.mm

diff --git a/mediapipe/tasks/ios/core/sources/MPPTextPacketCreator.h b/mediapipe/tasks/ios/core/sources/MPPTextPacketCreator.h
new file mode 100644
index 000000000..03f946dd0
--- /dev/null
+++ b/mediapipe/tasks/ios/core/sources/MPPTextPacketCreator.h
@@ -0,0 +1,26 @@
+// Copyright 2022 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#import <Foundation/Foundation.h>
+
+#include "mediapipe/framework/packet.h"
+
+/* This class creates MediaPipe text packets from NSString inputs, for
+ * feeding text into a MediaPipe task graph.
+ */
+@interface MPPTextPacketCreator : NSObject
+
++ (mediapipe::Packet)createWithText:(NSString *)text;
+
+@end
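For context, the packet created here is what would populate the PacketMap fed to MPPTaskRunner above; a sketch, where the stream name is purely illustrative:

// Sketch: building a text input for a task graph ("text_in" is assumed).
mediapipe::tasks::core::PacketMap inputs;
inputs["text_in"] = [MPPTextPacketCreator createWithText:@"hello world"];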
diff --git a/mediapipe/tasks/ios/core/sources/MPPTextPacketCreator.mm b/mediapipe/tasks/ios/core/sources/MPPTextPacketCreator.mm
new file mode 100644
index 000000000..ca86e7a0b
--- /dev/null
+++ b/mediapipe/tasks/ios/core/sources/MPPTextPacketCreator.mm
@@ -0,0 +1,29 @@
+// Copyright 2019 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#import "mediapipe/tasks/ios/core/sources/MPPTextPacketCreator.h"
+#import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h"
+
+namespace {
+using ::mediapipe::MakePacket;
+using ::mediapipe::Packet;
+}  // namespace
+
+@implementation MPPTextPacketCreator
+
++ (Packet)createWithText:(NSString *)text {
+  return MakePacket<std::string>(text.cppString);
+}
+
+@end
From ff901a80a5398276b04e03e561c9acf0892d3aaa Mon Sep 17 00:00:00 2001
From: Prianka Liz Kariat
Date: Wed, 21 Dec 2022 01:24:11 +0530
Subject: [PATCH 268/469] Added targets in core

---
 mediapipe/tasks/ios/core/BUILD | 53 ++++++++++++++++++++++++++++
 1 file changed, 53 insertions(+)

diff --git a/mediapipe/tasks/ios/core/BUILD b/mediapipe/tasks/ios/core/BUILD
index 7b648945e..adc37d901 100644
--- a/mediapipe/tasks/ios/core/BUILD
+++ b/mediapipe/tasks/ios/core/BUILD
@@ -36,3 +36,56 @@ objc_library(
     srcs = ["sources/MPPTaskResult.m"],
     hdrs = ["sources/MPPTaskResult.h"],
 )
+
+objc_library(
+    name = "MPPTaskOptionsProtocol",
+    hdrs = ["sources/MPPTaskOptionsProtocol.h"],
+    deps = [
+        "//mediapipe/framework:calculator_options_cc_proto",
+    ],
+)
+
+objc_library(
+    name = "MPPTaskInfo",
+    srcs = ["sources/MPPTaskInfo.mm"],
+    hdrs = ["sources/MPPTaskInfo.h"],
+    copts = [
+        "-ObjC++",
+        "-std=c++17",
+    ],
+    deps = [
+        "//mediapipe/framework:calculator_cc_proto",
+        "//mediapipe/framework:calculator_options_cc_proto",
+        "//mediapipe/calculators/core:flow_limiter_calculator_cc_proto",
+        ":MPPTaskOptions",
+        ":MPPTaskOptionsProtocol",
+        "//mediapipe/tasks/ios/common/utils:MPPCommonUtils",
+        "//mediapipe/tasks/ios/common/utils:NSStringHelpers",
+        "//mediapipe/tasks/ios/common:MPPCommon",
+    ],
+)
+
+objc_library(
+    name = "MPPTextPacketCreator",
+    srcs = ["sources/MPPTextPacketCreator.mm"],
+    hdrs = ["sources/MPPTextPacketCreator.h"],
+    copts = [
+        "-ObjC++",
+        "-std=c++17",
+    ],
+    deps = [
+        "//mediapipe/framework:packet",
+        "//mediapipe/tasks/ios/common/utils:NSStringHelpers",
+    ],
+)
+
+objc_library(
+    name = "MPPTaskRunner",
+    srcs = ["sources/MPPTaskRunner.mm"],
+    hdrs = ["sources/MPPTaskRunner.h"],
+    deps = [
+        "//mediapipe/tasks/cc/core:task_runner",
+        "//mediapipe/framework:calculator_cc_proto",
+        "//mediapipe/tasks/ios/common/utils:MPPCommonUtils",
+    ],
+)
From ce0bc2b9acb9c11d0e54aabd8cb9430aedfc0c9b Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Tue, 20 Dec 2022 13:51:12 -0800
Subject: [PATCH 269/469] Internal change

PiperOrigin-RevId: 496742964
---
 .github/bot_config.yml | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/.github/bot_config.yml b/.github/bot_config.yml
index 74a60e4b9..8ad724168 100644
--- a/.github/bot_config.yml
+++ b/.github/bot_config.yml
@@ -15,5 +15,4 @@
 
 # A list of assignees
 assignees:
-   - kuaashish
-   - ayushgdev
+   - sureshdagooglecom
From a7b52d2c5281e82c208932ff2bedcf85356f868f Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Tue, 20 Dec 2022 14:34:40 -0800
Subject: [PATCH 270/469] Internal changes

PiperOrigin-RevId: 496754449
---
 mediapipe/model_maker/requirements.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mediapipe/model_maker/requirements.txt b/mediapipe/model_maker/requirements.txt
index 9b3c9f906..d7e4a950f 100644
--- a/mediapipe/model_maker/requirements.txt
+++ b/mediapipe/model_maker/requirements.txt
@@ -1,5 +1,5 @@
 absl-py
-mediapipe==0.9.1
+mediapipe==0.9.0.1
 numpy
 opencv-python
 tensorflow>=2.10
From d2f738793c8ec5ad9b66aeec78fac74a76b37100
Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Tue, 20 Dec 2022 15:15:24 -0800 Subject: [PATCH 271/469] Use uppercase options name for "delegate" PiperOrigin-RevId: 496764089 --- .../tasks/web/components/processors/base_options.test.ts | 6 +++--- mediapipe/tasks/web/components/processors/base_options.ts | 2 +- mediapipe/tasks/web/core/task_runner_options.d.ts | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/mediapipe/tasks/web/components/processors/base_options.test.ts b/mediapipe/tasks/web/components/processors/base_options.test.ts index 46c2277e9..6d58be68f 100644 --- a/mediapipe/tasks/web/components/processors/base_options.test.ts +++ b/mediapipe/tasks/web/components/processors/base_options.test.ts @@ -86,7 +86,7 @@ describe('convertBaseOptionsToProto()', () => { it('can enable CPU delegate', async () => { const baseOptionsProto = await convertBaseOptionsToProto({ modelAssetBuffer: new Uint8Array(mockBytes), - delegate: 'cpu', + delegate: 'CPU', }); expect(baseOptionsProto.toObject()).toEqual(mockBytesResult); }); @@ -94,7 +94,7 @@ describe('convertBaseOptionsToProto()', () => { it('can enable GPU delegate', async () => { const baseOptionsProto = await convertBaseOptionsToProto({ modelAssetBuffer: new Uint8Array(mockBytes), - delegate: 'gpu', + delegate: 'GPU', }); expect(baseOptionsProto.toObject()).toEqual({ ...mockBytesResult, @@ -117,7 +117,7 @@ describe('convertBaseOptionsToProto()', () => { it('can reset delegate', async () => { let baseOptionsProto = await convertBaseOptionsToProto({ modelAssetBuffer: new Uint8Array(mockBytes), - delegate: 'gpu', + delegate: 'GPU', }); // Clear backend baseOptionsProto = diff --git a/mediapipe/tasks/web/components/processors/base_options.ts b/mediapipe/tasks/web/components/processors/base_options.ts index 16d562262..97b62b784 100644 --- a/mediapipe/tasks/web/components/processors/base_options.ts +++ b/mediapipe/tasks/web/components/processors/base_options.ts @@ -71,7 +71,7 @@ async function configureExternalFile( /** Configues the `acceleration` option. */ function configureAcceleration(options: BaseOptions, proto: BaseOptionsProto) { const acceleration = proto.getAcceleration() ?? new Acceleration(); - if (options.delegate === 'gpu') { + if (options.delegate === 'GPU') { acceleration.setGpu(new InferenceCalculatorOptions.Delegate.Gpu()); } else { acceleration.setTflite(new InferenceCalculatorOptions.Delegate.TfLite()); diff --git a/mediapipe/tasks/web/core/task_runner_options.d.ts b/mediapipe/tasks/web/core/task_runner_options.d.ts index aa0b4a028..5f23cd4bf 100644 --- a/mediapipe/tasks/web/core/task_runner_options.d.ts +++ b/mediapipe/tasks/web/core/task_runner_options.d.ts @@ -31,7 +31,7 @@ export declare interface BaseOptions { modelAssetBuffer?: Uint8Array|undefined; /** Overrides the default backend to use for the provided model. */ - delegate?: 'cpu'|'gpu'|undefined; + delegate?: 'CPU'|'GPU'|undefined; } /** Options to configure MediaPipe Tasks in general. 
*/ From 64406a9bf27cd324e6856dbeb0f8b9c69d496ac7 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Tue, 20 Dec 2022 16:39:52 -0800 Subject: [PATCH 272/469] Internal change PiperOrigin-RevId: 496781536 --- mediapipe/calculators/core/BUILD | 14 ++--- mediapipe/calculators/image/BUILD | 4 +- mediapipe/calculators/util/BUILD | 2 +- mediapipe/framework/BUILD | 70 ++++++++++++------------ mediapipe/framework/formats/BUILD | 24 ++++---- mediapipe/framework/formats/motion/BUILD | 2 +- mediapipe/framework/stream_handler/BUILD | 14 ++--- mediapipe/framework/tool/BUILD | 18 +++--- mediapipe/gpu/BUILD | 14 ++--- 9 files changed, 81 insertions(+), 81 deletions(-) diff --git a/mediapipe/calculators/core/BUILD b/mediapipe/calculators/core/BUILD index 2c143a609..b3378a74e 100644 --- a/mediapipe/calculators/core/BUILD +++ b/mediapipe/calculators/core/BUILD @@ -567,7 +567,7 @@ cc_library( name = "packet_thinner_calculator", srcs = ["packet_thinner_calculator.cc"], deps = [ - "//mediapipe/calculators/core:packet_thinner_calculator_cc_proto", + ":packet_thinner_calculator_cc_proto", "//mediapipe/framework:calculator_context", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:video_stream_header", @@ -584,7 +584,7 @@ cc_test( srcs = ["packet_thinner_calculator_test.cc"], deps = [ ":packet_thinner_calculator", - "//mediapipe/calculators/core:packet_thinner_calculator_cc_proto", + ":packet_thinner_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_runner", "//mediapipe/framework/formats:video_stream_header", @@ -762,7 +762,7 @@ cc_library( srcs = ["packet_resampler_calculator.cc"], hdrs = ["packet_resampler_calculator.h"], deps = [ - "//mediapipe/calculators/core:packet_resampler_calculator_cc_proto", + ":packet_resampler_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:collection_item_id", "//mediapipe/framework/deps:mathutil", @@ -786,7 +786,7 @@ cc_test( ], deps = [ ":packet_resampler_calculator", - "//mediapipe/calculators/core:packet_resampler_calculator_cc_proto", + ":packet_resampler_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_runner", "//mediapipe/framework/formats:video_stream_header", @@ -852,10 +852,10 @@ cc_test( name = "flow_limiter_calculator_test", srcs = ["flow_limiter_calculator_test.cc"], deps = [ + ":counting_source_calculator", ":flow_limiter_calculator", ":flow_limiter_calculator_cc_proto", - "//mediapipe/calculators/core:counting_source_calculator", - "//mediapipe/calculators/core:pass_through_calculator", + ":pass_through_calculator", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_runner", "//mediapipe/framework:test_calculators", @@ -1302,7 +1302,7 @@ cc_test( srcs = ["packet_sequencer_calculator_test.cc"], deps = [ ":packet_sequencer_calculator", - "//mediapipe/calculators/core:pass_through_calculator", + ":pass_through_calculator", "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:subgraph", diff --git a/mediapipe/calculators/image/BUILD b/mediapipe/calculators/image/BUILD index 530dd3d4a..9aae8cfbc 100644 --- a/mediapipe/calculators/image/BUILD +++ b/mediapipe/calculators/image/BUILD @@ -378,8 +378,8 @@ cc_library( name = "scale_image_calculator", srcs = ["scale_image_calculator.cc"], deps = [ + ":scale_image_calculator_cc_proto", ":scale_image_utils", - 
"//mediapipe/calculators/image:scale_image_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:image_format_cc_proto", "//mediapipe/framework/formats:image_frame", @@ -747,8 +747,8 @@ cc_test( tags = ["desktop_only_test"], deps = [ ":affine_transformation", + ":image_transformation_calculator", ":warp_affine_calculator", - "//mediapipe/calculators/image:image_transformation_calculator", "//mediapipe/calculators/tensor:image_to_tensor_converter", "//mediapipe/calculators/tensor:image_to_tensor_utils", "//mediapipe/calculators/util:from_image_calculator", diff --git a/mediapipe/calculators/util/BUILD b/mediapipe/calculators/util/BUILD index 1529ead8a..a679a80fd 100644 --- a/mediapipe/calculators/util/BUILD +++ b/mediapipe/calculators/util/BUILD @@ -22,8 +22,8 @@ cc_library( name = "alignment_points_to_rects_calculator", srcs = ["alignment_points_to_rects_calculator.cc"], deps = [ + ":detections_to_rects_calculator", ":detections_to_rects_calculator_cc_proto", - "//mediapipe/calculators/util:detections_to_rects_calculator", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_options_cc_proto", "//mediapipe/framework/formats:detection_cc_proto", diff --git a/mediapipe/framework/BUILD b/mediapipe/framework/BUILD index 0dd694760..082ea9994 100644 --- a/mediapipe/framework/BUILD +++ b/mediapipe/framework/BUILD @@ -226,13 +226,13 @@ cc_library( ":mediapipe_internal", ], deps = [ + ":calculator_cc_proto", ":graph_service", + ":mediapipe_options_cc_proto", + ":packet_generator_cc_proto", ":packet_type", ":port", - "//mediapipe/framework:calculator_cc_proto", - "//mediapipe/framework:mediapipe_options_cc_proto", - "//mediapipe/framework:packet_generator_cc_proto", - "//mediapipe/framework:status_handler_cc_proto", + ":status_handler_cc_proto", "//mediapipe/framework/port:any_proto", "//mediapipe/framework/port:status", "//mediapipe/framework/tool:options_map", @@ -328,10 +328,10 @@ cc_library( ":thread_pool_executor", ":timestamp", ":validated_graph_config", - "//mediapipe/framework:calculator_cc_proto", - "//mediapipe/framework:packet_generator_cc_proto", - "//mediapipe/framework:status_handler_cc_proto", - "//mediapipe/framework:thread_pool_executor_cc_proto", + ":calculator_cc_proto", + ":packet_generator_cc_proto", + ":status_handler_cc_proto", + ":thread_pool_executor_cc_proto", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:fixed_array", "@com_google_absl//absl/container:flat_hash_map", @@ -391,6 +391,7 @@ cc_library( visibility = [":mediapipe_internal"], deps = [ ":calculator_base", + ":calculator_cc_proto", ":calculator_context", ":calculator_context_manager", ":calculator_state", @@ -407,10 +408,9 @@ cc_library( ":packet_set", ":packet_type", ":port", + ":stream_handler_cc_proto", ":timestamp", ":validated_graph_config", - "//mediapipe/framework:calculator_cc_proto", - "//mediapipe/framework:stream_handler_cc_proto", "//mediapipe/framework/port:core_proto", "//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:logging", @@ -466,6 +466,7 @@ cc_library( hdrs = ["calculator_state.h"], visibility = [":mediapipe_internal"], deps = [ + ":calculator_cc_proto", ":counter", ":counter_factory", ":graph_service", @@ -475,7 +476,6 @@ cc_library( ":packet", ":packet_set", ":port", - "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework/port:any_proto", "//mediapipe/framework/port:logging", "//mediapipe/framework/tool:options_map", @@ -583,7 +583,7 @@ cc_library( hdrs 
= ["executor.h"], visibility = ["//visibility:public"], deps = [ - "//mediapipe/framework:mediapipe_options_cc_proto", + ":mediapipe_options_cc_proto", "//mediapipe/framework/deps:registration", "//mediapipe/framework/port:status", "//mediapipe/framework/port:statusor", @@ -670,11 +670,11 @@ cc_library( ":collection_item_id", ":input_stream_manager", ":input_stream_shard", + ":mediapipe_options_cc_proto", ":mediapipe_profiling", ":packet", ":packet_set", ":packet_type", - "//mediapipe/framework:mediapipe_options_cc_proto", "//mediapipe/framework/deps:registration", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", @@ -784,12 +784,12 @@ cc_library( ":calculator_context_manager", ":collection", ":collection_item_id", + ":mediapipe_options_cc_proto", ":output_stream_manager", ":output_stream_shard", ":packet_set", ":packet_type", ":timestamp", - "//mediapipe/framework:mediapipe_options_cc_proto", "//mediapipe/framework/deps:registration", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:status", @@ -875,10 +875,10 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":packet", + ":packet_generator_cc_proto", ":packet_set", ":packet_type", ":port", - "//mediapipe/framework:packet_generator_cc_proto", "//mediapipe/framework/deps:registration", "//mediapipe/framework/port:core_proto", "//mediapipe/framework/port:status", @@ -896,13 +896,13 @@ cc_library( ":delegating_executor", ":executor", ":packet", + ":packet_factory_cc_proto", ":packet_generator", + ":packet_generator_cc_proto", ":packet_type", ":port", ":thread_pool_executor", ":validated_graph_config", - "//mediapipe/framework:packet_factory_cc_proto", - "//mediapipe/framework:packet_generator_cc_proto", "//mediapipe/framework/port:core_proto", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:ret_check", @@ -1019,10 +1019,10 @@ cc_library( hdrs = ["status_handler.h"], visibility = ["//visibility:public"], deps = [ + ":mediapipe_options_cc_proto", ":packet_set", ":packet_type", ":port", - "//mediapipe/framework:mediapipe_options_cc_proto", "//mediapipe/framework/deps:registration", "//mediapipe/framework/port:status", "@com_google_absl//absl/memory", @@ -1035,10 +1035,10 @@ cc_library( hdrs = ["subgraph.h"], visibility = ["//visibility:public"], deps = [ + ":calculator_cc_proto", ":graph_service", ":graph_service_manager", ":port", - "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework/deps:registration", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", @@ -1097,7 +1097,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":executor", - "//mediapipe/framework:thread_pool_executor_cc_proto", + ":thread_pool_executor_cc_proto", "//mediapipe/framework/deps:thread_options", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:status", @@ -1162,22 +1162,22 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":calculator_base", + ":calculator_cc_proto", ":calculator_contract", ":graph_service_manager", ":legacy_calculator_support", ":packet", ":packet_generator", + ":packet_generator_cc_proto", ":packet_set", ":packet_type", ":port", ":status_handler", + ":status_handler_cc_proto", + ":stream_handler_cc_proto", ":subgraph", + ":thread_pool_executor_cc_proto", ":timestamp", - "//mediapipe/framework:calculator_cc_proto", - "//mediapipe/framework:packet_generator_cc_proto", - "//mediapipe/framework:status_handler_cc_proto", - "//mediapipe/framework:stream_handler_cc_proto", - 
"//mediapipe/framework:thread_pool_executor_cc_proto", "//mediapipe/framework/port:core_proto", "//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:logging", @@ -1202,11 +1202,11 @@ cc_test( name = "validated_graph_config_test", srcs = ["validated_graph_config_test.cc"], deps = [ + ":calculator_cc_proto", ":calculator_framework", ":graph_service", ":graph_service_manager", ":validated_graph_config", - "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework/api2:node", "//mediapipe/framework/api2:port", "//mediapipe/framework/port:gtest_main", @@ -1233,6 +1233,7 @@ cc_test( linkstatic = 1, deps = [ ":calculator_base", + ":calculator_cc_proto", ":calculator_context", ":calculator_context_manager", ":calculator_registry", @@ -1242,7 +1243,6 @@ cc_test( ":output_stream_shard", ":packet_set", ":packet_type", - "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:status", "//mediapipe/framework/tool:status_util", @@ -1256,11 +1256,11 @@ cc_test( srcs = ["calculator_contract_test.cc"], linkstatic = 1, deps = [ + ":calculator_cc_proto", ":calculator_contract", ":calculator_contract_test_cc_proto", - "//mediapipe/framework:calculator_cc_proto", - "//mediapipe/framework:packet_generator_cc_proto", - "//mediapipe/framework:status_handler_cc_proto", + ":packet_generator_cc_proto", + ":status_handler_cc_proto", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:parse_text_proto", ], @@ -1368,6 +1368,7 @@ cc_test( srcs = ["calculator_context_test.cc"], linkstatic = 1, deps = [ + ":calculator_cc_proto", ":calculator_context", ":calculator_context_manager", ":calculator_state", @@ -1376,7 +1377,6 @@ cc_test( ":output_stream_shard", ":packet_set", ":packet_type", - "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:status", @@ -1403,6 +1403,7 @@ cc_test( ":executor", ":input_stream_handler", ":lifetime_tracker", + ":mediapipe_options_cc_proto", ":output_stream_poller", ":packet_set", ":packet_type", @@ -1410,13 +1411,12 @@ cc_test( ":subgraph", ":test_calculators", ":thread_pool_executor", + ":thread_pool_executor_cc_proto", ":timestamp", ":type_map", "//mediapipe/calculators/core:counting_source_calculator", "//mediapipe/calculators/core:mux_calculator", "//mediapipe/calculators/core:pass_through_calculator", - "//mediapipe/framework:mediapipe_options_cc_proto", - "//mediapipe/framework:thread_pool_executor_cc_proto", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:parse_text_proto", @@ -1481,12 +1481,12 @@ cc_test( ], visibility = ["//visibility:public"], deps = [ + ":calculator_cc_proto", ":calculator_framework", ":test_calculators", "//mediapipe/calculators/core:counting_source_calculator", "//mediapipe/calculators/core:mux_calculator", "//mediapipe/calculators/core:pass_through_calculator", - "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:parse_text_proto", @@ -1630,8 +1630,8 @@ cc_test( srcs = ["packet_generator_test.cc"], deps = [ ":packet_generator", + ":packet_generator_cc_proto", ":packet_type", - "//mediapipe/framework:packet_generator_cc_proto", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/tool:validate_type", "@com_google_absl//absl/strings", diff --git a/mediapipe/framework/formats/BUILD 
b/mediapipe/framework/formats/BUILD index fdd9b8909..f5a043f10 100644 --- a/mediapipe/framework/formats/BUILD +++ b/mediapipe/framework/formats/BUILD @@ -26,7 +26,7 @@ licenses(["notice"]) mediapipe_proto_library( name = "detection_proto", srcs = ["detection.proto"], - deps = ["//mediapipe/framework/formats:location_data_proto"], + deps = [":location_data_proto"], ) mediapipe_register_type( @@ -38,7 +38,7 @@ mediapipe_register_type( "::std::vector<::mediapipe::Detection>", "::std::vector<::mediapipe::DetectionList>", ], - deps = ["//mediapipe/framework/formats:detection_cc_proto"], + deps = [":detection_cc_proto"], ) mediapipe_proto_library( @@ -105,8 +105,8 @@ cc_library( srcs = ["matrix.cc"], hdrs = ["matrix.h"], deps = [ + ":matrix_data_cc_proto", "//mediapipe/framework:port", - "//mediapipe/framework/formats:matrix_data_cc_proto", "//mediapipe/framework/port:core_proto", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:ret_check", @@ -142,7 +142,7 @@ cc_library( srcs = ["image_frame.cc"], hdrs = ["image_frame.h"], deps = [ - "//mediapipe/framework/formats:image_format_cc_proto", + ":image_format_cc_proto", "@com_google_absl//absl/base", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/memory", @@ -166,8 +166,8 @@ cc_library( srcs = ["image_frame_opencv.cc"], hdrs = ["image_frame_opencv.h"], deps = [ + ":image_format_cc_proto", ":image_frame", - "//mediapipe/framework/formats:image_format_cc_proto", "//mediapipe/framework/port:opencv_core", ], ) @@ -194,7 +194,7 @@ cc_library( deps = [ "@com_google_protobuf//:protobuf", "//mediapipe/framework/formats/annotation:locus_cc_proto", - "//mediapipe/framework/formats:location_data_cc_proto", + ":location_data_cc_proto", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", @@ -245,7 +245,7 @@ cc_library( name = "video_stream_header", hdrs = ["video_stream_header.h"], deps = [ - "//mediapipe/framework/formats:image_format_cc_proto", + ":image_format_cc_proto", ], ) @@ -263,9 +263,9 @@ cc_test( size = "small", srcs = ["image_frame_opencv_test.cc"], deps = [ + ":image_format_cc_proto", ":image_frame", ":image_frame_opencv", - "//mediapipe/framework/formats:image_format_cc_proto", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:logging", @@ -324,8 +324,8 @@ cc_library( "//conditions:default": [], }), deps = [ - "//mediapipe/framework/formats:image_format_cc_proto", - "//mediapipe/framework/formats:image_frame", + ":image_format_cc_proto", + ":image_frame", "@com_google_absl//absl/synchronization", "//mediapipe/framework:port", "//mediapipe/framework:type_map", @@ -354,7 +354,7 @@ cc_library( hdrs = ["image_multi_pool.h"], deps = [ ":image", - "//mediapipe/framework/formats:image_frame_pool", + ":image_frame_pool", "//mediapipe/framework:port", "//mediapipe/framework/port:logging", "@com_google_absl//absl/memory", @@ -390,7 +390,7 @@ cc_library( ], deps = [ ":image", - "//mediapipe/framework/formats:image_format_cc_proto", + ":image_format_cc_proto", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:opencv_core", "//mediapipe/framework/port:statusor", diff --git a/mediapipe/framework/formats/motion/BUILD b/mediapipe/framework/formats/motion/BUILD index f1bbc0289..c9bb8b4ff 100644 --- a/mediapipe/framework/formats/motion/BUILD +++ b/mediapipe/framework/formats/motion/BUILD @@ -38,11 +38,11 @@ cc_library( srcs = ["optical_flow_field.cc"], hdrs = ["optical_flow_field.h"], deps = 
[ + ":optical_flow_field_data_cc_proto", "//mediapipe/framework:type_map", "//mediapipe/framework/deps:mathutil", "//mediapipe/framework/formats:location", "//mediapipe/framework/formats:location_opencv", - "//mediapipe/framework/formats/motion:optical_flow_field_data_cc_proto", "//mediapipe/framework/port:file_helpers", "//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:logging", diff --git a/mediapipe/framework/stream_handler/BUILD b/mediapipe/framework/stream_handler/BUILD index 01ef6ee86..68a9af52d 100644 --- a/mediapipe/framework/stream_handler/BUILD +++ b/mediapipe/framework/stream_handler/BUILD @@ -88,8 +88,8 @@ cc_library( srcs = ["default_input_stream_handler.cc"], hdrs = ["default_input_stream_handler.h"], deps = [ + ":default_input_stream_handler_cc_proto", "//mediapipe/framework:input_stream_handler", - "//mediapipe/framework/stream_handler:default_input_stream_handler_cc_proto", "@com_google_absl//absl/strings", ], alwayslink = 1, @@ -110,8 +110,8 @@ cc_library( srcs = ["fixed_size_input_stream_handler.cc"], deps = [ ":default_input_stream_handler", + ":fixed_size_input_stream_handler_cc_proto", "//mediapipe/framework:input_stream_handler", - "//mediapipe/framework/stream_handler:fixed_size_input_stream_handler_cc_proto", ], alwayslink = 1, ) @@ -159,13 +159,13 @@ cc_library( name = "sync_set_input_stream_handler", srcs = ["sync_set_input_stream_handler.cc"], deps = [ + ":sync_set_input_stream_handler_cc_proto", "//mediapipe/framework:collection", "//mediapipe/framework:collection_item_id", "//mediapipe/framework:input_stream_handler", "//mediapipe/framework:mediapipe_options_cc_proto", "//mediapipe/framework:packet_set", "//mediapipe/framework:timestamp", - "//mediapipe/framework/stream_handler:sync_set_input_stream_handler_cc_proto", "//mediapipe/framework/tool:tag_map", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", @@ -177,10 +177,10 @@ cc_library( name = "timestamp_align_input_stream_handler", srcs = ["timestamp_align_input_stream_handler.cc"], deps = [ + ":timestamp_align_input_stream_handler_cc_proto", "//mediapipe/framework:collection_item_id", "//mediapipe/framework:input_stream_handler", "//mediapipe/framework:timestamp", - "//mediapipe/framework/stream_handler:timestamp_align_input_stream_handler_cc_proto", "//mediapipe/framework/tool:validate_name", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", @@ -243,6 +243,7 @@ cc_test( srcs = ["set_input_stream_handler_test.cc"], deps = [ ":fixed_size_input_stream_handler", + ":fixed_size_input_stream_handler_cc_proto", ":mux_input_stream_handler", "//mediapipe/calculators/core:mux_calculator", "//mediapipe/calculators/core:pass_through_calculator", @@ -251,7 +252,6 @@ cc_test( "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:ret_check", - "//mediapipe/framework/stream_handler:fixed_size_input_stream_handler_cc_proto", ], ) @@ -272,13 +272,13 @@ cc_test( srcs = ["fixed_size_input_stream_handler_test.cc"], deps = [ ":fixed_size_input_stream_handler", + ":fixed_size_input_stream_handler_cc_proto", "//mediapipe/calculators/core:counting_source_calculator", "//mediapipe/calculators/core:pass_through_calculator", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:parse_text_proto", - "//mediapipe/framework/stream_handler:fixed_size_input_stream_handler_cc_proto", 
"@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/synchronization", ], @@ -289,11 +289,11 @@ cc_test( srcs = ["sync_set_input_stream_handler_test.cc"], deps = [ ":sync_set_input_stream_handler", + ":sync_set_input_stream_handler_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:test_calculators", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:parse_text_proto", - "//mediapipe/framework/stream_handler:sync_set_input_stream_handler_cc_proto", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", diff --git a/mediapipe/framework/tool/BUILD b/mediapipe/framework/tool/BUILD index 89cb802da..193343a90 100644 --- a/mediapipe/framework/tool/BUILD +++ b/mediapipe/framework/tool/BUILD @@ -299,6 +299,7 @@ mediapipe_cc_test( requires_full_emulation = False, deps = [ ":node_chain_subgraph_cc_proto", + ":node_chain_subgraph_options_lib", ":options_field_util", ":options_registry", ":options_syntax_util", @@ -313,7 +314,6 @@ mediapipe_cc_test( "//mediapipe/framework/port:status", "//mediapipe/framework/testdata:night_light_calculator_cc_proto", "//mediapipe/framework/testdata:night_light_calculator_options_lib", - "//mediapipe/framework/tool:node_chain_subgraph_options_lib", "//mediapipe/util:header_util", "@com_google_absl//absl/strings", ], @@ -422,9 +422,9 @@ cc_library( srcs = ["source.cc"], visibility = ["//visibility:public"], deps = [ + ":source_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:status", - "//mediapipe/framework/tool:source_cc_proto", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/strings", ], @@ -485,13 +485,13 @@ cc_library( hdrs = ["template_expander.h"], visibility = ["//visibility:public"], deps = [ + ":calculator_graph_template_cc_proto", ":proto_util_lite", "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:numbers", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", - "//mediapipe/framework/tool:calculator_graph_template_cc_proto", "@com_google_absl//absl/strings", ], ) @@ -506,6 +506,7 @@ cc_library( ], visibility = ["//visibility:public"], deps = [ + ":calculator_graph_template_cc_proto", ":proto_util_lite", "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework/deps:proto_descriptor_cc_proto", @@ -515,7 +516,6 @@ cc_library( "//mediapipe/framework/port:map_util", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", - "//mediapipe/framework/tool:calculator_graph_template_cc_proto", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", @@ -661,8 +661,8 @@ cc_library( hdrs = ["simulation_clock_executor.h"], visibility = ["//visibility:public"], deps = [ + ":simulation_clock", "//mediapipe/framework:thread_pool_executor", - "//mediapipe/framework/tool:simulation_clock", ], ) @@ -789,10 +789,10 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":name_util", + ":switch_container_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", - "//mediapipe/framework/tool:switch_container_cc_proto", ], ) @@ -805,6 +805,7 @@ cc_library( deps = [ ":container_util", ":options_util", + ":switch_container_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:collection_item_id", "//mediapipe/framework/deps:mathutil", @@ 
-814,7 +815,6 @@ cc_library( "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", "//mediapipe/framework/stream_handler:immediate_input_stream_handler", - "//mediapipe/framework/tool:switch_container_cc_proto", "@com_google_absl//absl/strings", ], alwayslink = 1, @@ -841,6 +841,7 @@ cc_library( ], deps = [ ":container_util", + ":switch_container_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:collection_item_id", "//mediapipe/framework:input_stream_shard", @@ -850,7 +851,6 @@ cc_library( "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", "//mediapipe/framework/stream_handler:immediate_input_stream_handler", - "//mediapipe/framework/tool:switch_container_cc_proto", ], alwayslink = 1, ) @@ -893,6 +893,7 @@ cc_library( ":container_util", ":name_util", ":subgraph_expansion", + ":switch_container_cc_proto", ":switch_demux_calculator", ":switch_mux_calculator", "//mediapipe/calculators/core:packet_sequencer_calculator", @@ -904,7 +905,6 @@ cc_library( "//mediapipe/framework/port:core_proto", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", - "//mediapipe/framework/tool:switch_container_cc_proto", "@com_google_absl//absl/strings", ], alwayslink = 1, diff --git a/mediapipe/gpu/BUILD b/mediapipe/gpu/BUILD index 009eb3f9e..cc5e50dfc 100644 --- a/mediapipe/gpu/BUILD +++ b/mediapipe/gpu/BUILD @@ -564,6 +564,7 @@ cc_library( name = "gpu_shared_data_internal_stub", visibility = ["//visibility:private"], deps = [ + ":gl_context_options_cc_proto", ":graph_support", "//mediapipe/framework:calculator_context", "//mediapipe/framework:calculator_node", @@ -571,7 +572,6 @@ cc_library( "//mediapipe/framework:port", "//mediapipe/framework/deps:no_destructor", "//mediapipe/framework/port:ret_check", - "//mediapipe/gpu:gl_context_options_cc_proto", ], ) @@ -592,7 +592,7 @@ cc_library( }), visibility = ["//visibility:private"], deps = [ - "//mediapipe/gpu:gl_context_options_cc_proto", + ":gl_context_options_cc_proto", ":graph_support", "//mediapipe/framework:calculator_context", "//mediapipe/framework:executor", @@ -833,10 +833,10 @@ cc_library( deps = [ ":gl_base", ":gl_simple_shaders", + ":scale_mode_cc_proto", ":shader_util", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", - "//mediapipe/gpu:scale_mode_cc_proto", ], ) @@ -907,8 +907,8 @@ proto_library( srcs = ["gl_scaler_calculator.proto"], visibility = ["//visibility:public"], deps = [ + ":scale_mode_proto", "//mediapipe/framework:calculator_proto", - "//mediapipe/gpu:scale_mode_proto", ], ) @@ -930,6 +930,7 @@ cc_library( deps = [ ":gl_calculator_helper", ":gl_quad_renderer", + ":gl_scaler_calculator_cc_proto", ":gl_simple_shaders", ":shader_util", "//mediapipe/framework:calculator_framework", @@ -937,7 +938,6 @@ cc_library( "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", "//mediapipe/framework/tool:options_util", - "//mediapipe/gpu:gl_scaler_calculator_cc_proto", ], alwayslink = 1, ) @@ -950,13 +950,13 @@ cc_library( ":egl_surface_holder", ":gl_calculator_helper", ":gl_quad_renderer", + ":gl_surface_sink_calculator_cc_proto", ":gpu_buffer", ":shader_util", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:node", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", - "//mediapipe/gpu:gl_surface_sink_calculator_cc_proto", "@com_google_absl//absl/synchronization", ], alwayslink = 1, @@ -966,8 +966,8 @@ proto_library( name = 
"gl_surface_sink_calculator_proto", srcs = ["gl_surface_sink_calculator.proto"], deps = [ + ":scale_mode_proto", "//mediapipe/framework:calculator_proto", - "//mediapipe/gpu:scale_mode_proto", ], ) From 151e447614741f02185c94f4412a3ab665a16c17 Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Tue, 20 Dec 2022 17:50:21 -0800 Subject: [PATCH 273/469] Internal changes PiperOrigin-RevId: 496793199 --- mediapipe/calculators/core/sequence_shift_calculator.cc | 6 ++++++ mediapipe/calculators/core/sequence_shift_calculator.proto | 4 ++++ 2 files changed, 10 insertions(+) diff --git a/mediapipe/calculators/core/sequence_shift_calculator.cc b/mediapipe/calculators/core/sequence_shift_calculator.cc index 66dbdef2e..026048b79 100644 --- a/mediapipe/calculators/core/sequence_shift_calculator.cc +++ b/mediapipe/calculators/core/sequence_shift_calculator.cc @@ -66,12 +66,16 @@ class SequenceShiftCalculator : public Node { // The number of packets or timestamps we need to store to output packet[i] at // the timestamp of packet[i + packet_offset]; equal to abs(packet_offset). int cache_size_; + bool emit_empty_packets_before_first_packet_ = false; }; MEDIAPIPE_REGISTER_NODE(SequenceShiftCalculator); absl::Status SequenceShiftCalculator::Open(CalculatorContext* cc) { packet_offset_ = kOffset(cc).GetOr( cc->Options().packet_offset()); + emit_empty_packets_before_first_packet_ = + cc->Options() + .emit_empty_packets_before_first_packet(); cache_size_ = abs(packet_offset_); // An offset of zero is a no-op, but someone might still request it. if (packet_offset_ == 0) { @@ -96,6 +100,8 @@ void SequenceShiftCalculator::ProcessPositiveOffset(CalculatorContext* cc) { // Ready to output oldest packet with current timestamp. kOut(cc).Send(packet_cache_.front().At(cc->InputTimestamp())); packet_cache_.pop_front(); + } else if (emit_empty_packets_before_first_packet_) { + LOG(FATAL) << "Not supported yet"; } // Store current packet for later output. packet_cache_.push_back(kIn(cc).packet()); diff --git a/mediapipe/calculators/core/sequence_shift_calculator.proto b/mediapipe/calculators/core/sequence_shift_calculator.proto index 15b111d71..36b0bb959 100644 --- a/mediapipe/calculators/core/sequence_shift_calculator.proto +++ b/mediapipe/calculators/core/sequence_shift_calculator.proto @@ -23,4 +23,8 @@ message SequenceShiftCalculatorOptions { optional SequenceShiftCalculatorOptions ext = 107633927; } optional int32 packet_offset = 1 [default = -1]; + + // Emits empty packets before the first delayed packet is emitted. Takes + // effect only when packet offset is set to positive. + optional bool emit_empty_packets_before_first_packet = 2 [default = false]; } From 5c0f548f5f5b31d94b749456cdac306b5330dfa3 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Tue, 20 Dec 2022 20:51:23 -0800 Subject: [PATCH 274/469] Switches to tf.keras.optimizers.experimental.AdamW instead of the legacy AdamW. 
From 5c0f548f5f5b31d94b749456cdac306b5330dfa3 Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Tue, 20 Dec 2022 20:51:23 -0800
Subject: [PATCH 274/469] Switches to tf.keras.optimizers.experimental.AdamW
 instead of the legacy AdamW.

PiperOrigin-RevId: 496821354
---
 .../text/text_classifier/text_classifier.py | 18 ++++++++++++++++--
 1 file changed, 16 insertions(+), 2 deletions(-)

diff --git a/mediapipe/model_maker/python/text/text_classifier/text_classifier.py b/mediapipe/model_maker/python/text/text_classifier/text_classifier.py
index c285702d2..c4d3fdbe2 100644
--- a/mediapipe/model_maker/python/text/text_classifier/text_classifier.py
+++ b/mediapipe/model_maker/python/text/text_classifier/text_classifier.py
@@ -417,8 +417,22 @@ class _BertClassifier(TextClassifier):
     total_steps = self._hparams.steps_per_epoch * self._hparams.epochs
     warmup_steps = int(total_steps * 0.1)
     initial_lr = self._hparams.learning_rate
-    self._optimizer = optimization.create_optimizer(initial_lr, total_steps,
-                                                    warmup_steps)
+    # Implements linear decay of the learning rate.
+    lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay(
+        initial_learning_rate=initial_lr,
+        decay_steps=total_steps,
+        end_learning_rate=0.0,
+        power=1.0)
+    if warmup_steps:
+      lr_schedule = optimization.WarmUp(
+          initial_learning_rate=initial_lr,
+          decay_schedule_fn=lr_schedule,
+          warmup_steps=warmup_steps)
+
+    self._optimizer = tf.keras.optimizers.experimental.AdamW(
+        lr_schedule, weight_decay=0.01, epsilon=1e-6, global_clipnorm=1.0)
+    self._optimizer.exclude_from_weight_decay(
+        var_names=["LayerNorm", "layer_norm", "bias"])

   def _save_vocab(self, vocab_filepath: str):
     tf.io.gfile.copy(
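A quick way to sanity-check the decay part of the schedule introduced above (illustrative step counts and learning rate; the warmup wrapper is left out of the snippet):

    import tensorflow as tf

    total_steps, initial_lr = 1000, 3e-5
    decay = tf.keras.optimizers.schedules.PolynomialDecay(
        initial_learning_rate=initial_lr,
        decay_steps=total_steps,
        end_learning_rate=0.0,
        power=1.0)  # power=1.0 makes the decay linear
    optimizer = tf.keras.optimizers.experimental.AdamW(
        decay, weight_decay=0.01, epsilon=1e-6, global_clipnorm=1.0)
    for step in (0, 500, 1000):
        print(step, float(decay(step)))  # 3e-05, 1.5e-05, 0.0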
From 1341720d6db044d2771eabe5d5574d67bb04a4f6 Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Wed, 21 Dec 2022 00:52:17 -0800
Subject: [PATCH 275/469] Internal change

PiperOrigin-RevId: 496854337
---
 mediapipe/framework/BUILD | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mediapipe/framework/BUILD b/mediapipe/framework/BUILD
index 082ea9994..a4c9a520d 100644
--- a/mediapipe/framework/BUILD
+++ b/mediapipe/framework/BUILD
@@ -369,7 +369,7 @@ cc_library(
     visibility = [":mediapipe_internal"],
     deps = [
         ":graph_service",
-        "//mediapipe/framework:packet",
+        ":packet",
        "@com_google_absl//absl/status",
     ],
 )
@@ -379,7 +379,7 @@ cc_test(
     srcs = ["graph_service_manager_test.cc"],
     deps = [
         ":graph_service_manager",
-        "//mediapipe/framework:packet",
+        ":packet",
         "//mediapipe/framework/port:gtest_main",
     ],
 )

From 714a6e555b106e7fd4de1b1e83d70e1c1c8570f3 Mon Sep 17 00:00:00 2001
From: Jiuqiang Tang
Date: Wed, 21 Dec 2022 08:06:20 -0800
Subject: [PATCH 276/469] Enable creating mediapipe image c++ packet directly
 from an Android media image object when its format is RGBA_8888.

PiperOrigin-RevId: 496923491
---
 .../mediapipe/framework/AndroidPacketCreator.java | 15 ++++++++++++++-
 1 file changed, 14 insertions(+), 1 deletion(-)

diff --git a/mediapipe/java/com/google/mediapipe/framework/AndroidPacketCreator.java b/mediapipe/java/com/google/mediapipe/framework/AndroidPacketCreator.java
index 05700ba17..fc1e5484e 100644
--- a/mediapipe/java/com/google/mediapipe/framework/AndroidPacketCreator.java
+++ b/mediapipe/java/com/google/mediapipe/framework/AndroidPacketCreator.java
@@ -15,10 +15,13 @@
 package com.google.mediapipe.framework;

 import android.graphics.Bitmap;
+import android.graphics.PixelFormat;
+import android.media.Image;
 import com.google.mediapipe.framework.image.BitmapExtractor;
 import com.google.mediapipe.framework.image.ByteBufferExtractor;
 import com.google.mediapipe.framework.image.MPImage;
 import com.google.mediapipe.framework.image.MPImageProperties;
+import com.google.mediapipe.framework.image.MediaImageExtractor;
 import java.nio.ByteBuffer;

 // TODO: use Preconditions in this file.
@@ -97,7 +100,17 @@ public class AndroidPacketCreator extends PacketCreator {
       }
       return Packet.create(nativeCreateRgbaImage(mediapipeGraph.getNativeHandle(), bitmap));
     }
-
+    if (properties.getStorageType() == MPImage.STORAGE_TYPE_MEDIA_IMAGE) {
+      Image mediaImage = MediaImageExtractor.extract(image);
+      if (mediaImage.getFormat() != PixelFormat.RGBA_8888) {
+        throw new UnsupportedOperationException("Android media image must use RGBA_8888 config.");
+      }
+      return createImage(
+          mediaImage.getPlanes()[0].getBuffer(),
+          mediaImage.getWidth(),
+          mediaImage.getHeight(),
+          /* numChannels= */ 4);
+    }
     // Unsupported type.
     throw new UnsupportedOperationException(
         "Unsupported Image container type: " + properties.getStorageType());
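The Java path above is the Android route to an RGBA image packet. For comparison, a hedged sketch of the rough equivalent in the Python distribution, using the public mp.Image API from the pip package:

    import mediapipe as mp
    import numpy as np

    # A 640x480 RGBA frame; MediaPipe expects uint8, channels-last data here.
    rgba = np.zeros((480, 640, 4), dtype=np.uint8)
    image = mp.Image(image_format=mp.ImageFormat.SRGBA, data=rgba)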
From c8b8d1fe6b04ef906b7ef3956fdff266c2704228 Mon Sep 17 00:00:00 2001
From: Jiuqiang Tang
Date: Wed, 21 Dec 2022 11:08:01 -0800
Subject: [PATCH 277/469] Remove scripts for building MediaPipe Python 3.7
 wheels.

PiperOrigin-RevId: 496962729
---
 setup.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/setup.py b/setup.py
index b072a850e..992430cf1 100644
--- a/setup.py
+++ b/setup.py
@@ -490,10 +490,10 @@ setuptools.setup(
         'Operating System :: MacOS :: MacOS X',
         'Operating System :: Microsoft :: Windows',
         'Operating System :: POSIX :: Linux',
-        'Programming Language :: Python :: 3.7',
         'Programming Language :: Python :: 3.8',
         'Programming Language :: Python :: 3.9',
         'Programming Language :: Python :: 3.10',
+        'Programming Language :: Python :: 3.11',
         'Programming Language :: Python :: 3 :: Only',
         'Topic :: Scientific/Engineering',
         'Topic :: Scientific/Engineering :: Artificial Intelligence',

From ae28948ca150fd3a801c2ef1387b460151083204 Mon Sep 17 00:00:00 2001
From: Prianka Liz Kariat
Date: Thu, 22 Dec 2022 00:49:24 +0530
Subject: [PATCH 278/469] Marked designated initializers

---
 mediapipe/tasks/ios/core/sources/MPPTaskInfo.h   | 2 +-
 mediapipe/tasks/ios/core/sources/MPPTaskRunner.h | 3 ++-
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskInfo.h b/mediapipe/tasks/ios/core/sources/MPPTaskInfo.h
index fca660fae..4c01787a8 100644
--- a/mediapipe/tasks/ios/core/sources/MPPTaskInfo.h
+++ b/mediapipe/tasks/ios/core/sources/MPPTaskInfo.h
@@ -55,7 +55,7 @@ NS_ASSUME_NONNULL_BEGIN
                        outputStreams:(NSArray<NSString *> *)outputStreams
                          taskOptions:(id<MPPTaskOptionsProtocol>)taskOptions
                   enableFlowLimiting:(BOOL)enableFlowLimiting
-                               error:(NSError **)error;
+                               error:(NSError **)error NS_DESIGNATED_INITIALIZER;

 /**
  * Creates a MediaPipe Task protobuf message from the MPPTaskInfo instance.
  */

diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskRunner.h b/mediapipe/tasks/ios/core/sources/MPPTaskRunner.h
index 64e34b82e..e07cb344d 100644
--- a/mediapipe/tasks/ios/core/sources/MPPTaskRunner.h
+++ b/mediapipe/tasks/ios/core/sources/MPPTaskRunner.h
@@ -24,6 +24,7 @@ NS_ASSUME_NONNULL_BEGIN
 * This class is used to create and call appropriate methods on the C++ Task Runner.
 */
 @interface MPPTaskRunner : NSObject
+
 /**
 * Initializes a new `MPPTaskRunner` with the mediapipe task graph config proto.
 *
 * @return An instance of `MPPTaskRunner` initialized to the given graph config proto.
 */
 - (instancetype)initWithCalculatorGraphConfig:(mediapipe::CalculatorGraphConfig)graphConfig
-                                        error:(NSError **)error;
+                                        error:(NSError **)error NS_DESIGNATED_INITIALIZER;

 - (absl::StatusOr<mediapipe::tasks::core::PacketMap>)process:(const mediapipe::tasks::core::PacketMap&)packetMap error:(NSError **)error;

From 481f4e960e009df7431f7281312e985507428c87 Mon Sep 17 00:00:00 2001
From: Prianka Liz Kariat
Date: Thu, 22 Dec 2022 00:49:44 +0530
Subject: [PATCH 279/469] Updated comments

---
 mediapipe/tasks/ios/core/sources/MPPTaskOptionsProtocol.h | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskOptionsProtocol.h b/mediapipe/tasks/ios/core/sources/MPPTaskOptionsProtocol.h
index c6f115451..44fba4c0b 100644
--- a/mediapipe/tasks/ios/core/sources/MPPTaskOptionsProtocol.h
+++ b/mediapipe/tasks/ios/core/sources/MPPTaskOptionsProtocol.h
@@ -18,12 +18,12 @@ NS_ASSUME_NONNULL_BEGIN

 /**
- * Any mediapipe task options should confirm to this protocol.
+ * Any MediaPipe task options should conform to this protocol.
 */
 @protocol MPPTaskOptionsProtocol

 /**
- * Copies the iOS Mediapipe task options to an object of mediapipe::CalculatorOptions proto.
+ * Copies the iOS MediaPipe task options to an object of mediapipe::CalculatorOptions proto.
 */
 - (void)copyToProto:(mediapipe::CalculatorOptions *)optionsProto;

From 2943d1668e0e0c66dad9a0e6626dc1c82e38cd3d Mon Sep 17 00:00:00 2001
From: Prianka Liz Kariat
Date: Thu, 22 Dec 2022 00:51:20 +0530
Subject: [PATCH 280/469] Updated comments

---
 mediapipe/tasks/ios/core/sources/MPPTaskRunner.h | 1 +
 1 file changed, 1 insertion(+)

diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskRunner.h b/mediapipe/tasks/ios/core/sources/MPPTaskRunner.h
index e07cb344d..9dfef02e1 100644
--- a/mediapipe/tasks/ios/core/sources/MPPTaskRunner.h
+++ b/mediapipe/tasks/ios/core/sources/MPPTaskRunner.h
@@ -23,6 +23,7 @@ NS_ASSUME_NONNULL_BEGIN
 /**
 * This class is used to create and call appropriate methods on the C++ Task Runner.
 */
+
 @interface MPPTaskRunner : NSObject

 /**

From 20f2e136c520b937d81a8241aec7a4ca869d3f70 Mon Sep 17 00:00:00 2001
From: Prianka Liz Kariat
Date: Thu, 22 Dec 2022 00:59:22 +0530
Subject: [PATCH 281/469] Updated empty spaces

---
 .../processors/utils/sources/MPPClassifierOptions+Helpers.mm | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.mm b/mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.mm
index 25e657599..db7fa6bfd 100644
--- a/mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.mm
+++ b/mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.mm
@@ -20,12 +20,16 @@ using ClassifierOptionsProto = ::mediapipe::tasks::components::processors::proto
 }

 @implementation MPPClassifierOptions (Helpers)
+
 - (void)copyToProto:(ClassifierOptionsProto *)classifierOptionsProto {
   if (self.displayNamesLocale) {
     classifierOptionsProto->set_display_names_locale(self.displayNamesLocale.cppString);
   }
+
   classifierOptionsProto->set_max_results((int)self.maxResults);
+
   classifierOptionsProto->set_score_threshold(self.scoreThreshold);
+
   for (NSString *category in self.labelAllowList) {
     classifierOptionsProto->add_category_allowlist(category.cppString);
   }

From 1de369417572ebb09e26c72d8a6fa3e5f7685795 Mon Sep 17 00:00:00 2001
From: Prianka Liz Kariat
Date: Thu, 22 Dec 2022 01:00:36 +0530
Subject: [PATCH 282/469] Updated comments

---
 mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h b/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h
index 8a90856c7..d2e6067d5 100644
--- a/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h
+++ b/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h
@@ -24,7 +24,7 @@ extern NSString *const MPPTasksErrorDomain;
 @interface MPPCommonUtils : NSObject

 /**
- * Creates and saves an NSError in the Mediapipe task library domain, with the given code and
+ * Creates and saves an NSError in the MediaPipe task library domain, with the given code and
 * description.
 *
 * @param code Error code.
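The fields being plumbed through copyToProto above (max results, score threshold, category lists) mirror the cross-platform classifier options. For comparison, a hedged sketch of the equivalent configuration through the Python Tasks API; field names follow the Python package, and the exact import path is an assumption:

    from mediapipe.tasks.python.components.processors import classifier_options

    options = classifier_options.ClassifierOptions(
        display_names_locale="en",
        max_results=3,
        score_threshold=0.5,
        # Mutually exclusive with category_denylist.
        category_allowlist=["Speech"],
    )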
From 1de369417572ebb09e26c72d8a6fa3e5f7685795 Mon Sep 17 00:00:00 2001
From: Prianka Liz Kariat
Date: Thu, 22 Dec 2022 01:02:07 +0530
Subject: [PATCH 283/469] Updated comments

---
 mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h b/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h
index d2e6067d5..1a44ee45a 100644
--- a/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h
+++ b/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h
@@ -51,9 +51,9 @@ extern NSString *const MPPTasksErrorDomain;
                  description:(NSString *)description;

 /**
- * Converts an absl status to an NSError.
+ * Converts an absl::Status to an NSError.
 *
- * @param status absl status.
+ * @param status absl::Status.
 * @param error Pointer to the memory location where the created error should be saved. If `nil`,
 * no error will be saved.
 */
@@ -68,7 +68,7 @@ extern NSString *const MPPTasksErrorDomain;
 * @param error Pointer to the memory location where errors if any should be saved. If `nil`, no
 * error will be saved.
 *
- * @return Pointer to the allocated block of memory on successfull allocation. nil in case as
+ * @return Pointer to the allocated block of memory on successful allocation. `nil` in case an
 * error is encountered because of invalid memSize. If failure is due to any other reason, method
 * terminates program execution.
 */

From 7ae4b7e6394b5e75315fb46b4ee9c44b6e02ecc1 Mon Sep 17 00:00:00 2001
From: Prianka Liz Kariat
Date: Thu, 22 Dec 2022 01:03:39 +0530
Subject: [PATCH 284/469] Updated comments

---
 mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h b/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h
index 1a44ee45a..407d87aba 100644
--- a/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h
+++ b/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h
@@ -61,7 +61,7 @@ extern NSString *const MPPTasksErrorDomain;
 /**
 * Allocates a block of memory with the specified size and returns a pointer to it. If memory
- * cannot be allocated because of an invalid memSize, it saves an error. In other cases, it
+ * cannot be allocated because of an invalid `memSize`, it saves an error. In other cases, it
 * terminates program execution.
 *
 * @param memSize size of memory to be allocated
@@ -69,7 +69,7 @@ extern NSString *const MPPTasksErrorDomain;
 * error will be saved.
 *
 * @return Pointer to the allocated block of memory on successful allocation. `nil` in case an
- * error is encountered because of invalid memSize. If failure is due to any other reason, method
+ * error is encountered because of invalid `memSize`. If failure is due to any other reason, method
 * terminates program execution.
*/ + (void *)mallocWithSize:(size_t)memSize error:(NSError **)error; From 7ae4b7e6394b5e75315fb46b4ee9c44b6e02ecc1 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Thu, 22 Dec 2022 01:05:01 +0530 Subject: [PATCH 285/469] Updated error domain --- mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.mm | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.mm b/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.mm index 574f2ef9a..4d4880a87 100644 --- a/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.mm +++ b/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.mm @@ -24,7 +24,7 @@ #include "mediapipe/tasks/cc/common.h" /** Error domain of MediaPipe task library errors. */ -NSString *const MPPTasksErrorDomain = @"org.mediapipe.tasks"; +NSString *const MPPTasksErrorDomain = @"com.google.mediapipe.tasks"; @implementation MPPCommonUtils @@ -68,7 +68,7 @@ NSString *const MPPTasksErrorDomain = @"org.mediapipe.tasks"; if (status.ok()) { return YES; } - // Payload of absl::Status created by the Media Pipe task library stores an appropriate value of + // Payload of absl::Status created by the MediaPipe task library stores an appropriate value of // the enum MediaPipeTasksStatus. The integer value corresponding to the MediaPipeTasksStatus enum // stored in the payload is extracted here to later map to the appropriate error code to be // returned. In cases where the enum is not stored in (payload is NULL or the payload string From 54d36dfedad1d2e84f680fb69defb13b6eae45b9 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Thu, 22 Dec 2022 01:05:50 +0530 Subject: [PATCH 286/469] Update MPPClassifierOptions.h --- .../ios/components/processors/sources/MPPClassifierOptions.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h b/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h index 8c4981642..b31dadb63 100644 --- a/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h +++ b/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h @@ -17,7 +17,7 @@ NS_ASSUME_NONNULL_BEGIN /** - * Holds settings for any single iOS Mediapipe classification task. + * Holds settings for any single iOS MediaPipe classification task. */ NS_SWIFT_NAME(ClassifierOptions) @interface MPPClassifierOptions : NSObject From 673b38dfe87c35504ac81f5b29935ab6b25beaa1 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Thu, 22 Dec 2022 01:08:13 +0530 Subject: [PATCH 287/469] Updated comments --- .../processors/sources/MPPClassifierOptions.h | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h b/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h index b31dadb63..d6b9a9582 100644 --- a/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h +++ b/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h @@ -22,16 +22,18 @@ NS_ASSUME_NONNULL_BEGIN NS_SWIFT_NAME(ClassifierOptions) @interface MPPClassifierOptions : NSObject -/** If set, all classes in this list will be filtered out from the results . */ +/** If set, all classes in this list will be filtered out from the results. */ @property(nonatomic, copy) NSArray *labelDenyList; -/** If set, all classes not in this list will be filtered out from the results . 
*/ +/** If set, all classes not in this list will be filtered out from the results. */ @property(nonatomic, copy) NSArray *labelAllowList; -/** Display names local for display names*/ +/** The locale to use for display names specified through the TFLite Model + * Metadata, if any. Defaults to English. + */ @property(nonatomic, copy) NSString *displayNamesLocale; -/** Results with score threshold greater than this value are returned . */ +/** Results with score threshold greater than this value are returned. */ @property(nonatomic) float scoreThreshold; /** Limit to the number of classes that can be returned in results. */ From 66ee8d47c0d13c4f0a4f4ee91bde7bf570fbaa61 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Thu, 22 Dec 2022 01:10:07 +0530 Subject: [PATCH 288/469] Resorted options --- .../processors/sources/MPPClassifierOptions.h | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h b/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h index d6b9a9582..0c22ed9de 100644 --- a/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h +++ b/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h @@ -22,22 +22,22 @@ NS_ASSUME_NONNULL_BEGIN NS_SWIFT_NAME(ClassifierOptions) @interface MPPClassifierOptions : NSObject -/** If set, all classes in this list will be filtered out from the results. */ -@property(nonatomic, copy) NSArray *labelDenyList; - -/** If set, all classes not in this list will be filtered out from the results. */ -@property(nonatomic, copy) NSArray *labelAllowList; - /** The locale to use for display names specified through the TFLite Model * Metadata, if any. Defaults to English. */ @property(nonatomic, copy) NSString *displayNamesLocale; +/** Limit to the number of classes that can be returned in results. */ +@property(nonatomic) NSInteger maxResults; + /** Results with score threshold greater than this value are returned. */ @property(nonatomic) float scoreThreshold; -/** Limit to the number of classes that can be returned in results. */ -@property(nonatomic) NSInteger maxResults; +/** If set, all classes not in this list will be filtered out from the results. */ +@property(nonatomic, copy) NSArray *labelAllowList; + +/** If set, all classes in this list will be filtered out from the results. */ +@property(nonatomic, copy) NSArray *labelDenyList; @end From e1dfcf03cf41f0f9519206ea4fa97f255161191f Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Thu, 22 Dec 2022 01:12:34 +0530 Subject: [PATCH 289/469] Updated comments in MPPClassifierOptions.h --- .../ios/components/processors/sources/MPPClassifierOptions.h | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h b/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h index 0c22ed9de..371472cab 100644 --- a/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h +++ b/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h @@ -30,7 +30,9 @@ NS_SWIFT_NAME(ClassifierOptions) /** Limit to the number of classes that can be returned in results. */ @property(nonatomic) NSInteger maxResults; -/** Results with score threshold greater than this value are returned. */ +/** Score threshold to override the one provided in the model metadata (if any). + * Results below this value are rejected. 
+ */ @property(nonatomic) float scoreThreshold; /** If set, all classes not in this list will be filtered out from the results. */ From c185dc9ad7ba33844fac9560f33c59fb2c9e4ad6 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Thu, 22 Dec 2022 01:19:01 +0530 Subject: [PATCH 290/469] Renamed label to category in classifier options --- .../ios/components/processors/sources/MPPClassifierOptions.h | 4 ++-- .../ios/components/processors/sources/MPPClassifierOptions.m | 4 ++-- .../processors/utils/sources/MPPClassifierOptions+Helpers.mm | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h b/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h index 371472cab..0f0abe398 100644 --- a/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h +++ b/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h @@ -36,10 +36,10 @@ NS_SWIFT_NAME(ClassifierOptions) @property(nonatomic) float scoreThreshold; /** If set, all classes not in this list will be filtered out from the results. */ -@property(nonatomic, copy) NSArray *labelAllowList; +@property(nonatomic, copy) NSArray *categoryAllowList; /** If set, all classes in this list will be filtered out from the results. */ -@property(nonatomic, copy) NSArray *labelDenyList; +@property(nonatomic, copy) NSArray *categoryDenyList; @end diff --git a/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.m b/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.m index 52dce23e4..1d9191802 100644 --- a/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.m +++ b/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.m @@ -30,8 +30,8 @@ classifierOptions.scoreThreshold = self.scoreThreshold; classifierOptions.maxResults = self.maxResults; - classifierOptions.labelDenyList = self.labelDenyList; - classifierOptions.labelAllowList = self.labelAllowList; + classifierOptions.categoryDenyList = self.categoryDenyList; + classifierOptions.categoryAllowList = self.categoryAllowList; classifierOptions.displayNamesLocale = self.displayNamesLocale; return classifierOptions; diff --git a/mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.mm b/mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.mm index db7fa6bfd..3d8397efa 100644 --- a/mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.mm +++ b/mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.mm @@ -30,11 +30,11 @@ using ClassifierOptionsProto = ::mediapipe::tasks::components::processors::proto classifierOptionsProto->set_score_threshold(self.scoreThreshold); - for (NSString *category in self.labelAllowList) { + for (NSString *category in self.categoryAllowList) { classifierOptionsProto->add_category_allowlist(category.cppString); } - for (NSString *category in self.labelDenyList) { + for (NSString *category in self.categoryDenyList) { classifierOptionsProto->add_category_denylist(category.cppString); } } From 20c3388ab68c11082de43aff825762686b3bc8e1 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Thu, 22 Dec 2022 01:59:38 +0530 Subject: [PATCH 291/469] Updated category allowlist and denylist names --- .../ios/components/processors/sources/MPPClassifierOptions.h | 4 ++-- .../ios/components/processors/sources/MPPClassifierOptions.m | 4 ++-- 
.../processors/utils/sources/MPPClassifierOptions+Helpers.mm | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h b/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h index 0f0abe398..e95de89e4 100644 --- a/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h +++ b/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h @@ -36,10 +36,10 @@ NS_SWIFT_NAME(ClassifierOptions) @property(nonatomic) float scoreThreshold; /** If set, all classes not in this list will be filtered out from the results. */ -@property(nonatomic, copy) NSArray *categoryAllowList; +@property(nonatomic, copy) NSArray *categoryAllowlist; /** If set, all classes in this list will be filtered out from the results. */ -@property(nonatomic, copy) NSArray *categoryDenyList; +@property(nonatomic, copy) NSArray *categoryDenylist; @end diff --git a/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.m b/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.m index 1d9191802..accb6c7dd 100644 --- a/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.m +++ b/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.m @@ -30,8 +30,8 @@ classifierOptions.scoreThreshold = self.scoreThreshold; classifierOptions.maxResults = self.maxResults; - classifierOptions.categoryDenyList = self.categoryDenyList; - classifierOptions.categoryAllowList = self.categoryAllowList; + classifierOptions.categoryDenylist = self.categoryDenylist; + classifierOptions.categoryAllowlist = self.categoryAllowlist; classifierOptions.displayNamesLocale = self.displayNamesLocale; return classifierOptions; diff --git a/mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.mm b/mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.mm index 3d8397efa..81fe57d13 100644 --- a/mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.mm +++ b/mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.mm @@ -30,11 +30,11 @@ using ClassifierOptionsProto = ::mediapipe::tasks::components::processors::proto classifierOptionsProto->set_score_threshold(self.scoreThreshold); - for (NSString *category in self.categoryAllowList) { + for (NSString *category in self.categoryAllowlist) { classifierOptionsProto->add_category_allowlist(category.cppString); } - for (NSString *category in self.categoryDenyList) { + for (NSString *category in self.categoryDenylist) { classifierOptionsProto->add_category_denylist(category.cppString); } } From b4a7644428ac05af45dc737bd1002b6a8f6154cc Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Thu, 22 Dec 2022 02:01:04 +0530 Subject: [PATCH 292/469] Updated comments --- .../ios/components/processors/sources/MPPClassifierOptions.h | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h b/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h index e95de89e4..348e94e96 100644 --- a/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h +++ b/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h @@ -27,7 +27,10 @@ NS_SWIFT_NAME(ClassifierOptions) */ @property(nonatomic, copy) NSString *displayNamesLocale; -/** Limit to the number of classes that can be returned in results. 
*/ +/** The maximum number of top-scored classification results to return. If < 0, + * all available results will be returned. If 0, an invalid argument error is + * returned. + */ @property(nonatomic) NSInteger maxResults; /** Score threshold to override the one provided in the model metadata (if any). From e559613b9de8d73e0d4956688561174b58e2dcb9 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Thu, 22 Dec 2022 02:02:48 +0530 Subject: [PATCH 293/469] Updated comments in MPPClassifierOptions.h --- .../processors/sources/MPPClassifierOptions.h | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h b/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h index 348e94e96..7bf5744f7 100644 --- a/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h +++ b/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h @@ -38,10 +38,16 @@ NS_SWIFT_NAME(ClassifierOptions) */ @property(nonatomic) float scoreThreshold; -/** If set, all classes not in this list will be filtered out from the results. */ +/** The allowlist of category names. If non-empty, detection results whose + * category name is not in this set will be filtered out. Duplicate or unknown + * category names are ignored. Mutually exclusive with categoryDenylist. + */ @property(nonatomic, copy) NSArray *categoryAllowlist; -/** If set, all classes in this list will be filtered out from the results. */ +/** The denylist of category names. If non-empty, detection results whose + * category name is in this set will be filtered out. Duplicate or unknown + * category names are ignored. Mutually exclusive with categoryAllowlist. + */ @property(nonatomic, copy) NSArray *categoryDenylist; @end From 69b6d9d970a9eae8d7c9e085201ba888ef4ef54b Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Wed, 21 Dec 2022 17:39:54 -0800 Subject: [PATCH 294/469] Internal change PiperOrigin-RevId: 497043596 --- mediapipe/web/graph_runner/graph_runner.ts | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/mediapipe/web/graph_runner/graph_runner.ts b/mediapipe/web/graph_runner/graph_runner.ts index a9bb979af..ef866bc91 100644 --- a/mediapipe/web/graph_runner/graph_runner.ts +++ b/mediapipe/web/graph_runner/graph_runner.ts @@ -1028,7 +1028,9 @@ export class GraphRunner { // Set up our TS listener to receive any packets for this stream, and // additionally reformat our Uint8Array into a Float32Array for the user. 
     this.setListener(outputStreamName, (data: Uint8Array) => {
-      const floatArray = new Float32Array(data.buffer);  // Should be very fast
+      // Should be very fast
+      const floatArray =
+          new Float32Array(data.buffer, data.byteOffset, data.length / 4);
       callbackFcn(floatArray);
     });

From e47256ae55af3921d0878cf131c32625a2500082 Mon Sep 17 00:00:00 2001
From: Prianka Liz Kariat
Date: Thu, 22 Dec 2022 12:10:23 +0530
Subject: [PATCH 295/469] Clearing proto before assigning new values in
 MPPClassifierOptions Helpers

---
 .../processors/utils/sources/MPPClassifierOptions+Helpers.mm | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.mm b/mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.mm
index 81fe57d13..efe9572e1 100644
--- a/mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.mm
+++ b/mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.mm
@@ -22,6 +22,8 @@ using ClassifierOptionsProto = ::mediapipe::tasks::components::processors::proto
 @implementation MPPClassifierOptions (Helpers)

 - (void)copyToProto:(ClassifierOptionsProto *)classifierOptionsProto {
+  classifierOptionsProto->Clear();
+
   if (self.displayNamesLocale) {
     classifierOptionsProto->set_display_names_locale(self.displayNamesLocale.cppString);
   }

From 613ed588908ac3bd39b48bf05e21c2fa52eeb9ad Mon Sep 17 00:00:00 2001
From: Prianka Liz Kariat
Date: Thu, 22 Dec 2022 12:16:33 +0530
Subject: [PATCH 296/469] Inverted condition check in MPPTaskInfo

---
 .../tasks/ios/core/sources/MPPTaskInfo.mm | 67 ++++++++++---------
 1 file changed, 34 insertions(+), 33 deletions(-)

diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskInfo.mm b/mediapipe/tasks/ios/core/sources/MPPTaskInfo.mm
index 7d2fd6f28..be3c8cbf7 100644
--- a/mediapipe/tasks/ios/core/sources/MPPTaskInfo.mm
+++ b/mediapipe/tasks/ios/core/sources/MPPTaskInfo.mm
@@ -24,9 +24,9 @@
 namespace {

 using CalculatorGraphConfig = ::mediapipe::CalculatorGraphConfig;
 using Node = ::mediapipe::CalculatorGraphConfig::Node;
-using ::mediapipe::InputStreamInfo;
 using ::mediapipe::CalculatorOptions;
 using ::mediapipe::FlowLimiterCalculatorOptions;
+using ::mediapipe::InputStreamInfo;
 }  // namespace

 @implementation MPPTaskInfo
stripTagIndex:taskInputStream]; - flow_limit_calculator_node->add_output_stream(strippedTaskInputStream.cppString); - } - - NSString *firstOutputStream = self.outputStreams[0]; - auto finished_output_stream = "FINISHED:" + firstOutputStream.cppString; - flow_limit_calculator_node->add_input_stream(finished_output_stream); - } else { + if (!self.enableFlowLimiting) { for (NSString *inputStream in self.inputStreams) { auto cpp_input_stream = inputStream.cppString; task_subgraph_node->add_input_stream(cpp_input_stream); graph_config.add_input_stream(cpp_input_stream); } + return graph_config; } + Node *flow_limit_calculator_node = graph_config.add_node(); + + flow_limit_calculator_node->set_calculator("FlowLimiterCalculator"); + + InputStreamInfo *input_stream_info = flow_limit_calculator_node->add_input_stream_info(); + input_stream_info->set_tag_index("FINISHED"); + input_stream_info->set_back_edge(true); + + FlowLimiterCalculatorOptions *flow_limit_calculator_options = + flow_limit_calculator_node->mutable_options()->MutableExtension( + FlowLimiterCalculatorOptions::ext); + flow_limit_calculator_options->set_max_in_flight(1); + flow_limit_calculator_options->set_max_in_queue(1); + + for (NSString *inputStream in self.inputStreams) { + graph_config.add_input_stream(inputStream.cppString); + + NSString *strippedInputStream = [MPPTaskInfo stripTagIndex:inputStream]; + flow_limit_calculator_node->add_input_stream(strippedInputStream.cppString); + + NSString *taskInputStream = [MPPTaskInfo addStreamNamePrefix:inputStream]; + task_subgraph_node->add_input_stream(taskInputStream.cppString); + + NSString *strippedTaskInputStream = [MPPTaskInfo stripTagIndex:taskInputStream]; + flow_limit_calculator_node->add_output_stream(strippedTaskInputStream.cppString); + } + + NSString *firstOutputStream = self.outputStreams[0]; + auto finished_output_stream = "FINISHED:" + firstOutputStream.cppString; + flow_limit_calculator_node->add_input_stream(finished_output_stream); + return graph_config; } From 48eeae4d9d3582661f002ddc2424e3e6c8cdd512 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Thu, 22 Dec 2022 12:16:43 +0530 Subject: [PATCH 297/469] Formatted code --- mediapipe/tasks/ios/core/sources/MPPTaskInfo.h | 1 - mediapipe/tasks/ios/core/sources/MPPTaskRunner.h | 5 +++-- mediapipe/tasks/ios/core/sources/MPPTaskRunner.mm | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskInfo.h b/mediapipe/tasks/ios/core/sources/MPPTaskInfo.h index 4c01787a8..ae4c9eba1 100644 --- a/mediapipe/tasks/ios/core/sources/MPPTaskInfo.h +++ b/mediapipe/tasks/ios/core/sources/MPPTaskInfo.h @@ -17,7 +17,6 @@ #import "mediapipe/tasks/ios/core/sources/MPPTaskOptions.h" #import "mediapipe/tasks/ios/core/sources/MPPTaskOptionsProtocol.h" - NS_ASSUME_NONNULL_BEGIN /** diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskRunner.h b/mediapipe/tasks/ios/core/sources/MPPTaskRunner.h index 9dfef02e1..6561e136d 100644 --- a/mediapipe/tasks/ios/core/sources/MPPTaskRunner.h +++ b/mediapipe/tasks/ios/core/sources/MPPTaskRunner.h @@ -17,7 +17,6 @@ #include "mediapipe/framework/calculator.pb.h" #include "mediapipe/tasks/cc/core/task_runner.h" - NS_ASSUME_NONNULL_BEGIN /** @@ -36,7 +35,9 @@ NS_ASSUME_NONNULL_BEGIN - (instancetype)initWithCalculatorGraphConfig:(mediapipe::CalculatorGraphConfig)graphConfig error:(NSError **)error NS_DESIGNATED_INITIALIZER; -- (absl::StatusOr)process:(const mediapipe::tasks::core::PacketMap&)packetMap error:(NSError **)error; +- (absl::StatusOr) + 
process:(const mediapipe::tasks::core::PacketMap &)packetMap + error:(NSError **)error; - (void)close; diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskRunner.mm b/mediapipe/tasks/ios/core/sources/MPPTaskRunner.mm index 404f6c582..e08d0bc1b 100644 --- a/mediapipe/tasks/ios/core/sources/MPPTaskRunner.mm +++ b/mediapipe/tasks/ios/core/sources/MPPTaskRunner.mm @@ -45,7 +45,7 @@ using TaskRunnerCpp = ::mediapipe::tasks::core::TaskRunner; return self; } -- (absl::StatusOr)process:(const PacketMap&)packetMap { +- (absl::StatusOr)process:(const PacketMap &)packetMap { return _cppTaskRunner->Process(packetMap); } From 967384160524a7be56da549f17abd129493ada78 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Thu, 22 Dec 2022 09:47:35 -0800 Subject: [PATCH 298/469] Internal visibility update PiperOrigin-RevId: 497185157 --- mediapipe/framework/deps/BUILD | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/mediapipe/framework/deps/BUILD b/mediapipe/framework/deps/BUILD index 27bc105c8..7ff004f1e 100644 --- a/mediapipe/framework/deps/BUILD +++ b/mediapipe/framework/deps/BUILD @@ -20,9 +20,14 @@ load("@bazel_skylib//:bzl_library.bzl", "bzl_library") licenses(["notice"]) -package(default_visibility = [ - "//mediapipe:__subpackages__", -]) +package_group( + name = "mediapipe_internal", + packages = [ + "//mediapipe/...", + ], +) + +package(default_visibility = ["mediapipe_internal"]) bzl_library( name = "expand_template_bzl", @@ -214,6 +219,9 @@ cc_library( name = "registration", srcs = ["registration.cc"], hdrs = ["registration.h"], + visibility = [ + "mediapipe_internal", + ], deps = [ ":registration_token", "//mediapipe/framework/port:logging", From 5b90afda701d1ddb91a435f064507b43636ea966 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Thu, 22 Dec 2022 10:19:59 -0800 Subject: [PATCH 299/469] Internal change PiperOrigin-RevId: 497191969 --- mediapipe/framework/BUILD | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mediapipe/framework/BUILD b/mediapipe/framework/BUILD index a4c9a520d..83346dad1 100644 --- a/mediapipe/framework/BUILD +++ b/mediapipe/framework/BUILD @@ -1060,7 +1060,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":calculator_framework", - "//mediapipe/framework:test_calculators_cc_proto", + ":test_calculators_cc_proto", "//mediapipe/framework/deps:mathutil", "//mediapipe/framework/formats:matrix", "//mediapipe/framework/port:integral_types", From 36f054dfbe391b450aeb11bfc4b71e962644b72d Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Thu, 22 Dec 2022 10:41:03 -0800 Subject: [PATCH 300/469] Internal model maker change PiperOrigin-RevId: 497196512 --- .../model_maker/python/text/text_classifier/text_classifier.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/mediapipe/model_maker/python/text/text_classifier/text_classifier.py b/mediapipe/model_maker/python/text/text_classifier/text_classifier.py index c4d3fdbe2..1a338e345 100644 --- a/mediapipe/model_maker/python/text/text_classifier/text_classifier.py +++ b/mediapipe/model_maker/python/text/text_classifier/text_classifier.py @@ -33,7 +33,6 @@ from mediapipe.model_maker.python.text.text_classifier import preprocessor from mediapipe.model_maker.python.text.text_classifier import text_classifier_options from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer from mediapipe.tasks.python.metadata.metadata_writers import text_classifier as text_classifier_writer -from official.nlp import optimization def _validate(options: 
text_classifier_options.TextClassifierOptions): @@ -424,7 +423,7 @@ class _BertClassifier(TextClassifier): end_learning_rate=0.0, power=1.0) if warmup_steps: - lr_schedule = optimization.WarmUp( + lr_schedule = model_util.WarmUp( initial_learning_rate=initial_lr, decay_schedule_fn=lr_schedule, warmup_steps=warmup_steps) From 5a71b551e5ad4b85aa18bb23d994fe09b753f0f9 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Thu, 22 Dec 2022 15:29:18 -0800 Subject: [PATCH 301/469] Remove duplicate and non-public api for model_maker PiperOrigin-RevId: 497251246 --- mediapipe/model_maker/__init__.py | 3 +++ .../python/text/text_classifier/__init__.py | 9 ++++++++ .../python/vision/gesture_recognizer/BUILD | 2 ++ .../vision/gesture_recognizer/__init__.py | 9 ++++++++ .../gesture_recognizer_test.py | 22 ++++++++++--------- .../python/vision/image_classifier/BUILD | 2 ++ .../vision/image_classifier/__init__.py | 9 ++++++++ .../image_classifier/image_classifier_test.py | 10 +++++---- 8 files changed, 52 insertions(+), 14 deletions(-) diff --git a/mediapipe/model_maker/__init__.py b/mediapipe/model_maker/__init__.py index 9899a145b..b37088764 100644 --- a/mediapipe/model_maker/__init__.py +++ b/mediapipe/model_maker/__init__.py @@ -17,3 +17,6 @@ from mediapipe.model_maker.python.core.utils import quantization from mediapipe.model_maker.python.vision import image_classifier from mediapipe.model_maker.python.vision import gesture_recognizer from mediapipe.model_maker.python.text import text_classifier + +# Remove duplicated and non-public API +del python diff --git a/mediapipe/model_maker/python/text/text_classifier/__init__.py b/mediapipe/model_maker/python/text/text_classifier/__init__.py index 618e51645..697461969 100644 --- a/mediapipe/model_maker/python/text/text_classifier/__init__.py +++ b/mediapipe/model_maker/python/text/text_classifier/__init__.py @@ -29,3 +29,12 @@ BertModelOptions = model_options.BertModelOptions SupportedModels = model_spec.SupportedModels TextClassifier = text_classifier.TextClassifier TextClassifierOptions = text_classifier_options.TextClassifierOptions + +# Remove duplicated and non-public API +del hyperparameters +del dataset +del model_options +del model_spec +del preprocessor # pylint: disable=undefined-variable +del text_classifier +del text_classifier_options diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/BUILD b/mediapipe/model_maker/python/vision/gesture_recognizer/BUILD index 9123e36b0..cbdff7cf3 100644 --- a/mediapipe/model_maker/python/vision/gesture_recognizer/BUILD +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/BUILD @@ -146,6 +146,8 @@ py_test( tags = ["notsan"], deps = [ ":gesture_recognizer_import", + ":hyperparameters", + ":model_options", "//mediapipe/model_maker/python/core/utils:test_util", "//mediapipe/tasks/python/test:test_utils", ], diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/__init__.py b/mediapipe/model_maker/python/vision/gesture_recognizer/__init__.py index dc6923fac..a302e8d79 100644 --- a/mediapipe/model_maker/python/vision/gesture_recognizer/__init__.py +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/__init__.py @@ -25,3 +25,12 @@ HParams = hyperparameters.HParams Dataset = dataset.Dataset HandDataPreprocessingParams = dataset.HandDataPreprocessingParams GestureRecognizerOptions = gesture_recognizer_options.GestureRecognizerOptions + +# Remove duplicated and non-public API +del constants # pylint: disable=undefined-variable +del dataset +del gesture_recognizer +del 
gesture_recognizer_options +del hyperparameters +del metadata_writer # pylint: disable=undefined-variable +del model_options diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py b/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py index 08fda4fea..4fdb74225 100644 --- a/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py @@ -23,6 +23,8 @@ import tensorflow as tf from mediapipe.model_maker.python.core.utils import test_util from mediapipe.model_maker.python.vision import gesture_recognizer +from mediapipe.model_maker.python.vision.gesture_recognizer import hyperparameters +from mediapipe.model_maker.python.vision.gesture_recognizer import model_options from mediapipe.tasks.python.test import test_utils _TEST_DATA_DIR = 'mediapipe/model_maker/python/vision/gesture_recognizer/testdata' @@ -48,11 +50,11 @@ class GestureRecognizerTest(tf.test.TestCase): self._train_data, self._validation_data = all_data.split(0.9) def test_gesture_recognizer_model(self): - model_options = gesture_recognizer.ModelOptions() + mo = gesture_recognizer.ModelOptions() hparams = gesture_recognizer.HParams( export_dir=tempfile.mkdtemp(), epochs=2) gesture_recognizer_options = gesture_recognizer.GestureRecognizerOptions( - model_options=model_options, hparams=hparams) + model_options=mo, hparams=hparams) model = gesture_recognizer.GestureRecognizer.create( train_data=self._train_data, validation_data=self._validation_data, @@ -64,11 +66,11 @@ class GestureRecognizerTest(tf.test.TestCase): tf.keras.layers, 'Dense', wraps=tf.keras.layers.Dense) def test_gesture_recognizer_model_layer_widths(self, mock_dense): layer_widths = [64, 32] - model_options = gesture_recognizer.ModelOptions(layer_widths=layer_widths) + mo = gesture_recognizer.ModelOptions(layer_widths=layer_widths) hparams = gesture_recognizer.HParams( export_dir=tempfile.mkdtemp(), epochs=2) gesture_recognizer_options = gesture_recognizer.GestureRecognizerOptions( - model_options=model_options, hparams=hparams) + model_options=mo, hparams=hparams) model = gesture_recognizer.GestureRecognizer.create( train_data=self._train_data, validation_data=self._validation_data, @@ -87,11 +89,11 @@ class GestureRecognizerTest(tf.test.TestCase): self._test_accuracy(model) def test_export_gesture_recognizer_model(self): - model_options = gesture_recognizer.ModelOptions() + mo = gesture_recognizer.ModelOptions() hparams = gesture_recognizer.HParams( export_dir=tempfile.mkdtemp(), epochs=2) gesture_recognizer_options = gesture_recognizer.GestureRecognizerOptions( - model_options=model_options, hparams=hparams) + model_options=mo, hparams=hparams) model = gesture_recognizer.GestureRecognizer.create( train_data=self._train_data, validation_data=self._validation_data, @@ -128,12 +130,12 @@ class GestureRecognizerTest(tf.test.TestCase): self.assertGreater(accuracy, threshold) @unittest_mock.patch.object( - gesture_recognizer.hyperparameters, + hyperparameters, 'HParams', autospec=True, return_value=gesture_recognizer.HParams(epochs=1)) @unittest_mock.patch.object( - gesture_recognizer.model_options, + model_options, 'GestureRecognizerModelOptions', autospec=True, return_value=gesture_recognizer.ModelOptions()) @@ -148,11 +150,11 @@ class GestureRecognizerTest(tf.test.TestCase): mock_model_options.assert_called_once() def test_continual_training_by_loading_checkpoint(self): - model_options = 
gesture_recognizer.ModelOptions() + mo = gesture_recognizer.ModelOptions() hparams = gesture_recognizer.HParams( export_dir=tempfile.mkdtemp(), epochs=2) gesture_recognizer_options = gesture_recognizer.GestureRecognizerOptions( - model_options=model_options, hparams=hparams) + model_options=mo, hparams=hparams) mock_stdout = io.StringIO() with mock.patch('sys.stdout', mock_stdout): model = gesture_recognizer.GestureRecognizer.create( diff --git a/mediapipe/model_maker/python/vision/image_classifier/BUILD b/mediapipe/model_maker/python/vision/image_classifier/BUILD index 29ae189e9..d7c47a359 100644 --- a/mediapipe/model_maker/python/vision/image_classifier/BUILD +++ b/mediapipe/model_maker/python/vision/image_classifier/BUILD @@ -121,7 +121,9 @@ py_library( srcs = ["image_classifier_test.py"], data = ["//mediapipe/model_maker/python/vision/image_classifier/testdata"], deps = [ + ":hyperparameters", ":image_classifier_import", + ":model_options", "//mediapipe/tasks/python/test:test_utils", ], ) diff --git a/mediapipe/model_maker/python/vision/image_classifier/__init__.py b/mediapipe/model_maker/python/vision/image_classifier/__init__.py index 3d0543cd2..0f964ef66 100644 --- a/mediapipe/model_maker/python/vision/image_classifier/__init__.py +++ b/mediapipe/model_maker/python/vision/image_classifier/__init__.py @@ -27,3 +27,12 @@ ModelOptions = model_options.ImageClassifierModelOptions ModelSpec = model_spec.ModelSpec SupportedModels = model_spec.SupportedModels ImageClassifierOptions = image_classifier_options.ImageClassifierOptions + +# Remove duplicated and non-public API +del dataset +del hyperparameters +del image_classifier +del image_classifier_options +del model_options +del model_spec +del train_image_classifier_lib # pylint: disable=undefined-variable diff --git a/mediapipe/model_maker/python/vision/image_classifier/image_classifier_test.py b/mediapipe/model_maker/python/vision/image_classifier/image_classifier_test.py index 252659edc..6ca21d334 100644 --- a/mediapipe/model_maker/python/vision/image_classifier/image_classifier_test.py +++ b/mediapipe/model_maker/python/vision/image_classifier/image_classifier_test.py @@ -24,6 +24,8 @@ import numpy as np import tensorflow as tf from mediapipe.model_maker.python.vision import image_classifier +from mediapipe.model_maker.python.vision.image_classifier import hyperparameters +from mediapipe.model_maker.python.vision.image_classifier import model_options from mediapipe.tasks.python.test import test_utils @@ -159,15 +161,15 @@ class ImageClassifierTest(tf.test.TestCase, parameterized.TestCase): self.assertGreaterEqual(accuracy, threshold) @unittest_mock.patch.object( - image_classifier.hyperparameters, + hyperparameters, 'HParams', autospec=True, - return_value=image_classifier.HParams(epochs=1)) + return_value=hyperparameters.HParams(epochs=1)) @unittest_mock.patch.object( - image_classifier.model_options, + model_options, 'ImageClassifierModelOptions', autospec=True, - return_value=image_classifier.ModelOptions()) + return_value=model_options.ImageClassifierModelOptions()) def test_create_hparams_and_model_options_if_none_in_image_classifier_options( self, mock_hparams, mock_model_options): options = image_classifier.ImageClassifierOptions( From 557cd050f3bf079266aaa7b88987a2cab5ab9ab3 Mon Sep 17 00:00:00 2001 From: Hadon Nash Date: Thu, 22 Dec 2022 16:25:35 -0800 Subject: [PATCH 302/469] Deprecating RealTimeFlowLimiterCalculator in favor of FlowLimiterCalculator. 
PiperOrigin-RevId: 497260577 --- .../calculators/core/real_time_flow_limiter_calculator.cc | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/mediapipe/calculators/core/real_time_flow_limiter_calculator.cc b/mediapipe/calculators/core/real_time_flow_limiter_calculator.cc index ef3cb9896..e3c92ba52 100644 --- a/mediapipe/calculators/core/real_time_flow_limiter_calculator.cc +++ b/mediapipe/calculators/core/real_time_flow_limiter_calculator.cc @@ -76,7 +76,11 @@ constexpr char kMaxInFlightTag[] = "MAX_IN_FLIGHT"; // } // output_stream: "gated_frames" // } -class RealTimeFlowLimiterCalculator : public CalculatorBase { +// +// Please use FlowLimiterCalculator, which replaces this calculator and +// defines a few additional configuration options. +class ABSL_DEPRECATED("Use FlowLimiterCalculator instead.") + RealTimeFlowLimiterCalculator : public CalculatorBase { public: static absl::Status GetContract(CalculatorContract* cc) { int num_data_streams = cc->Inputs().NumEntries(""); From 5a5ff5393a7bfd9e76f7c3c867957eb18c48f80e Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Thu, 22 Dec 2022 17:29:23 -0800 Subject: [PATCH 303/469] Internal change PiperOrigin-RevId: 497269082 --- mediapipe/framework/api2/builder.h | 2 +- mediapipe/framework/api2/packet.h | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/mediapipe/framework/api2/builder.h b/mediapipe/framework/api2/builder.h index 19273bf44..2a98c4166 100644 --- a/mediapipe/framework/api2/builder.h +++ b/mediapipe/framework/api2/builder.h @@ -398,7 +398,7 @@ template class Node; #if __cplusplus >= 201703L // Deduction guide to silence -Wctad-maybe-unsupported. -explicit Node()->Node; +explicit Node() -> Node; #endif // C++17 template <> diff --git a/mediapipe/framework/api2/packet.h b/mediapipe/framework/api2/packet.h index 7933575d3..b1ebb0410 100644 --- a/mediapipe/framework/api2/packet.h +++ b/mediapipe/framework/api2/packet.h @@ -181,7 +181,7 @@ template class Packet; #if __cplusplus >= 201703L // Deduction guide to silence -Wctad-maybe-unsupported. -explicit Packet()->Packet; +explicit Packet() -> Packet; #endif // C++17 template <> From 175aff9be8ca719257e15355ecc1b682e7e4e299 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Tue, 27 Dec 2022 11:24:50 -0800 Subject: [PATCH 304/469] Update list of issue assignments PiperOrigin-RevId: 498003950 --- .github/bot_config.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/bot_config.yml b/.github/bot_config.yml index 8ad724168..74a60e4b9 100644 --- a/.github/bot_config.yml +++ b/.github/bot_config.yml @@ -15,4 +15,5 @@ # A list of assignees assignees: - - sureshdagooglecom + - kuaashish + - ayushgdev From 7e36a5e2ae8c66ef9717d399fa4004f448dde13f Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Wed, 28 Dec 2022 11:22:52 -0800 Subject: [PATCH 305/469] Set filecmp.cmp(shallow=False) in model_maker unit tests. 
PiperOrigin-RevId: 498218578 --- .../python/text/text_classifier/text_classifier_test.py | 6 ++++-- .../python/vision/image_classifier/image_classifier_test.py | 4 +++- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/mediapipe/model_maker/python/text/text_classifier/text_classifier_test.py b/mediapipe/model_maker/python/text/text_classifier/text_classifier_test.py index 7a30d19fd..d2edb78bc 100644 --- a/mediapipe/model_maker/python/text/text_classifier/text_classifier_test.py +++ b/mediapipe/model_maker/python/text/text_classifier/text_classifier_test.py @@ -72,8 +72,10 @@ class TextClassifierTest(tf.test.TestCase): self.assertTrue(os.path.exists(output_metadata_file)) self.assertGreater(os.path.getsize(output_metadata_file), 0) self.assertTrue( - filecmp.cmp(output_metadata_file, - self._AVERAGE_WORD_EMBEDDING_JSON_FILE)) + filecmp.cmp( + output_metadata_file, + self._AVERAGE_WORD_EMBEDDING_JSON_FILE, + shallow=False)) def test_create_and_train_bert(self): train_data, validation_data = self._get_data() diff --git a/mediapipe/model_maker/python/vision/image_classifier/image_classifier_test.py b/mediapipe/model_maker/python/vision/image_classifier/image_classifier_test.py index 6ca21d334..14c67d831 100644 --- a/mediapipe/model_maker/python/vision/image_classifier/image_classifier_test.py +++ b/mediapipe/model_maker/python/vision/image_classifier/image_classifier_test.py @@ -135,7 +135,9 @@ class ImageClassifierTest(tf.test.TestCase, parameterized.TestCase): self.assertTrue(os.path.exists(output_metadata_file)) self.assertGreater(os.path.getsize(output_metadata_file), 0) - self.assertTrue(filecmp.cmp(output_metadata_file, expected_metadata_file)) + self.assertTrue( + filecmp.cmp( + output_metadata_file, expected_metadata_file, shallow=False)) def test_continual_training_by_loading_checkpoint(self): mock_stdout = io.StringIO() From 9580f045710327b7a22d738b911af70121e2a79a Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Wed, 28 Dec 2022 13:57:20 -0800 Subject: [PATCH 306/469] Apply most graph options synchronously PiperOrigin-RevId: 498244085 --- .../audio_classifier/audio_classifier.ts | 7 +- .../audio_classifier/audio_classifier_test.ts | 3 +- .../audio/audio_embedder/audio_embedder.ts | 7 +- .../audio_embedder/audio_embedder_test.ts | 3 +- .../tasks/web/components/processors/BUILD | 26 --- .../processors/base_options.test.ts | 127 --------------- .../web/components/processors/base_options.ts | 80 ---------- mediapipe/tasks/web/core/BUILD | 5 +- mediapipe/tasks/web/core/task_runner.ts | 75 ++++++++- mediapipe/tasks/web/core/task_runner_test.ts | 148 +++++++++++++++++- .../text/text_classifier/text_classifier.ts | 7 +- .../text_classifier/text_classifier_test.ts | 3 +- .../web/text/text_embedder/text_embedder.ts | 7 +- .../text/text_embedder/text_embedder_test.ts | 3 +- mediapipe/tasks/web/vision/core/BUILD | 1 + .../vision/core/vision_task_runner.test.ts | 32 ++-- .../web/vision/core/vision_task_runner.ts | 4 +- .../gesture_recognizer/gesture_recognizer.ts | 8 +- .../gesture_recognizer_test.ts | 3 +- .../vision/hand_landmarker/hand_landmarker.ts | 8 +- .../hand_landmarker/hand_landmarker_test.ts | 3 +- .../image_classifier/image_classifier.ts | 7 +- .../image_classifier/image_classifier_test.ts | 3 +- .../vision/image_embedder/image_embedder.ts | 7 +- .../image_embedder/image_embedder_test.ts | 3 +- .../vision/object_detector/object_detector.ts | 8 +- .../object_detector/object_detector_test.ts | 3 +- 27 files changed, 280 insertions(+), 311 deletions(-) delete mode 
100644 mediapipe/tasks/web/components/processors/base_options.test.ts delete mode 100644 mediapipe/tasks/web/components/processors/base_options.ts diff --git a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts index 7bfca680a..51573f50a 100644 --- a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts +++ b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts @@ -119,11 +119,10 @@ export class AudioClassifier extends AudioTaskRunner { * * @param options The options for the audio classifier. */ - override async setOptions(options: AudioClassifierOptions): Promise { - await super.setOptions(options); + override setOptions(options: AudioClassifierOptions): Promise { this.options.setClassifierOptions(convertClassifierOptionsToProto( options, this.options.getClassifierOptions())); - this.refreshGraph(); + return this.applyOptions(options); } /** @@ -171,7 +170,7 @@ export class AudioClassifier extends AudioTaskRunner { } /** Updates the MediaPipe graph configuration. */ - private refreshGraph(): void { + protected override refreshGraph(): void { const graphConfig = new CalculatorGraphConfig(); graphConfig.addInputStream(AUDIO_STREAM); graphConfig.addInputStream(SAMPLE_RATE_STREAM); diff --git a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier_test.ts b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier_test.ts index d5c0a9429..2089f184f 100644 --- a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier_test.ts +++ b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier_test.ts @@ -79,7 +79,8 @@ describe('AudioClassifier', () => { beforeEach(async () => { addJasmineCustomFloatEqualityTester(); audioClassifier = new AudioClassifierFake(); - await audioClassifier.setOptions({}); // Initialize graph + await audioClassifier.setOptions( + {baseOptions: {modelAssetBuffer: new Uint8Array([])}}); }); it('initializes graph', async () => { diff --git a/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts index 246cba883..6a4b8ce39 100644 --- a/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts +++ b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts @@ -121,11 +121,10 @@ export class AudioEmbedder extends AudioTaskRunner { * * @param options The options for the audio embedder. */ - override async setOptions(options: AudioEmbedderOptions): Promise { - await super.setOptions(options); + override setOptions(options: AudioEmbedderOptions): Promise { this.options.setEmbedderOptions(convertEmbedderOptionsToProto( options, this.options.getEmbedderOptions())); - this.refreshGraph(); + return this.applyOptions(options); } /** @@ -171,7 +170,7 @@ export class AudioEmbedder extends AudioTaskRunner { } /** Updates the MediaPipe graph configuration. 
*/ - private refreshGraph(): void { + protected override refreshGraph(): void { const graphConfig = new CalculatorGraphConfig(); graphConfig.addInputStream(AUDIO_STREAM); graphConfig.addInputStream(SAMPLE_RATE_STREAM); diff --git a/mediapipe/tasks/web/audio/audio_embedder/audio_embedder_test.ts b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder_test.ts index 2f605ff98..dde61a6e9 100644 --- a/mediapipe/tasks/web/audio/audio_embedder/audio_embedder_test.ts +++ b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder_test.ts @@ -70,7 +70,8 @@ describe('AudioEmbedder', () => { beforeEach(async () => { addJasmineCustomFloatEqualityTester(); audioEmbedder = new AudioEmbedderFake(); - await audioEmbedder.setOptions({}); // Initialize graph + await audioEmbedder.setOptions( + {baseOptions: {modelAssetBuffer: new Uint8Array([])}}); }); it('initializes graph', () => { diff --git a/mediapipe/tasks/web/components/processors/BUILD b/mediapipe/tasks/web/components/processors/BUILD index 148a08238..cab24293d 100644 --- a/mediapipe/tasks/web/components/processors/BUILD +++ b/mediapipe/tasks/web/components/processors/BUILD @@ -103,29 +103,3 @@ jasmine_node_test( name = "embedder_options_test", deps = [":embedder_options_test_lib"], ) - -mediapipe_ts_library( - name = "base_options", - srcs = [ - "base_options.ts", - ], - deps = [ - "//mediapipe/calculators/tensor:inference_calculator_jspb_proto", - "//mediapipe/tasks/cc/core/proto:acceleration_jspb_proto", - "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", - "//mediapipe/tasks/cc/core/proto:external_file_jspb_proto", - "//mediapipe/tasks/web/core", - ], -) - -mediapipe_ts_library( - name = "base_options_test_lib", - testonly = True, - srcs = ["base_options.test.ts"], - deps = [":base_options"], -) - -jasmine_node_test( - name = "base_options_test", - deps = [":base_options_test_lib"], -) diff --git a/mediapipe/tasks/web/components/processors/base_options.test.ts b/mediapipe/tasks/web/components/processors/base_options.test.ts deleted file mode 100644 index 6d58be68f..000000000 --- a/mediapipe/tasks/web/components/processors/base_options.test.ts +++ /dev/null @@ -1,127 +0,0 @@ -/** - * Copyright 2022 The MediaPipe Authors. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -import 'jasmine'; - -// Placeholder for internal dependency on encodeByteArray -// Placeholder for internal dependency on trusted resource URL builder - -import {convertBaseOptionsToProto} from './base_options'; - -describe('convertBaseOptionsToProto()', () => { - const mockBytes = new Uint8Array([0, 1, 2, 3]); - const mockBytesResult = { - modelAsset: { - fileContent: Buffer.from(mockBytes).toString('base64'), - fileName: undefined, - fileDescriptorMeta: undefined, - filePointerMeta: undefined, - }, - useStreamMode: false, - acceleration: { - xnnpack: undefined, - gpu: undefined, - tflite: {}, - }, - }; - - let fetchSpy: jasmine.Spy; - - beforeEach(() => { - fetchSpy = jasmine.createSpy().and.callFake(async url => { - expect(url).toEqual('foo'); - return { - arrayBuffer: () => mockBytes.buffer, - } as unknown as Response; - }); - global.fetch = fetchSpy; - }); - - it('verifies that at least one model asset option is provided', async () => { - await expectAsync(convertBaseOptionsToProto({})) - .toBeRejectedWithError( - /Either baseOptions.modelAssetPath or baseOptions.modelAssetBuffer must be set/); - }); - - it('verifies that no more than one model asset option is provided', async () => { - await expectAsync(convertBaseOptionsToProto({ - modelAssetPath: `foo`, - modelAssetBuffer: new Uint8Array([]) - })) - .toBeRejectedWithError( - /Cannot set both baseOptions.modelAssetPath and baseOptions.modelAssetBuffer/); - }); - - it('downloads model', async () => { - const baseOptionsProto = await convertBaseOptionsToProto({ - modelAssetPath: `foo`, - }); - - expect(fetchSpy).toHaveBeenCalled(); - expect(baseOptionsProto.toObject()).toEqual(mockBytesResult); - }); - - it('does not download model when bytes are provided', async () => { - const baseOptionsProto = await convertBaseOptionsToProto({ - modelAssetBuffer: new Uint8Array(mockBytes), - }); - - expect(fetchSpy).not.toHaveBeenCalled(); - expect(baseOptionsProto.toObject()).toEqual(mockBytesResult); - }); - - it('can enable CPU delegate', async () => { - const baseOptionsProto = await convertBaseOptionsToProto({ - modelAssetBuffer: new Uint8Array(mockBytes), - delegate: 'CPU', - }); - expect(baseOptionsProto.toObject()).toEqual(mockBytesResult); - }); - - it('can enable GPU delegate', async () => { - const baseOptionsProto = await convertBaseOptionsToProto({ - modelAssetBuffer: new Uint8Array(mockBytes), - delegate: 'GPU', - }); - expect(baseOptionsProto.toObject()).toEqual({ - ...mockBytesResult, - acceleration: { - xnnpack: undefined, - gpu: { - useAdvancedGpuApi: false, - api: 0, - allowPrecisionLoss: true, - cachedKernelPath: undefined, - serializedModelDir: undefined, - modelToken: undefined, - usage: 2, - }, - tflite: undefined, - }, - }); - }); - - it('can reset delegate', async () => { - let baseOptionsProto = await convertBaseOptionsToProto({ - modelAssetBuffer: new Uint8Array(mockBytes), - delegate: 'GPU', - }); - // Clear backend - baseOptionsProto = - await convertBaseOptionsToProto({delegate: undefined}, baseOptionsProto); - expect(baseOptionsProto.toObject()).toEqual(mockBytesResult); - }); -}); diff --git a/mediapipe/tasks/web/components/processors/base_options.ts b/mediapipe/tasks/web/components/processors/base_options.ts deleted file mode 100644 index 97b62b784..000000000 --- a/mediapipe/tasks/web/components/processors/base_options.ts +++ /dev/null @@ -1,80 +0,0 @@ -/** - * Copyright 2022 The MediaPipe Authors. All Rights Reserved. 
- * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -import {InferenceCalculatorOptions} from '../../../../calculators/tensor/inference_calculator_pb'; -import {Acceleration} from '../../../../tasks/cc/core/proto/acceleration_pb'; -import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; -import {ExternalFile} from '../../../../tasks/cc/core/proto/external_file_pb'; -import {BaseOptions} from '../../../../tasks/web/core/task_runner_options'; - -// The OSS JS API does not support the builder pattern. -// tslint:disable:jspb-use-builder-pattern - -/** - * Converts a BaseOptions API object to its Protobuf representation. - * @throws If neither a model assset path or buffer is provided - */ -export async function convertBaseOptionsToProto( - updatedOptions: BaseOptions, - currentOptions?: BaseOptionsProto): Promise { - const result = - currentOptions ? currentOptions.clone() : new BaseOptionsProto(); - - await configureExternalFile(updatedOptions, result); - configureAcceleration(updatedOptions, result); - - return result; -} - -/** - * Configues the `externalFile` option and validates that a single model is - * provided. - */ -async function configureExternalFile( - options: BaseOptions, proto: BaseOptionsProto) { - const externalFile = proto.getModelAsset() || new ExternalFile(); - proto.setModelAsset(externalFile); - - if (options.modelAssetPath || options.modelAssetBuffer) { - if (options.modelAssetPath && options.modelAssetBuffer) { - throw new Error( - 'Cannot set both baseOptions.modelAssetPath and baseOptions.modelAssetBuffer'); - } - - let modelAssetBuffer = options.modelAssetBuffer; - if (!modelAssetBuffer) { - const response = await fetch(options.modelAssetPath!.toString()); - modelAssetBuffer = new Uint8Array(await response.arrayBuffer()); - } - externalFile.setFileContent(modelAssetBuffer); - } - - if (!externalFile.hasFileContent()) { - throw new Error( - 'Either baseOptions.modelAssetPath or baseOptions.modelAssetBuffer must be set'); - } -} - -/** Configues the `acceleration` option. */ -function configureAcceleration(options: BaseOptions, proto: BaseOptionsProto) { - const acceleration = proto.getAcceleration() ?? 
new Acceleration(); - if (options.delegate === 'GPU') { - acceleration.setGpu(new InferenceCalculatorOptions.Delegate.Gpu()); - } else { - acceleration.setTflite(new InferenceCalculatorOptions.Delegate.TfLite()); - } - proto.setAcceleration(acceleration); -} diff --git a/mediapipe/tasks/web/core/BUILD b/mediapipe/tasks/web/core/BUILD index 1721661f5..c0d10d28b 100644 --- a/mediapipe/tasks/web/core/BUILD +++ b/mediapipe/tasks/web/core/BUILD @@ -18,8 +18,10 @@ mediapipe_ts_library( srcs = ["task_runner.ts"], deps = [ ":core", + "//mediapipe/calculators/tensor:inference_calculator_jspb_proto", + "//mediapipe/tasks/cc/core/proto:acceleration_jspb_proto", "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", - "//mediapipe/tasks/web/components/processors:base_options", + "//mediapipe/tasks/cc/core/proto:external_file_jspb_proto", "//mediapipe/web/graph_runner:graph_runner_image_lib_ts", "//mediapipe/web/graph_runner:graph_runner_ts", "//mediapipe/web/graph_runner:register_model_resources_graph_service_ts", @@ -53,6 +55,7 @@ mediapipe_ts_library( "task_runner_test.ts", ], deps = [ + ":core", ":task_runner", ":task_runner_test_utils", "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", diff --git a/mediapipe/tasks/web/core/task_runner.ts b/mediapipe/tasks/web/core/task_runner.ts index 2011fadef..ffb538b52 100644 --- a/mediapipe/tasks/web/core/task_runner.ts +++ b/mediapipe/tasks/web/core/task_runner.ts @@ -14,9 +14,11 @@ * limitations under the License. */ +import {InferenceCalculatorOptions} from '../../../calculators/tensor/inference_calculator_pb'; +import {Acceleration} from '../../../tasks/cc/core/proto/acceleration_pb'; import {BaseOptions as BaseOptionsProto} from '../../../tasks/cc/core/proto/base_options_pb'; -import {convertBaseOptionsToProto} from '../../../tasks/web/components/processors/base_options'; -import {TaskRunnerOptions} from '../../../tasks/web/core/task_runner_options'; +import {ExternalFile} from '../../../tasks/cc/core/proto/external_file_pb'; +import {BaseOptions, TaskRunnerOptions} from '../../../tasks/web/core/task_runner_options'; import {createMediaPipeLib, FileLocator, GraphRunner, WasmMediaPipeConstructor, WasmModule} from '../../../web/graph_runner/graph_runner'; import {SupportImage} from '../../../web/graph_runner/graph_runner_image_lib'; import {SupportModelResourcesGraphService} from '../../../web/graph_runner/register_model_resources_graph_service'; @@ -91,14 +93,52 @@ export abstract class TaskRunner { this.graphRunner.registerModelResourcesGraphService(); } - /** Configures the shared options of a MediaPipe Task. */ - async setOptions(options: TaskRunnerOptions): Promise { - if (options.baseOptions) { - this.baseOptions = await convertBaseOptionsToProto( - options.baseOptions, this.baseOptions); + /** Configures the task with custom options. */ + abstract setOptions(options: TaskRunnerOptions): Promise; + + /** + * Applies the current set of options, including any base options that have + * not been processed by the task implementation. The options are applied + * synchronously unless a `modelAssetPath` is provided. This ensures that + * for most use cases options are applied directly and immediately affect + * the next inference. 
+   */
+  protected applyOptions(options: TaskRunnerOptions): Promise<void> {
+    const baseOptions: BaseOptions = options.baseOptions || {};
+
+    // Validate that exactly one model is configured
+    if (options.baseOptions?.modelAssetBuffer &&
+        options.baseOptions?.modelAssetPath) {
+      throw new Error(
+          'Cannot set both baseOptions.modelAssetPath and baseOptions.modelAssetBuffer');
+    } else if (!(this.baseOptions.getModelAsset()?.hasFileContent() ||
+                 options.baseOptions?.modelAssetBuffer ||
+                 options.baseOptions?.modelAssetPath)) {
+      throw new Error(
+          'Either baseOptions.modelAssetPath or baseOptions.modelAssetBuffer must be set');
+    }
+
+    this.setAcceleration(baseOptions);
+    if (baseOptions.modelAssetPath) {
+      // We don't use `await` here since we want to apply most settings
+      // synchronously.
+      return fetch(baseOptions.modelAssetPath.toString())
+          .then(response => response.arrayBuffer())
+          .then(buffer => {
+            this.setExternalFile(new Uint8Array(buffer));
+            this.refreshGraph();
+          });
+    } else {
+      // Apply the setting synchronously.
+      this.setExternalFile(baseOptions.modelAssetBuffer);
+      this.refreshGraph();
+      return Promise.resolve();
     }
   }
 
+  /** Applies the current options to the MediaPipe graph. */
+  protected abstract refreshGraph(): void;
+
   /**
    * Takes the raw data from a MediaPipe graph, and passes it to C++ to be run
    * over the video stream. Will replace the previously running MediaPipe graph,
@@ -140,6 +180,27 @@ export abstract class TaskRunner {
     }
     this.processingErrors = [];
   }
+
+  /** Configures the `externalFile` option */
+  private setExternalFile(modelAssetBuffer?: Uint8Array): void {
+    const externalFile = this.baseOptions.getModelAsset() || new ExternalFile();
+    if (modelAssetBuffer) {
+      externalFile.setFileContent(modelAssetBuffer);
+    }
+    this.baseOptions.setModelAsset(externalFile);
+  }
+
+  /** Configures the `acceleration` option. */
+  private setAcceleration(options: BaseOptions) {
+    const acceleration =
+        this.baseOptions.getAcceleration() ?? 
new Acceleration(); + if (options.delegate === 'GPU') { + acceleration.setGpu(new InferenceCalculatorOptions.Delegate.Gpu()); + } else { + acceleration.setTflite(new InferenceCalculatorOptions.Delegate.TfLite()); + } + this.baseOptions.setAcceleration(acceleration); + } } diff --git a/mediapipe/tasks/web/core/task_runner_test.ts b/mediapipe/tasks/web/core/task_runner_test.ts index c9aad9d25..a55ac04d7 100644 --- a/mediapipe/tasks/web/core/task_runner_test.ts +++ b/mediapipe/tasks/web/core/task_runner_test.ts @@ -15,18 +15,22 @@ */ import 'jasmine'; +// Placeholder for internal dependency on encodeByteArray import {BaseOptions as BaseOptionsProto} from '../../../tasks/cc/core/proto/base_options_pb'; import {TaskRunner} from '../../../tasks/web/core/task_runner'; import {createSpyWasmModule, SpyWasmModule} from '../../../tasks/web/core/task_runner_test_utils'; import {ErrorListener} from '../../../web/graph_runner/graph_runner'; +// Placeholder for internal dependency on trusted resource URL builder import {GraphRunnerImageLib} from './task_runner'; +import {TaskRunnerOptions} from './task_runner_options.d'; class TaskRunnerFake extends TaskRunner { - protected baseOptions = new BaseOptionsProto(); private errorListener: ErrorListener|undefined; private errors: string[] = []; + baseOptions = new BaseOptionsProto(); + static createFake(): TaskRunnerFake { const wasmModule = createSpyWasmModule(); return new TaskRunnerFake(wasmModule); @@ -61,10 +65,16 @@ class TaskRunnerFake extends TaskRunner { super.finishProcessing(); } + override refreshGraph(): void {} + override setGraph(graphData: Uint8Array, isBinary: boolean): void { super.setGraph(graphData, isBinary); } + setOptions(options: TaskRunnerOptions): Promise { + return this.applyOptions(options); + } + private throwErrors(): void { expect(this.errorListener).toBeDefined(); for (const error of this.errors) { @@ -75,8 +85,38 @@ class TaskRunnerFake extends TaskRunner { } describe('TaskRunner', () => { + const mockBytes = new Uint8Array([0, 1, 2, 3]); + const mockBytesResult = { + modelAsset: { + fileContent: Buffer.from(mockBytes).toString('base64'), + fileName: undefined, + fileDescriptorMeta: undefined, + filePointerMeta: undefined, + }, + useStreamMode: false, + acceleration: { + xnnpack: undefined, + gpu: undefined, + tflite: {}, + }, + }; + + let fetchSpy: jasmine.Spy; + let taskRunner: TaskRunnerFake; + + beforeEach(() => { + fetchSpy = jasmine.createSpy().and.callFake(async url => { + expect(url).toEqual('foo'); + return { + arrayBuffer: () => mockBytes.buffer, + } as unknown as Response; + }); + global.fetch = fetchSpy; + + taskRunner = TaskRunnerFake.createFake(); + }); + it('handles errors during graph update', () => { - const taskRunner = TaskRunnerFake.createFake(); taskRunner.enqueueError('Test error'); expect(() => { @@ -85,7 +125,6 @@ describe('TaskRunner', () => { }); it('handles errors during graph execution', () => { - const taskRunner = TaskRunnerFake.createFake(); taskRunner.setGraph(new Uint8Array(0), /* isBinary= */ true); taskRunner.enqueueError('Test error'); @@ -96,7 +135,6 @@ describe('TaskRunner', () => { }); it('can handle multiple errors', () => { - const taskRunner = TaskRunnerFake.createFake(); taskRunner.enqueueError('Test error 1'); taskRunner.enqueueError('Test error 2'); @@ -104,4 +142,106 @@ describe('TaskRunner', () => { taskRunner.setGraph(new Uint8Array(0), /* isBinary= */ true); }).toThrowError(/Test error 1, Test error 2/); }); + + it('verifies that at least one model asset option is provided', () 
=> { + expect(() => { + taskRunner.setOptions({}); + }) + .toThrowError( + /Either baseOptions.modelAssetPath or baseOptions.modelAssetBuffer must be set/); + }); + + it('verifies that no more than one model asset option is provided', () => { + expect(() => { + taskRunner.setOptions({ + baseOptions: { + modelAssetPath: `foo`, + modelAssetBuffer: new Uint8Array([]) + } + }); + }) + .toThrowError( + /Cannot set both baseOptions.modelAssetPath and baseOptions.modelAssetBuffer/); + }); + + it('doesn\'t require model once it is configured', async () => { + await taskRunner.setOptions( + {baseOptions: {modelAssetBuffer: new Uint8Array(mockBytes)}}); + expect(() => { + taskRunner.setOptions({}); + }).not.toThrowError(); + }); + + it('downloads model', async () => { + await taskRunner.setOptions( + {baseOptions: {modelAssetPath: `foo`}}); + + expect(fetchSpy).toHaveBeenCalled(); + expect(taskRunner.baseOptions.toObject()).toEqual(mockBytesResult); + }); + + it('does not download model when bytes are provided', async () => { + await taskRunner.setOptions( + {baseOptions: {modelAssetBuffer: new Uint8Array(mockBytes)}}); + + expect(fetchSpy).not.toHaveBeenCalled(); + expect(taskRunner.baseOptions.toObject()).toEqual(mockBytesResult); + }); + + it('changes model synchronously when bytes are provided', () => { + const resolvedPromise = taskRunner.setOptions( + {baseOptions: {modelAssetBuffer: new Uint8Array(mockBytes)}}); + + // Check that the change has been applied even though we do not await the + // above Promise + expect(taskRunner.baseOptions.toObject()).toEqual(mockBytesResult); + return resolvedPromise; + }); + + it('can enable CPU delegate', async () => { + await taskRunner.setOptions({ + baseOptions: { + modelAssetBuffer: new Uint8Array(mockBytes), + delegate: 'CPU', + } + }); + expect(taskRunner.baseOptions.toObject()).toEqual(mockBytesResult); + }); + + it('can enable GPU delegate', async () => { + await taskRunner.setOptions({ + baseOptions: { + modelAssetBuffer: new Uint8Array(mockBytes), + delegate: 'GPU', + } + }); + expect(taskRunner.baseOptions.toObject()).toEqual({ + ...mockBytesResult, + acceleration: { + xnnpack: undefined, + gpu: { + useAdvancedGpuApi: false, + api: 0, + allowPrecisionLoss: true, + cachedKernelPath: undefined, + serializedModelDir: undefined, + modelToken: undefined, + usage: 2, + }, + tflite: undefined, + }, + }); + }); + + it('can reset delegate', async () => { + await taskRunner.setOptions({ + baseOptions: { + modelAssetBuffer: new Uint8Array(mockBytes), + delegate: 'GPU', + } + }); + // Clear backend + await taskRunner.setOptions({baseOptions: {delegate: undefined}}); + expect(taskRunner.baseOptions.toObject()).toEqual(mockBytesResult); + }); }); diff --git a/mediapipe/tasks/web/text/text_classifier/text_classifier.ts b/mediapipe/tasks/web/text/text_classifier/text_classifier.ts index 62708700a..981438625 100644 --- a/mediapipe/tasks/web/text/text_classifier/text_classifier.ts +++ b/mediapipe/tasks/web/text/text_classifier/text_classifier.ts @@ -109,11 +109,10 @@ export class TextClassifier extends TaskRunner { * * @param options The options for the text classifier. 
*/ - override async setOptions(options: TextClassifierOptions): Promise { - await super.setOptions(options); + override setOptions(options: TextClassifierOptions): Promise { this.options.setClassifierOptions(convertClassifierOptionsToProto( options, this.options.getClassifierOptions())); - this.refreshGraph(); + return this.applyOptions(options); } protected override get baseOptions(): BaseOptionsProto { @@ -141,7 +140,7 @@ export class TextClassifier extends TaskRunner { } /** Updates the MediaPipe graph configuration. */ - private refreshGraph(): void { + protected override refreshGraph(): void { const graphConfig = new CalculatorGraphConfig(); graphConfig.addInputStream(INPUT_STREAM); graphConfig.addOutputStream(CLASSIFICATIONS_STREAM); diff --git a/mediapipe/tasks/web/text/text_classifier/text_classifier_test.ts b/mediapipe/tasks/web/text/text_classifier/text_classifier_test.ts index 841bf8c48..5578362cb 100644 --- a/mediapipe/tasks/web/text/text_classifier/text_classifier_test.ts +++ b/mediapipe/tasks/web/text/text_classifier/text_classifier_test.ts @@ -56,7 +56,8 @@ describe('TextClassifier', () => { beforeEach(async () => { addJasmineCustomFloatEqualityTester(); textClassifier = new TextClassifierFake(); - await textClassifier.setOptions({}); // Initialize graph + await textClassifier.setOptions( + {baseOptions: {modelAssetBuffer: new Uint8Array([])}}); }); it('initializes graph', async () => { diff --git a/mediapipe/tasks/web/text/text_embedder/text_embedder.ts b/mediapipe/tasks/web/text/text_embedder/text_embedder.ts index 611233e02..7aa0aa6b9 100644 --- a/mediapipe/tasks/web/text/text_embedder/text_embedder.ts +++ b/mediapipe/tasks/web/text/text_embedder/text_embedder.ts @@ -113,11 +113,10 @@ export class TextEmbedder extends TaskRunner { * * @param options The options for the text embedder. */ - override async setOptions(options: TextEmbedderOptions): Promise { - await super.setOptions(options); + override setOptions(options: TextEmbedderOptions): Promise { this.options.setEmbedderOptions(convertEmbedderOptionsToProto( options, this.options.getEmbedderOptions())); - this.refreshGraph(); + return this.applyOptions(options); } protected override get baseOptions(): BaseOptionsProto { @@ -157,7 +156,7 @@ export class TextEmbedder extends TaskRunner { } /** Updates the MediaPipe graph configuration. 
*/ - private refreshGraph(): void { + protected override refreshGraph(): void { const graphConfig = new CalculatorGraphConfig(); graphConfig.addInputStream(INPUT_STREAM); graphConfig.addOutputStream(EMBEDDINGS_STREAM); diff --git a/mediapipe/tasks/web/text/text_embedder/text_embedder_test.ts b/mediapipe/tasks/web/text/text_embedder/text_embedder_test.ts index 04a9b371a..2804e4deb 100644 --- a/mediapipe/tasks/web/text/text_embedder/text_embedder_test.ts +++ b/mediapipe/tasks/web/text/text_embedder/text_embedder_test.ts @@ -56,7 +56,8 @@ describe('TextEmbedder', () => { beforeEach(async () => { addJasmineCustomFloatEqualityTester(); textEmbedder = new TextEmbedderFake(); - await textEmbedder.setOptions({}); // Initialize graph + await textEmbedder.setOptions( + {baseOptions: {modelAssetBuffer: new Uint8Array([])}}); }); it('initializes graph', async () => { diff --git a/mediapipe/tasks/web/vision/core/BUILD b/mediapipe/tasks/web/vision/core/BUILD index e4ea3036f..03958a819 100644 --- a/mediapipe/tasks/web/vision/core/BUILD +++ b/mediapipe/tasks/web/vision/core/BUILD @@ -29,6 +29,7 @@ mediapipe_ts_library( testonly = True, srcs = ["vision_task_runner.test.ts"], deps = [ + ":vision_task_options", ":vision_task_runner", "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", "//mediapipe/tasks/web/core:task_runner_test_utils", diff --git a/mediapipe/tasks/web/vision/core/vision_task_runner.test.ts b/mediapipe/tasks/web/vision/core/vision_task_runner.test.ts index 6cc9ea328..d77cc4fed 100644 --- a/mediapipe/tasks/web/vision/core/vision_task_runner.test.ts +++ b/mediapipe/tasks/web/vision/core/vision_task_runner.test.ts @@ -20,6 +20,7 @@ import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/b import {createSpyWasmModule} from '../../../../tasks/web/core/task_runner_test_utils'; import {ImageSource} from '../../../../web/graph_runner/graph_runner'; +import {VisionTaskOptions} from './vision_task_options'; import {VisionTaskRunner} from './vision_task_runner'; class VisionTaskRunnerFake extends VisionTaskRunner { @@ -31,6 +32,12 @@ class VisionTaskRunnerFake extends VisionTaskRunner { protected override process(): void {} + protected override refreshGraph(): void {} + + override setOptions(options: VisionTaskOptions): Promise { + return this.applyOptions(options); + } + override processImageData(image: ImageSource): void { super.processImageData(image); } @@ -41,32 +48,24 @@ class VisionTaskRunnerFake extends VisionTaskRunner { } describe('VisionTaskRunner', () => { - const streamMode = { - modelAsset: undefined, - useStreamMode: true, - acceleration: undefined, - }; - - const imageMode = { - modelAsset: undefined, - useStreamMode: false, - acceleration: undefined, - }; - let visionTaskRunner: VisionTaskRunnerFake; - beforeEach(() => { + beforeEach(async () => { visionTaskRunner = new VisionTaskRunnerFake(); + await visionTaskRunner.setOptions( + {baseOptions: {modelAssetBuffer: new Uint8Array([])}}); }); it('can enable image mode', async () => { await visionTaskRunner.setOptions({runningMode: 'image'}); - expect(visionTaskRunner.baseOptions.toObject()).toEqual(imageMode); + expect(visionTaskRunner.baseOptions.toObject()) + .toEqual(jasmine.objectContaining({useStreamMode: false})); }); it('can enable video mode', async () => { await visionTaskRunner.setOptions({runningMode: 'video'}); - expect(visionTaskRunner.baseOptions.toObject()).toEqual(streamMode); + expect(visionTaskRunner.baseOptions.toObject()) + .toEqual(jasmine.objectContaining({useStreamMode: true})); }); 
it('can clear running mode', async () => { @@ -74,7 +73,8 @@ describe('VisionTaskRunner', () => { // Clear running mode await visionTaskRunner.setOptions({runningMode: undefined}); - expect(visionTaskRunner.baseOptions.toObject()).toEqual(imageMode); + expect(visionTaskRunner.baseOptions.toObject()) + .toEqual(jasmine.objectContaining({useStreamMode: false})); }); it('cannot process images with video mode', async () => { diff --git a/mediapipe/tasks/web/vision/core/vision_task_runner.ts b/mediapipe/tasks/web/vision/core/vision_task_runner.ts index 3432b521b..952990326 100644 --- a/mediapipe/tasks/web/vision/core/vision_task_runner.ts +++ b/mediapipe/tasks/web/vision/core/vision_task_runner.ts @@ -22,13 +22,13 @@ import {VisionTaskOptions} from './vision_task_options'; /** Base class for all MediaPipe Vision Tasks. */ export abstract class VisionTaskRunner extends TaskRunner { /** Configures the shared options of a vision task. */ - override async setOptions(options: VisionTaskOptions): Promise { - await super.setOptions(options); + override applyOptions(options: VisionTaskOptions): Promise { if ('runningMode' in options) { const useStreamMode = !!options.runningMode && options.runningMode !== 'image'; this.baseOptions.setUseStreamMode(useStreamMode); } + return super.applyOptions(options); } /** Sends an image packet to the graph and awaits results. */ diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts index b6b795076..cfeb179f5 100644 --- a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts +++ b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts @@ -169,9 +169,7 @@ export class GestureRecognizer extends * * @param options The options for the gesture recognizer. */ - override async setOptions(options: GestureRecognizerOptions): Promise { - await super.setOptions(options); - + override setOptions(options: GestureRecognizerOptions): Promise { if ('numHands' in options) { this.handDetectorGraphOptions.setNumHands( options.numHands ?? DEFAULT_NUM_HANDS); @@ -221,7 +219,7 @@ export class GestureRecognizer extends ?.clearClassifierOptions(); } - this.refreshGraph(); + return this.applyOptions(options); } /** @@ -342,7 +340,7 @@ export class GestureRecognizer extends } /** Updates the MediaPipe graph configuration. 
*/ - private refreshGraph(): void { + protected override refreshGraph(): void { const graphConfig = new CalculatorGraphConfig(); graphConfig.addInputStream(IMAGE_STREAM); graphConfig.addInputStream(NORM_RECT_STREAM); diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_test.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_test.ts index c0f0d1554..ff6bba613 100644 --- a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_test.ts +++ b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_test.ts @@ -109,7 +109,8 @@ describe('GestureRecognizer', () => { beforeEach(async () => { addJasmineCustomFloatEqualityTester(); gestureRecognizer = new GestureRecognizerFake(); - await gestureRecognizer.setOptions({}); // Initialize graph + await gestureRecognizer.setOptions( + {baseOptions: {modelAssetBuffer: new Uint8Array([])}}); }); it('initializes graph', async () => { diff --git a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts index 2a0e8286c..24cf9a402 100644 --- a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts +++ b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts @@ -150,9 +150,7 @@ export class HandLandmarker extends VisionTaskRunner { * * @param options The options for the hand landmarker. */ - override async setOptions(options: HandLandmarkerOptions): Promise { - await super.setOptions(options); - + override setOptions(options: HandLandmarkerOptions): Promise { // Configure hand detector options. if ('numHands' in options) { this.handDetectorGraphOptions.setNumHands( @@ -173,7 +171,7 @@ export class HandLandmarker extends VisionTaskRunner { options.minHandPresenceConfidence ?? DEFAULT_SCORE_THRESHOLD); } - this.refreshGraph(); + return this.applyOptions(options); } /** @@ -291,7 +289,7 @@ export class HandLandmarker extends VisionTaskRunner { } /** Updates the MediaPipe graph configuration. */ - private refreshGraph(): void { + protected override refreshGraph(): void { const graphConfig = new CalculatorGraphConfig(); graphConfig.addInputStream(IMAGE_STREAM); graphConfig.addInputStream(NORM_RECT_STREAM); diff --git a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_test.ts b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_test.ts index fc26680e0..76e77b4bf 100644 --- a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_test.ts +++ b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_test.ts @@ -98,7 +98,8 @@ describe('HandLandmarker', () => { beforeEach(async () => { addJasmineCustomFloatEqualityTester(); handLandmarker = new HandLandmarkerFake(); - await handLandmarker.setOptions({}); // Initialize graph + await handLandmarker.setOptions( + {baseOptions: {modelAssetBuffer: new Uint8Array([])}}); }); it('initializes graph', async () => { diff --git a/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts b/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts index 36e7311fb..9298a860c 100644 --- a/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts +++ b/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts @@ -118,11 +118,10 @@ export class ImageClassifier extends VisionTaskRunner { * * @param options The options for the image classifier. 
*/ - override async setOptions(options: ImageClassifierOptions): Promise { - await super.setOptions(options); + override setOptions(options: ImageClassifierOptions): Promise { this.options.setClassifierOptions(convertClassifierOptionsToProto( options, this.options.getClassifierOptions())); - this.refreshGraph(); + return this.applyOptions(options); } /** @@ -163,7 +162,7 @@ export class ImageClassifier extends VisionTaskRunner { } /** Updates the MediaPipe graph configuration. */ - private refreshGraph(): void { + protected override refreshGraph(): void { const graphConfig = new CalculatorGraphConfig(); graphConfig.addInputStream(INPUT_STREAM); graphConfig.addOutputStream(CLASSIFICATIONS_STREAM); diff --git a/mediapipe/tasks/web/vision/image_classifier/image_classifier_test.ts b/mediapipe/tasks/web/vision/image_classifier/image_classifier_test.ts index 2041a0cef..da4a01d02 100644 --- a/mediapipe/tasks/web/vision/image_classifier/image_classifier_test.ts +++ b/mediapipe/tasks/web/vision/image_classifier/image_classifier_test.ts @@ -61,7 +61,8 @@ describe('ImageClassifier', () => { beforeEach(async () => { addJasmineCustomFloatEqualityTester(); imageClassifier = new ImageClassifierFake(); - await imageClassifier.setOptions({}); // Initialize graph + await imageClassifier.setOptions( + {baseOptions: {modelAssetBuffer: new Uint8Array([])}}); }); it('initializes graph', async () => { diff --git a/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts b/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts index 0c45ba5e7..cf0bd8c5d 100644 --- a/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts +++ b/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts @@ -120,11 +120,10 @@ export class ImageEmbedder extends VisionTaskRunner { * * @param options The options for the image embedder. */ - override async setOptions(options: ImageEmbedderOptions): Promise { - await super.setOptions(options); + override setOptions(options: ImageEmbedderOptions): Promise { this.options.setEmbedderOptions(convertEmbedderOptionsToProto( options, this.options.getEmbedderOptions())); - this.refreshGraph(); + return this.applyOptions(options); } /** @@ -186,7 +185,7 @@ export class ImageEmbedder extends VisionTaskRunner { } /** Updates the MediaPipe graph configuration. 
*/
-  private refreshGraph(): void {
+  protected override refreshGraph(): void {
     const graphConfig = new CalculatorGraphConfig();
     graphConfig.addInputStream(INPUT_STREAM);
     graphConfig.addOutputStream(EMBEDDINGS_STREAM);
diff --git a/mediapipe/tasks/web/vision/image_embedder/image_embedder_test.ts b/mediapipe/tasks/web/vision/image_embedder/image_embedder_test.ts
index cafe0f3d8..b63bb374c 100644
--- a/mediapipe/tasks/web/vision/image_embedder/image_embedder_test.ts
+++ b/mediapipe/tasks/web/vision/image_embedder/image_embedder_test.ts
@@ -57,7 +57,8 @@ describe('ImageEmbedder', () => {
   beforeEach(async () => {
     addJasmineCustomFloatEqualityTester();
     imageEmbedder = new ImageEmbedderFake();
-    await imageEmbedder.setOptions({});  // Initialize graph
+    await imageEmbedder.setOptions(
+        {baseOptions: {modelAssetBuffer: new Uint8Array([])}});
   });
 
   it('initializes graph', async () => {
diff --git a/mediapipe/tasks/web/vision/object_detector/object_detector.ts b/mediapipe/tasks/web/vision/object_detector/object_detector.ts
index fbfaced12..e4c51de08 100644
--- a/mediapipe/tasks/web/vision/object_detector/object_detector.ts
+++ b/mediapipe/tasks/web/vision/object_detector/object_detector.ts
@@ -117,9 +117,7 @@ export class ObjectDetector extends VisionTaskRunner {
    *
    * @param options The options for the object detector.
    */
-  override async setOptions(options: ObjectDetectorOptions): Promise<void> {
-    await super.setOptions(options);
-
+  override setOptions(options: ObjectDetectorOptions): Promise<void> {
     // Note that we have to support both JSPB and ProtobufJS, hence we
     // have to explicitly clear the values instead of setting them to
     // `undefined`.
@@ -153,7 +151,7 @@ export class ObjectDetector extends VisionTaskRunner {
       this.options.clearCategoryDenylistList();
     }
 
-    this.refreshGraph();
+    return this.applyOptions(options);
   }
 
   /**
@@ -226,7 +224,7 @@ export class ObjectDetector extends VisionTaskRunner {
   }
 
   /** Updates the MediaPipe graph configuration. */
-  private refreshGraph(): void {
+  protected override refreshGraph(): void {
     const graphConfig = new CalculatorGraphConfig();
     graphConfig.addInputStream(INPUT_STREAM);
     graphConfig.addOutputStream(DETECTIONS_STREAM);
diff --git a/mediapipe/tasks/web/vision/object_detector/object_detector_test.ts b/mediapipe/tasks/web/vision/object_detector/object_detector_test.ts
index fff1a1c48..43b7035d5 100644
--- a/mediapipe/tasks/web/vision/object_detector/object_detector_test.ts
+++ b/mediapipe/tasks/web/vision/object_detector/object_detector_test.ts
@@ -61,7 +61,8 @@ describe('ObjectDetector', () => {
   beforeEach(async () => {
     addJasmineCustomFloatEqualityTester();
     objectDetector = new ObjectDetectorFake();
-    await objectDetector.setOptions({});  // Initialize graph
+    await objectDetector.setOptions(
+        {baseOptions: {modelAssetBuffer: new Uint8Array([])}});
   });
 
   it('initializes graph', async () => {
From 1924f1cdff94af953c2cd9b01a13d623ea13e7a7 Mon Sep 17 00:00:00 2001
From: Nikolay Chirkov
Date: Wed, 28 Dec 2022 14:27:42 -0800
Subject: [PATCH 307/469] Tensor: Fix use_ahwb_ flag and the tests involved on a local device. 
PiperOrigin-RevId: 498249332
---
 mediapipe/framework/formats/tensor_ahwb.cc    |  3 +-
 .../framework/formats/tensor_ahwb_gpu_test.cc | 16 ++++++--
 .../framework/formats/tensor_ahwb_test.cc     | 39 ++++---------------
 3 files changed, 22 insertions(+), 36 deletions(-)

diff --git a/mediapipe/framework/formats/tensor_ahwb.cc b/mediapipe/framework/formats/tensor_ahwb.cc
index 466811be7..74b2dca93 100644
--- a/mediapipe/framework/formats/tensor_ahwb.cc
+++ b/mediapipe/framework/formats/tensor_ahwb.cc
@@ -458,7 +458,8 @@ void Tensor::TrackAhwbUsage(uint64_t source_location_hash) const {
       ahwb_tracking_key_ = tensor_internal::FnvHash64(ahwb_tracking_key_, dim);
     }
   }
-  use_ahwb_ = ahwb_usage_track_.contains(ahwb_tracking_key_);
+  // Keep flag value if it was set previously.
+  use_ahwb_ = use_ahwb_ || ahwb_usage_track_.contains(ahwb_tracking_key_);
 }

 #else  // MEDIAPIPE_TENSOR_USE_AHWB
diff --git a/mediapipe/framework/formats/tensor_ahwb_gpu_test.cc b/mediapipe/framework/formats/tensor_ahwb_gpu_test.cc
index a6ca00949..e2ad869f9 100644
--- a/mediapipe/framework/formats/tensor_ahwb_gpu_test.cc
+++ b/mediapipe/framework/formats/tensor_ahwb_gpu_test.cc
@@ -92,9 +92,14 @@ class TensorAhwbGpuTest : public mediapipe::GpuTestBase {
 };

 TEST_F(TensorAhwbGpuTest, TestGpuToCpuFloat32) {
-  Tensor::SetPreferredStorageType(Tensor::StorageType::kAhwb);
   constexpr size_t num_elements = 20;
   Tensor tensor{Tensor::ElementType::kFloat32, Tensor::Shape({num_elements})};
+  {
+    // Request Ahwb first to get Ahwb storage allocated internally.
+    auto view = tensor.GetAHardwareBufferWriteView();
+    EXPECT_NE(view.handle(), nullptr);
+    view.SetWritingFinishedFD(-1, [](bool) { return true; });
+  }
   RunInGlContext([&tensor] {
     auto ssbo_view = tensor.GetOpenGlBufferWriteView();
     auto ssbo_name = ssbo_view.name();
@@ -114,9 +119,14 @@ TEST_F(TensorAhwbGpuTest, TestGpuToCpuFloat32) {
 }

 TEST_F(TensorAhwbGpuTest, TestGpuToCpuFloat16) {
-  Tensor::SetPreferredStorageType(Tensor::StorageType::kAhwb);
   constexpr size_t num_elements = 20;
   Tensor tensor{Tensor::ElementType::kFloat16, Tensor::Shape({num_elements})};
+  {
+    // Request Ahwb first to get Ahwb storage allocated internally.
+    auto view = tensor.GetAHardwareBufferWriteView();
+    EXPECT_NE(view.handle(), nullptr);
+    view.SetReadingFinishedFunc([](bool) { return true; });
+  }
   RunInGlContext([&tensor] {
     auto ssbo_view = tensor.GetOpenGlBufferWriteView();
     auto ssbo_name = ssbo_view.name();
@@ -139,7 +149,6 @@ TEST_F(TensorAhwbGpuTest, TestGpuToCpuFloat16) {
 TEST_F(TensorAhwbGpuTest, TestReplacingCpuByAhwb) {
   // Request the CPU view to get the memory to be allocated.
   // Request Ahwb view then to transform the storage into Ahwb.
-  Tensor::SetPreferredStorageType(Tensor::StorageType::kDefault);
   constexpr size_t num_elements = 20;
   Tensor tensor{Tensor::ElementType::kFloat32, Tensor::Shape({num_elements})};
   {
@@ -168,7 +177,6 @@ TEST_F(TensorAhwbGpuTest, TestReplacingCpuByAhwb) {
 TEST_F(TensorAhwbGpuTest, TestReplacingGpuByAhwb) {
   // Request the GPU view to get the ssbo allocated internally.
   // Request Ahwb view then to transform the storage into Ahwb.
-  Tensor::SetPreferredStorageType(Tensor::StorageType::kDefault);
   constexpr size_t num_elements = 20;
   Tensor tensor{Tensor::ElementType::kFloat32, Tensor::Shape({num_elements})};
   RunInGlContext([&tensor] {
diff --git a/mediapipe/framework/formats/tensor_ahwb_test.cc b/mediapipe/framework/formats/tensor_ahwb_test.cc
index 7ab5a4925..f0baa6303 100644
--- a/mediapipe/framework/formats/tensor_ahwb_test.cc
+++ b/mediapipe/framework/formats/tensor_ahwb_test.cc
@@ -1,34 +1,28 @@
 #include "mediapipe/framework/formats/tensor.h"

-#include "mediapipe/gpu/gpu_test_base.h"
 #include "testing/base/public/gmock.h"
 #include "testing/base/public/gunit.h"

-#ifdef MEDIAPIPE_TENSOR_USE_AHWB
-#if !MEDIAPIPE_DISABLE_GPU
-
 namespace mediapipe {

-class TensorAhwbTest : public mediapipe::GpuTestBase {
- public:
-};
-
-TEST_F(TensorAhwbTest, TestCpuThenAHWB) {
+TEST(TensorAhwbTest, TestCpuThenAHWB) {
   Tensor tensor(Tensor::ElementType::kFloat32, Tensor::Shape{1});
   {
     auto ptr = tensor.GetCpuWriteView().buffer<float>();
     EXPECT_NE(ptr, nullptr);
   }
   {
-    auto ahwb = tensor.GetAHardwareBufferReadView().handle();
-    EXPECT_NE(ahwb, nullptr);
+    auto view = tensor.GetAHardwareBufferReadView();
+    EXPECT_NE(view.handle(), nullptr);
+    view.SetReadingFinishedFunc([](bool) { return true; });
   }
 }

-TEST_F(TensorAhwbTest, TestAHWBThenCpu) {
+TEST(TensorAhwbTest, TestAHWBThenCpu) {
   Tensor tensor(Tensor::ElementType::kFloat32, Tensor::Shape{1});
   {
-    auto ahwb = tensor.GetAHardwareBufferWriteView().handle();
-    EXPECT_NE(ahwb, nullptr);
+    auto view = tensor.GetAHardwareBufferWriteView();
+    EXPECT_NE(view.handle(), nullptr);
+    view.SetWritingFinishedFD(-1, [](bool) { return true; });
   }
   {
     auto ptr = tensor.GetCpuReadView().buffer<float>();
@@ -36,21 +30,4 @@ TEST_F(TensorAhwbTest, TestAHWBThenCpu) {
   }
 }

-TEST_F(TensorAhwbTest, TestCpuThenGl) {
-  RunInGlContext([] {
-    Tensor tensor(Tensor::ElementType::kFloat32, Tensor::Shape{1});
-    {
-      auto ptr = tensor.GetCpuWriteView().buffer<float>();
-      EXPECT_NE(ptr, nullptr);
-    }
-    {
-      auto ssbo = tensor.GetOpenGlBufferReadView().name();
-      EXPECT_GT(ssbo, 0);
-    }
-  });
-}
-
 }  // namespace mediapipe
-
-#endif  // !MEDIAPIPE_DISABLE_GPU
-#endif  // MEDIAPIPE_TENSOR_USE_AHWB

From 2d9a969d10bdcac98e0e86f617817e08cf656331 Mon Sep 17 00:00:00 2001
From: Nikolay Chirkov
Date: Wed, 28 Dec 2022 16:07:09 -0800
Subject: [PATCH 308/469] Tensor: memorize size_alignment when tracking the
 ahwb usage.

When a CPU/GPU buffer is allocated and the tracker selects Ahwb storage to be
used, the previously recorded alignment must be used.
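As a back-of-the-envelope illustration (a standalone sketch; the helper name is made up, assuming sizes are rounded up to the recorded alignment): a 16-byte alignment turns the 20-byte and 36-byte tensors used in the new tests into 32 and 48 bytes respectively.

    #include <cstddef>

    // Round `size` up to the next multiple of `alignment` (0 keeps it as-is).
    constexpr std::size_t AlignedSize(std::size_t size, std::size_t alignment) {
      return alignment == 0 ? size
                            : ((size + alignment - 1) / alignment) * alignment;
    }

    static_assert(AlignedSize(20, 16) == 32, "sizeof(float) * 5, aligned to 16");
    static_assert(AlignedSize(36, 16) == 48, "sizeof(float) * 9, aligned to 16");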
PiperOrigin-RevId: 498264759
---
 mediapipe/framework/formats/BUILD          |  2 +-
 mediapipe/framework/formats/tensor.h       |  7 +-
 mediapipe/framework/formats/tensor_ahwb.cc |  7 +-
 .../framework/formats/tensor_ahwb_test.cc  | 67 +++++++++++++++++++
 4 files changed, 78 insertions(+), 5 deletions(-)

diff --git a/mediapipe/framework/formats/BUILD b/mediapipe/framework/formats/BUILD
index f5a043f10..cce7e5bd0 100644
--- a/mediapipe/framework/formats/BUILD
+++ b/mediapipe/framework/formats/BUILD
@@ -455,7 +455,7 @@ cc_library(
         ],
     }),
     deps = [
-        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/memory",
         "@com_google_absl//absl/synchronization",
         "//mediapipe/framework:port",
diff --git a/mediapipe/framework/formats/tensor.h b/mediapipe/framework/formats/tensor.h
index 8a6f02e9d..0f19bb5ee 100644
--- a/mediapipe/framework/formats/tensor.h
+++ b/mediapipe/framework/formats/tensor.h
@@ -24,7 +24,7 @@
 #include
 #include

-#include "absl/container/flat_hash_set.h"
+#include "absl/container/flat_hash_map.h"
 #include "absl/synchronization/mutex.h"
 #include "mediapipe/framework/formats/tensor_internal.h"
 #include "mediapipe/framework/port.h"
@@ -434,8 +434,9 @@ class Tensor {
   mutable bool use_ahwb_ = false;
   mutable uint64_t ahwb_tracking_key_ = 0;
   // TODO: Tracks all unique tensors. Can grow to a large number. LRU
-  // can be more predicted.
-  static inline absl::flat_hash_set<uint64_t> ahwb_usage_track_;
+  // (Least Recently Used) may be more predictable.
+  // The value contains the size alignment parameter.
+  static inline absl::flat_hash_map<uint64_t, int> ahwb_usage_track_;
   // Expects the target SSBO to be already bound.
   bool AllocateAhwbMapToSsbo() const;
   bool InsertAhwbToSsboFence() const;
diff --git a/mediapipe/framework/formats/tensor_ahwb.cc b/mediapipe/framework/formats/tensor_ahwb.cc
index 74b2dca93..525f05f31 100644
--- a/mediapipe/framework/formats/tensor_ahwb.cc
+++ b/mediapipe/framework/formats/tensor_ahwb.cc
@@ -266,7 +266,12 @@ Tensor::AHardwareBufferView Tensor::GetAHardwareBufferWriteView(

 bool Tensor::AllocateAHardwareBuffer(int size_alignment) const {
   // Mark current tracking key as Ahwb-use.
-  ahwb_usage_track_.insert(ahwb_tracking_key_);
+  if (auto it = ahwb_usage_track_.find(ahwb_tracking_key_);
+      it != ahwb_usage_track_.end()) {
+    size_alignment = it->second;
+  } else if (ahwb_tracking_key_ != 0) {
+    ahwb_usage_track_.insert({ahwb_tracking_key_, size_alignment});
+  }
   use_ahwb_ = true;

   if (__builtin_available(android 26, *)) {
diff --git a/mediapipe/framework/formats/tensor_ahwb_test.cc b/mediapipe/framework/formats/tensor_ahwb_test.cc
index f0baa6303..3da6ca8d3 100644
--- a/mediapipe/framework/formats/tensor_ahwb_test.cc
+++ b/mediapipe/framework/formats/tensor_ahwb_test.cc
@@ -30,4 +30,71 @@ TEST(TensorAhwbTest, TestAHWBThenCpu) {
   }
 }

+TEST(TensorAhwbTest, TestAhwbAlignment) {
+  Tensor tensor(Tensor::ElementType::kFloat32, Tensor::Shape{5});
+  {
+    auto view = tensor.GetAHardwareBufferWriteView(16);
+    EXPECT_NE(view.handle(), nullptr);
+    if (__builtin_available(android 26, *)) {
+      AHardwareBuffer_Desc desc;
+      AHardwareBuffer_describe(view.handle(), &desc);
+      // sizeof(float) * 5 = 20; the closest size aligned to 16 is 32.
+      EXPECT_EQ(desc.width, 32);
+    }
+    view.SetWritingFinishedFD(-1, [](bool) { return true; });
+  }
+}
+
+// Tensor::GetCpuView uses a source location mechanism that gives the source
+// file name and line from where the method is called. The function below
+// exists just so that two calls provide the same source file name and line.
+auto GetCpuView(const Tensor &tensor) { return tensor.GetCpuWriteView(); }
+
+// The test checks the tracking mechanism: when a tensor's Cpu view is
+// retrieved for the first time, the source location is attached to the
+// tensor. If the Ahwb view is then requested from the tensor, the previously
+// recorded Cpu view request source location is marked for using Ahwb storage.
+// When a Cpu view with the same source location (but for a newly allocated
+// tensor) is requested and the location is marked to use Ahwb storage, the
+// Ahwb storage is allocated for the CpuView.
+TEST(TensorAhwbTest, TestTrackingAhwb) {
+  // Create the first tensor and request a Cpu and then an Ahwb view to mark
+  // the source location for Ahwb storage.
+  {
+    Tensor tensor(Tensor::ElementType::kFloat32, Tensor::Shape{9});
+    {
+      auto view = GetCpuView(tensor);
+      EXPECT_NE(view.buffer<float>(), nullptr);
+    }
+    {
+      // Align the size of the Ahwb to a multiple of 16.
+      auto view = tensor.GetAHardwareBufferWriteView(16);
+      EXPECT_NE(view.handle(), nullptr);
+      view.SetReadingFinishedFunc([](bool) { return true; });
+    }
+  }
+  {
+    Tensor tensor(Tensor::ElementType::kFloat32, Tensor::Shape{9});
+    {
+      // The second tensor uses the same Cpu view source location so Ahwb
+      // storage is allocated internally.
+      auto view = GetCpuView(tensor);
+      EXPECT_NE(view.buffer<float>(), nullptr);
+    }
+    {
+      // Check that the Ahwb size is aligned to a multiple of 16. The
+      // alignment was stored by the previous request for the Ahwb view.
+      auto view = tensor.GetAHardwareBufferReadView();
+      EXPECT_NE(view.handle(), nullptr);
+      if (__builtin_available(android 26, *)) {
+        AHardwareBuffer_Desc desc;
+        AHardwareBuffer_describe(view.handle(), &desc);
+        // sizeof(float) * 9 = 36. The closest aligned size is 48.
+        EXPECT_EQ(desc.width, 48);
+      }
+      view.SetReadingFinishedFunc([](bool) { return true; });
+    }
+  }
+}
+
 }  // namespace mediapipe

From aaa16eca1fedf9450689be422ea2dc01c7d74c93 Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Thu, 29 Dec 2022 08:33:58 -0800
Subject: [PATCH 309/469] Sets the graph service packets before initializing
 (and validating the graph) in the objc graph wrapper.

PiperOrigin-RevId: 498393761
---
 mediapipe/objc/MPPGraph.mm | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/mediapipe/objc/MPPGraph.mm b/mediapipe/objc/MPPGraph.mm
index 1bd177e80..3123eb863 100644
--- a/mediapipe/objc/MPPGraph.mm
+++ b/mediapipe/objc/MPPGraph.mm
@@ -230,16 +230,17 @@ if ([wrapper.delegate
 }

 - (absl::Status)performStart {
-  absl::Status status = _graph->Initialize(_config);
-  if (!status.ok()) {
-    return status;
-  }
+  absl::Status status;
   for (const auto& service_packet : _servicePackets) {
     status = _graph->SetServicePacket(*service_packet.first, service_packet.second);
     if (!status.ok()) {
       return status;
     }
   }
+  status = _graph->Initialize(_config);
+  if (!status.ok()) {
+    return status;
+  }
   status = _graph->StartRun(_inputSidePackets, _streamHeaders);
   if (!status.ok()) {
     return status;

From 60c6b155f626f40e2971cda10aa4c3565897874a Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Thu, 29 Dec 2022 10:16:10 -0800
Subject: [PATCH 310/469] Save an integer id in graph profiler objects to
 distinguish between different profiler instances during benchmarking.
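A minimal sketch of the id scheme with illustrative names (the real fields appear in the diff below): a process-wide atomic counter hands out ids, so an id is never reused even after the profiler that owned it is destroyed.

    #include <atomic>
    #include <cstdint>

    class ProfilerIdSketch {
     public:
      // The real code assigns the id during Initialize(); the constructor is
      // used here only to keep the sketch short.
      ProfilerIdSketch() : graph_id_(++next_instance_id_) {}
      uint64_t GetGraphId() const { return graph_id_; }

     private:
      // Globally incrementing for all instances in a process.
      static inline std::atomic_int next_instance_id_ = 0;
      uint64_t graph_id_;  // unique within the process
    };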
PiperOrigin-RevId: 498409363
---
 .../framework/profiler/graph_profiler.cc      |  1 +
 mediapipe/framework/profiler/graph_profiler.h |  9 +++++++
 .../framework/profiler/graph_profiler_test.cc | 26 +++++++++++++++++++
 3 files changed, 36 insertions(+)

diff --git a/mediapipe/framework/profiler/graph_profiler.cc b/mediapipe/framework/profiler/graph_profiler.cc
index f14acfc78..6aead5250 100644
--- a/mediapipe/framework/profiler/graph_profiler.cc
+++ b/mediapipe/framework/profiler/graph_profiler.cc
@@ -194,6 +194,7 @@ void GraphProfiler::Initialize(
         "Calculator \"$0\" has already been added.", node_name);
   }
   profile_builder_ = std::make_unique<GraphProfileBuilder>(this);
+  graph_id_ = ++next_instance_id_;

   is_initialized_ = true;
 }
diff --git a/mediapipe/framework/profiler/graph_profiler.h b/mediapipe/framework/profiler/graph_profiler.h
index 23caed4ec..6358cb057 100644
--- a/mediapipe/framework/profiler/graph_profiler.h
+++ b/mediapipe/framework/profiler/graph_profiler.h
@@ -237,6 +237,9 @@ class GraphProfiler : public std::enable_shared_from_this<GraphProfiler> {
     return validated_graph_;
   }

+  // Gets a numerical identifier for this GraphProfiler object.
+  uint64_t GetGraphId() { return graph_id_; }
+
  private:
   // This can be used to add packet info for the input streams to the graph.
   // It treats the stream defined by |stream_name| as a stream produced by a
@@ -357,6 +360,12 @@ class GraphProfiler : public std::enable_shared_from_this<GraphProfiler> {
   class GraphProfileBuilder;
   std::unique_ptr<GraphProfileBuilder> profile_builder_;

+  // The globally incrementing identifier for all graphs in a process.
+  static inline std::atomic_int next_instance_id_ = 0;
+
+  // A unique identifier for this object. Only unique within a process.
+  uint64_t graph_id_;
+
   // For testing.
   friend GraphProfilerTestPeer;
 };
diff --git a/mediapipe/framework/profiler/graph_profiler_test.cc b/mediapipe/framework/profiler/graph_profiler_test.cc
index 81ba90cda..75d1c7ebd 100644
--- a/mediapipe/framework/profiler/graph_profiler_test.cc
+++ b/mediapipe/framework/profiler/graph_profiler_test.cc
@@ -442,6 +442,32 @@ TEST_F(GraphProfilerTestPeer, InitializeMultipleTimes) {
                "Cannot initialize .* multiple times.");
 }

+// Tests that graph identifiers are not reused, even after destruction.
+TEST_F(GraphProfilerTestPeer, InitializeMultipleProfilers) {
+  auto raw_graph_config = R"(
+    profiler_config {
+      enable_profiler: true
+    }
+    input_stream: "input_stream"
+    node {
+      calculator: "DummyTestCalculator"
+      input_stream: "input_stream"
+    })";
+  const int n_iterations = 100;
+  absl::flat_hash_set<int> seen_ids;
+  for (int i = 0; i < n_iterations; ++i) {
+    std::shared_ptr<GraphProfiler> profiler =
+        std::make_shared<GraphProfiler>();
+    auto graph_config = CreateGraphConfig(raw_graph_config);
+    mediapipe::ValidatedGraphConfig validated_graph;
+    QCHECK_OK(validated_graph.Initialize(graph_config));
+    profiler->Initialize(validated_graph);
+
+    int id = profiler->GetGraphId();
+    ASSERT_THAT(seen_ids, testing::Not(testing::Contains(id)));
+    seen_ids.insert(id);
+  }
+}
+
 // Tests that Pause(), Resume(), and Reset() work.
TEST_F(GraphProfilerTestPeer, PauseResumeReset) {
   InitializeProfilerWithGraphConfig(R"(

From 9252a025e5604cb61b11cbf23943dc7fb9e6f679 Mon Sep 17 00:00:00 2001
From: Sebastian Schmidt
Date: Fri, 30 Dec 2022 04:56:57 -0800
Subject: [PATCH 311/469] Use custom gesture options in GestureRecognizer

PiperOrigin-RevId: 498567432
---
 .../tasks/cc/vision/gesture_recognizer/gesture_recognizer.cc | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.cc
index 01f444742..91a5ec213 100644
--- a/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.cc
+++ b/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.cc
@@ -151,11 +151,11 @@ ConvertGestureRecognizerGraphOptionsProto(GestureRecognizerOptions* options) {
   auto custom_gestures_classifier_options_proto =
       std::make_unique<components::processors::proto::ClassifierOptions>(
           components::processors::ConvertClassifierOptionsToProto(
-              &(options->canned_gestures_classifier_options)));
+              &(options->custom_gestures_classifier_options)));
   hand_gesture_recognizer_graph_options
       ->mutable_custom_gesture_classifier_graph_options()
       ->mutable_classifier_options()
-      ->Swap(canned_gestures_classifier_options_proto.get());
+      ->Swap(custom_gestures_classifier_options_proto.get());
   return options_proto;
 }

From 2f4bb5d545fbd6b6389248b7123635dcdfff02b7 Mon Sep 17 00:00:00 2001
From: Camillo Lugaresi
Date: Tue, 3 Jan 2023 09:34:21 -0800
Subject: [PATCH 312/469] Use utility framebuffer in
 ViewDoneWritingSimulatorWorkaround

This code needs an FBO to bind the texture. Fixes invalid results when
running under the simulator.

PiperOrigin-RevId: 499241867
---
 .../gpu/gpu_buffer_storage_cv_pixel_buffer.cc | 75 +++++++++++--------
 1 file changed, 42 insertions(+), 33 deletions(-)

diff --git a/mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.cc b/mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.cc
index 014cc1c69..7cac32b7f 100644
--- a/mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.cc
+++ b/mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.cc
@@ -74,42 +74,51 @@ GlTextureView GpuBufferStorageCvPixelBuffer::GetReadView(
 static void ViewDoneWritingSimulatorWorkaround(CVPixelBufferRef pixel_buffer,
                                                const GlTextureView& view) {
   CHECK(pixel_buffer);
-  CVReturn err = CVPixelBufferLockBaseAddress(pixel_buffer, 0);
-  CHECK(err == kCVReturnSuccess)
-      << "CVPixelBufferLockBaseAddress failed: " << err;
-  OSType pixel_format = CVPixelBufferGetPixelFormatType(pixel_buffer);
-  size_t bytes_per_row = CVPixelBufferGetBytesPerRow(pixel_buffer);
-  uint8_t* pixel_ptr =
-      static_cast<uint8_t*>(CVPixelBufferGetBaseAddress(pixel_buffer));
-  if (pixel_format == kCVPixelFormatType_32BGRA) {
-    // TODO: restore previous framebuffer? Move this to helper so we
-    // can use BindFramebuffer?
-    glViewport(0, 0, view.width(), view.height());
-    glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, view.target(),
-                           view.name(), 0);
+  auto ctx = GlContext::GetCurrent().get();
+  if (!ctx) ctx = view.gl_context();
+  ctx->Run([pixel_buffer, &view, ctx] {
+    CVReturn err = CVPixelBufferLockBaseAddress(pixel_buffer, 0);
+    CHECK(err == kCVReturnSuccess)
+        << "CVPixelBufferLockBaseAddress failed: " << err;
+    OSType pixel_format = CVPixelBufferGetPixelFormatType(pixel_buffer);
+    size_t bytes_per_row = CVPixelBufferGetBytesPerRow(pixel_buffer);
+    uint8_t* pixel_ptr =
+        static_cast<uint8_t*>(CVPixelBufferGetBaseAddress(pixel_buffer));
+    if (pixel_format == kCVPixelFormatType_32BGRA) {
+      glBindFramebuffer(GL_FRAMEBUFFER, kUtilityFramebuffer.Get(*ctx));
+      glViewport(0, 0, view.width(), view.height());
+      glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0,
+                             view.target(), view.name(), 0);

-    size_t contiguous_bytes_per_row = view.width() * 4;
-    if (bytes_per_row == contiguous_bytes_per_row) {
-      glReadPixels(0, 0, view.width(), view.height(), GL_BGRA, GL_UNSIGNED_BYTE,
-                   pixel_ptr);
-    } else {
-      std::vector<uint8_t> contiguous_buffer(contiguous_bytes_per_row *
-                                             view.height());
-      uint8_t* temp_ptr = contiguous_buffer.data();
-      glReadPixels(0, 0, view.width(), view.height(), GL_BGRA, GL_UNSIGNED_BYTE,
-                   temp_ptr);
-      for (int i = 0; i < view.height(); ++i) {
-        memcpy(pixel_ptr, temp_ptr, contiguous_bytes_per_row);
-        temp_ptr += contiguous_bytes_per_row;
-        pixel_ptr += bytes_per_row;
+      size_t contiguous_bytes_per_row = view.width() * 4;
+      if (bytes_per_row == contiguous_bytes_per_row) {
+        glReadPixels(0, 0, view.width(), view.height(), GL_BGRA,
+                     GL_UNSIGNED_BYTE, pixel_ptr);
+      } else {
+        // TODO: use GL_PACK settings for row length. We can expect
+        // GLES 3.0 on iOS now.
+        std::vector<uint8_t> contiguous_buffer(contiguous_bytes_per_row *
+                                               view.height());
+        uint8_t* temp_ptr = contiguous_buffer.data();
+        glReadPixels(0, 0, view.width(), view.height(), GL_BGRA,
+                     GL_UNSIGNED_BYTE, temp_ptr);
+        for (int i = 0; i < view.height(); ++i) {
+          memcpy(pixel_ptr, temp_ptr, contiguous_bytes_per_row);
+          temp_ptr += contiguous_bytes_per_row;
+          pixel_ptr += bytes_per_row;
+        }
       }
+      // TODO: restore previous framebuffer?
+      glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0,
+                             view.target(), 0, 0);
+      glBindFramebuffer(GL_FRAMEBUFFER, 0);
+    } else {
+      LOG(ERROR) << "unsupported pixel format: " << pixel_format;
     }
-  } else {
-    LOG(ERROR) << "unsupported pixel format: " << pixel_format;
-  }
-  err = CVPixelBufferUnlockBaseAddress(pixel_buffer, 0);
-  CHECK(err == kCVReturnSuccess)
-      << "CVPixelBufferUnlockBaseAddress failed: " << err;
+    err = CVPixelBufferUnlockBaseAddress(pixel_buffer, 0);
+    CHECK(err == kCVReturnSuccess)
+        << "CVPixelBufferUnlockBaseAddress failed: " << err;
+  });
 }
 #endif  // TARGET_IPHONE_SIMULATOR

From f53c0eaceeae9b7cb622764d78054f8e44222ba3 Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Tue, 3 Jan 2023 09:38:02 -0800
Subject: [PATCH 313/469] Extend tag conversion behavior to also convert `:`
 (in addition to the current `/`, `-`, and `.`) to `_`.
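A standalone sketch of the conversion without the absl helpers (the function name and the sample key are made up for illustration): uppercase the signature key and map the four special characters to underscores, e.g. a hypothetical key "probabilities:0" would become "PROBABILITIES_0".

    #include <cctype>
    #include <string>

    std::string SignatureToTagSketch(std::string name) {
      for (char& c : name) {
        if (c == '/' || c == '-' || c == '.' || c == ':') {
          c = '_';  // the characters covered after this change
        } else {
          c = static_cast<char>(std::toupper(static_cast<unsigned char>(c)));
        }
      }
      return name;
    }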
PiperOrigin-RevId: 499243005
---
 .../tensorflow_session_from_saved_model_calculator.cc    | 7 +++----
 .../tensorflow_session_from_saved_model_calculator.proto | 4 ++--
 .../tensorflow_session_from_saved_model_generator.cc     | 7 +++----
 .../tensorflow_session_from_saved_model_generator.proto  | 4 ++--
 4 files changed, 10 insertions(+), 12 deletions(-)

diff --git a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.cc b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.cc
index 922eb9d50..18bddbbe3 100644
--- a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.cc
+++ b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.cc
@@ -55,7 +55,7 @@ absl::Status GetLatestDirectory(std::string* path) {
 }

 // If options.convert_signature_to_tags() is set, will convert letters to
-// uppercase and replace /'s and -'s with _'s. This enables the standard
+// uppercase and replace /, -, . and :'s with _'s. This enables the standard
 // SavedModel classification, regression, and prediction signatures to be used
 // as uppercase INPUTS and OUTPUTS tags for streams and supports other common
 // patterns.
@@ -67,9 +67,8 @@ const std::string MaybeConvertSignatureToTag(
     output.resize(name.length());
     std::transform(name.begin(), name.end(), output.begin(),
                    [](unsigned char c) { return std::toupper(c); });
-    output = absl::StrReplaceAll(output, {{"/", "_"}});
-    output = absl::StrReplaceAll(output, {{"-", "_"}});
-    output = absl::StrReplaceAll(output, {{".", "_"}});
+    output = absl::StrReplaceAll(
+        output, {{"/", "_"}, {"-", "_"}, {".", "_"}, {":", "_"}});
     LOG(INFO) << "Renamed TAG from: " << name << " to " << output;
     return output;
   } else {
diff --git a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.proto b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.proto
index 927d3b51f..515b46fa9 100644
--- a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.proto
+++ b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.proto
@@ -33,8 +33,8 @@ message TensorFlowSessionFromSavedModelCalculatorOptions {
   // The name of the generic signature to load into the mapping from tags to
   // tensor names.
   optional string signature_name = 2 [default = "serving_default"];
-  // Whether to convert the signature keys to uppercase as well as switch /'s
-  // and -'s to _'s, which enables common signatures to be used as Tags.
+  // Whether to convert the signature keys to uppercase as well as switch
+  // /, -, ., and :'s to _'s, which enables common signatures to be used as Tags.
   optional bool convert_signature_to_tags = 3 [default = true];
   // If true, saved_model_path can have multiple exported models in
   // subdirectories saved_model_path/%08d and the alphabetically last (i.e.,
diff --git a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.cc b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.cc
index d5236f1cc..ee69ec56a 100644
--- a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.cc
+++ b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.cc
@@ -61,7 +61,7 @@ absl::Status GetLatestDirectory(std::string* path) {
 }

 // If options.convert_signature_to_tags() is set, will convert letters to
-// uppercase and replace /'s and -'s with _'s. This enables the standard
+// uppercase and replace /, -, ., and :'s with _'s. This enables the standard
 // SavedModel classification, regression, and prediction signatures to be used
 // as uppercase INPUTS and OUTPUTS tags for streams and supports other common
 // patterns.
@@ -73,9 +73,8 @@ const std::string MaybeConvertSignatureToTag(
     output.resize(name.length());
     std::transform(name.begin(), name.end(), output.begin(),
                    [](unsigned char c) { return std::toupper(c); });
-    output = absl::StrReplaceAll(output, {{"/", "_"}});
-    output = absl::StrReplaceAll(output, {{"-", "_"}});
-    output = absl::StrReplaceAll(output, {{".", "_"}});
+    output = absl::StrReplaceAll(
+        output, {{"/", "_"}, {"-", "_"}, {".", "_"}, {":", "_"}});
     LOG(INFO) << "Renamed TAG from: " << name << " to " << output;
     return output;
   } else {
diff --git a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.proto b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.proto
index d24a1cd73..d45fcb662 100644
--- a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.proto
+++ b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.proto
@@ -33,8 +33,8 @@ message TensorFlowSessionFromSavedModelGeneratorOptions {
   // The name of the generic signature to load into the mapping from tags to
   // tensor names.
   optional string signature_name = 2 [default = "serving_default"];
-  // Whether to convert the signature keys to uppercase as well as switch /'s
-  // and -'s to _'s, which enables common signatures to be used as Tags.
+  // Whether to convert the signature keys to uppercase, as well as switch /'s,
+  // -'s, .'s, and :'s to _'s, enabling common signatures to be used as Tags.
   optional bool convert_signature_to_tags = 3 [default = true];
   // If true, saved_model_path can have multiple exported models in
   // subdirectories saved_model_path/%08d and the alphabetically last (i.e.,

From 987f4dc1ed89801e54c408abd670f63ce0c77007 Mon Sep 17 00:00:00 2001
From: Sebastian Schmidt
Date: Tue, 3 Jan 2023 10:52:41 -0800
Subject: [PATCH 314/469] Make addJasmineCustomFloatEqualityTester configurable

PiperOrigin-RevId: 499263931
---
 mediapipe/tasks/web/core/task_runner_test_utils.ts | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mediapipe/tasks/web/core/task_runner_test_utils.ts b/mediapipe/tasks/web/core/task_runner_test_utils.ts
index 2a1161a55..838b3f585 100644
--- a/mediapipe/tasks/web/core/task_runner_test_utils.ts
+++ b/mediapipe/tasks/web/core/task_runner_test_utils.ts
@@ -44,10 +44,10 @@ export function createSpyWasmModule(): SpyWasmModule {
 * Sets up our equality testing to use a custom float equality checking function
 * to avoid incorrect test results due to minor floating point inaccuracies.
 */
-export function addJasmineCustomFloatEqualityTester() {
+export function addJasmineCustomFloatEqualityTester(tolerance = 5e-8) {
   jasmine.addCustomEqualityTester((a, b) => {  // Custom float equality
     if (a === +a && b === +b && (a !== (a | 0) || b !== (b | 0))) {
-      return Math.abs(a - b) < 5e-8;
+      return Math.abs(a - b) < tolerance;
     }
     return;
   });

From 68f247a5c7a2f081e6f0ff8b25b9187de5646e2b Mon Sep 17 00:00:00 2001
From: Sebastian Schmidt
Date: Tue, 3 Jan 2023 12:03:57 -0800
Subject: [PATCH 315/469] Internal change

PiperOrigin-RevId: 499282085
---
 .../web/vision/hand_landmarker/hand_landmarker_result.d.ts | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_result.d.ts b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_result.d.ts
index 89f867d69..8a6d9bfa6 100644
--- a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_result.d.ts
+++ b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_result.d.ts
@@ -17,6 +17,8 @@
 import {Category} from '../../../../tasks/web/components/containers/category';
 import {Landmark, NormalizedLandmark} from '../../../../tasks/web/components/containers/landmark';

+export {Landmark, NormalizedLandmark, Category};
+
 /**
  * Represents the hand landmarks detection results generated by `HandLandmarker`.
  */

From 75b87e0e321090bf73653d83ebfa69cf6f73621f Mon Sep 17 00:00:00 2001
From: Sebastian Schmidt
Date: Tue, 3 Jan 2023 12:09:59 -0800
Subject: [PATCH 316/469] Internal change

PiperOrigin-RevId: 499283559
---
 .../gesture_recognizer/gesture_recognizer.ts  | 35 ++++++++++++++-----
 .../gesture_recognizer_result.d.ts            |  8 ++++-
 .../gesture_recognizer_test.ts                | 23 +++++++++++-
 3 files changed, 55 insertions(+), 11 deletions(-)

diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts
index cfeb179f5..c77f2c67a 100644
--- a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts
+++ b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts
@@ -263,12 +263,22 @@ export class GestureRecognizer extends
         NORM_RECT_STREAM, timestamp);
     this.finishProcessing();

-    return {
-      gestures: this.gestures,
-      landmarks: this.landmarks,
-      worldLandmarks: this.worldLandmarks,
-      handednesses: this.handednesses
-    };
+    if (this.gestures.length === 0) {
+      // If no gestures are detected in the image, just return an empty list
+      return {
+        gestures: [],
+        landmarks: [],
+        worldLandmarks: [],
+        handednesses: [],
+      };
+    } else {
+      return {
+        gestures: this.gestures,
+        landmarks: this.landmarks,
+        worldLandmarks: this.worldLandmarks,
+        handednesses: this.handednesses
+      };
+    }
   }

   /** Sets the default values for the graph. */
@@ -283,15 +293,19 @@ export class GestureRecognizer extends
   }

   /** Converts the proto data to a Category[][] structure. */
-  private toJsCategories(data: Uint8Array[]): Category[][] {
+  private toJsCategories(data: Uint8Array[], populateIndex = true):
+      Category[][] {
     const result: Category[][] = [];
     for (const binaryProto of data) {
       const inputList = ClassificationList.deserializeBinary(binaryProto);
       const outputList: Category[] = [];
       for (const classification of inputList.getClassificationList()) {
+        const index = populateIndex && classification.hasIndex() ?
+            classification.getIndex()! :
+            DEFAULT_CATEGORY_INDEX;
         outputList.push({
           score: classification.getScore() ?? 0,
-          index: classification.getIndex() ?? DEFAULT_CATEGORY_INDEX,
+          index,
           categoryName: classification.getLabel() ??
              '',
           displayName: classification.getDisplayName() ?? '',
         });
       }
@@ -375,7 +389,10 @@ export class GestureRecognizer extends
         });
     this.graphRunner.attachProtoVectorListener(
         HAND_GESTURES_STREAM, binaryProto => {
-          this.gestures.push(...this.toJsCategories(binaryProto));
+          // Gesture index is not used, because the final gesture result comes
+          // from multiple classifiers.
+          this.gestures.push(
+              ...this.toJsCategories(binaryProto, /* populateIndex= */ false));
         });
     this.graphRunner.attachProtoVectorListener(
         HANDEDNESS_STREAM, binaryProto => {
diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_result.d.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_result.d.ts
index e570270b2..323290008 100644
--- a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_result.d.ts
+++ b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_result.d.ts
@@ -17,6 +17,8 @@
 import {Category} from '../../../../tasks/web/components/containers/category';
 import {Landmark, NormalizedLandmark} from '../../../../tasks/web/components/containers/landmark';

+export {Category, Landmark, NormalizedLandmark};
+
 /**
  * Represents the gesture recognition results generated by `GestureRecognizer`.
  */
@@ -30,6 +32,10 @@ export declare interface GestureRecognizerResult {
   /** Handedness of detected hands. */
   handednesses: Category[][];

-  /** Recognized hand gestures of detected hands */
+  /**
+   * Recognized hand gestures of detected hands. Note that the index of the
+   * gesture is always -1, because the raw indices from multiple gesture
+   * classifiers cannot be consolidated into a meaningful index.
+   */
   gestures: Category[][];
 }
diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_test.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_test.ts
index ff6bba613..ee51fd32a 100644
--- a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_test.ts
+++ b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_test.ts
@@ -272,7 +272,7 @@ describe('GestureRecognizer', () => {
     expect(gestures).toEqual({
       'gestures': [[{
         'score': 0.2,
-        'index': 2,
+        'index': -1,
         'categoryName': 'gesture_label',
         'displayName': 'gesture_display_name'
       }]],
@@ -305,4 +305,25 @@ describe('GestureRecognizer', () => {
     // gestures.
     expect(gestures2).toEqual(gestures1);
   });
+
+  it('returns empty results when no gestures are detected', async () => {
+    // Pass the test data to our listener
+    gestureRecognizer.fakeWasmModule._waitUntilIdle.and.callFake(() => {
+      verifyListenersRegistered(gestureRecognizer);
+      gestureRecognizer.listeners.get('hand_landmarks')!(createLandmarks());
+      gestureRecognizer.listeners.get('world_hand_landmarks')!
+          (createWorldLandmarks());
+      gestureRecognizer.listeners.get('handedness')!(createHandednesses());
+      gestureRecognizer.listeners.get('hand_gestures')!([]);
+    });
+
+    // Invoke the gesture recognizer
+    const gestures = gestureRecognizer.recognize({} as HTMLImageElement);
+    expect(gestures).toEqual({
+      'gestures': [],
+      'landmarks': [],
+      'worldLandmarks': [],
+      'handednesses': []
+    });
+  });
 });

From e7dc989f715382c10ac6d714f4f4be5d330f903d Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Tue, 3 Jan 2023 14:12:34 -0800
Subject: [PATCH 317/469] Internal Change

PiperOrigin-RevId: 499313491
---
 mediapipe/examples/desktop/autoflip/BUILD | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/mediapipe/examples/desktop/autoflip/BUILD b/mediapipe/examples/desktop/autoflip/BUILD
index 562f11c49..0e28746dc 100644
--- a/mediapipe/examples/desktop/autoflip/BUILD
+++ b/mediapipe/examples/desktop/autoflip/BUILD
@@ -30,6 +30,10 @@ proto_library(

 java_lite_proto_library(
     name = "autoflip_messages_java_proto_lite",
+    visibility = [
+        "//java/com/google/android/apps/photos:__subpackages__",
+        "//javatests/com/google/android/apps/photos:__subpackages__",
+    ],
     deps = [
         ":autoflip_messages_proto",
     ],

From add5600d0d4e9f0213ebf58088301dc7e743194a Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Tue, 3 Jan 2023 17:18:59 -0800
Subject: [PATCH 318/469] Internal change

PiperOrigin-RevId: 499351795
---
 .../python/text/text_classifier/text_classifier_test.py     | 1 +
 .../python/vision/image_classifier/image_classifier_test.py | 1 +
 2 files changed, 2 insertions(+)

diff --git a/mediapipe/model_maker/python/text/text_classifier/text_classifier_test.py b/mediapipe/model_maker/python/text/text_classifier/text_classifier_test.py
index d2edb78bc..eb4443b44 100644
--- a/mediapipe/model_maker/python/text/text_classifier/text_classifier_test.py
+++ b/mediapipe/model_maker/python/text/text_classifier/text_classifier_test.py
@@ -71,6 +71,7 @@ class TextClassifierTest(tf.test.TestCase):
     self.assertTrue(os.path.exists(output_metadata_file))
     self.assertGreater(os.path.getsize(output_metadata_file), 0)

+    filecmp.clear_cache()
     self.assertTrue(
         filecmp.cmp(
             output_metadata_file,
diff --git a/mediapipe/model_maker/python/vision/image_classifier/image_classifier_test.py b/mediapipe/model_maker/python/vision/image_classifier/image_classifier_test.py
index 14c67d831..afda8643b 100644
--- a/mediapipe/model_maker/python/vision/image_classifier/image_classifier_test.py
+++ b/mediapipe/model_maker/python/vision/image_classifier/image_classifier_test.py
@@ -135,6 +135,7 @@ class ImageClassifierTest(tf.test.TestCase, parameterized.TestCase):
     self.assertTrue(os.path.exists(output_metadata_file))
     self.assertGreater(os.path.getsize(output_metadata_file), 0)

+    filecmp.clear_cache()
     self.assertTrue(
         filecmp.cmp(
             output_metadata_file, expected_metadata_file, shallow=False))

From a4ea606eac3adf3ca5e149e9e6ff6573168971a6 Mon Sep 17 00:00:00 2001
From: Nikolay Chirkov
Date: Wed, 4 Jan 2023 08:21:55 -0800
Subject: [PATCH 319/469] Internal change.
PiperOrigin-RevId: 499490514
---
 .../framework/formats/tensor_ahwb_gpu_test.cc | 28 +++++++++----------
 .../framework/formats/tensor_ahwb_test.cc     |  2 +-
 2 files changed, 15 insertions(+), 15 deletions(-)

diff --git a/mediapipe/framework/formats/tensor_ahwb_gpu_test.cc b/mediapipe/framework/formats/tensor_ahwb_gpu_test.cc
index e2ad869f9..45d341e20 100644
--- a/mediapipe/framework/formats/tensor_ahwb_gpu_test.cc
+++ b/mediapipe/framework/formats/tensor_ahwb_gpu_test.cc
@@ -83,8 +83,8 @@ void FillGpuBuffer(GLuint name, std::size_t size,
       TFLITE_GPU_CALL_GL(glBindBufferBase, GL_SHADER_STORAGE_BUFFER, 0, name));
   MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glUseProgram, to_buffer_program));
   MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glDispatchCompute, size / 2, 1, 1));
-  MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glBindBuffer, GL_SHADER_STORAGE_BUFFER, 0));
-  MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glDeleteProgram, to_buffer_program));
+  MP_EXPECT_OK(TFLITE_GPU_CALL_GL(glBindBuffer, GL_SHADER_STORAGE_BUFFER, 0));
+  MP_EXPECT_OK(TFLITE_GPU_CALL_GL(glDeleteProgram, to_buffer_program));
 }

 class TensorAhwbGpuTest : public mediapipe::GpuTestBase {
@@ -97,18 +97,18 @@ TEST_F(TensorAhwbGpuTest, TestGpuToCpuFloat32) {
   {
     // Request Ahwb first to get Ahwb storage allocated internally.
     auto view = tensor.GetAHardwareBufferWriteView();
-    EXPECT_NE(view.handle(), nullptr);
+    ASSERT_NE(view.handle(), nullptr);
     view.SetWritingFinishedFD(-1, [](bool) { return true; });
   }
   RunInGlContext([&tensor] {
     auto ssbo_view = tensor.GetOpenGlBufferWriteView();
     auto ssbo_name = ssbo_view.name();
-    EXPECT_GT(ssbo_name, 0);
+    ASSERT_GT(ssbo_name, 0);
     FillGpuBuffer(ssbo_name, tensor.shape().num_elements(),
                   tensor.element_type());
   });
   auto ptr = tensor.GetCpuReadView().buffer<float>();
-  EXPECT_NE(ptr, nullptr);
+  ASSERT_NE(ptr, nullptr);
   std::vector<float> reference;
   reference.resize(num_elements);
   for (int i = 0; i < num_elements; i++) {
@@ -124,18 +124,18 @@ TEST_F(TensorAhwbGpuTest, TestGpuToCpuFloat16) {
   {
     // Request Ahwb first to get Ahwb storage allocated internally.
     auto view = tensor.GetAHardwareBufferWriteView();
-    EXPECT_NE(view.handle(), nullptr);
+    ASSERT_NE(view.handle(), nullptr);
     view.SetReadingFinishedFunc([](bool) { return true; });
   }
   RunInGlContext([&tensor] {
     auto ssbo_view = tensor.GetOpenGlBufferWriteView();
     auto ssbo_name = ssbo_view.name();
-    EXPECT_GT(ssbo_name, 0);
+    ASSERT_GT(ssbo_name, 0);
     FillGpuBuffer(ssbo_name, tensor.shape().num_elements(),
                   tensor.element_type());
   });
   auto ptr = tensor.GetCpuReadView().buffer<Float16>();
-  EXPECT_NE(ptr, nullptr);
+  ASSERT_NE(ptr, nullptr);
   std::vector<Float16> reference;
   reference.resize(num_elements);
   for (int i = 0; i < num_elements; i++) {
@@ -153,18 +153,18 @@ TEST_F(TensorAhwbGpuTest, TestReplacingCpuByAhwb) {
   Tensor tensor{Tensor::ElementType::kFloat32, Tensor::Shape({num_elements})};
   {
     auto ptr = tensor.GetCpuWriteView().buffer<float>();
-    EXPECT_NE(ptr, nullptr);
+    ASSERT_NE(ptr, nullptr);
     for (int i = 0; i < num_elements; i++) {
       ptr[i] = static_cast<float>(i) / 10.0f;
     }
   }
   {
     auto view = tensor.GetAHardwareBufferReadView();
-    EXPECT_NE(view.handle(), nullptr);
+    ASSERT_NE(view.handle(), nullptr);
     view.SetReadingFinishedFunc([](bool) { return true; });
   }
   auto ptr = tensor.GetCpuReadView().buffer<float>();
-  EXPECT_NE(ptr, nullptr);
+  ASSERT_NE(ptr, nullptr);
   std::vector<float> reference;
   reference.resize(num_elements);
   for (int i = 0; i < num_elements; i++) {
@@ -182,17 +182,17 @@ TEST_F(TensorAhwbGpuTest, TestReplacingGpuByAhwb) {
   RunInGlContext([&tensor] {
     auto ssbo_view = tensor.GetOpenGlBufferWriteView();
     auto ssbo_name = ssbo_view.name();
-    EXPECT_GT(ssbo_name, 0);
+    ASSERT_GT(ssbo_name, 0);
     FillGpuBuffer(ssbo_name, tensor.shape().num_elements(),
                   tensor.element_type());
   });
   {
     auto view = tensor.GetAHardwareBufferReadView();
-    EXPECT_NE(view.handle(), nullptr);
+    ASSERT_NE(view.handle(), nullptr);
     view.SetReadingFinishedFunc([](bool) { return true; });
   }
   auto ptr = tensor.GetCpuReadView().buffer<float>();
-  EXPECT_NE(ptr, nullptr);
+  ASSERT_NE(ptr, nullptr);
   std::vector<float> reference;
   reference.resize(num_elements);
   for (int i = 0; i < num_elements; i++) {
diff --git a/mediapipe/framework/formats/tensor_ahwb_test.cc b/mediapipe/framework/formats/tensor_ahwb_test.cc
index 3da6ca8d3..69e49dd58 100644
--- a/mediapipe/framework/formats/tensor_ahwb_test.cc
+++ b/mediapipe/framework/formats/tensor_ahwb_test.cc
@@ -34,7 +34,7 @@ TEST(TensorAhwbTest, TestAhwbAlignment) {
   Tensor tensor(Tensor::ElementType::kFloat32, Tensor::Shape{5});
   {
     auto view = tensor.GetAHardwareBufferWriteView(16);
-    EXPECT_NE(view.handle(), nullptr);
+    ASSERT_NE(view.handle(), nullptr);
     if (__builtin_available(android 26, *)) {
       AHardwareBuffer_Desc desc;
       AHardwareBuffer_describe(view.handle(), &desc);

From 9a70af146432dcbbbc961f9c1a5af4a039d0909a Mon Sep 17 00:00:00 2001
From: Nikolay Chirkov
Date: Wed, 4 Jan 2023 08:52:03 -0800
Subject: [PATCH 320/469] Internal change.
PiperOrigin-RevId: 499496793
---
 mediapipe/framework/formats/tensor_ahwb_gpu_test.cc | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/mediapipe/framework/formats/tensor_ahwb_gpu_test.cc b/mediapipe/framework/formats/tensor_ahwb_gpu_test.cc
index 45d341e20..ff78d1f88 100644
--- a/mediapipe/framework/formats/tensor_ahwb_gpu_test.cc
+++ b/mediapipe/framework/formats/tensor_ahwb_gpu_test.cc
@@ -68,8 +68,9 @@ void FillGpuBuffer(GLuint name, std::size_t size,
     MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glGetShaderiv, shader, GL_INFO_LOG_LENGTH,
                                     &max_length));
     std::vector<GLchar> error_log(max_length);
-    glGetShaderInfoLog(shader, max_length, &max_length, error_log.data());
-    glDeleteShader(shader);
+    MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glGetShaderInfoLog, shader, max_length,
+                                    &max_length, error_log.data()));
+    MP_EXPECT_OK(TFLITE_GPU_CALL_GL(glDeleteShader, shader));
     FAIL() << error_log.data();
     return;
   }

From e3131d7d7856771def3c1c141720ca311ed0f3d9 Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Wed, 4 Jan 2023 10:31:04 -0800
Subject: [PATCH 321/469] Internal change

PiperOrigin-RevId: 499521620
---
 mediapipe/model_maker/setup.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mediapipe/model_maker/setup.py b/mediapipe/model_maker/setup.py
index ea193db94..7114e2080 100644
--- a/mediapipe/model_maker/setup.py
+++ b/mediapipe/model_maker/setup.py
@@ -132,9 +132,9 @@ setuptools.setup(
         'Operating System :: MacOS :: MacOS X',
         'Operating System :: Microsoft :: Windows',
         'Operating System :: POSIX :: Linux',
-        'Programming Language :: Python :: 3.7',
         'Programming Language :: Python :: 3.8',
         'Programming Language :: Python :: 3.9',
+        'Programming Language :: Python :: 3.10',
         'Programming Language :: Python :: 3 :: Only',
         'Topic :: Scientific/Engineering',
         'Topic :: Scientific/Engineering :: Artificial Intelligence',

From 24cc0672c47b0b2fac28bbc8434e93a9fccb47ad Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Wed, 4 Jan 2023 10:57:33 -0800
Subject: [PATCH 322/469] Internal change

PiperOrigin-RevId: 499529022
---
 mediapipe/examples/desktop/autoflip/BUILD | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/mediapipe/examples/desktop/autoflip/BUILD b/mediapipe/examples/desktop/autoflip/BUILD
index 0e28746dc..340205caa 100644
--- a/mediapipe/examples/desktop/autoflip/BUILD
+++ b/mediapipe/examples/desktop/autoflip/BUILD
@@ -18,6 +18,8 @@ licenses(["notice"])

 package(default_visibility = [
     "//mediapipe/examples:__subpackages__",
+    "//photos/editing/mobile/mediapipe/calculators:__subpackages__",
+    "//photos/editing/mobile/mediapipe/proto:__subpackages__",
 ])

 proto_library(
@@ -45,6 +47,8 @@ mediapipe_cc_proto_library(
     cc_deps = ["//mediapipe/framework:calculator_cc_proto"],
     visibility = [
         "//mediapipe/examples:__subpackages__",
+        "//photos/editing/mobile/mediapipe/calculators:__pkg__",
+        "//photos/editing/mobile/mediapipe/calculators:__subpackages__",
     ],
     deps = [":autoflip_messages_proto"],
 )

From 43bf02443c1b8b7f237c9f7ef408da5cb56619b8 Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Wed, 4 Jan 2023 17:31:48 -0800
Subject: [PATCH 323/469] Option to remove overlapping values computed for
 different timestamps.
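A minimal sketch of the trimming behavior (an illustrative class mirroring the calculator logic and the new test below, not the calculator itself): the first vector passes through in full, and every later vector loses its first `overlap` values because they repeat the tail of the previous output.

    #include <cstdint>
    #include <vector>

    class OverlapTrimmerSketch {
     public:
      explicit OverlapTrimmerSketch(int overlap) : overlap_(overlap) {}

      // Call once per timestamp, in order.
      void Trim(std::vector<int64_t>& v) {
        if (pending_ > 0) {
          if (v.size() < static_cast<size_t>(pending_)) {
            v.clear();
          } else {
            v.erase(v.begin(), v.begin() + pending_);
          }
        }
        pending_ = overlap_;  // applies from the second vector onward
      }

     private:
      int overlap_;
      int pending_ = 0;  // no trimming for the very first vector
    };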
PiperOrigin-RevId: 499635143
---
 .../tensor_to_vector_int_calculator.cc        | 20 +++++++
 ...sor_to_vector_int_calculator_options.proto |  4 ++
 .../tensor_to_vector_int_calculator_test.cc   | 53 ++++++++++++++++++-
 3 files changed, 76 insertions(+), 1 deletion(-)

diff --git a/mediapipe/calculators/tensorflow/tensor_to_vector_int_calculator.cc b/mediapipe/calculators/tensorflow/tensor_to_vector_int_calculator.cc
index 2f4ff28cf..f92ddf08d 100644
--- a/mediapipe/calculators/tensorflow/tensor_to_vector_int_calculator.cc
+++ b/mediapipe/calculators/tensorflow/tensor_to_vector_int_calculator.cc
@@ -37,8 +37,10 @@ class TensorToVectorIntCalculator : public CalculatorBase {

 private:
  void TokenizeVector(std::vector<int64>* vector) const;
+  void RemoveOverlapVector(std::vector<int64>* vector);

  TensorToVectorIntCalculatorOptions options_;
+  int32_t overlapping_values_;
 };
 REGISTER_CALCULATOR(TensorToVectorIntCalculator);

@@ -66,6 +68,7 @@ absl::Status TensorToVectorIntCalculator::GetContract(CalculatorContract* cc) {

 absl::Status TensorToVectorIntCalculator::Open(CalculatorContext* cc) {
   options_ = cc->Options<TensorToVectorIntCalculatorOptions>();
+  overlapping_values_ = 0;

   // Inform mediapipe that this calculator produces an output at time t for
   // each input received at time t (i.e. this calculator does not buffer
@@ -106,6 +109,7 @@ absl::Status TensorToVectorIntCalculator::Process(CalculatorContext* cc) {
         }
       }
       TokenizeVector(&instance_output);
+      RemoveOverlapVector(&instance_output);
     }
     cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp());
   } else {
@@ -128,12 +132,28 @@ absl::Status TensorToVectorIntCalculator::Process(CalculatorContext* cc) {
       }
     }
     TokenizeVector(output.get());
+    RemoveOverlapVector(output.get());
     cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp());
   }

   return absl::OkStatus();
 }

+void TensorToVectorIntCalculator::RemoveOverlapVector(
+    std::vector<int64>* vector) {
+  if (options_.overlap() <= 0) {
+    return;
+  }
+  if (overlapping_values_ > 0) {
+    if (vector->size() < overlapping_values_) {
+      vector->clear();
+    } else {
+      vector->erase(vector->begin(), vector->begin() + overlapping_values_);
+    }
+  }
+  overlapping_values_ = options_.overlap();
+}
+
 void TensorToVectorIntCalculator::TokenizeVector(
     std::vector<int64>* vector) const {
   if (!options_.tensor_is_token()) {
diff --git a/mediapipe/calculators/tensorflow/tensor_to_vector_int_calculator_options.proto b/mediapipe/calculators/tensorflow/tensor_to_vector_int_calculator_options.proto
index 9da3298b9..76b9be952 100644
--- a/mediapipe/calculators/tensorflow/tensor_to_vector_int_calculator_options.proto
+++ b/mediapipe/calculators/tensorflow/tensor_to_vector_int_calculator_options.proto
@@ -36,4 +36,8 @@ message TensorToVectorIntCalculatorOptions {
   optional bool tensor_is_token = 3 [default = false];
   // Threshold for the token generation.
   optional float token_threshold = 4 [default = 0.5];
+
+  // Values which overlap between temporally consecutive vectors. They are
+  // removed from the output to reduce redundancy.
+  optional int32 overlap = 5 [default = 0];
 }
diff --git a/mediapipe/calculators/tensorflow/tensor_to_vector_int_calculator_test.cc b/mediapipe/calculators/tensorflow/tensor_to_vector_int_calculator_test.cc
index 60c0d47ec..406c2c1a7 100644
--- a/mediapipe/calculators/tensorflow/tensor_to_vector_int_calculator_test.cc
+++ b/mediapipe/calculators/tensorflow/tensor_to_vector_int_calculator_test.cc
@@ -28,7 +28,8 @@ namespace tf = ::tensorflow;
 class TensorToVectorIntCalculatorTest : public ::testing::Test {
  protected:
   void SetUpRunner(const bool tensor_is_2d, const bool flatten_nd,
-                   const bool tensor_is_token = false) {
+                   const bool tensor_is_token = false,
+                   const int32_t overlap = 0) {
     CalculatorGraphConfig::Node config;
     config.set_calculator("TensorToVectorIntCalculator");
     config.add_input_stream("input_tensor");
@@ -38,6 +39,7 @@ class TensorToVectorIntCalculatorTest : public ::testing::Test {
     options->set_tensor_is_2d(tensor_is_2d);
     options->set_flatten_nd(flatten_nd);
     options->set_tensor_is_token(tensor_is_token);
+    options->set_overlap(overlap);
     runner_ = absl::make_unique<CalculatorRunner>(config);
   }

@@ -188,5 +190,54 @@ TEST_F(TensorToVectorIntCalculatorTest, FlattenShouldTakeAllDimensions) {
   }
 }

+TEST_F(TensorToVectorIntCalculatorTest, Overlap) {
+  SetUpRunner(false, false, false, 2);
+  for (int time = 0; time < 3; ++time) {
+    const tf::TensorShape tensor_shape(std::vector<tf::int64>{5});
+    auto tensor = absl::make_unique<tf::Tensor>(tf::DT_INT64, tensor_shape);
+    auto tensor_vec = tensor->vec<int64>();
+    for (int i = 0; i < 5; ++i) {
+      // 2^i can be represented exactly in floating point numbers if 'i' is
+      // small.
+      tensor_vec(i) = static_cast<int64>(time + (1 << i));
+    }
+
+    runner_->MutableInputs()->Index(0).packets.push_back(
+        Adopt(tensor.release()).At(Timestamp(time)));
+  }
+
+  ASSERT_TRUE(runner_->Run().ok());
+  const std::vector<Packet>& output_packets =
+      runner_->Outputs().Index(0).packets;
+  EXPECT_EQ(3, output_packets.size());
+
+  {
+    // First vector in full.
+    int time = 0;
+    EXPECT_EQ(time, output_packets[time].Timestamp().Value());
+    const std::vector<int64>& output_vector =
+        output_packets[time].Get<std::vector<int64>>();
+
+    EXPECT_EQ(5, output_vector.size());
+    for (int i = 0; i < 5; ++i) {
+      const int64 expected = static_cast<int64>(time + (1 << i));
+      EXPECT_EQ(expected, output_vector[i]);
+    }
+  }
+
+  // All following vectors have the overlap removed.
+  for (int time = 1; time < 3; ++time) {
+    EXPECT_EQ(time, output_packets[time].Timestamp().Value());
+    const std::vector<int64>& output_vector =
+        output_packets[time].Get<std::vector<int64>>();
+
+    EXPECT_EQ(3, output_vector.size());
+    for (int i = 0; i < 3; ++i) {
+      const int64 expected = static_cast<int64>(time + (1 << (i + 2)));
+      EXPECT_EQ(expected, output_vector[i]);
+    }
+  }
+}
+
 }  // namespace
 }  // namespace mediapipe

From 463cbb60eea6af436bbec6d13fceae0f65cdbe64 Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Thu, 5 Jan 2023 07:55:57 -0800
Subject: [PATCH 324/469] Fix RGBA vs RGB selection when creating GLTexture.
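The essence of the one-line fix, as a standalone sketch (the helper name is made up; a GL header such as <GLES3/gl3.h> is assumed for the enums): a four-channel buffer must be described as GL_RGBA and a three-channel one as GL_RGB; the old ternary had the two constants swapped.

    // Pick the GL format that matches the channel count.
    GLenum GlFormatForChannels(int num_channels) {
      return num_channels == 4 ? GL_RGBA : GL_RGB;
    }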
PiperOrigin-RevId: 499877590
---
 .../calculators/tensor/image_to_tensor_converter_gl_buffer.cc | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mediapipe/calculators/tensor/image_to_tensor_converter_gl_buffer.cc b/mediapipe/calculators/tensor/image_to_tensor_converter_gl_buffer.cc
index a551e7f8d..eb1726aac 100644
--- a/mediapipe/calculators/tensor/image_to_tensor_converter_gl_buffer.cc
+++ b/mediapipe/calculators/tensor/image_to_tensor_converter_gl_buffer.cc
@@ -285,7 +285,7 @@ class GlProcessor : public ImageToTensorConverter {
     auto source_texture = gl_helper_.CreateSourceTexture(input);
     tflite::gpu::gl::GlTexture input_texture(
         GL_TEXTURE_2D, source_texture.name(),
-        input_num_channels == 4 ? GL_RGB : GL_RGBA,
+        input_num_channels == 4 ? GL_RGBA : GL_RGB,
         source_texture.width() * source_texture.height() *
             input_num_channels * sizeof(uint8_t),
         /*layer=*/0,

From 35293d88bcb35b87162fbbb40b76226677f98d3f Mon Sep 17 00:00:00 2001
From: Nikolay Chirkov
Date: Thu, 5 Jan 2023 08:54:25 -0800
Subject: [PATCH 325/469] Tensor: move into tensor sub-directory.

PiperOrigin-RevId: 499896489
---
 mediapipe/framework/formats/BUILD                   |  2 +-
 mediapipe/framework/formats/tensor.h                |  2 +-
 mediapipe/framework/formats/tensor/BUILD            | 24 +++++++++++++++
 .../{tensor_internal.h => tensor/internal.h}        |  0
 mediapipe/framework/formats/tensor_ahwb_gpu_test.cc |  2 +-
 5 files changed, 27 insertions(+), 3 deletions(-)
 create mode 100644 mediapipe/framework/formats/tensor/BUILD
 rename mediapipe/framework/formats/{tensor_internal.h => tensor/internal.h} (100%)

diff --git a/mediapipe/framework/formats/BUILD b/mediapipe/framework/formats/BUILD
index cce7e5bd0..371f23ed1 100644
--- a/mediapipe/framework/formats/BUILD
+++ b/mediapipe/framework/formats/BUILD
@@ -430,7 +430,7 @@ cc_library(
     ],
     hdrs = [
         "tensor.h",
-        "tensor_internal.h",
+        "//mediapipe/framework/formats/tensor:internal.h",
     ],
     copts = select({
         "//mediapipe:apple": [
diff --git a/mediapipe/framework/formats/tensor.h b/mediapipe/framework/formats/tensor.h
index 0f19bb5ee..4a952ae09 100644
--- a/mediapipe/framework/formats/tensor.h
+++ b/mediapipe/framework/formats/tensor.h
@@ -26,7 +26,7 @@
 #include "absl/container/flat_hash_map.h"
 #include "absl/synchronization/mutex.h"
-#include "mediapipe/framework/formats/tensor_internal.h"
+#include "mediapipe/framework/formats/tensor/internal.h"
 #include "mediapipe/framework/port.h"

 #if MEDIAPIPE_METAL_ENABLED
diff --git a/mediapipe/framework/formats/tensor/BUILD b/mediapipe/framework/formats/tensor/BUILD
new file mode 100644
index 000000000..c634b0dda
--- /dev/null
+++ b/mediapipe/framework/formats/tensor/BUILD
@@ -0,0 +1,24 @@
+# Copyright 2019 The MediaPipe Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+package(
+    default_visibility = ["//visibility:public"],
+    features = ["-layering_check"],
+)
+
+licenses(["notice"])
+
+exports_files([
+    "internal.h",
+])
diff --git a/mediapipe/framework/formats/tensor_internal.h b/mediapipe/framework/formats/tensor/internal.h
similarity index 100%
rename from mediapipe/framework/formats/tensor_internal.h
rename to mediapipe/framework/formats/tensor/internal.h
diff --git a/mediapipe/framework/formats/tensor_ahwb_gpu_test.cc b/mediapipe/framework/formats/tensor_ahwb_gpu_test.cc
index ff78d1f88..b06bd3ef2 100644
--- a/mediapipe/framework/formats/tensor_ahwb_gpu_test.cc
+++ b/mediapipe/framework/formats/tensor_ahwb_gpu_test.cc
@@ -7,7 +7,7 @@
 #include

 #include "mediapipe/framework/formats/tensor.h"
-#include "mediapipe/framework/formats/tensor_data_types.h"
+#include "mediapipe/framework/formats/tensor/views/data_types.h"
 #include "mediapipe/gpu/gpu_test_base.h"
 #include "mediapipe/gpu/shader_util.h"
 #include "tensorflow/lite/delegates/gpu/gl/gl_call.h"

From 81a46bb31a5da15a0ddd7123b92499c6ca14dc86 Mon Sep 17 00:00:00 2001
From: Sebastian Schmidt
Date: Thu, 5 Jan 2023 09:12:06 -0800
Subject: [PATCH 326/469] Internal change

PiperOrigin-RevId: 499902323
---
 mediapipe/web/graph_runner/graph_runner.ts | 73 ++++++++++++++--------
 1 file changed, 46 insertions(+), 27 deletions(-)

diff --git a/mediapipe/web/graph_runner/graph_runner.ts b/mediapipe/web/graph_runner/graph_runner.ts
index ef866bc91..644d74918 100644
--- a/mediapipe/web/graph_runner/graph_runner.ts
+++ b/mediapipe/web/graph_runner/graph_runner.ts
@@ -73,10 +73,11 @@ export declare interface WasmModule {

   // Wasm Module output listener entrypoints. Also built as part of
   // gl_graph_runner_internal_multi_input.
-  simpleListeners?: {[outputStreamName: string]: (data: unknown) => void};
+  simpleListeners?:
+      {[outputStreamName: string]: (data: unknown, timestamp: number) => void};
   vectorListeners?: {
     [outputStreamName: string]: (
-        data: unknown, index: number, length: number) => void
+        data: unknown, index: number, length: number, timestamp: number) => void
   };
   _attachBoolListener: (streamNamePtr: number) => void;
   _attachBoolVectorListener: (streamNamePtr: number) => void;
@@ -418,10 +419,12 @@ export class GraphRunner {
    * Ensures existence of the simple listeners table and registers the callback.
    * Intended for internal usage.
    */
-  setListener<T>(outputStreamName: string, callbackFcn: (data: T) => void) {
+  setListener<T>(
+      outputStreamName: string,
+      callbackFcn: (data: T, timestamp: number) => void) {
     this.wasmModule.simpleListeners = this.wasmModule.simpleListeners || {};
     this.wasmModule.simpleListeners[outputStreamName] =
-        callbackFcn as (data: unknown) => void;
+        callbackFcn as (data: unknown, timestamp: number) => void;
   }

   /**
    * Ensures existence of the vector listeners table and registers the callback.
    * Intended for internal usage.
    */
   setVectorListener<T>(
-      outputStreamName: string, callbackFcn: (data: T[]) => void) {
+      outputStreamName: string,
+      callbackFcn: (data: T[], timestamp: number) => void) {
     let buffer: T[] = [];
     this.wasmModule.vectorListeners = this.wasmModule.vectorListeners || {};
     this.wasmModule.vectorListeners[outputStreamName] =
-        (data: unknown, index: number, length: number) => {
+        (data: unknown, index: number, length: number, timestamp: number) => {
          // The Wasm listener gets invoked once for each element. Once we
          // receive all elements, we invoke the registered callback with the
          // full array.
@@ -442,7 +446,7 @@ export class GraphRunner { // Invoke the user callback directly, as the Wasm layer may clean up // the underlying data elements once we leave the scope of the // listener. - callbackFcn(buffer); + callbackFcn(buffer, timestamp); buffer = []; } }; @@ -740,7 +744,8 @@ export class GraphRunner { * should not perform overly complicated (or any async) behavior. */ attachBoolListener( - outputStreamName: string, callbackFcn: (data: boolean) => void): void { + outputStreamName: string, + callbackFcn: (data: boolean, timestamp: number) => void): void { // Set up our TS listener to receive any packets for this stream. this.setListener(outputStreamName, callbackFcn); @@ -760,7 +765,8 @@ export class GraphRunner { * should not perform overly complicated (or any async) behavior. */ attachBoolVectorListener( - outputStreamName: string, callbackFcn: (data: boolean[]) => void): void { + outputStreamName: string, + callbackFcn: (data: boolean[], timestamp: number) => void): void { // Set up our TS listener to receive any packets for this stream. this.setVectorListener(outputStreamName, callbackFcn); @@ -780,7 +786,8 @@ export class GraphRunner { * should not perform overly complicated (or any async) behavior. */ attachIntListener( - outputStreamName: string, callbackFcn: (data: number) => void): void { + outputStreamName: string, + callbackFcn: (data: number, timestamp: number) => void): void { // Set up our TS listener to receive any packets for this stream. this.setListener(outputStreamName, callbackFcn); @@ -800,7 +807,8 @@ export class GraphRunner { * should not perform overly complicated (or any async) behavior. */ attachIntVectorListener( - outputStreamName: string, callbackFcn: (data: number[]) => void): void { + outputStreamName: string, + callbackFcn: (data: number[], timestamp: number) => void): void { // Set up our TS listener to receive any packets for this stream. this.setVectorListener(outputStreamName, callbackFcn); @@ -820,7 +828,8 @@ export class GraphRunner { * should not perform overly complicated (or any async) behavior. */ attachDoubleListener( - outputStreamName: string, callbackFcn: (data: number) => void): void { + outputStreamName: string, + callbackFcn: (data: number, timestamp: number) => void): void { // Set up our TS listener to receive any packets for this stream. this.setListener(outputStreamName, callbackFcn); @@ -840,7 +849,8 @@ export class GraphRunner { * should not perform overly complicated (or any async) behavior. */ attachDoubleVectorListener( - outputStreamName: string, callbackFcn: (data: number[]) => void): void { + outputStreamName: string, + callbackFcn: (data: number[], timestamp: number) => void): void { // Set up our TS listener to receive any packets for this stream. this.setVectorListener(outputStreamName, callbackFcn); @@ -860,7 +870,8 @@ export class GraphRunner { * should not perform overly complicated (or any async) behavior. */ attachFloatListener( - outputStreamName: string, callbackFcn: (data: number) => void): void { + outputStreamName: string, + callbackFcn: (data: number, timestamp: number) => void): void { // Set up our TS listener to receive any packets for this stream. this.setListener(outputStreamName, callbackFcn); @@ -880,7 +891,8 @@ export class GraphRunner { * should not perform overly complicated (or any async) behavior. 
*/ attachFloatVectorListener( - outputStreamName: string, callbackFcn: (data: number[]) => void): void { + outputStreamName: string, + callbackFcn: (data: number[], timestamp: number) => void): void { // Set up our TS listener to receive any packets for this stream. this.setVectorListener(outputStreamName, callbackFcn); @@ -900,7 +912,8 @@ export class GraphRunner { * should not perform overly complicated (or any async) behavior. */ attachStringListener( - outputStreamName: string, callbackFcn: (data: string) => void): void { + outputStreamName: string, + callbackFcn: (data: string, timestamp: number) => void): void { // Set up our TS listener to receive any packets for this stream. this.setListener(outputStreamName, callbackFcn); @@ -920,7 +933,8 @@ export class GraphRunner { * should not perform overly complicated (or any async) behavior. */ attachStringVectorListener( - outputStreamName: string, callbackFcn: (data: string[]) => void): void { + outputStreamName: string, + callbackFcn: (data: string[], timestamp: number) => void): void { // Set up our TS listener to receive any packets for this stream. this.setVectorListener(outputStreamName, callbackFcn); @@ -950,7 +964,8 @@ export class GraphRunner { * with it). */ attachProtoListener( - outputStreamName: string, callbackFcn: (data: Uint8Array) => void, + outputStreamName: string, + callbackFcn: (data: Uint8Array, timestamp: number) => void, makeDeepCopy?: boolean): void { // Set up our TS listener to receive any packets for this stream. this.setListener(outputStreamName, callbackFcn); @@ -984,7 +999,8 @@ export class GraphRunner { * with it). */ attachProtoVectorListener( - outputStreamName: string, callbackFcn: (data: Uint8Array[]) => void, + outputStreamName: string, + callbackFcn: (data: Uint8Array[], timestamp: number) => void, makeDeepCopy?: boolean): void { // Set up our TS listener to receive any packets for this stream. this.setVectorListener(outputStreamName, callbackFcn); @@ -1017,8 +1033,10 @@ export class GraphRunner { * up automatically by JS garbage collection whenever the user is finished * with it). */ - attachAudioListener(outputStreamName: string, - callbackFcn: (data: Float32Array) => void, makeDeepCopy?: boolean): void { + attachAudioListener( + outputStreamName: string, + callbackFcn: (data: Float32Array, timestamp: number) => void, + makeDeepCopy?: boolean): void { if (!this.wasmModule._attachAudioListener) { console.warn( 'Attempting to use attachAudioListener without support for ' + @@ -1027,12 +1045,13 @@ export class GraphRunner { // Set up our TS listener to receive any packets for this stream, and // additionally reformat our Uint8Array into a Float32Array for the user. - this.setListener(outputStreamName, (data: Uint8Array) => { - // Should be very fast - const floatArray = - new Float32Array(data.buffer, data.byteOffset, data.length / 4); - callbackFcn(floatArray); - }); + this.setListener( + outputStreamName, (data: Uint8Array, timestamp: number) => { + // Should be very fast + const floatArray = + new Float32Array(data.buffer, data.byteOffset, data.length / 4); + callbackFcn(floatArray, timestamp); + }); // Tell our graph to listen for string packets on this stream. 
this.wrapStringPtr(outputStreamName, (outputStreamNamePtr: number) => { From 667fd81ddc12be213c0091c73f4c71fe0e4e35b2 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Thu, 5 Jan 2023 11:40:59 -0800 Subject: [PATCH 327/469] Internal change PiperOrigin-RevId: 499956657 --- .../audio_classifier/audio_classifier_test.ts | 33 ++++++++++--------- .../audio_embedder/audio_embedder_test.ts | 11 ++++--- .../text_classifier/text_classifier_test.ts | 18 +++++----- .../text/text_embedder/text_embedder_test.ts | 7 ++-- .../gesture_recognizer_test.ts | 32 ++++++++++-------- .../hand_landmarker/hand_landmarker_test.ts | 14 ++++---- .../image_classifier/image_classifier_test.ts | 6 ++-- .../image_embedder/image_embedder_test.ts | 5 +-- .../object_detector/object_detector_test.ts | 5 +-- 9 files changed, 75 insertions(+), 56 deletions(-) diff --git a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier_test.ts b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier_test.ts index 2089f184f..b7bb158de 100644 --- a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier_test.ts +++ b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier_test.ts @@ -34,7 +34,8 @@ class AudioClassifierFake extends AudioClassifier implements attachListenerSpies: jasmine.Spy[] = []; graph: CalculatorGraphConfig|undefined; - private protoVectorListener: ((binaryProtos: Uint8Array[]) => void)|undefined; + private protoVectorListener: + ((binaryProtos: Uint8Array[], timestamp: number) => void)|undefined; private resultProtoVector: ClassificationResult[] = []; constructor() { @@ -59,8 +60,10 @@ class AudioClassifierFake extends AudioClassifier implements }); spyOn(this.graphRunner, 'finishProcessing').and.callFake(() => { if (!this.protoVectorListener) return; - this.protoVectorListener(this.resultProtoVector.map( - classificationResult => classificationResult.serializeBinary())); + this.protoVectorListener( + this.resultProtoVector.map( + classificationResult => classificationResult.serializeBinary()), + 1337); }); spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => { this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph); @@ -138,12 +141,12 @@ describe('AudioClassifier', () => { classifcations.setHeadIndex(1); classifcations.setHeadName('headName'); let classificationList = new ClassificationList(); - let clasification = new Classification(); - clasification.setIndex(1); - clasification.setScore(0.2); - clasification.setDisplayName('displayName'); - clasification.setLabel('categoryName'); - classificationList.addClassification(clasification); + let classification = new Classification(); + classification.setIndex(1); + classification.setScore(0.2); + classification.setDisplayName('displayName'); + classification.setLabel('categoryName'); + classificationList.addClassification(classification); classifcations.setClassificationList(classificationList); classificationResult.addClassifications(classifcations); resultProtoVector.push(classificationResult); @@ -152,10 +155,10 @@ describe('AudioClassifier', () => { classificationResult.setTimestampMs(1); classifcations = new Classifications(); classificationList = new ClassificationList(); - clasification = new Classification(); - clasification.setIndex(2); - clasification.setScore(0.3); - classificationList.addClassification(clasification); + classification = new Classification(); + classification.setIndex(2); + classification.setScore(0.3); + classificationList.addClassification(classification); 
classifcations.setClassificationList(classificationList); classificationResult.addClassifications(classifcations); resultProtoVector.push(classificationResult); @@ -191,8 +194,8 @@ describe('AudioClassifier', () => { const classificationResult = new ClassificationResult(); const classifcations = new Classifications(); const classificationList = new ClassificationList(); - const clasification = new Classification(); - classificationList.addClassification(clasification); + const classification = new Classification(); + classificationList.addClassification(classification); classifcations.setClassificationList(classificationList); classificationResult.addClassifications(classifcations); diff --git a/mediapipe/tasks/web/audio/audio_embedder/audio_embedder_test.ts b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder_test.ts index dde61a6e9..a8a2b232b 100644 --- a/mediapipe/tasks/web/audio/audio_embedder/audio_embedder_test.ts +++ b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder_test.ts @@ -34,8 +34,10 @@ class AudioEmbedderFake extends AudioEmbedder implements MediapipeTasksFake { attachListenerSpies: jasmine.Spy[] = []; fakeWasmModule: SpyWasmModule; - protoListener: ((binaryProto: Uint8Array) => void)|undefined; - protoVectorListener: ((binaryProtos: Uint8Array[]) => void)|undefined; + protoListener: + ((binaryProto: Uint8Array, timestamp: number) => void)|undefined; + protoVectorListener: + ((binaryProtos: Uint8Array[], timestamp: number) => void)|undefined; constructor() { super(createSpyWasmModule(), /* glCanvas= */ null); @@ -163,7 +165,7 @@ describe('AudioEmbedder', () => { audioEmbedder.fakeWasmModule._waitUntilIdle.and.callFake(() => { verifyListenersRegistered(audioEmbedder); // Pass the test data to our listener - audioEmbedder.protoListener!(resultProto.serializeBinary()); + audioEmbedder.protoListener!(resultProto.serializeBinary(), 1337); }); // Invoke the audio embedder @@ -175,7 +177,8 @@ describe('AudioEmbedder', () => { audioEmbedder.fakeWasmModule._waitUntilIdle.and.callFake(() => { verifyListenersRegistered(audioEmbedder); // Pass the test data to our listener - audioEmbedder.protoVectorListener!([resultProto.serializeBinary()]); + audioEmbedder.protoVectorListener! 
+ ([resultProto.serializeBinary()], 1337); }); // Invoke the audio embedder diff --git a/mediapipe/tasks/web/text/text_classifier/text_classifier_test.ts b/mediapipe/tasks/web/text/text_classifier/text_classifier_test.ts index 5578362cb..d9eb14865 100644 --- a/mediapipe/tasks/web/text/text_classifier/text_classifier_test.ts +++ b/mediapipe/tasks/web/text/text_classifier/text_classifier_test.ts @@ -32,7 +32,8 @@ class TextClassifierFake extends TextClassifier implements MediapipeTasksFake { attachListenerSpies: jasmine.Spy[] = []; graph: CalculatorGraphConfig|undefined; fakeWasmModule: SpyWasmModule; - protoListener: ((binaryProto: Uint8Array) => void)|undefined; + protoListener: + ((binaryProto: Uint8Array, timestamp: number) => void)|undefined; constructor() { super(createSpyWasmModule(), /* glCanvas= */ null); @@ -118,19 +119,20 @@ describe('TextClassifier', () => { classifcations.setHeadIndex(1); classifcations.setHeadName('headName'); const classificationList = new ClassificationList(); - const clasification = new Classification(); - clasification.setIndex(1); - clasification.setScore(0.2); - clasification.setDisplayName('displayName'); - clasification.setLabel('categoryName'); - classificationList.addClassification(clasification); + const classification = new Classification(); + classification.setIndex(1); + classification.setScore(0.2); + classification.setDisplayName('displayName'); + classification.setLabel('categoryName'); + classificationList.addClassification(classification); classifcations.setClassificationList(classificationList); classificationResult.addClassifications(classifcations); // Pass the test data to our listener textClassifier.fakeWasmModule._waitUntilIdle.and.callFake(() => { verifyListenersRegistered(textClassifier); - textClassifier.protoListener!(classificationResult.serializeBinary()); + textClassifier.protoListener! 
+ (classificationResult.serializeBinary(), 1337); }); // Invoke the text classifier diff --git a/mediapipe/tasks/web/text/text_embedder/text_embedder_test.ts b/mediapipe/tasks/web/text/text_embedder/text_embedder_test.ts index 2804e4deb..e26b85bf4 100644 --- a/mediapipe/tasks/web/text/text_embedder/text_embedder_test.ts +++ b/mediapipe/tasks/web/text/text_embedder/text_embedder_test.ts @@ -31,7 +31,8 @@ class TextEmbedderFake extends TextEmbedder implements MediapipeTasksFake { graph: CalculatorGraphConfig|undefined; attachListenerSpies: jasmine.Spy[] = []; fakeWasmModule: SpyWasmModule; - protoListener: ((binaryProtos: Uint8Array) => void)|undefined; + protoListener: + ((binaryProtos: Uint8Array, timestamp: number) => void)|undefined; constructor() { super(createSpyWasmModule(), /* glCanvas= */ null); @@ -120,7 +121,7 @@ describe('TextEmbedder', () => { // Pass the test data to our listener textEmbedder.fakeWasmModule._waitUntilIdle.and.callFake(() => { verifyListenersRegistered(textEmbedder); - textEmbedder.protoListener!(resultProto.serializeBinary()); + textEmbedder.protoListener!(resultProto.serializeBinary(), 1337); }); // Invoke the text embedder @@ -149,7 +150,7 @@ describe('TextEmbedder', () => { // Pass the test data to our listener textEmbedder.fakeWasmModule._waitUntilIdle.and.callFake(() => { verifyListenersRegistered(textEmbedder); - textEmbedder.protoListener!(resultProto.serializeBinary()); + textEmbedder.protoListener!(resultProto.serializeBinary(), 1337); }); // Invoke the text embedder diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_test.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_test.ts index ee51fd32a..3611c3a7d 100644 --- a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_test.ts +++ b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_test.ts @@ -26,7 +26,7 @@ import {GestureRecognizer, GestureRecognizerOptions} from './gesture_recognizer' // The OSS JS API does not support the builder pattern. // tslint:disable:jspb-use-builder-pattern -type ProtoListener = ((binaryProtos: Uint8Array[]) => void); +type ProtoListener = ((binaryProtos: Uint8Array[], timestamp: number) => void); function createHandednesses(): Uint8Array[] { const handsProto = new ClassificationList(); @@ -254,11 +254,13 @@ describe('GestureRecognizer', () => { // Pass the test data to our listener gestureRecognizer.fakeWasmModule._waitUntilIdle.and.callFake(() => { verifyListenersRegistered(gestureRecognizer); - gestureRecognizer.listeners.get('hand_landmarks')!(createLandmarks()); + gestureRecognizer.listeners.get('hand_landmarks')! + (createLandmarks(), 1337); gestureRecognizer.listeners.get('world_hand_landmarks')! - (createWorldLandmarks()); - gestureRecognizer.listeners.get('handedness')!(createHandednesses()); - gestureRecognizer.listeners.get('hand_gestures')!(createGestures()); + (createWorldLandmarks(), 1337); + gestureRecognizer.listeners.get('handedness')! + (createHandednesses(), 1337); + gestureRecognizer.listeners.get('hand_gestures')!(createGestures(), 1337); }); // Invoke the gesture recognizer @@ -290,11 +292,13 @@ describe('GestureRecognizer', () => { it('clears results between invoations', async () => { // Pass the test data to our listener gestureRecognizer.fakeWasmModule._waitUntilIdle.and.callFake(() => { - gestureRecognizer.listeners.get('hand_landmarks')!(createLandmarks()); + gestureRecognizer.listeners.get('hand_landmarks')! 
+ (createLandmarks(), 1337); gestureRecognizer.listeners.get('world_hand_landmarks')! - (createWorldLandmarks()); - gestureRecognizer.listeners.get('handedness')!(createHandednesses()); - gestureRecognizer.listeners.get('hand_gestures')!(createGestures()); + (createWorldLandmarks(), 1337); + gestureRecognizer.listeners.get('handedness')! + (createHandednesses(), 1337); + gestureRecognizer.listeners.get('hand_gestures')!(createGestures(), 1337); }); // Invoke the gesture recognizer twice @@ -310,11 +314,13 @@ describe('GestureRecognizer', () => { // Pass the test data to our listener gestureRecognizer.fakeWasmModule._waitUntilIdle.and.callFake(() => { verifyListenersRegistered(gestureRecognizer); - gestureRecognizer.listeners.get('hand_landmarks')!(createLandmarks()); + gestureRecognizer.listeners.get('hand_landmarks')! + (createLandmarks(), 1337); gestureRecognizer.listeners.get('world_hand_landmarks')! - (createWorldLandmarks()); - gestureRecognizer.listeners.get('handedness')!(createHandednesses()); - gestureRecognizer.listeners.get('hand_gestures')!([]); + (createWorldLandmarks(), 1337); + gestureRecognizer.listeners.get('handedness')! + (createHandednesses(), 1337); + gestureRecognizer.listeners.get('hand_gestures')!([], 1337); }); // Invoke the gesture recognizer diff --git a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_test.ts b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_test.ts index 76e77b4bf..1a813c6f7 100644 --- a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_test.ts +++ b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_test.ts @@ -27,7 +27,7 @@ import {HandLandmarkerOptions} from './hand_landmarker_options'; // The OSS JS API does not support the builder pattern. // tslint:disable:jspb-use-builder-pattern -type ProtoListener = ((binaryProtos: Uint8Array[]) => void); +type ProtoListener = ((binaryProtos: Uint8Array[], timestamp: number) => void); function createHandednesses(): Uint8Array[] { const handsProto = new ClassificationList(); @@ -206,10 +206,10 @@ describe('HandLandmarker', () => { // Pass the test data to our listener handLandmarker.fakeWasmModule._waitUntilIdle.and.callFake(() => { verifyListenersRegistered(handLandmarker); - handLandmarker.listeners.get('hand_landmarks')!(createLandmarks()); + handLandmarker.listeners.get('hand_landmarks')!(createLandmarks(), 1337); handLandmarker.listeners.get('world_hand_landmarks')! - (createWorldLandmarks()); - handLandmarker.listeners.get('handedness')!(createHandednesses()); + (createWorldLandmarks(), 1337); + handLandmarker.listeners.get('handedness')!(createHandednesses(), 1337); }); // Invoke the hand landmarker @@ -235,10 +235,10 @@ describe('HandLandmarker', () => { it('clears results between invoations', async () => { // Pass the test data to our listener handLandmarker.fakeWasmModule._waitUntilIdle.and.callFake(() => { - handLandmarker.listeners.get('hand_landmarks')!(createLandmarks()); + handLandmarker.listeners.get('hand_landmarks')!(createLandmarks(), 1337); handLandmarker.listeners.get('world_hand_landmarks')! 
- (createWorldLandmarks()); - handLandmarker.listeners.get('handedness')!(createHandednesses()); + (createWorldLandmarks(), 1337); + handLandmarker.listeners.get('handedness')!(createHandednesses(), 1337); }); // Invoke the hand landmarker twice diff --git a/mediapipe/tasks/web/vision/image_classifier/image_classifier_test.ts b/mediapipe/tasks/web/vision/image_classifier/image_classifier_test.ts index da4a01d02..60595310e 100644 --- a/mediapipe/tasks/web/vision/image_classifier/image_classifier_test.ts +++ b/mediapipe/tasks/web/vision/image_classifier/image_classifier_test.ts @@ -35,7 +35,8 @@ class ImageClassifierFake extends ImageClassifier implements graph: CalculatorGraphConfig|undefined; fakeWasmModule: SpyWasmModule; - protoListener: ((binaryProto: Uint8Array) => void)|undefined; + protoListener: + ((binaryProto: Uint8Array, timestamp: number) => void)|undefined; constructor() { super(createSpyWasmModule(), /* glCanvas= */ null); @@ -128,7 +129,8 @@ describe('ImageClassifier', () => { // Pass the test data to our listener imageClassifier.fakeWasmModule._waitUntilIdle.and.callFake(() => { verifyListenersRegistered(imageClassifier); - imageClassifier.protoListener!(classificationResult.serializeBinary()); + imageClassifier.protoListener! + (classificationResult.serializeBinary(), 1337); }); // Invoke the image classifier diff --git a/mediapipe/tasks/web/vision/image_embedder/image_embedder_test.ts b/mediapipe/tasks/web/vision/image_embedder/image_embedder_test.ts index b63bb374c..01ec751e3 100644 --- a/mediapipe/tasks/web/vision/image_embedder/image_embedder_test.ts +++ b/mediapipe/tasks/web/vision/image_embedder/image_embedder_test.ts @@ -31,7 +31,8 @@ class ImageEmbedderFake extends ImageEmbedder implements MediapipeTasksFake { graph: CalculatorGraphConfig|undefined; attachListenerSpies: jasmine.Spy[] = []; fakeWasmModule: SpyWasmModule; - protoListener: ((binaryProtos: Uint8Array) => void)|undefined; + protoListener: + ((binaryProtos: Uint8Array, timestamp: number) => void)|undefined; constructor() { super(createSpyWasmModule(), /* glCanvas= */ null); @@ -125,7 +126,7 @@ describe('ImageEmbedder', () => { // Pass the test data to our listener imageEmbedder.fakeWasmModule._waitUntilIdle.and.callFake(() => { verifyListenersRegistered(imageEmbedder); - imageEmbedder.protoListener!(resultProto.serializeBinary()); + imageEmbedder.protoListener!(resultProto.serializeBinary(), 1337); }); }); diff --git a/mediapipe/tasks/web/vision/object_detector/object_detector_test.ts b/mediapipe/tasks/web/vision/object_detector/object_detector_test.ts index 43b7035d5..5bfb74ab6 100644 --- a/mediapipe/tasks/web/vision/object_detector/object_detector_test.ts +++ b/mediapipe/tasks/web/vision/object_detector/object_detector_test.ts @@ -35,7 +35,8 @@ class ObjectDetectorFake extends ObjectDetector implements MediapipeTasksFake { graph: CalculatorGraphConfig|undefined; fakeWasmModule: SpyWasmModule; - protoListener: ((binaryProtos: Uint8Array[]) => void)|undefined; + protoListener: + ((binaryProtos: Uint8Array[], timestamp: number) => void)|undefined; constructor() { super(createSpyWasmModule(), /* glCanvas= */ null); @@ -200,7 +201,7 @@ describe('ObjectDetector', () => { // Pass the test data to our listener objectDetector.fakeWasmModule._waitUntilIdle.and.callFake(() => { verifyListenersRegistered(objectDetector); - objectDetector.protoListener!(detectionProtos); + objectDetector.protoListener!(detectionProtos, 1337); }); // Invoke the object detector From 33df6c042fc3d78f525a6f7c86b10d67d091ddf1 Mon Sep 17 
00:00:00 2001
From: Prianka Liz Kariat
Date: Fri, 6 Jan 2023 16:07:11 +0530
Subject: [PATCH 328/469] Added iOS result containers for classification tasks

---
 .../tasks/ios/components/containers/BUILD     |  32 +++++
 .../containers/sources/MPPCategory.h          |  68 ++++++++++
 .../containers/sources/MPPCategory.m          |  33 +++++
 .../sources/MPPClassificationResult.h         | 116 ++++++++++++++++++
 .../sources/MPPClassificationResult.m         |  51 ++++++++
 5 files changed, 300 insertions(+)
 create mode 100644 mediapipe/tasks/ios/components/containers/BUILD
 create mode 100644 mediapipe/tasks/ios/components/containers/sources/MPPCategory.h
 create mode 100644 mediapipe/tasks/ios/components/containers/sources/MPPCategory.m
 create mode 100644 mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.h
 create mode 100644 mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.m

diff --git a/mediapipe/tasks/ios/components/containers/BUILD b/mediapipe/tasks/ios/components/containers/BUILD
new file mode 100644
index 000000000..9d82fc55a
--- /dev/null
+++ b/mediapipe/tasks/ios/components/containers/BUILD
@@ -0,0 +1,32 @@
+# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+package(default_visibility = ["//mediapipe/tasks:internal"])
+
+licenses(["notice"])
+
+objc_library(
+    name = "MPPCategory",
+    srcs = ["sources/MPPCategory.m"],
+    hdrs = ["sources/MPPCategory.h"],
+)
+
+objc_library(
+    name = "MPPClassificationResult",
+    srcs = ["sources/MPPClassificationResult.m"],
+    hdrs = ["sources/MPPClassificationResult.h"],
+    deps = [
+        ":MPPCategory",
+    ],
+)
diff --git a/mediapipe/tasks/ios/components/containers/sources/MPPCategory.h b/mediapipe/tasks/ios/components/containers/sources/MPPCategory.h
new file mode 100644
index 000000000..648725d95
--- /dev/null
+++ b/mediapipe/tasks/ios/components/containers/sources/MPPCategory.h
@@ -0,0 +1,68 @@
+// Copyright 2023 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#import <Foundation/Foundation.h>
+
+NS_ASSUME_NONNULL_BEGIN
+
+/**
+ * Category is a util class that contains a label, its display name, a float value as score, and the
+ * index of the label in the corresponding label file. Typically it's used as the result of
+ * classification tasks.
+ **/
+NS_SWIFT_NAME(ClassificationCategory)
+@interface MPPCategory : NSObject
+
+/**
+ * The index of the label in the corresponding label file. Set to -1 if the index is
+ * not set.
+ **/
+@property(nonatomic, readonly) NSInteger index;
+
+/** Confidence score for this class. **/
+@property(nonatomic, readonly) float score;
+
+/** The label of this category object. **/
+@property(nonatomic, readonly, nullable) NSString *categoryName;
+
+/**
+ * The display name of the label, which may be translated for different locales. For example, a
+ * label, "apple", may be translated into Spanish for display purposes, so that the display name is
+ * "manzana".
+ **/
+@property(nonatomic, readonly, nullable) NSString *displayName;
+
+/**
+ * Initializes a new `MPPCategory` with the given index, score, category name and display name.
+ *
+ * @param index The index of the label in the corresponding label file.
+ * @param score The probability score of this label category.
+ * @param categoryName The label of this category object.
+ * @param displayName The display name of the label.
+ *
+ * @return An instance of `MPPCategory` initialized with the given index, score, category name and
+ * display name.
+ **/
+- (instancetype)initWithIndex:(NSInteger)index
+                        score:(float)score
+                 categoryName:(nullable NSString *)categoryName
+                  displayName:(nullable NSString *)displayName NS_DESIGNATED_INITIALIZER;
+
+- (instancetype)init NS_UNAVAILABLE;
+
++ (instancetype)new NS_UNAVAILABLE;
+
+@end
+
+NS_ASSUME_NONNULL_END
diff --git a/mediapipe/tasks/ios/components/containers/sources/MPPCategory.m b/mediapipe/tasks/ios/components/containers/sources/MPPCategory.m
new file mode 100644
index 000000000..824fae65e
--- /dev/null
+++ b/mediapipe/tasks/ios/components/containers/sources/MPPCategory.m
@@ -0,0 +1,33 @@
+// Copyright 2023 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#import "mediapipe/tasks/ios/components/containers/sources/MPPCategory.h"
+
+@implementation MPPCategory
+
+- (instancetype)initWithIndex:(NSInteger)index
+                        score:(float)score
+                 categoryName:(nullable NSString *)categoryName
+                  displayName:(nullable NSString *)displayName {
+  self = [super init];
+  if (self) {
+    _index = index;
+    _score = score;
+    _categoryName = categoryName;
+    _displayName = displayName;
+  }
+  return self;
+}
+
+@end
diff --git a/mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.h b/mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.h
new file mode 100644
index 000000000..9c8b9bd2e
--- /dev/null
+++ b/mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.h
@@ -0,0 +1,116 @@
+// Copyright 2023 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#import <Foundation/Foundation.h>
+#import "mediapipe/tasks/ios/components/containers/sources/MPPCategory.h"
+
+NS_ASSUME_NONNULL_BEGIN
+
+/**
+ * Represents the list of classifications for a given classifier head. Typically used as a result
+ * for classification tasks.
+ **/
+NS_SWIFT_NAME(Classifications)
+@interface MPPClassifications : NSObject
+
+/**
+ * The index of the classifier head these entries refer to. This is useful for multi-head models.
+ **/
+@property(nonatomic, readonly) NSInteger headIndex;
+
+/** The optional name of the classifier head, which is the corresponding tensor metadata name. **/
+@property(nonatomic, readonly, nullable) NSString *headName;
+
+/** An array of `MPPCategory` objects containing the predicted categories. **/
+@property(nonatomic, readonly) NSArray *categories;
+
+/**
+ * Initializes a new `MPPClassifications` object with the given head index and array of categories.
+ * Head name is initialized to `nil`.
+ *
+ * @param headIndex The index of the classifier head.
+ * @param categories An array of `MPPCategory` objects containing the predicted categories.
+ *
+ * @return An instance of `MPPClassifications` initialized with the given head index and
+ * array of categories.
+ **/
+- (instancetype)initWithHeadIndex:(NSInteger)headIndex
+                       categories:(NSArray *)categories;
+
+/**
+ * Initializes a new `MPPClassifications` with the given head index, head name and array of
+ * categories.
+ *
+ * @param headIndex The index of the classifier head.
+ * @param headName The name of the classifier head, which is the corresponding tensor metadata
+ * name.
+ * @param categories An array of `MPPCategory` objects containing the predicted categories.
+ *
+ * @return An object of `MPPClassifications` initialized with the given head index, head name and
+ * array of categories.
+ **/
+- (instancetype)initWithHeadIndex:(NSInteger)headIndex
+                         headName:(nullable NSString *)headName
+                       categories:(NSArray *)categories NS_DESIGNATED_INITIALIZER;
+
+- (instancetype)init NS_UNAVAILABLE;
+
++ (instancetype)new NS_UNAVAILABLE;
+
+@end
+
+/**
+ * Represents the classification results of a model. Typically used as a result for classification
+ * tasks.
+ **/
+NS_SWIFT_NAME(ClassificationResult)
+@interface MPPClassificationResult : NSObject
+
+/**
+ * An Array of `MPPClassifications` objects containing the predicted categories for each head of
+ * the model.
+ **/
+@property(nonatomic, readonly) NSArray *classifications;
+
+/**
+ * The optional timestamp (in milliseconds) of the start of the chunk of data corresponding to
+ * these results. If it is set to the value -1, it signifies the absence of a timestamp. This is
+ * only used for classification on time series (e.g. audio classification). In these use cases, the
+ * amount of data to process might exceed the maximum size that the model can process: to solve
+ * this, the input data is split into multiple chunks starting at different timestamps.
+ **/
+@property(nonatomic, readonly) NSInteger timestampMs;
+
+/**
+ * Initializes a new `MPPClassificationResult` with the given array of classifications and time
+ * stamp (in milliseconds).
+ *
+ * @param classifications An Array of `MPPClassifications` objects containing the predicted
+ * categories for each head of the model.
+ * @param timestampMs The timestamp (in milliseconds) of the start of the chunk of data
+ * corresponding to these results.
+ * + * @return An instance of `MPPClassificationResult` initialized with the given array of + * classifications and timestampMs. + **/ +- (instancetype)initWithClassifications:(NSArray *)classifications + timestampMs:(NSInteger)timestampMs NS_DESIGNATED_INITIALIZER; + +- (instancetype)init NS_UNAVAILABLE; + ++ (instancetype)new NS_UNAVAILABLE; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.m b/mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.m new file mode 100644 index 000000000..6d42d22ca --- /dev/null +++ b/mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.m @@ -0,0 +1,51 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.h" + +@implementation MPPClassifications + +- (instancetype)initWithHeadIndex:(NSInteger)headIndex + headName:(nullable NSString *)headName + categories:(NSArray *)categories { + self = [super init]; + if (self) { + _headIndex = headIndex; + _headName = headName; + _categories = categories; + } + return self; +} + +- (instancetype)initWithHeadIndex:(NSInteger)headIndex + categories:(NSArray *)categories { + return [self initWithHeadIndex:headIndex headName:nil categories:categories]; +} + +@end + +@implementation MPPClassificationResult + +- (instancetype)initWithClassifications:(NSArray *)classifications + timestampMs:(NSInteger)timestampMs { + self = [super init]; + if (self) { + _classifications = classifications; + _timestampMs = timestampMs; + } + + return self; +} + +@end From 89aad67a877424d1715b0a0dbfb386cfd06e8c2a Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Fri, 6 Jan 2023 16:07:50 +0530 Subject: [PATCH 329/469] Added iOS helpers for classification result containers --- .../ios/components/containers/utils/BUILD | 40 ++++++++++++ .../utils/sources/MPPCategory+Helpers.h | 26 ++++++++ .../utils/sources/MPPCategory+Helpers.mm | 43 +++++++++++++ .../sources/MPPClassificationResult+Helpers.h | 35 ++++++++++ .../MPPClassificationResult+Helpers.mm | 64 +++++++++++++++++++ 5 files changed, 208 insertions(+) create mode 100644 mediapipe/tasks/ios/components/containers/utils/BUILD create mode 100644 mediapipe/tasks/ios/components/containers/utils/sources/MPPCategory+Helpers.h create mode 100644 mediapipe/tasks/ios/components/containers/utils/sources/MPPCategory+Helpers.mm create mode 100644 mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.h create mode 100644 mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.mm diff --git a/mediapipe/tasks/ios/components/containers/utils/BUILD b/mediapipe/tasks/ios/components/containers/utils/BUILD new file mode 100644 index 000000000..e4c76ac4b --- /dev/null +++ b/mediapipe/tasks/ios/components/containers/utils/BUILD @@ -0,0 +1,40 @@ +# Copyright 2023 The MediaPipe Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +objc_library( + name = "MPPCategoryHelpers", + srcs = ["sources/MPPCategory+Helpers.mm"], + hdrs = ["sources/MPPCategory+Helpers.h"], + deps = [ + "//mediapipe/tasks/ios/components/containers:MPPCategory", + "//mediapipe/framework/formats:classification_cc_proto", + "//mediapipe/tasks/ios/common/utils:NSStringHelpers", + ], +) + +objc_library( + name = "MPPClassificationResultHelpers", + srcs = ["sources/MPPClassificationResult+Helpers.mm"], + hdrs = ["sources/MPPClassificationResult+Helpers.h"], + deps = [ + "//mediapipe/tasks/ios/components/containers:MPPClassificationResult", + ":MPPCategoryHelpers", + "//mediapipe/tasks/ios/common/utils:NSStringHelpers", + "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", + ], +) diff --git a/mediapipe/tasks/ios/components/containers/utils/sources/MPPCategory+Helpers.h b/mediapipe/tasks/ios/components/containers/utils/sources/MPPCategory+Helpers.h new file mode 100644 index 000000000..7580cfeeb --- /dev/null +++ b/mediapipe/tasks/ios/components/containers/utils/sources/MPPCategory+Helpers.h @@ -0,0 +1,26 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/framework/formats/classification.pb.h" +#import "mediapipe/tasks/ios/components/containers/sources/MPPCategory.h" + +NS_ASSUME_NONNULL_BEGIN + +@interface MPPCategory (Helpers) + ++ (MPPCategory *)categoryWithProto:(const mediapipe::Classification &)classificationProto; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/components/containers/utils/sources/MPPCategory+Helpers.mm b/mediapipe/tasks/ios/components/containers/utils/sources/MPPCategory+Helpers.mm new file mode 100644 index 000000000..1c6c951d0 --- /dev/null +++ b/mediapipe/tasks/ios/components/containers/utils/sources/MPPCategory+Helpers.mm @@ -0,0 +1,43 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#import "mediapipe/tasks/ios/components/containers/utils/sources/MPPCategory+Helpers.h"
+
+#import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h"
+
+namespace {
+using ClassificationProto = ::mediapipe::Classification;
+}
+
+@implementation MPPCategory (Helpers)
+
++ (MPPCategory *)categoryWithProto:(const ClassificationProto &)classificationProto {
+  NSString *categoryName;
+  NSString *displayName;
+
+  if (classificationProto.has_label()) {
+    categoryName = [NSString stringWithCppString:classificationProto.label()];
+  }
+
+  if (classificationProto.has_display_name()) {
+    displayName = [NSString stringWithCppString:classificationProto.display_name()];
+  }
+
+  return [[MPPCategory alloc] initWithIndex:classificationProto.index()
+                                      score:classificationProto.score()
+                               categoryName:categoryName
+                                displayName:displayName];
+}
+
+@end
diff --git a/mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.h b/mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.h
new file mode 100644
index 000000000..fde436feb
--- /dev/null
+++ b/mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.h
@@ -0,0 +1,35 @@
+// Copyright 2023 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
+#import "mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.h"
+
+NS_ASSUME_NONNULL_BEGIN
+
+@interface MPPClassifications (Helpers)
+
++ (MPPClassifications *)classificationsWithProto:
+    (const mediapipe::tasks::components::containers::proto::Classifications &)classificationsProto;
+
+@end
+
+@interface MPPClassificationResult (Helpers)
+
++ (MPPClassificationResult *)classificationResultWithProto:
+    (const mediapipe::tasks::components::containers::proto::ClassificationResult &)
+        classificationResultProto;
+
+@end
+
+NS_ASSUME_NONNULL_END
diff --git a/mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.mm b/mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.mm
new file mode 100644
index 000000000..78bc0b6a3
--- /dev/null
+++ b/mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.mm
@@ -0,0 +1,64 @@
+// Copyright 2023 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#import "mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.h"
+
+#import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h"
+#import "mediapipe/tasks/ios/components/containers/utils/sources/MPPCategory+Helpers.h"
+
+namespace {
+using ClassificationsProto = ::mediapipe::tasks::components::containers::proto::Classifications;
+using ClassificationResultProto =
+    ::mediapipe::tasks::components::containers::proto::ClassificationResult;
+}  // namespace
+
+@implementation MPPClassifications (Helpers)
+
++ (MPPClassifications *)classificationsWithProto:
+    (const ClassificationsProto &)classificationsProto {
+  NSMutableArray *categories = [NSMutableArray arrayWithCapacity:(NSUInteger)classificationsProto.classification_list().classification_size()];
+  for (const auto &classification : classificationsProto.classification_list().classification()) {
+    [categories addObject:[MPPCategory categoryWithProto:classification]];
+  }
+
+  NSString *headName;
+  if (classificationsProto.has_head_name()) {
+    headName = [NSString stringWithCppString:classificationsProto.head_name()];
+  }
+
+  return [[MPPClassifications alloc] initWithHeadIndex:(NSInteger)classificationsProto.head_index()
+                                              headName:headName
+                                            categories:categories];
+}
+
+@end
+
+@implementation MPPClassificationResult (Helpers)
+
++ (MPPClassificationResult *)classificationResultWithProto:
+    (const ClassificationResultProto &)classificationResultProto {
+  NSMutableArray *classifications = [NSMutableArray arrayWithCapacity:(NSUInteger)classificationResultProto.classifications_size()];
+  for (const auto &classificationsProto : classificationResultProto.classifications()) {
+    [classifications addObject:[MPPClassifications classificationsWithProto:classificationsProto]];
+  }
+
+  // Default to -1, which the header documents as signifying an absent timestamp.
+  NSInteger timestampMs = -1;
+  if (classificationResultProto.has_timestamp_ms()) {
+    timestampMs = (NSInteger)classificationResultProto.timestamp_ms();
+  }
+
+  return [[MPPClassificationResult alloc] initWithClassifications:classifications timestampMs:timestampMs];
+}
+
+@end
From 8f74a175d831bf09510afafc0e7ecbfb4f281a65 Mon Sep 17 00:00:00 2001
From: Prianka Liz Kariat
Date: Fri, 6 Jan 2023 16:08:06 +0530
Subject: [PATCH 330/469] Removed MPPClassifierOptions and helpers

---
 .../tasks/ios/components/processors/BUILD     | 23 -------
 .../processors/sources/MPPClassifierOptions.h | 60 -------------------
 .../processors/sources/MPPClassifierOptions.m | 40 -------------
 .../ios/components/processors/utils/BUILD     | 28 ---------
 .../sources/MPPClassifierOptions+Helpers.h    | 26 --------
 .../sources/MPPClassifierOptions+Helpers.mm   | 43 -------------
 6 files changed, 220 deletions(-)
 delete mode 100644 mediapipe/tasks/ios/components/processors/BUILD
 delete mode 100644 mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h
 delete mode 100644 mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.m
 delete mode 100644 mediapipe/tasks/ios/components/processors/utils/BUILD
 delete mode 100644 mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.h
 delete mode 100644 mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.mm

diff --git a/mediapipe/tasks/ios/components/processors/BUILD b/mediapipe/tasks/ios/components/processors/BUILD
deleted file mode 100644
index 165145076..000000000
--- a/mediapipe/tasks/ios/components/processors/BUILD
+++ /dev/null
@@
-1,23 +0,0 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -package(default_visibility = ["//mediapipe/tasks:internal"]) - -licenses(["notice"]) - -objc_library( - name = "MPPClassifierOptions", - srcs = ["sources/MPPClassifierOptions.m"], - hdrs = ["sources/MPPClassifierOptions.h"], -) diff --git a/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h b/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h deleted file mode 100644 index 13dca4030..000000000 --- a/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h +++ /dev/null @@ -1,60 +0,0 @@ -// Copyright 2022 The MediaPipe Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#import - -NS_ASSUME_NONNULL_BEGIN - -/** - * Holds settings for any single iOS MediaPipe classification task. - */ -NS_SWIFT_NAME(ClassifierOptions) -@interface MPPClassifierOptions : NSObject - -/** - * The locale to use for display names specified through the TFLite Model - * Metadata, if any. Defaults to English. - */ -@property(nonatomic, copy) NSString *displayNamesLocale; - -/** - * The maximum number of top-scored classification results to return. If < 0, - * all available results will be returned. If 0, an invalid argument error is - * returned. - */ -@property(nonatomic) NSInteger maxResults; - -/** - * Score threshold to override the one provided in the model metadata (if any). - * Results below this value are rejected. - */ -@property(nonatomic) float scoreThreshold; - -/** - * The allowlist of category names. If non-empty, detection results whose - * category name is not in this set will be filtered out. Duplicate or unknown - * category names are ignored. Mutually exclusive with categoryDenylist. - */ -@property(nonatomic, copy) NSArray *categoryAllowlist; - -/** - * The denylist of category names. If non-empty, detection results whose - * category name is in this set will be filtered out. Duplicate or unknown - * category names are ignored. Mutually exclusive with categoryAllowlist. 
- */ -@property(nonatomic, copy) NSArray *categoryDenylist; - -@end - -NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.m b/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.m deleted file mode 100644 index 01f498184..000000000 --- a/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.m +++ /dev/null @@ -1,40 +0,0 @@ -// Copyright 2022 The MediaPipe Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#import "mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h" - -@implementation MPPClassifierOptions - -- (instancetype)init { - self = [super init]; - if (self) { - _maxResults = -1; - _scoreThreshold = 0; - } - return self; -} - -- (id)copyWithZone:(NSZone *)zone { - MPPClassifierOptions *classifierOptions = [[MPPClassifierOptions alloc] init]; - - classifierOptions.scoreThreshold = self.scoreThreshold; - classifierOptions.maxResults = self.maxResults; - classifierOptions.categoryDenylist = self.categoryDenylist; - classifierOptions.categoryAllowlist = self.categoryAllowlist; - classifierOptions.displayNamesLocale = self.displayNamesLocale; - - return classifierOptions; -} - -@end diff --git a/mediapipe/tasks/ios/components/processors/utils/BUILD b/mediapipe/tasks/ios/components/processors/utils/BUILD deleted file mode 100644 index 5344c5fdf..000000000 --- a/mediapipe/tasks/ios/components/processors/utils/BUILD +++ /dev/null @@ -1,28 +0,0 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -package(default_visibility = ["//mediapipe/tasks:internal"]) - -licenses(["notice"]) - -objc_library( - name = "MPPClassifierOptionsHelpers", - srcs = ["sources/MPPClassifierOptions+Helpers.mm"], - hdrs = ["sources/MPPClassifierOptions+Helpers.h"], - deps = [ - "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto", - "//mediapipe/tasks/ios/common/utils:NSStringHelpers", - "//mediapipe/tasks/ios/components/processors:MPPClassifierOptions", - ], -) diff --git a/mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.h b/mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.h deleted file mode 100644 index e156020df..000000000 --- a/mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.h +++ /dev/null @@ -1,26 +0,0 @@ -// Copyright 2022 The MediaPipe Authors. 
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h" - -#import "mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h" - -NS_ASSUME_NONNULL_BEGIN - -@interface MPPClassifierOptions (Helpers) -- (void)copyToProto: - (mediapipe::tasks::components::processors::proto::ClassifierOptions *)classifierOptionsProto; -@end - -NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.mm b/mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.mm deleted file mode 100644 index 24b54fd6a..000000000 --- a/mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.mm +++ /dev/null @@ -1,43 +0,0 @@ -// Copyright 2022 The MediaPipe Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-
-#import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h"
-#import "mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.h"
-
-namespace {
-using ClassifierOptionsProto = ::mediapipe::tasks::components::processors::proto::ClassifierOptions;
-}
-
-@implementation MPPClassifierOptions (Helpers)
-
-- (void)copyToProto:(ClassifierOptionsProto *)classifierOptionsProto {
-  classifierOptionsProto->Clear();
-
-  if (self.displayNamesLocale) {
-    classifierOptionsProto->set_display_names_locale(self.displayNamesLocale.cppString);
-  }
-
-  classifierOptionsProto->set_max_results((int)self.maxResults);
-  classifierOptionsProto->set_score_threshold(self.scoreThreshold);
-
-  for (NSString *category in self.categoryAllowlist) {
-    classifierOptionsProto->add_category_allowlist(category.cppString);
-  }
-
-  for (NSString *category in self.categoryDenylist) {
-    classifierOptionsProto->add_category_denylist(category.cppString);
-  }
-}
-
-@end
From 4e38c7e623eaa7dc1219ab5291858e05703350c2 Mon Sep 17 00:00:00 2001
From: Prianka Liz Kariat
Date: Fri, 6 Jan 2023 16:15:32 +0530
Subject: [PATCH 331/469] Updated documentation for MPPCommon.h

---
 mediapipe/tasks/ios/common/sources/MPPCommon.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mediapipe/tasks/ios/common/sources/MPPCommon.h b/mediapipe/tasks/ios/common/sources/MPPCommon.h
index 7ce791d12..09a61e20d 100644
--- a/mediapipe/tasks/ios/common/sources/MPPCommon.h
+++ b/mediapipe/tasks/ios/common/sources/MPPCommon.h
@@ -18,7 +18,7 @@ NS_ASSUME_NONNULL_BEGIN
 
 /**
  * @enum MPPTasksErrorCode
- * This enum specifies error codes for Mediapipe Task Library.
+ * This enum specifies error codes for MediaPipe Task Library.
  * It maintains a 1:1 mapping to MediaPipeTasksStatus of the C++ library.
 */
 typedef NS_ENUM(NSUInteger, MPPTasksErrorCode) {
From f37689fc33de306cead655cadfd283430fc5003f Mon Sep 17 00:00:00 2001
From: Prianka Liz Kariat
Date: Fri, 6 Jan 2023 16:15:53 +0530
Subject: [PATCH 332/469] Updated documentation for MPPCommonUtils.m

---
 mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h  | 2 +-
 mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.mm | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h b/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h
index 5404a074d..69c28b916 100644
--- a/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h
+++ b/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h
@@ -18,7 +18,7 @@
 
 NS_ASSUME_NONNULL_BEGIN
 
-/** Error domain of Mediapipe Task related errors. */
+/** Error domain of MediaPipe Task related errors. */
 extern NSString *const MPPTasksErrorDomain;
 
 /** Helper utility for all tasks which encapsulates common functionality. */
diff --git a/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.mm b/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.mm
index 8234ac6d3..1a37f8465 100644
--- a/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.mm
+++ b/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.mm
@@ -96,7 +96,7 @@ NSString *const MPPTasksErrorDomain = @"com.google.mediapipe.tasks";
   // The mapping to absl::Status::code() is done to generate a more specific error code than
   // MPPTasksErrorCodeError in cases when the payload can't be mapped to
   // MPPTasksErrorCode. This can happen when absl::Status returned by TFLite library are in turn
-  // returned without modification by Mediapipe cc library methods.
+ // returned without modification by MediaPipe cc library methods. if (errorCode > MPPTasksErrorCodeLast || errorCode <= MPPTasksErrorCodeFirst) { switch (status.code()) { case absl::StatusCode::kInternal: From 27ce2ec00f0fd5526c186c9b92e570a3acdca58c Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Fri, 6 Jan 2023 16:22:11 +0530 Subject: [PATCH 333/469] Updated C++ types to camel case in MPPTaskInfo --- .../tasks/ios/core/sources/MPPTaskInfo.mm | 55 ++++++++++--------- 1 file changed, 28 insertions(+), 27 deletions(-) diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskInfo.mm b/mediapipe/tasks/ios/core/sources/MPPTaskInfo.mm index 5f2290497..ae6ed2a70 100644 --- a/mediapipe/tasks/ios/core/sources/MPPTaskInfo.mm +++ b/mediapipe/tasks/ios/core/sources/MPPTaskInfo.mm @@ -13,6 +13,7 @@ // limitations under the License. #import "mediapipe/tasks/ios/core/sources/MPPTaskInfo.h" + #import "mediapipe/tasks/ios/common/sources/MPPCommon.h" #import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h" #import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h" @@ -69,59 +70,59 @@ using ::mediapipe::InputStreamInfo; } - (CalculatorGraphConfig)generateGraphConfig { - CalculatorGraphConfig graph_config; + CalculatorGraphConfig graphConfig; - Node *task_subgraph_node = graph_config.add_node(); - task_subgraph_node->set_calculator(self.taskGraphName.cppString); - [self.taskOptions copyToProto:task_subgraph_node->mutable_options()]; + Node *taskSubgraphNode = graphConfig.add_node(); + taskSubgraphNode->set_calculator(self.taskGraphName.cppString); + [self.taskOptions copyToProto:taskSubgraphNode->mutable_options()]; for (NSString *outputStream in self.outputStreams) { - auto cpp_output_stream = std::string(outputStream.cppString); - task_subgraph_node->add_output_stream(cpp_output_stream); - graph_config.add_output_stream(cpp_output_stream); + auto cppOutputStream = std::string(outputStream.cppString); + taskSubgraphNode->add_output_stream(cppOutputStream); + graphConfig.add_output_stream(cppOutputStream); } if (!self.enableFlowLimiting) { for (NSString *inputStream in self.inputStreams) { - auto cpp_input_stream = inputStream.cppString; - task_subgraph_node->add_input_stream(cpp_input_stream); - graph_config.add_input_stream(cpp_input_stream); + auto cppInputStream = inputStream.cppString; + taskSubgraphNode->add_input_stream(cppInputStream); + graphConfig.add_input_stream(cppInputStream); } - return graph_config; + return graphConfig; } - Node *flow_limit_calculator_node = graph_config.add_node(); + Node *flowLimitCalculatorNode = graphConfig.add_node(); - flow_limit_calculator_node->set_calculator("FlowLimiterCalculator"); + flowLimitCalculatorNode->set_calculator("FlowLimiterCalculator"); - InputStreamInfo *input_stream_info = flow_limit_calculator_node->add_input_stream_info(); - input_stream_info->set_tag_index("FINISHED"); - input_stream_info->set_back_edge(true); + InputStreamInfo *inputStreamInfo = flowLimitCalculatorNode->add_input_stream_info(); + inputStreamInfo->set_tag_index("FINISHED"); + inputStreamInfo->set_back_edge(true); - FlowLimiterCalculatorOptions *flow_limit_calculator_options = - flow_limit_calculator_node->mutable_options()->MutableExtension( + FlowLimiterCalculatorOptions *flowLimitCalculatorOptions = + flowLimitCalculatorNode->mutable_options()->MutableExtension( FlowLimiterCalculatorOptions::ext); - flow_limit_calculator_options->set_max_in_flight(1); - flow_limit_calculator_options->set_max_in_queue(1); + 
flowLimitCalculatorOptions->set_max_in_flight(1); + flowLimitCalculatorOptions->set_max_in_queue(1); for (NSString *inputStream in self.inputStreams) { - graph_config.add_input_stream(inputStream.cppString); + graphConfig.add_input_stream(inputStream.cppString); NSString *strippedInputStream = [MPPTaskInfo stripTagIndex:inputStream]; - flow_limit_calculator_node->add_input_stream(strippedInputStream.cppString); + flowLimitCalculatorNode->add_input_stream(strippedInputStream.cppString); NSString *taskInputStream = [MPPTaskInfo addStreamNamePrefix:inputStream]; - task_subgraph_node->add_input_stream(taskInputStream.cppString); + taskSubgraphNode->add_input_stream(taskInputStream.cppString); NSString *strippedTaskInputStream = [MPPTaskInfo stripTagIndex:taskInputStream]; - flow_limit_calculator_node->add_output_stream(strippedTaskInputStream.cppString); + flowLimitCalculatorNode->add_output_stream(strippedTaskInputStream.cppString); } NSString *firstOutputStream = self.outputStreams[0]; - auto finished_output_stream = "FINISHED:" + firstOutputStream.cppString; - flow_limit_calculator_node->add_input_stream(finished_output_stream); + auto finishedOutputStream = "FINISHED:" + firstOutputStream.cppString; + flowLimitCalculatorNode->add_input_stream(finishedOutputStream); - return graph_config; + return graphConfig; } + (NSString *)stripTagIndex:(NSString *)tagIndexName { From 61d16b284b9d0b063a61f31ae9565532f6e69798 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Fri, 6 Jan 2023 16:23:22 +0530 Subject: [PATCH 334/469] Updated comments in MPPTaskOptions.h --- mediapipe/tasks/ios/core/sources/MPPTaskOptions.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskOptions.h b/mediapipe/tasks/ios/core/sources/MPPTaskOptions.h index ee2f7d032..e10678348 100644 --- a/mediapipe/tasks/ios/core/sources/MPPTaskOptions.h +++ b/mediapipe/tasks/ios/core/sources/MPPTaskOptions.h @@ -25,7 +25,7 @@ NS_SWIFT_NAME(TaskOptions) @interface MPPTaskOptions : NSObject /** - * Base options for configuring the Mediapipe task. + * Base options for configuring the MediaPipe task. */ @property(nonatomic, copy) MPPBaseOptions *baseOptions; From 16f9831c3fece5d9907868ea60c7bb7fb0c01cc5 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Fri, 6 Jan 2023 16:23:37 +0530 Subject: [PATCH 335/469] Updated formatting in MPPTaskOptions.m --- mediapipe/tasks/ios/core/sources/MPPTaskOptions.m | 1 + 1 file changed, 1 insertion(+) diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskOptions.m b/mediapipe/tasks/ios/core/sources/MPPTaskOptions.m index fe74517c3..ad11bbc6e 100644 --- a/mediapipe/tasks/ios/core/sources/MPPTaskOptions.m +++ b/mediapipe/tasks/ios/core/sources/MPPTaskOptions.m @@ -13,6 +13,7 @@ // limitations under the License. 
#import "mediapipe/tasks/ios/core/sources/MPPTaskOptions.h" + #import "mediapipe/tasks/ios/core/sources/MPPBaseOptions.h" @implementation MPPTaskOptions From bc1b069edf818e9431697ceb040cc1c105984ef3 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Fri, 6 Jan 2023 16:24:41 +0530 Subject: [PATCH 336/469] Updated property name in MPPTaskResult --- mediapipe/tasks/ios/core/sources/MPPTaskResult.h | 4 ++-- mediapipe/tasks/ios/core/sources/MPPTaskResult.m | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskResult.h b/mediapipe/tasks/ios/core/sources/MPPTaskResult.h index d15d4f258..4ee7b2fc6 100644 --- a/mediapipe/tasks/ios/core/sources/MPPTaskResult.h +++ b/mediapipe/tasks/ios/core/sources/MPPTaskResult.h @@ -26,11 +26,11 @@ NS_SWIFT_NAME(TaskResult) /** * Timestamp that is associated with the task result object. */ -@property(nonatomic, assign, readonly) long timestamp; +@property(nonatomic, assign, readonly) NSInteger timestampMs; - (instancetype)init NS_UNAVAILABLE; -- (instancetype)initWithTimestamp:(long)timestamp NS_DESIGNATED_INITIALIZER; +- (instancetype)initWithTimestampMs:(NSInteger)timestampMs NS_DESIGNATED_INITIALIZER; @end diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskResult.m b/mediapipe/tasks/ios/core/sources/MPPTaskResult.m index 7088eb246..6c08014ff 100644 --- a/mediapipe/tasks/ios/core/sources/MPPTaskResult.m +++ b/mediapipe/tasks/ios/core/sources/MPPTaskResult.m @@ -16,16 +16,16 @@ @implementation MPPTaskResult -- (instancetype)initWithTimestamp:(long)timestamp { +- (instancetype)initWithTimestampMs:(NSInteger)timestampMs { self = [super init]; if (self) { - _timestamp = timestamp; + _timestampMs = timestampMs; } return self; } - (id)copyWithZone:(NSZone *)zone { - return [[MPPTaskResult alloc] initWithTimestamp:self.timestamp]; + return [[MPPTaskResult alloc] initWithTimestampMs:self.timestampMs]; } @end From c6bae99a2fc120c4de58f352dc64b6dc0aff728b Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Fri, 6 Jan 2023 16:25:56 +0530 Subject: [PATCH 337/469] Updated formatting in MPPTextPacketCreator.mm --- mediapipe/tasks/ios/core/sources/MPPTextPacketCreator.mm | 1 + 1 file changed, 1 insertion(+) diff --git a/mediapipe/tasks/ios/core/sources/MPPTextPacketCreator.mm b/mediapipe/tasks/ios/core/sources/MPPTextPacketCreator.mm index ca86e7a0b..fb59b363d 100644 --- a/mediapipe/tasks/ios/core/sources/MPPTextPacketCreator.mm +++ b/mediapipe/tasks/ios/core/sources/MPPTextPacketCreator.mm @@ -13,6 +13,7 @@ // limitations under the License. 
#import "mediapipe/tasks/ios/core/sources/MPPTextPacketCreator.h" + #import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h" namespace { From b6bcc35adef1ea3d27af6f35488a1608a4670be5 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Fri, 6 Jan 2023 16:36:15 +0530 Subject: [PATCH 338/469] Added provision for packets callback in iOS task runner --- mediapipe/tasks/ios/core/BUILD | 12 +++-- .../tasks/ios/core/sources/MPPTaskRunner.h | 52 +++++++++++++++++-- .../tasks/ios/core/sources/MPPTaskRunner.mm | 12 ++++- 3 files changed, 64 insertions(+), 12 deletions(-) diff --git a/mediapipe/tasks/ios/core/BUILD b/mediapipe/tasks/ios/core/BUILD index 434d20085..757e2d4cc 100644 --- a/mediapipe/tasks/ios/core/BUILD +++ b/mediapipe/tasks/ios/core/BUILD @@ -56,12 +56,12 @@ objc_library( deps = [ ":MPPTaskOptions", ":MPPTaskOptionsProtocol", - "//mediapipe/calculators/core:flow_limiter_calculator_cc_proto", - "//mediapipe/framework:calculator_cc_proto", - "//mediapipe/framework:calculator_options_cc_proto", "//mediapipe/tasks/ios/common:MPPCommon", "//mediapipe/tasks/ios/common/utils:MPPCommonUtils", "//mediapipe/tasks/ios/common/utils:NSStringHelpers", + "//mediapipe/calculators/core:flow_limiter_calculator_cc_proto", + "//mediapipe/framework:calculator_cc_proto", + "//mediapipe/framework:calculator_options_cc_proto", ], ) @@ -88,8 +88,10 @@ objc_library( "-std=c++17", ], deps = [ - "//mediapipe/framework:calculator_cc_proto", - "//mediapipe/tasks/cc/core:task_runner", "//mediapipe/tasks/ios/common/utils:MPPCommonUtils", + "//mediapipe/framework:calculator_cc_proto", + "//mediapipe/tasks/cc/core:mediapipe_builtin_op_resolver", + "//mediapipe/tasks/cc/core:task_runner", + "@org_tensorflow//tensorflow/lite/core/api:op_resolver", ], ) diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskRunner.h b/mediapipe/tasks/ios/core/sources/MPPTaskRunner.h index 2b9f2ecdb..a1b1dfad4 100644 --- a/mediapipe/tasks/ios/core/sources/MPPTaskRunner.h +++ b/mediapipe/tasks/ios/core/sources/MPPTaskRunner.h @@ -20,23 +20,65 @@ NS_ASSUME_NONNULL_BEGIN /** - * This class is used to create and call appropriate methods on the C++ Task Runner. - */ + * This class is used to create and call appropriate methods on the C++ Task Runner to initialize, + * execute and terminate any MediaPipe task. + * + * An instance of the newly created C++ task runner will be stored until this class is destroyed. + * When methods are called for processing (performing inference), closing etc., on this class, + * internally the appropriate methods will be called on the C++ task runner instance to execute the + * appropriate actions. For each type of task, a subclass of this class must be defined to add any + * additional functionality. For eg:, vision tasks must create an `MPPVisionTaskRunner` and provide + * additional functionality. An instance of `MPPVisionTaskRunner` can in turn be used by the each + * vision task for creation and execution of the task. Please see the documentation for the C++ Task + * Runner for more details on how the taks runner operates. + **/ @interface MPPTaskRunner : NSObject /** - * Initializes a new `MPPTaskRunner` with the mediapipe task graph config proto. + * Initializes a new `MPPTaskRunner` with the MediaPipe calculator configuration proto and an + * optional C++ packets callback. + * + * You can pass `nullptr` for `packetsCallback` in case the mode of operation requested by the user + * is synchronous. 
+ *
+ * If the task is operating in asynchronous mode, any iOS MediaPipe task that uses the
+ * `MPPTaskRunner` must define a C++ callback function to obtain the results of inference
+ * asynchronously and deliver the results to the user. To accomplish this, the callback function
+ * should in turn invoke the block provided by the user in the task options supplied to create the
+ * task. Please see the documentation of the C++ Task Runner for more information on the synchronous
+ * and asynchronous modes of operation.
 *
 * @param graphConfig A mediapipe task graph config proto.
+ * @param packetsCallback An optional C++ callback function that takes a list of output packets as
+ * the input argument. If provided, the callback must in turn call the block provided by the user in
+ * the appropriate task options.
 *
- * @return An instance of `MPPTaskRunner` initialized to the given graph config proto.
- */
+ * @return An instance of `MPPTaskRunner` initialized to the given graph config proto and optional
+ * packetsCallback.
+ **/
- (instancetype)initWithCalculatorGraphConfig:(mediapipe::CalculatorGraphConfig)graphConfig
+                              packetsCallback:
+                                  (mediapipe::tasks::core::PacketsCallback)packetsCallback
                                         error:(NSError **)error NS_DESIGNATED_INITIALIZER;
 
+/**
+ * A synchronous method for processing batch data or offline streaming data. This method is designed
+ * for processing either batch data such as unrelated images and texts or offline streaming data
+ * such as the decoded frames from a video file or audio file. The call blocks the current
+ * thread until a failure status or a successful result is returned. If the input packets have no
+ * timestamp, an internal timestamp will be assigned per invocation. Otherwise, when the timestamp
+ * is set in the input packets, the caller must ensure that the input packet timestamps are greater
+ * than the timestamps of the previous invocation. This method is thread-unsafe and it is the
+ * caller's responsibility to synchronize access to this method across multiple threads and to
+ * ensure that the input packet timestamps are in order.
+ **/
- (absl::StatusOr<mediapipe::tasks::core::PacketMap>)process:
    (const mediapipe::tasks::core::PacketMap &)packetMap;
 
+/**
+ * Shuts down the C++ task runner. After the runner is closed, any calls that send input data to the
+ * runner are illegal and will receive errors.
+ **/
- (absl::Status)close;
 
- (instancetype)init NS_UNAVAILABLE;
diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskRunner.mm b/mediapipe/tasks/ios/core/sources/MPPTaskRunner.mm
index c5c307fd5..a77f206b2 100644
--- a/mediapipe/tasks/ios/core/sources/MPPTaskRunner.mm
+++ b/mediapipe/tasks/ios/core/sources/MPPTaskRunner.mm
@@ -13,11 +13,17 @@
 // limitations under the License.
#import "mediapipe/tasks/ios/core/sources/MPPTaskRunner.h" + #import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h" +#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h" +#include "tensorflow/lite/core/api/op_resolver.h" + namespace { using ::mediapipe::CalculatorGraphConfig; +using ::mediapipe::tasks::core::MediaPipeBuiltinOpResolver; using ::mediapipe::tasks::core::PacketMap; +using ::mediapipe::tasks::core::PacketsCallback; using TaskRunnerCpp = ::mediapipe::tasks::core::TaskRunner; } // namespace @@ -30,15 +36,17 @@ using TaskRunnerCpp = ::mediapipe::tasks::core::TaskRunner; @implementation MPPTaskRunner - (instancetype)initWithCalculatorGraphConfig:(CalculatorGraphConfig)graphConfig + packetsCallback:(PacketsCallback)packetsCallback error:(NSError **)error { self = [super init]; if (self) { - auto taskRunnerResult = TaskRunnerCpp::Create(std::move(graphConfig)); + auto taskRunnerResult = TaskRunnerCpp::Create(std::move(graphConfig), + absl::make_unique(), + std::move(packetsCallback)); if (![MPPCommonUtils checkCppError:taskRunnerResult.status() toError:error]) { return nil; } - _cppTaskRunner = std::move(taskRunnerResult.value()); } return self; From b91b485035545f3263cb88ade8444eb6fc32d407 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Fri, 6 Jan 2023 16:36:28 +0530 Subject: [PATCH 339/469] Added MPPBaseOptions Helpers --- mediapipe/tasks/ios/core/utils/BUILD | 27 ++++++++++++ .../utils/sources/MPPBaseOptions+Helpers.h | 26 +++++++++++ .../utils/sources/MPPBaseOptions+Helpers.mm | 44 +++++++++++++++++++ 3 files changed, 97 insertions(+) create mode 100644 mediapipe/tasks/ios/core/utils/BUILD create mode 100644 mediapipe/tasks/ios/core/utils/sources/MPPBaseOptions+Helpers.h create mode 100644 mediapipe/tasks/ios/core/utils/sources/MPPBaseOptions+Helpers.mm diff --git a/mediapipe/tasks/ios/core/utils/BUILD b/mediapipe/tasks/ios/core/utils/BUILD new file mode 100644 index 000000000..1cfc75e6a --- /dev/null +++ b/mediapipe/tasks/ios/core/utils/BUILD @@ -0,0 +1,27 @@ +# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +objc_library( + name = "MPPBaseOptionsHelpers", + srcs = ["sources/MPPBaseOptions+Helpers.mm"], + hdrs = ["sources/MPPBaseOptions+Helpers.h"], + deps = [ + "//mediapipe/tasks/ios/core:MPPBaseOptions", + "//mediapipe/tasks/cc/core/proto:base_options_cc_proto", + ], +) diff --git a/mediapipe/tasks/ios/core/utils/sources/MPPBaseOptions+Helpers.h b/mediapipe/tasks/ios/core/utils/sources/MPPBaseOptions+Helpers.h new file mode 100644 index 000000000..d52df2ae4 --- /dev/null +++ b/mediapipe/tasks/ios/core/utils/sources/MPPBaseOptions+Helpers.h @@ -0,0 +1,26 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/tasks/cc/core/proto/base_options.pb.h" +#import "mediapipe/tasks/ios/core/sources/MPPBaseOptions.h" + +NS_ASSUME_NONNULL_BEGIN + +@interface MPPBaseOptions (Helpers) + +- (void)copyToProto:(mediapipe::tasks::core::proto::BaseOptions *)baseOptionsProto; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/core/utils/sources/MPPBaseOptions+Helpers.mm b/mediapipe/tasks/ios/core/utils/sources/MPPBaseOptions+Helpers.mm new file mode 100644 index 000000000..3fd8fbda3 --- /dev/null +++ b/mediapipe/tasks/ios/core/utils/sources/MPPBaseOptions+Helpers.mm @@ -0,0 +1,44 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "mediapipe/tasks/ios/core/utils/sources/MPPBaseOptions+Helpers.h" + +namespace { +using BaseOptionsProto = ::mediapipe::tasks::core::proto::BaseOptions; +} + +@implementation MPPBaseOptions (Helpers) + +- (void)copyToProto:(BaseOptionsProto *)baseOptionsProto { + baseOptionsProto->Clear(); + + if (self.modelAssetPath) { + baseOptionsProto->mutable_model_asset()->set_file_name(self.modelAssetPath.UTF8String); + } + + switch (self.delegate) { + case MPPDelegateCPU: { + baseOptionsProto->mutable_acceleration()->mutable_tflite(); + break; + } + case MPPDelegateGPU: { + // TODO: Provide an implementation for GPU Delegate. + [NSException raise:@"Invalid value for delegate" format:@"GPU Delegate is not implemented."]; + } + default: + break; + } +} + +@end From 14e3de49ad09d9fc33cfe95ffb8038e473e132cf Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Fri, 6 Jan 2023 16:37:31 +0530 Subject: [PATCH 340/469] Added MPPTextTaskRunner --- mediapipe/tasks/ios/text/core/BUILD | 31 +++++++++++++ .../ios/text/core/sources/MPPTextTaskRunner.h | 43 +++++++++++++++++++ .../text/core/sources/MPPTextTaskRunner.mm | 29 +++++++++++++ 3 files changed, 103 insertions(+) create mode 100644 mediapipe/tasks/ios/text/core/BUILD create mode 100644 mediapipe/tasks/ios/text/core/sources/MPPTextTaskRunner.h create mode 100644 mediapipe/tasks/ios/text/core/sources/MPPTextTaskRunner.mm diff --git a/mediapipe/tasks/ios/text/core/BUILD b/mediapipe/tasks/ios/text/core/BUILD new file mode 100644 index 000000000..bf88f5734 --- /dev/null +++ b/mediapipe/tasks/ios/text/core/BUILD @@ -0,0 +1,31 @@ +# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+package(default_visibility = ["//mediapipe/tasks:internal"])
+
+licenses(["notice"])
+
+objc_library(
+    name = "MPPTextTaskRunner",
+    srcs = ["sources/MPPTextTaskRunner.mm"],
+    hdrs = ["sources/MPPTextTaskRunner.h"],
+    copts = [
+        "-ObjC++",
+        "-std=c++17",
+    ],
+    deps = [
+        "//mediapipe/tasks/ios/core:MPPTaskRunner",
+    ],
+)
+
diff --git a/mediapipe/tasks/ios/text/core/sources/MPPTextTaskRunner.h b/mediapipe/tasks/ios/text/core/sources/MPPTextTaskRunner.h
new file mode 100644
index 000000000..e3df3de9d
--- /dev/null
+++ b/mediapipe/tasks/ios/text/core/sources/MPPTextTaskRunner.h
@@ -0,0 +1,43 @@
+// Copyright 2023 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#import <Foundation/Foundation.h>
+#import "mediapipe/tasks/ios/core/sources/MPPTaskRunner.h"
+
+NS_ASSUME_NONNULL_BEGIN
+
+/**
+ * This class is used to create and call appropriate methods on the C++ Task Runner to initialize,
+ * execute and terminate any MediaPipe text task.
+ **/
+@interface MPPTextTaskRunner : MPPTaskRunner
+
+/**
+ * Initializes a new `MPPTextTaskRunner` with the MediaPipe calculator config proto.
+ *
+ * @param graphConfig A MediaPipe calculator config proto.
+ *
+ * @return An instance of `MPPTextTaskRunner` initialized to the given MediaPipe calculator config
+ * proto.
+ **/
+- (instancetype)initWithCalculatorGraphConfig:(mediapipe::CalculatorGraphConfig)graphConfig
+                                        error:(NSError **)error NS_DESIGNATED_INITIALIZER;
+
+- (instancetype)init NS_UNAVAILABLE;
+
++ (instancetype)new NS_UNAVAILABLE;
+
+@end
+
+NS_ASSUME_NONNULL_END
diff --git a/mediapipe/tasks/ios/text/core/sources/MPPTextTaskRunner.mm b/mediapipe/tasks/ios/text/core/sources/MPPTextTaskRunner.mm
new file mode 100644
index 000000000..956448c17
--- /dev/null
+++ b/mediapipe/tasks/ios/text/core/sources/MPPTextTaskRunner.mm
@@ -0,0 +1,29 @@
+// Copyright 2023 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#import "mediapipe/tasks/ios/text/core/sources/MPPTextTaskRunner.h"
+
+namespace {
+using ::mediapipe::CalculatorGraphConfig;
+}  // namespace
+
+@implementation MPPTextTaskRunner
+
+- (instancetype)initWithCalculatorGraphConfig:(CalculatorGraphConfig)graphConfig
+                                        error:(NSError **)error {
+  self = [super initWithCalculatorGraphConfig:graphConfig packetsCallback:nullptr error:error];
+  return self;
+}
+
+@end
From 2cce88080e8d320a547a870e9bf3f2f9f86fa2e0 Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Fri, 6 Jan 2023 15:18:12 -0800
Subject: [PATCH 341/469] Internal change

PiperOrigin-RevId: 500271109
---
 mediapipe/calculators/image/scale_image_utils.cc      |  6 ++++++
 mediapipe/calculators/image/scale_image_utils_test.cc | 10 ++++++++++
 2 files changed, 16 insertions(+)

diff --git a/mediapipe/calculators/image/scale_image_utils.cc b/mediapipe/calculators/image/scale_image_utils.cc
index 490d0336a..86a53ffc5 100644
--- a/mediapipe/calculators/image/scale_image_utils.cc
+++ b/mediapipe/calculators/image/scale_image_utils.cc
@@ -142,6 +142,9 @@ absl::Status FindOutputDimensions(int input_width,   //
 static_cast<double>(input_height));
   try_width = (try_width / 2) * 2;
   try_height = (try_height / 2) * 2;
+  // The output width/height should be greater than 0.
+  try_width = std::max(try_width, 1);
+  try_height = std::max(try_height, 1);
 
   if (target_height <= 0 || try_height <= target_height) {
     // The resulting height based on the target width and aspect ratio
@@ -160,6 +163,9 @@ absl::Status FindOutputDimensions(int input_width,   //
 static_cast<double>(input_width));
   try_width = (try_width / 2) * 2;
   try_height = (try_height / 2) * 2;
+  // The output width/height should be greater than 0.
+  try_width = std::max(try_width, 1);
+  try_height = std::max(try_height, 1);
 
   if (target_width <= 0 || try_width <= target_width) {
     // The resulting width based on the target width and aspect ratio
diff --git a/mediapipe/calculators/image/scale_image_utils_test.cc b/mediapipe/calculators/image/scale_image_utils_test.cc
index bda1fa4d6..b4810071c 100644
--- a/mediapipe/calculators/image/scale_image_utils_test.cc
+++ b/mediapipe/calculators/image/scale_image_utils_test.cc
@@ -124,6 +124,16 @@ TEST(ScaleImageUtilsTest, FindOutputDimensionsPreserveRatio) {
 &output_width, &output_height));
   EXPECT_EQ(151, output_width);
   EXPECT_EQ(101, output_height);
+  // Scale to height 1.
+  MP_ASSERT_OK(FindOutputDimensions(10000, 10, 100, 0, 0, true, 2,
+                                    &output_width, &output_height));
+  EXPECT_EQ(100, output_width);
+  EXPECT_EQ(1, output_height);
+  // Scale to width 1.
+  MP_ASSERT_OK(FindOutputDimensions(10, 10000, 0, 100, 0, true, 2,
+                                    &output_width, &output_height));
+  EXPECT_EQ(1, output_width);
+  EXPECT_EQ(100, output_height);
 }
 
 // Tests scaling without keeping the aspect ratio fixed.
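The clamp added in the commit above guards FindOutputDimensions against degenerate sizes: the output sides are rounded down to even values before the clamp, so an extreme aspect ratio such as 10000x10 could previously round one side all the way down to zero. The following minimal, standalone C++ sketch illustrates that behavior; the function and variable names are simplified for illustration and this is not the actual MediaPipe implementation.

#include <algorithm>
#include <cassert>

// Sketch of an aspect-ratio-preserving resize with the new clamp: both sides
// are rounded down to even values, then clamped to at least 1 pixel.
void ScaleToTargetWidth(int input_width, int input_height, int target_width,
                        int* output_width, int* output_height) {
  int try_width = (target_width / 2) * 2;
  int try_height = static_cast<int>(static_cast<double>(try_width) /
                                    input_width * input_height);
  try_height = (try_height / 2) * 2;
  // Without the clamp, a 10000x10 input scaled to width 100 would produce a
  // 100x0 output; the patch guarantees at least one pixel per side.
  *output_width = std::max(try_width, 1);
  *output_height = std::max(try_height, 1);
}

int main() {
  int width = 0;
  int height = 0;
  ScaleToTargetWidth(/*input_width=*/10000, /*input_height=*/10,
                     /*target_width=*/100, &width, &height);
  assert(width == 100 && height == 1);  // Previously height came out as 0.
  return 0;
}

Note that the clamp runs after the even-rounding rather than before it, so a clamped side can end up odd (1); that matches the expectations in the unit tests added above.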
From 9b34a105cfc3ca01a2a45afc011d613daaab7f26 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Fri, 6 Jan 2023 18:15:34 -0800 Subject: [PATCH 342/469] Do not depend on Image methods in TaskRunner PiperOrigin-RevId: 500299571 --- .../tasks/web/audio/audio_classifier/BUILD | 1 + .../audio_classifier/audio_classifier.ts | 3 ++- .../tasks/web/audio/audio_embedder/BUILD | 1 + .../audio/audio_embedder/audio_embedder.ts | 3 ++- mediapipe/tasks/web/core/BUILD | 2 -- mediapipe/tasks/web/core/task_runner.ts | 21 +++++++------------ mediapipe/tasks/web/core/task_runner_test.ts | 20 +++++++----------- .../text/text_classifier/text_classifier.ts | 4 ++-- .../web/text/text_embedder/text_embedder.ts | 4 ++-- mediapipe/tasks/web/vision/core/BUILD | 2 ++ .../vision/core/vision_task_runner.test.ts | 4 ++-- .../web/vision/core/vision_task_runner.ts | 15 ++++++++++++- .../tasks/web/vision/gesture_recognizer/BUILD | 2 +- .../gesture_recognizer/gesture_recognizer.ts | 4 ++-- .../gesture_recognizer_test.ts | 4 ++-- .../tasks/web/vision/hand_landmarker/BUILD | 2 +- .../vision/hand_landmarker/hand_landmarker.ts | 4 ++-- .../hand_landmarker/hand_landmarker_test.ts | 5 +++-- .../image_classifier/image_classifier.ts | 4 ++-- .../vision/image_embedder/image_embedder.ts | 4 ++-- .../vision/object_detector/object_detector.ts | 4 ++-- 21 files changed, 61 insertions(+), 52 deletions(-) diff --git a/mediapipe/tasks/web/audio/audio_classifier/BUILD b/mediapipe/tasks/web/audio/audio_classifier/BUILD index 24ef31feb..a94b4931d 100644 --- a/mediapipe/tasks/web/audio/audio_classifier/BUILD +++ b/mediapipe/tasks/web/audio/audio_classifier/BUILD @@ -27,6 +27,7 @@ mediapipe_ts_library( "//mediapipe/tasks/web/components/processors:classifier_result", "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:classifier_options", + "//mediapipe/tasks/web/core:task_runner", "//mediapipe/web/graph_runner:graph_runner_ts", ], ) diff --git a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts index 51573f50a..92fca93ad 100644 --- a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts +++ b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts @@ -22,6 +22,7 @@ import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/b import {AudioTaskRunner} from '../../../../tasks/web/audio/core/audio_task_runner'; import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/processors/classifier_options'; import {convertFromClassificationResultProto} from '../../../../tasks/web/components/processors/classifier_result'; +import {CachedGraphRunner} from '../../../../tasks/web/core/task_runner'; import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; import {WasmModule} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url @@ -98,7 +99,7 @@ export class AudioClassifier extends AudioTaskRunner { constructor( wasmModule: WasmModule, glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { - super(wasmModule, glCanvas); + super(new CachedGraphRunner(wasmModule, glCanvas)); this.options.setBaseOptions(new BaseOptionsProto()); } diff --git a/mediapipe/tasks/web/audio/audio_embedder/BUILD b/mediapipe/tasks/web/audio/audio_embedder/BUILD index 0817776c5..68a7f7bd5 100644 --- a/mediapipe/tasks/web/audio/audio_embedder/BUILD +++ b/mediapipe/tasks/web/audio/audio_embedder/BUILD @@ -27,6 +27,7 @@ mediapipe_ts_library( 
"//mediapipe/tasks/web/components/utils:cosine_similarity", "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:embedder_options", + "//mediapipe/tasks/web/core:task_runner", "//mediapipe/web/graph_runner:graph_runner_ts", ], ) diff --git a/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts index 6a4b8ce39..2e210f969 100644 --- a/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts +++ b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts @@ -24,6 +24,7 @@ import {Embedding} from '../../../../tasks/web/components/containers/embedding_r import {convertEmbedderOptionsToProto} from '../../../../tasks/web/components/processors/embedder_options'; import {convertFromEmbeddingResultProto} from '../../../../tasks/web/components/processors/embedder_result'; import {computeCosineSimilarity} from '../../../../tasks/web/components/utils/cosine_similarity'; +import {CachedGraphRunner} from '../../../../tasks/web/core/task_runner'; import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; import {WasmModule} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url @@ -100,7 +101,7 @@ export class AudioEmbedder extends AudioTaskRunner { constructor( wasmModule: WasmModule, glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { - super(wasmModule, glCanvas); + super(new CachedGraphRunner(wasmModule, glCanvas)); this.options.setBaseOptions(new BaseOptionsProto()); } diff --git a/mediapipe/tasks/web/core/BUILD b/mediapipe/tasks/web/core/BUILD index c0d10d28b..371c75da0 100644 --- a/mediapipe/tasks/web/core/BUILD +++ b/mediapipe/tasks/web/core/BUILD @@ -22,7 +22,6 @@ mediapipe_ts_library( "//mediapipe/tasks/cc/core/proto:acceleration_jspb_proto", "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", "//mediapipe/tasks/cc/core/proto:external_file_jspb_proto", - "//mediapipe/web/graph_runner:graph_runner_image_lib_ts", "//mediapipe/web/graph_runner:graph_runner_ts", "//mediapipe/web/graph_runner:register_model_resources_graph_service_ts", ], @@ -57,7 +56,6 @@ mediapipe_ts_library( deps = [ ":core", ":task_runner", - ":task_runner_test_utils", "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", "//mediapipe/web/graph_runner:graph_runner_ts", ], diff --git a/mediapipe/tasks/web/core/task_runner.ts b/mediapipe/tasks/web/core/task_runner.ts index ffb538b52..a3df7adf5 100644 --- a/mediapipe/tasks/web/core/task_runner.ts +++ b/mediapipe/tasks/web/core/task_runner.ts @@ -19,8 +19,7 @@ import {Acceleration} from '../../../tasks/cc/core/proto/acceleration_pb'; import {BaseOptions as BaseOptionsProto} from '../../../tasks/cc/core/proto/base_options_pb'; import {ExternalFile} from '../../../tasks/cc/core/proto/external_file_pb'; import {BaseOptions, TaskRunnerOptions} from '../../../tasks/web/core/task_runner_options'; -import {createMediaPipeLib, FileLocator, GraphRunner, WasmMediaPipeConstructor, WasmModule} from '../../../web/graph_runner/graph_runner'; -import {SupportImage} from '../../../web/graph_runner/graph_runner_image_lib'; +import {createMediaPipeLib, FileLocator, GraphRunner, WasmMediaPipeConstructor} from '../../../web/graph_runner/graph_runner'; import {SupportModelResourcesGraphService} from '../../../web/graph_runner/register_model_resources_graph_service'; import {WasmFileset} from './wasm_fileset'; @@ -29,10 +28,12 @@ import {WasmFileset} from './wasm_fileset'; const NO_ASSETS = undefined; // tslint:disable-next-line:enforce-name-casing 
-const GraphRunnerImageLibType =
-    SupportModelResourcesGraphService(SupportImage(GraphRunner));
-/** An implementation of the GraphRunner that supports image operations */
-export class GraphRunnerImageLib extends GraphRunnerImageLibType {}
+const CachedGraphRunnerType = SupportModelResourcesGraphService(GraphRunner);
+/**
+ * An implementation of the GraphRunner that exposes the resource graph
+ * service.
+ */
+export class CachedGraphRunner extends CachedGraphRunnerType {}
 
 /**
  * Creates a new instance of a Mediapipe Task. Determines if SIMD is
@@ -64,7 +65,6 @@ export async function createTaskRunner(
 /** Base class for all MediaPipe Tasks. */
 export abstract class TaskRunner {
   protected abstract baseOptions: BaseOptionsProto;
-  protected graphRunner: GraphRunnerImageLib;
   private processingErrors: Error[] = [];
 
   /**
@@ -79,12 +79,7 @@ export abstract class TaskRunner {
   }
 
   /** @hideconstructor protected */
-  constructor(
-      wasmModule: WasmModule, glCanvas?: HTMLCanvasElement|OffscreenCanvas|null,
-      graphRunner?: GraphRunnerImageLib) {
-    this.graphRunner =
-        graphRunner ?? new GraphRunnerImageLib(wasmModule, glCanvas);
-
+  constructor(protected readonly graphRunner: CachedGraphRunner) {
     // Disables the automatic render-to-screen code, which allows for pure
     // CPU processing.
     this.graphRunner.setAutoRenderToScreen(false);
diff --git a/mediapipe/tasks/web/core/task_runner_test.ts b/mediapipe/tasks/web/core/task_runner_test.ts
index a55ac04d7..684beb70c 100644
--- a/mediapipe/tasks/web/core/task_runner_test.ts
+++ b/mediapipe/tasks/web/core/task_runner_test.ts
@@ -18,11 +18,10 @@ import 'jasmine';
 
 // Placeholder for internal dependency on encodeByteArray
 import {BaseOptions as BaseOptionsProto} from '../../../tasks/cc/core/proto/base_options_pb';
 import {TaskRunner} from '../../../tasks/web/core/task_runner';
-import {createSpyWasmModule, SpyWasmModule} from '../../../tasks/web/core/task_runner_test_utils';
 import {ErrorListener} from '../../../web/graph_runner/graph_runner';
 // Placeholder for internal dependency on trusted resource URL builder
 
-import {GraphRunnerImageLib} from './task_runner';
+import {CachedGraphRunner} from './task_runner';
 import {TaskRunnerOptions} from './task_runner_options.d';
 
 class TaskRunnerFake extends TaskRunner {
@@ -32,18 +31,15 @@ class TaskRunnerFake extends TaskRunner {
   baseOptions = new BaseOptionsProto();
 
   static createFake(): TaskRunnerFake {
-    const wasmModule = createSpyWasmModule();
-    return new TaskRunnerFake(wasmModule);
+    return new TaskRunnerFake();
   }
 
-  constructor(wasmModuleFake: SpyWasmModule) {
-    super(
-        wasmModuleFake, /* glCanvas= */ null,
-        jasmine.createSpyObj([
-          'setAutoRenderToScreen', 'setGraph', 'finishProcessing',
-          'registerModelResourcesGraphService', 'attachErrorListener'
-        ]));
-    const graphRunner = this.graphRunner as jasmine.SpyObj<GraphRunnerImageLib>;
+  constructor() {
+    super(jasmine.createSpyObj([
+      'setAutoRenderToScreen', 'setGraph', 'finishProcessing',
+      'registerModelResourcesGraphService', 'attachErrorListener'
+    ]));
+    const graphRunner = this.graphRunner as jasmine.SpyObj<CachedGraphRunner>;
     expect(graphRunner.registerModelResourcesGraphService).toHaveBeenCalled();
     expect(graphRunner.setAutoRenderToScreen).toHaveBeenCalled();
     graphRunner.attachErrorListener.and.callFake(listener => {
diff --git a/mediapipe/tasks/web/text/text_classifier/text_classifier.ts b/mediapipe/tasks/web/text/text_classifier/text_classifier.ts
index 981438625..6aef1b3e4 100644
--- a/mediapipe/tasks/web/text/text_classifier/text_classifier.ts
+++
b/mediapipe/tasks/web/text/text_classifier/text_classifier.ts @@ -21,7 +21,7 @@ import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/b import {TextClassifierGraphOptions} from '../../../../tasks/cc/text/text_classifier/proto/text_classifier_graph_options_pb'; import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/processors/classifier_options'; import {convertFromClassificationResultProto} from '../../../../tasks/web/components/processors/classifier_result'; -import {TaskRunner} from '../../../../tasks/web/core/task_runner'; +import {CachedGraphRunner, TaskRunner} from '../../../../tasks/web/core/task_runner'; import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; import {WasmModule} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url @@ -96,7 +96,7 @@ export class TextClassifier extends TaskRunner { constructor( wasmModule: WasmModule, glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { - super(wasmModule, glCanvas); + super(new CachedGraphRunner(wasmModule, glCanvas)); this.options.setBaseOptions(new BaseOptionsProto()); } diff --git a/mediapipe/tasks/web/text/text_embedder/text_embedder.ts b/mediapipe/tasks/web/text/text_embedder/text_embedder.ts index 7aa0aa6b9..db7986dec 100644 --- a/mediapipe/tasks/web/text/text_embedder/text_embedder.ts +++ b/mediapipe/tasks/web/text/text_embedder/text_embedder.ts @@ -23,7 +23,7 @@ import {Embedding} from '../../../../tasks/web/components/containers/embedding_r import {convertEmbedderOptionsToProto} from '../../../../tasks/web/components/processors/embedder_options'; import {convertFromEmbeddingResultProto} from '../../../../tasks/web/components/processors/embedder_result'; import {computeCosineSimilarity} from '../../../../tasks/web/components/utils/cosine_similarity'; -import {TaskRunner} from '../../../../tasks/web/core/task_runner'; +import {CachedGraphRunner, TaskRunner} from '../../../../tasks/web/core/task_runner'; import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; import {WasmModule} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url @@ -100,7 +100,7 @@ export class TextEmbedder extends TaskRunner { constructor( wasmModule: WasmModule, glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { - super(wasmModule, glCanvas); + super(new CachedGraphRunner(wasmModule, glCanvas)); this.options.setBaseOptions(new BaseOptionsProto()); } diff --git a/mediapipe/tasks/web/vision/core/BUILD b/mediapipe/tasks/web/vision/core/BUILD index 03958a819..3574483df 100644 --- a/mediapipe/tasks/web/vision/core/BUILD +++ b/mediapipe/tasks/web/vision/core/BUILD @@ -20,7 +20,9 @@ mediapipe_ts_library( ":vision_task_options", "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:task_runner", + "//mediapipe/web/graph_runner:graph_runner_image_lib_ts", "//mediapipe/web/graph_runner:graph_runner_ts", + "//mediapipe/web/graph_runner:register_model_resources_graph_service_ts", ], ) diff --git a/mediapipe/tasks/web/vision/core/vision_task_runner.test.ts b/mediapipe/tasks/web/vision/core/vision_task_runner.test.ts index d77cc4fed..f3f25070e 100644 --- a/mediapipe/tasks/web/vision/core/vision_task_runner.test.ts +++ b/mediapipe/tasks/web/vision/core/vision_task_runner.test.ts @@ -21,13 +21,13 @@ import {createSpyWasmModule} from '../../../../tasks/web/core/task_runner_test_u import {ImageSource} from '../../../../web/graph_runner/graph_runner'; import {VisionTaskOptions} 
from './vision_task_options'; -import {VisionTaskRunner} from './vision_task_runner'; +import {VisionGraphRunner, VisionTaskRunner} from './vision_task_runner'; class VisionTaskRunnerFake extends VisionTaskRunner { baseOptions = new BaseOptionsProto(); constructor() { - super(createSpyWasmModule(), /* glCanvas= */ null); + super(new VisionGraphRunner(createSpyWasmModule(), /* glCanvas= */ null)); } protected override process(): void {} diff --git a/mediapipe/tasks/web/vision/core/vision_task_runner.ts b/mediapipe/tasks/web/vision/core/vision_task_runner.ts index 952990326..c3e0d3c7e 100644 --- a/mediapipe/tasks/web/vision/core/vision_task_runner.ts +++ b/mediapipe/tasks/web/vision/core/vision_task_runner.ts @@ -15,12 +15,25 @@ */ import {TaskRunner} from '../../../../tasks/web/core/task_runner'; -import {ImageSource} from '../../../../web/graph_runner/graph_runner'; +import {GraphRunner, ImageSource} from '../../../../web/graph_runner/graph_runner'; +import {SupportImage} from '../../../../web/graph_runner/graph_runner_image_lib'; +import {SupportModelResourcesGraphService} from '../../../../web/graph_runner/register_model_resources_graph_service'; import {VisionTaskOptions} from './vision_task_options'; +// tslint:disable-next-line:enforce-name-casing +const GraphRunnerVisionType = + SupportModelResourcesGraphService(SupportImage(GraphRunner)); +/** An implementation of the GraphRunner that supports image operations */ +export class VisionGraphRunner extends GraphRunnerVisionType {} + /** Base class for all MediaPipe Vision Tasks. */ export abstract class VisionTaskRunner extends TaskRunner { + /** @hideconstructor protected */ + constructor(protected override readonly graphRunner: VisionGraphRunner) { + super(graphRunner); + } + /** Configures the shared options of a vision task. 
*/ override applyOptions(options: VisionTaskOptions): Promise { if ('runningMode' in options) { diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/BUILD b/mediapipe/tasks/web/vision/gesture_recognizer/BUILD index aa2f9c366..5fdf9b43e 100644 --- a/mediapipe/tasks/web/vision/gesture_recognizer/BUILD +++ b/mediapipe/tasks/web/vision/gesture_recognizer/BUILD @@ -67,8 +67,8 @@ mediapipe_ts_library( "//mediapipe/framework/formats:classification_jspb_proto", "//mediapipe/framework/formats:landmark_jspb_proto", "//mediapipe/tasks/web/core", - "//mediapipe/tasks/web/core:task_runner", "//mediapipe/tasks/web/core:task_runner_test_utils", + "//mediapipe/tasks/web/vision/core:vision_task_runner", ], ) diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts index c77f2c67a..8d36ed89c 100644 --- a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts +++ b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts @@ -30,7 +30,7 @@ import {Category} from '../../../../tasks/web/components/containers/category'; import {Landmark, NormalizedLandmark} from '../../../../tasks/web/components/containers/landmark'; import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/processors/classifier_options'; import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; -import {VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; +import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url @@ -131,7 +131,7 @@ export class GestureRecognizer extends constructor( wasmModule: WasmModule, glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { - super(wasmModule, glCanvas); + super(new VisionGraphRunner(wasmModule, glCanvas)); this.options = new GestureRecognizerGraphOptions(); this.options.setBaseOptions(new BaseOptionsProto()); diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_test.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_test.ts index 3611c3a7d..3699033b2 100644 --- a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_test.ts +++ b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_test.ts @@ -18,8 +18,8 @@ import 'jasmine'; import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; import {Classification, ClassificationList} from '../../../../framework/formats/classification_pb'; import {Landmark, LandmarkList, NormalizedLandmark, NormalizedLandmarkList} from '../../../../framework/formats/landmark_pb'; -import {GraphRunnerImageLib} from '../../../../tasks/web/core/task_runner'; import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils'; +import {VisionGraphRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; import {GestureRecognizer, GestureRecognizerOptions} from './gesture_recognizer'; @@ -98,7 +98,7 @@ class GestureRecognizerFake extends GestureRecognizer implements spyOn(this.graphRunner, 'addProtoToStream'); } - getGraphRunner(): GraphRunnerImageLib { + getGraphRunner(): VisionGraphRunner { return this.graphRunner; } } diff --git a/mediapipe/tasks/web/vision/hand_landmarker/BUILD 
b/mediapipe/tasks/web/vision/hand_landmarker/BUILD index d1f1e48f3..e7083a050 100644 --- a/mediapipe/tasks/web/vision/hand_landmarker/BUILD +++ b/mediapipe/tasks/web/vision/hand_landmarker/BUILD @@ -62,8 +62,8 @@ mediapipe_ts_library( "//mediapipe/framework/formats:classification_jspb_proto", "//mediapipe/framework/formats:landmark_jspb_proto", "//mediapipe/tasks/web/core", - "//mediapipe/tasks/web/core:task_runner", "//mediapipe/tasks/web/core:task_runner_test_utils", + "//mediapipe/tasks/web/vision/core:vision_task_runner", ], ) diff --git a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts index 24cf9a402..5db6d48f5 100644 --- a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts +++ b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts @@ -26,7 +26,7 @@ import {HandLandmarksDetectorGraphOptions} from '../../../../tasks/cc/vision/han import {Category} from '../../../../tasks/web/components/containers/category'; import {Landmark, NormalizedLandmark} from '../../../../tasks/web/components/containers/landmark'; import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; -import {VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; +import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url @@ -119,7 +119,7 @@ export class HandLandmarker extends VisionTaskRunner { constructor( wasmModule: WasmModule, glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { - super(wasmModule, glCanvas); + super(new VisionGraphRunner(wasmModule, glCanvas)); this.options = new HandLandmarkerGraphOptions(); this.options.setBaseOptions(new BaseOptionsProto()); diff --git a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_test.ts b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_test.ts index 1a813c6f7..bce0eac02 100644 --- a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_test.ts +++ b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_test.ts @@ -18,12 +18,13 @@ import 'jasmine'; import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; import {Classification, ClassificationList} from '../../../../framework/formats/classification_pb'; import {Landmark, LandmarkList, NormalizedLandmark, NormalizedLandmarkList} from '../../../../framework/formats/landmark_pb'; -import {GraphRunnerImageLib} from '../../../../tasks/web/core/task_runner'; import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils'; +import {VisionGraphRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; import {HandLandmarker} from './hand_landmarker'; import {HandLandmarkerOptions} from './hand_landmarker_options'; + // The OSS JS API does not support the builder pattern. 
// tslint:disable:jspb-use-builder-pattern @@ -87,7 +88,7 @@ class HandLandmarkerFake extends HandLandmarker implements MediapipeTasksFake { spyOn(this.graphRunner, 'addProtoToStream'); } - getGraphRunner(): GraphRunnerImageLib { + getGraphRunner(): VisionGraphRunner { return this.graphRunner; } } diff --git a/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts b/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts index 9298a860c..4a2be5566 100644 --- a/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts +++ b/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts @@ -22,7 +22,7 @@ import {ImageClassifierGraphOptions} from '../../../../tasks/cc/vision/image_cla import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/processors/classifier_options'; import {convertFromClassificationResultProto} from '../../../../tasks/web/components/processors/classifier_result'; import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; -import {VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; +import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url @@ -97,7 +97,7 @@ export class ImageClassifier extends VisionTaskRunner { constructor( wasmModule: WasmModule, glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { - super(wasmModule, glCanvas); + super(new VisionGraphRunner(wasmModule, glCanvas)); this.options.setBaseOptions(new BaseOptionsProto()); } diff --git a/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts b/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts index cf0bd8c5d..4651ae4ce 100644 --- a/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts +++ b/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts @@ -24,7 +24,7 @@ import {convertEmbedderOptionsToProto} from '../../../../tasks/web/components/pr import {convertFromEmbeddingResultProto} from '../../../../tasks/web/components/processors/embedder_result'; import {computeCosineSimilarity} from '../../../../tasks/web/components/utils/cosine_similarity'; import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; -import {VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; +import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url @@ -99,7 +99,7 @@ export class ImageEmbedder extends VisionTaskRunner { constructor( wasmModule: WasmModule, glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { - super(wasmModule, glCanvas); + super(new VisionGraphRunner(wasmModule, glCanvas)); this.options.setBaseOptions(new BaseOptionsProto()); } diff --git a/mediapipe/tasks/web/vision/object_detector/object_detector.ts b/mediapipe/tasks/web/vision/object_detector/object_detector.ts index e4c51de08..ac489ec00 100644 --- a/mediapipe/tasks/web/vision/object_detector/object_detector.ts +++ b/mediapipe/tasks/web/vision/object_detector/object_detector.ts @@ -20,7 +20,7 @@ import {Detection as DetectionProto} from '../../../../framework/formats/detecti import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; import {ObjectDetectorOptions as ObjectDetectorOptionsProto} from 
'../../../../tasks/cc/vision/object_detector/proto/object_detector_options_pb';
 import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
-import {VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
+import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
 import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner';
 // Placeholder for internal dependency on trusted resource url
@@ -96,7 +96,7 @@ export class ObjectDetector extends VisionTaskRunner {
   constructor(
       wasmModule: WasmModule,
       glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) {
-    super(wasmModule, glCanvas);
+    super(new VisionGraphRunner(wasmModule, glCanvas));
     this.options.setBaseOptions(new BaseOptionsProto());
   }

From 9055effddd35e0424db2a11a81445f32f6badae8 Mon Sep 17 00:00:00 2001
From: Sebastian Schmidt
Date: Fri, 6 Jan 2023 20:52:30 -0800
Subject: [PATCH 343/469] Add ImageProcessingOptions to all Vision Tasks

PiperOrigin-RevId: 500323261
---
 .../cc/vision/core/image_processing_options.h |   2 +-
 .../tasks/web/components/containers/BUILD     |   5 +
 .../tasks/web/components/containers/rect.d.ts |  41 +++++
 .../tasks/web/core/task_runner_test_utils.ts  |   6 +-
 mediapipe/tasks/web/vision/core/BUILD         |  12 ++
 .../vision/core/image_processing_options.d.ts |  42 +++++
 .../vision/core/vision_task_runner.test.ts    | 158 ++++++++++++++++--
 .../web/vision/core/vision_task_runner.ts     |  95 +++++++++--
 .../tasks/web/vision/gesture_recognizer/BUILD |   2 +-
 .../gesture_recognizer/gesture_recognizer.ts  |  48 +++---
 .../tasks/web/vision/hand_landmarker/BUILD    |   2 +-
 .../vision/hand_landmarker/hand_landmarker.ts |  46 ++---
 .../tasks/web/vision/image_classifier/BUILD   |   1 +
 .../image_classifier/image_classifier.ts      |  43 ++---
 .../tasks/web/vision/image_embedder/BUILD     |   1 +
 .../vision/image_embedder/image_embedder.ts   |  45 ++---
 .../tasks/web/vision/object_detector/BUILD    |   1 +
 .../vision/object_detector/object_detector.ts |  42 +++--
 18 files changed, 460 insertions(+), 132 deletions(-)
 create mode 100644 mediapipe/tasks/web/components/containers/rect.d.ts
 create mode 100644 mediapipe/tasks/web/vision/core/image_processing_options.d.ts

diff --git a/mediapipe/tasks/cc/vision/core/image_processing_options.h b/mediapipe/tasks/cc/vision/core/image_processing_options.h
index 1983272fc..e2647be71 100644
--- a/mediapipe/tasks/cc/vision/core/image_processing_options.h
+++ b/mediapipe/tasks/cc/vision/core/image_processing_options.h
@@ -28,7 +28,7 @@ namespace core {
 // Options for image processing.
 //
 // If both region-of-interest and rotation are specified, the crop around the
-// region-of-interest is extracted first, the the specified rotation is applied
+// region-of-interest is extracted first, then the specified rotation is applied
 // to the crop.
 struct ImageProcessingOptions {
   // The optional region-of-interest to crop from the image.
If not specified, diff --git a/mediapipe/tasks/web/components/containers/BUILD b/mediapipe/tasks/web/components/containers/BUILD index fb0fdff16..a0db59d0b 100644 --- a/mediapipe/tasks/web/components/containers/BUILD +++ b/mediapipe/tasks/web/components/containers/BUILD @@ -24,3 +24,8 @@ mediapipe_ts_declaration( name = "embedding_result", srcs = ["embedding_result.d.ts"], ) + +mediapipe_ts_declaration( + name = "rect", + srcs = ["rect.d.ts"], +) diff --git a/mediapipe/tasks/web/components/containers/rect.d.ts b/mediapipe/tasks/web/components/containers/rect.d.ts new file mode 100644 index 000000000..9afece9ca --- /dev/null +++ b/mediapipe/tasks/web/components/containers/rect.d.ts @@ -0,0 +1,41 @@ +/** + * Copyright 2023 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Defines a rectangle, used e.g. as part of detection results or as input + * region-of-interest. + */ +export declare interface Rect { + left: number; + top: number; + right: number; + bottom: number; +} + +/** + * Defines a rectangle, used e.g. as part of detection results or as input + * region-of-interest. + * + * The coordinates are normalized with respect to the image dimensions, i.e. + * generally in [0,1] but they may exceed these bounds if describing a region + * overlapping the image. The origin is on the top-left corner of the image. + */ +export declare interface RectF { + left: number; + top: number; + right: number; + bottom: number; +} diff --git a/mediapipe/tasks/web/core/task_runner_test_utils.ts b/mediapipe/tasks/web/core/task_runner_test_utils.ts index 838b3f585..62dd0463a 100644 --- a/mediapipe/tasks/web/core/task_runner_test_utils.ts +++ b/mediapipe/tasks/web/core/task_runner_test_utils.ts @@ -32,12 +32,14 @@ export declare type SpyWasmModule = jasmine.SpyObj; * in pure JS/TS (and optionally spy on the calls). 
 */
export function createSpyWasmModule(): SpyWasmModule {
-  return jasmine.createSpyObj([
+  const spyWasmModule = jasmine.createSpyObj([
     '_setAutoRenderToScreen', 'stringToNewUTF8', '_attachProtoListener',
     '_attachProtoVectorListener', '_free', '_waitUntilIdle',
     '_addStringToInputStream', '_registerModelResourcesGraphService',
-    '_configureAudio'
+    '_configureAudio', '_malloc', '_addProtoToInputStream'
   ]);
+  spyWasmModule.HEAPU8 = jasmine.createSpyObj(['set']);
+  return spyWasmModule;
 }

 /**
diff --git a/mediapipe/tasks/web/vision/core/BUILD b/mediapipe/tasks/web/vision/core/BUILD
index 3574483df..a0a008122 100644
--- a/mediapipe/tasks/web/vision/core/BUILD
+++ b/mediapipe/tasks/web/vision/core/BUILD
@@ -5,6 +5,14 @@ load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration",

 package(default_visibility = ["//mediapipe/tasks:internal"])

+mediapipe_ts_declaration(
+    name = "image_processing_options",
+    srcs = ["image_processing_options.d.ts"],
+    deps = [
+        "//mediapipe/tasks/web/components/containers:rect",
+    ],
+)
+
 mediapipe_ts_declaration(
     name = "vision_task_options",
     srcs = ["vision_task_options.d.ts"],
@@ -17,7 +25,9 @@ mediapipe_ts_library(
     name = "vision_task_runner",
     srcs = ["vision_task_runner.ts"],
     deps = [
+        ":image_processing_options",
         ":vision_task_options",
+        "//mediapipe/framework/formats:rect_jspb_proto",
         "//mediapipe/tasks/web/core",
         "//mediapipe/tasks/web/core:task_runner",
         "//mediapipe/web/graph_runner:graph_runner_image_lib_ts",
@@ -31,8 +41,10 @@ mediapipe_ts_library(
     testonly = True,
     srcs = ["vision_task_runner.test.ts"],
     deps = [
+        ":image_processing_options",
         ":vision_task_options",
         ":vision_task_runner",
+        "//mediapipe/framework/formats:rect_jspb_proto",
         "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
         "//mediapipe/tasks/web/core:task_runner_test_utils",
         "//mediapipe/web/graph_runner:graph_runner_ts",
diff --git a/mediapipe/tasks/web/vision/core/image_processing_options.d.ts b/mediapipe/tasks/web/vision/core/image_processing_options.d.ts
new file mode 100644
index 000000000..b76731546
--- /dev/null
+++ b/mediapipe/tasks/web/vision/core/image_processing_options.d.ts
@@ -0,0 +1,42 @@
+/**
+ * Copyright 2023 The MediaPipe Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import {RectF} from '../../../../tasks/web/components/containers/rect';
+
+/**
+ * Options for image processing.
+ *
+ * If both region-of-interest and rotation are specified, the crop around the
+ * region-of-interest is extracted first, then the specified rotation is applied
+ * to the crop.
+ */
+export declare interface ImageProcessingOptions {
+  /**
+   * The optional region-of-interest to crop from the image. If not specified,
+   * the full image is used.
+   *
+   * Coordinates must be in [0,1] with 'left' < 'right' and 'top' < 'bottom'.
+   */
+  regionOfInterest?: RectF;
+
+  /**
+   * The rotation to apply to the image (or cropped region-of-interest), in
+   * degrees clockwise.
+   *
+   * The rotation must be a multiple (positive or negative) of 90°.
+   */
+  rotationDegrees?: number;
+}
diff --git a/mediapipe/tasks/web/vision/core/vision_task_runner.test.ts b/mediapipe/tasks/web/vision/core/vision_task_runner.test.ts
index f3f25070e..a48381038 100644
--- a/mediapipe/tasks/web/vision/core/vision_task_runner.test.ts
+++ b/mediapipe/tasks/web/vision/core/vision_task_runner.test.ts
@@ -16,21 +16,62 @@

 import 'jasmine';

+import {NormalizedRect} from '../../../../framework/formats/rect_pb';
 import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb';
-import {createSpyWasmModule} from '../../../../tasks/web/core/task_runner_test_utils';
+import {addJasmineCustomFloatEqualityTester} from '../../../../tasks/web/core/task_runner_test_utils';
+import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options';
 import {ImageSource} from '../../../../web/graph_runner/graph_runner';

 import {VisionTaskOptions} from './vision_task_options';
 import {VisionGraphRunner, VisionTaskRunner} from './vision_task_runner';

-class VisionTaskRunnerFake extends VisionTaskRunner<void> {
+
+// The OSS JS API does not support the builder pattern.
+// tslint:disable:jspb-use-builder-pattern
+
+const IMAGE_STREAM = 'image_in';
+const NORM_RECT_STREAM = 'norm_rect';
+
+const IMAGE = {} as unknown as HTMLImageElement;
+const TIMESTAMP = 42;
+
+class VisionTaskRunnerFake extends VisionTaskRunner {
   baseOptions = new BaseOptionsProto();
+  fakeGraphRunner: jasmine.SpyObj<VisionGraphRunner>;
+  expectedImageSource?: ImageSource;
+  expectedNormalizedRect?: NormalizedRect;

   constructor() {
-    super(new VisionGraphRunner(createSpyWasmModule(), /* glCanvas= */ null));
-  }
+    super(
+        jasmine.createSpyObj<VisionGraphRunner>([
+          'addProtoToStream', 'addGpuBufferAsImageToStream',
+          'setAutoRenderToScreen', 'registerModelResourcesGraphService',
+          'finishProcessing'
+        ]),
+        IMAGE_STREAM, NORM_RECT_STREAM);

-  protected override process(): void {}
+    this.fakeGraphRunner =
+        this.graphRunner as unknown as jasmine.SpyObj<VisionGraphRunner>;
+
+    (this.graphRunner.addProtoToStream as jasmine.Spy)
+        .and.callFake((serializedData, type, streamName, timestamp) => {
+          expect(type).toBe('mediapipe.NormalizedRect');
+          expect(streamName).toBe(NORM_RECT_STREAM);
+          expect(timestamp).toBe(TIMESTAMP);
+
+          const actualNormalizedRect =
+              NormalizedRect.deserializeBinary(serializedData);
+          expect(actualNormalizedRect.toObject())
+              .toEqual(this.expectedNormalizedRect!.toObject());
+        });
+
+    (this.graphRunner.addGpuBufferAsImageToStream as jasmine.Spy)
+        .and.callFake((imageSource, streamName, timestamp) => {
+          expect(streamName).toBe(IMAGE_STREAM);
+          expect(timestamp).toBe(TIMESTAMP);
+          expect(imageSource).toBe(this.expectedImageSource!);
+        });
+  }

   protected override refreshGraph(): void {}

@@ -38,12 +79,31 @@ class VisionTaskRunnerFake extends VisionTaskRunner {
     return this.applyOptions(options);
   }

-  override processImageData(image: ImageSource): void {
-    super.processImageData(image);
+  override processImageData(
+      image: ImageSource,
+      imageProcessingOptions: ImageProcessingOptions|undefined): void {
+    super.processImageData(image, imageProcessingOptions);
   }

-  override processVideoData(imageFrame: ImageSource, timestamp: number): void {
-    super.processVideoData(imageFrame, timestamp);
+  override processVideoData(
+      imageFrame: ImageSource,
+      imageProcessingOptions: ImageProcessingOptions|undefined,
+      timestamp: number): void {
+    super.processVideoData(imageFrame, imageProcessingOptions, timestamp);
+  }
+
+  expectNormalizedRect(
+      xCenter: number, yCenter: number, width: number, height: number): void {
+    const rect = new NormalizedRect();
+    rect.setXCenter(xCenter);
+    rect.setYCenter(yCenter);
+    rect.setWidth(width);
+    rect.setHeight(height);
+    this.expectedNormalizedRect = rect;
+  }
+
+  expectImage(imageSource: ImageSource): void {
+    this.expectedImageSource = imageSource;
+  }
 }

@@ -51,6 +111,7 @@ describe('VisionTaskRunner', () => {
   let visionTaskRunner: VisionTaskRunnerFake;

   beforeEach(async () => {
+    addJasmineCustomFloatEqualityTester();
     visionTaskRunner = new VisionTaskRunnerFake();
     await visionTaskRunner.setOptions(
         {baseOptions: {modelAssetBuffer: new Uint8Array([])}});
@@ -72,7 +133,8 @@ describe('VisionTaskRunner', () => {
     await visionTaskRunner.setOptions({runningMode: 'video'});

     // Clear running mode
-    await visionTaskRunner.setOptions({runningMode: undefined});
+    await visionTaskRunner.setOptions(
+        {runningMode: undefined});
     expect(visionTaskRunner.baseOptions.toObject())
         .toEqual(jasmine.objectContaining({useStreamMode: false}));
   });

@@ -80,20 +142,90 @@ describe('VisionTaskRunner', () => {
   it('cannot process images with video mode', async () => {
     await visionTaskRunner.setOptions({runningMode: 'video'});
     expect(() => {
-      visionTaskRunner.processImageData({} as HTMLImageElement);
+      visionTaskRunner.processImageData(
+          IMAGE, /* imageProcessingOptions= */ undefined);
     }).toThrowError(/Task is not initialized with image mode./);
   });

   it('cannot process video with image mode', async () => {
     // Use default for `useStreamMode`
     expect(() => {
-      visionTaskRunner.processVideoData({} as HTMLImageElement, 42);
+      visionTaskRunner.processVideoData(
+          IMAGE, /* imageProcessingOptions= */ undefined, TIMESTAMP);
     }).toThrowError(/Task is not initialized with video mode./);

     // Explicitly set to image mode
     await visionTaskRunner.setOptions({runningMode: 'image'});
     expect(() => {
-      visionTaskRunner.processVideoData({} as HTMLImageElement, 42);
+      visionTaskRunner.processVideoData(
+          IMAGE, /* imageProcessingOptions= */ undefined, TIMESTAMP);
     }).toThrowError(/Task is not initialized with video mode./);
   });
+
+  it('sends packets to graph', async () => {
+    await visionTaskRunner.setOptions({runningMode: 'video'});
+
+    visionTaskRunner.expectImage(IMAGE);
+    visionTaskRunner.expectNormalizedRect(0.5, 0.5, 1, 1);
+    visionTaskRunner.processVideoData(
+        IMAGE, /* imageProcessingOptions= */ undefined, TIMESTAMP);
+  });
+
+  it('sends packets to graph with image processing options', async () => {
+    await visionTaskRunner.setOptions({runningMode: 'video'});
+
+    visionTaskRunner.expectImage(IMAGE);
+    visionTaskRunner.expectNormalizedRect(0.3, 0.6, 0.2, 0.4);
+    visionTaskRunner.processVideoData(
+        IMAGE,
+        {regionOfInterest: {left: 0.2, right: 0.4, top: 0.4, bottom: 0.8}},
+        TIMESTAMP);
+  });
+
+  describe('validates processing options', () => {
+    it('with left > right', () => {
+      expect(() => {
+        visionTaskRunner.processImageData(IMAGE, {
+          regionOfInterest: {
+            left: 0.2,
+            right: 0.1,
+            top: 0.1,
+            bottom: 0.2,
+          }
+        });
+      }).toThrowError('Expected RectF with left < right and top < bottom.');
+    });
+
+    it('with top > bottom', () => {
+      expect(() => {
+        visionTaskRunner.processImageData(IMAGE, {
+          regionOfInterest: {
+            left: 0.1,
+            right: 0.2,
+            top: 0.2,
+            bottom: 0.1,
+          }
+        });
+      }).toThrowError('Expected RectF with left < right and top < bottom.');
+    });
+
+    it('with out of range values', () => {
+      expect(() => {
+        visionTaskRunner.processImageData(IMAGE, {
+          regionOfInterest: {
+            left: 0.1,
+            right: 1.1,
+            top: 0.1,
+            bottom: 0.2,
+          }
+        });
+      }).toThrowError('Expected RectF values to be in [0,1].');
+    });
+
+    it('with non-90 degree rotation', () => {
+      expect(() => {
+        visionTaskRunner.processImageData(IMAGE, {rotationDegrees: 42});
+      }).toThrowError('Expected rotation to be a multiple of 90°.');
+    });
+  });
 });
diff --git a/mediapipe/tasks/web/vision/core/vision_task_runner.ts b/mediapipe/tasks/web/vision/core/vision_task_runner.ts
index c3e0d3c7e..9adc810fc 100644
--- a/mediapipe/tasks/web/vision/core/vision_task_runner.ts
+++ b/mediapipe/tasks/web/vision/core/vision_task_runner.ts
@@ -14,7 +14,9 @@
  * limitations under the License.
  */

+import {NormalizedRect} from '../../../../framework/formats/rect_pb';
 import {TaskRunner} from '../../../../tasks/web/core/task_runner';
+import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options';
 import {GraphRunner, ImageSource} from '../../../../web/graph_runner/graph_runner';
 import {SupportImage} from '../../../../web/graph_runner/graph_runner_image_lib';
 import {SupportModelResourcesGraphService} from '../../../../web/graph_runner/register_model_resources_graph_service';
@@ -27,10 +29,26 @@ const GraphRunnerVisionType =
 /** An implementation of the GraphRunner that supports image operations */
 export class VisionGraphRunner extends GraphRunnerVisionType {}

+// The OSS JS API does not support the builder pattern.
+// tslint:disable:jspb-use-builder-pattern
+
 /** Base class for all MediaPipe Vision Tasks. */
-export abstract class VisionTaskRunner<T> extends TaskRunner {
-  /** @hideconstructor protected */
-  constructor(protected override readonly graphRunner: VisionGraphRunner) {
+export abstract class VisionTaskRunner extends TaskRunner {
+  /**
+   * Constructor to initialize a `VisionTaskRunner`.
+   *
+   * @param graphRunner the graph runner for this task.
+   * @param imageStreamName the name of the input image stream.
+   * @param normRectStreamName the name of the input normalized rect image
+   *     stream used to provide (mandatory) rotation and (optional)
+   *     region-of-interest.
+   *
+   * @hideconstructor protected
+   */
+  constructor(
+      protected override readonly graphRunner: VisionGraphRunner,
+      private readonly imageStreamName: string,
+      private readonly normRectStreamName: string) {
     super(graphRunner);
   }

@@ -44,27 +62,84 @@ export abstract class VisionTaskRunner extends TaskRunner {
     return super.applyOptions(options);
   }

-  /** Sends an image packet to the graph and awaits results. */
-  protected abstract process(input: ImageSource, timestamp: number): T;
-
   /** Sends a single image to the graph and awaits results. */
-  protected processImageData(image: ImageSource): T {
+  protected processImageData(
+      image: ImageSource,
+      imageProcessingOptions: ImageProcessingOptions|undefined): void {
     if (!!this.baseOptions?.getUseStreamMode()) {
       throw new Error(
           'Task is not initialized with image mode. ' +
           '\'runningMode\' must be set to \'image\'.');
     }
-    return this.process(image, performance.now());
+    this.process(image, imageProcessingOptions, performance.now());
   }

   /** Sends a single video frame to the graph and awaits results. */
-  protected processVideoData(imageFrame: ImageSource, timestamp: number): T {
+  protected processVideoData(
+      imageFrame: ImageSource,
+      imageProcessingOptions: ImageProcessingOptions|undefined,
+      timestamp: number): void {
     if (!this.baseOptions?.getUseStreamMode()) {
       throw new Error(
          'Task is not initialized with video mode.
' + '\'runningMode\' must be set to \'video\'.'); } - return this.process(imageFrame, timestamp); + this.process(imageFrame, imageProcessingOptions, timestamp); + } + + private convertToNormalizedRect(imageProcessingOptions?: + ImageProcessingOptions): NormalizedRect { + const normalizedRect = new NormalizedRect(); + + if (imageProcessingOptions?.regionOfInterest) { + const roi = imageProcessingOptions.regionOfInterest; + + if (roi.left >= roi.right || roi.top >= roi.bottom) { + throw new Error('Expected RectF with left < right and top < bottom.'); + } + if (roi.left < 0 || roi.top < 0 || roi.right > 1 || roi.bottom > 1) { + throw new Error('Expected RectF values to be in [0,1].'); + } + + normalizedRect.setXCenter((roi.left + roi.right) / 2.0); + normalizedRect.setYCenter((roi.top + roi.bottom) / 2.0); + normalizedRect.setWidth(roi.right - roi.left); + normalizedRect.setHeight(roi.bottom - roi.top); + return normalizedRect; + } else { + normalizedRect.setXCenter(0.5); + normalizedRect.setYCenter(0.5); + normalizedRect.setWidth(1); + normalizedRect.setHeight(1); + } + + if (imageProcessingOptions?.rotationDegrees) { + if (imageProcessingOptions?.rotationDegrees % 90 !== 0) { + throw new Error( + 'Expected rotation to be a multiple of 90°.', + ); + } + + // Convert to radians anti-clockwise. + normalizedRect.setRotation( + -Math.PI * imageProcessingOptions.rotationDegrees / 180.0); + } + + return normalizedRect; + } + + /** Runs the graph and blocks on the response. */ + private process( + imageSource: ImageSource, + imageProcessingOptions: ImageProcessingOptions|undefined, + timestamp: number): void { + const normalizedRect = this.convertToNormalizedRect(imageProcessingOptions); + this.graphRunner.addProtoToStream( + normalizedRect.serializeBinary(), 'mediapipe.NormalizedRect', + this.normRectStreamName, timestamp); + this.graphRunner.addGpuBufferAsImageToStream( + imageSource, this.imageStreamName, timestamp ?? 
performance.now()); + this.finishProcessing(); } } diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/BUILD b/mediapipe/tasks/web/vision/gesture_recognizer/BUILD index 5fdf9b43e..9156e89b7 100644 --- a/mediapipe/tasks/web/vision/gesture_recognizer/BUILD +++ b/mediapipe/tasks/web/vision/gesture_recognizer/BUILD @@ -20,7 +20,6 @@ mediapipe_ts_library( "//mediapipe/framework:calculator_options_jspb_proto", "//mediapipe/framework/formats:classification_jspb_proto", "//mediapipe/framework/formats:landmark_jspb_proto", - "//mediapipe/framework/formats:rect_jspb_proto", "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_classifier_graph_options_jspb_proto", "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_jspb_proto", @@ -33,6 +32,7 @@ mediapipe_ts_library( "//mediapipe/tasks/web/components/processors:classifier_options", "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:classifier_options", + "//mediapipe/tasks/web/vision/core:image_processing_options", "//mediapipe/tasks/web/vision/core:vision_task_runner", "//mediapipe/web/graph_runner:graph_runner_ts", ], diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts index 8d36ed89c..e0c6affcb 100644 --- a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts +++ b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts @@ -18,7 +18,6 @@ import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; import {CalculatorOptions} from '../../../../framework/calculator_options_pb'; import {ClassificationList} from '../../../../framework/formats/classification_pb'; import {LandmarkList, NormalizedLandmarkList} from '../../../../framework/formats/landmark_pb'; -import {NormalizedRect} from '../../../../framework/formats/rect_pb'; import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; import {GestureClassifierGraphOptions} from '../../../../tasks/cc/vision/gesture_recognizer/proto/gesture_classifier_graph_options_pb'; import {GestureRecognizerGraphOptions} from '../../../../tasks/cc/vision/gesture_recognizer/proto/gesture_recognizer_graph_options_pb'; @@ -30,6 +29,7 @@ import {Category} from '../../../../tasks/web/components/containers/category'; import {Landmark, NormalizedLandmark} from '../../../../tasks/web/components/containers/landmark'; import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/processors/classifier_options'; import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; +import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options'; import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url @@ -57,15 +57,8 @@ const DEFAULT_NUM_HANDS = 1; const DEFAULT_SCORE_THRESHOLD = 0.5; const DEFAULT_CATEGORY_INDEX = -1; -const FULL_IMAGE_RECT = new NormalizedRect(); -FULL_IMAGE_RECT.setXCenter(0.5); -FULL_IMAGE_RECT.setYCenter(0.5); -FULL_IMAGE_RECT.setWidth(1); -FULL_IMAGE_RECT.setHeight(1); - /** Performs hand gesture recognition on images. 
 */
-export class GestureRecognizer extends
-    VisionTaskRunner<GestureRecognizerResult> {
+export class GestureRecognizer extends VisionTaskRunner {
   private gestures: Category[][] = [];
   private landmarks: NormalizedLandmark[][] = [];
   private worldLandmarks: Landmark[][] = [];
@@ -131,7 +124,9 @@ export class GestureRecognizer extends
   constructor(
       wasmModule: WasmModule,
       glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) {
-    super(new VisionGraphRunner(wasmModule, glCanvas));
+    super(
+        new VisionGraphRunner(wasmModule, glCanvas), IMAGE_STREAM,
+        NORM_RECT_STREAM);

     this.options = new GestureRecognizerGraphOptions();
     this.options.setBaseOptions(new BaseOptionsProto());
@@ -228,10 +223,16 @@ export class GestureRecognizer extends
    * GestureRecognizer is created with running mode `image`.
    *
    * @param image A single image to process.
+   * @param imageProcessingOptions the `ImageProcessingOptions` specifying how
+   *     to process the input image before running inference.
    * @return The detected gestures.
    */
-  recognize(image: ImageSource): GestureRecognizerResult {
-    return this.processImageData(image);
+  recognize(
+      image: ImageSource, imageProcessingOptions?: ImageProcessingOptions):
+      GestureRecognizerResult {
+    this.resetResults();
+    this.processImageData(image, imageProcessingOptions);
+    return this.processResults();
   }

   /**
@@ -241,28 +242,27 @@ export class GestureRecognizer extends
    *
    * @param videoFrame A video frame to process.
    * @param timestamp The timestamp of the current frame, in ms.
+   * @param imageProcessingOptions the `ImageProcessingOptions` specifying how
+   *     to process the input image before running inference.
    * @return The detected gestures.
    */
-  recognizeForVideo(videoFrame: ImageSource, timestamp: number):
+  recognizeForVideo(
+      videoFrame: ImageSource, timestamp: number,
+      imageProcessingOptions?: ImageProcessingOptions):
       GestureRecognizerResult {
-    return this.processVideoData(videoFrame, timestamp);
+    this.resetResults();
+    this.processVideoData(videoFrame, imageProcessingOptions, timestamp);
+    return this.processResults();
   }

-  /** Runs the gesture recognition and blocks on the response.
*/ - protected override process(imageSource: ImageSource, timestamp: number): - GestureRecognizerResult { + private resetResults(): void { this.gestures = []; this.landmarks = []; this.worldLandmarks = []; this.handednesses = []; + } - this.graphRunner.addGpuBufferAsImageToStream( - imageSource, IMAGE_STREAM, timestamp); - this.graphRunner.addProtoToStream( - FULL_IMAGE_RECT.serializeBinary(), 'mediapipe.NormalizedRect', - NORM_RECT_STREAM, timestamp); - this.finishProcessing(); - + private processResults(): GestureRecognizerResult { if (this.gestures.length === 0) { // If no gestures are detected in the image, just return an empty list return { diff --git a/mediapipe/tasks/web/vision/hand_landmarker/BUILD b/mediapipe/tasks/web/vision/hand_landmarker/BUILD index e7083a050..c5687ee2f 100644 --- a/mediapipe/tasks/web/vision/hand_landmarker/BUILD +++ b/mediapipe/tasks/web/vision/hand_landmarker/BUILD @@ -20,7 +20,6 @@ mediapipe_ts_library( "//mediapipe/framework:calculator_options_jspb_proto", "//mediapipe/framework/formats:classification_jspb_proto", "//mediapipe/framework/formats:landmark_jspb_proto", - "//mediapipe/framework/formats:rect_jspb_proto", "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_jspb_proto", "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_jspb_proto", @@ -28,6 +27,7 @@ mediapipe_ts_library( "//mediapipe/tasks/web/components/containers:category", "//mediapipe/tasks/web/components/containers:landmark", "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/vision/core:image_processing_options", "//mediapipe/tasks/web/vision/core:vision_task_runner", "//mediapipe/web/graph_runner:graph_runner_ts", ], diff --git a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts index 5db6d48f5..e238bc96f 100644 --- a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts +++ b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts @@ -18,7 +18,6 @@ import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; import {CalculatorOptions} from '../../../../framework/calculator_options_pb'; import {ClassificationList} from '../../../../framework/formats/classification_pb'; import {LandmarkList, NormalizedLandmarkList} from '../../../../framework/formats/landmark_pb'; -import {NormalizedRect} from '../../../../framework/formats/rect_pb'; import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; import {HandDetectorGraphOptions} from '../../../../tasks/cc/vision/hand_detector/proto/hand_detector_graph_options_pb'; import {HandLandmarkerGraphOptions} from '../../../../tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options_pb'; @@ -26,6 +25,7 @@ import {HandLandmarksDetectorGraphOptions} from '../../../../tasks/cc/vision/han import {Category} from '../../../../tasks/web/components/containers/category'; import {Landmark, NormalizedLandmark} from '../../../../tasks/web/components/containers/landmark'; import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; +import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options'; import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url @@ -51,14 
+51,9 @@ const HAND_LANDMARKER_GRAPH =
 const DEFAULT_NUM_HANDS = 1;
 const DEFAULT_SCORE_THRESHOLD = 0.5;
 const DEFAULT_CATEGORY_INDEX = -1;
-const FULL_IMAGE_RECT = new NormalizedRect();
-FULL_IMAGE_RECT.setXCenter(0.5);
-FULL_IMAGE_RECT.setYCenter(0.5);
-FULL_IMAGE_RECT.setWidth(1);
-FULL_IMAGE_RECT.setHeight(1);

 /** Performs hand landmarks detection on images. */
-export class HandLandmarker extends VisionTaskRunner<HandLandmarkerResult> {
+export class HandLandmarker extends VisionTaskRunner {
   private landmarks: NormalizedLandmark[][] = [];
   private worldLandmarks: Landmark[][] = [];
   private handednesses: Category[][] = [];
@@ -119,7 +114,9 @@ export class HandLandmarker extends VisionTaskRunner {
   constructor(
       wasmModule: WasmModule,
       glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) {
-    super(new VisionGraphRunner(wasmModule, glCanvas));
+    super(
+        new VisionGraphRunner(wasmModule, glCanvas), IMAGE_STREAM,
+        NORM_RECT_STREAM);

     this.options = new HandLandmarkerGraphOptions();
     this.options.setBaseOptions(new BaseOptionsProto());
@@ -180,10 +177,15 @@ export class HandLandmarker extends VisionTaskRunner {
    * HandLandmarker is created with running mode `image`.
    *
    * @param image An image to process.
+   * @param imageProcessingOptions the `ImageProcessingOptions` specifying how
+   *     to process the input image before running inference.
    * @return The detected hand landmarks.
    */
-  detect(image: ImageSource): HandLandmarkerResult {
-    return this.processImageData(image);
+  detect(image: ImageSource, imageProcessingOptions?: ImageProcessingOptions):
+      HandLandmarkerResult {
+    this.resetResults();
+    this.processImageData(image, imageProcessingOptions);
+    return this.processResults();
   }

   /**
@@ -193,27 +195,25 @@ export class HandLandmarker extends VisionTaskRunner {
    *
    * @param videoFrame A video frame to process.
    * @param timestamp The timestamp of the current frame, in ms.
+   * @param imageProcessingOptions the `ImageProcessingOptions` specifying how
+   *     to process the input image before running inference.
    * @return The detected hand landmarks.
    */
-  detectForVideo(videoFrame: ImageSource, timestamp: number):
-      HandLandmarkerResult {
-    return this.processVideoData(videoFrame, timestamp);
+  detectForVideo(
+      videoFrame: ImageSource, timestamp: number,
+      imageProcessingOptions?: ImageProcessingOptions): HandLandmarkerResult {
+    this.resetResults();
+    this.processVideoData(videoFrame, imageProcessingOptions, timestamp);
+    return this.processResults();
   }

-  /** Runs the hand landmarker graph and blocks on the response.
 */
-  protected override process(imageSource: ImageSource, timestamp: number):
-      HandLandmarkerResult {
+  private resetResults(): void {
     this.landmarks = [];
     this.worldLandmarks = [];
     this.handednesses = [];
+  }

-    this.graphRunner.addGpuBufferAsImageToStream(
-        imageSource, IMAGE_STREAM, timestamp);
-    this.graphRunner.addProtoToStream(
-        FULL_IMAGE_RECT.serializeBinary(), 'mediapipe.NormalizedRect',
-        NORM_RECT_STREAM, timestamp);
-    this.finishProcessing();
-
+  private processResults(): HandLandmarkerResult {
     return {
       landmarks: this.landmarks,
       worldLandmarks: this.worldLandmarks,
diff --git a/mediapipe/tasks/web/vision/image_classifier/BUILD b/mediapipe/tasks/web/vision/image_classifier/BUILD
index 310575964..86c7d8457 100644
--- a/mediapipe/tasks/web/vision/image_classifier/BUILD
+++ b/mediapipe/tasks/web/vision/image_classifier/BUILD
@@ -26,6 +26,7 @@ mediapipe_ts_library(
         "//mediapipe/tasks/web/components/processors:classifier_result",
         "//mediapipe/tasks/web/core",
         "//mediapipe/tasks/web/core:classifier_options",
+        "//mediapipe/tasks/web/vision/core:image_processing_options",
         "//mediapipe/tasks/web/vision/core:vision_task_runner",
         "//mediapipe/web/graph_runner:graph_runner_ts",
     ],
diff --git a/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts b/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts
index 4a2be5566..2ad4a821d 100644
--- a/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts
+++ b/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts
@@ -22,6 +22,7 @@ import {ImageClassifierGraphOptions} from '../../../../tasks/cc/vision/image_cla
 import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/processors/classifier_options';
 import {convertFromClassificationResultProto} from '../../../../tasks/web/components/processors/classifier_result';
 import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
+import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options';
 import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
 import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner';
 // Placeholder for internal dependency on trusted resource url
@@ -31,7 +32,8 @@ import {ImageClassifierResult} from './image_classifier_result';

 const IMAGE_CLASSIFIER_GRAPH =
     'mediapipe.tasks.vision.image_classifier.ImageClassifierGraph';
-const INPUT_STREAM = 'input_image';
+const IMAGE_STREAM = 'input_image';
+const NORM_RECT_STREAM = 'norm_rect';
 const CLASSIFICATIONS_STREAM = 'classifications';

 export * from './image_classifier_options';
@@ -42,7 +44,7 @@ export {ImageSource};  // Used in the public API
 // tslint:disable:jspb-use-builder-pattern

 /** Performs classification on images. */
-export class ImageClassifier extends VisionTaskRunner<ImageClassifierResult> {
+export class ImageClassifier extends VisionTaskRunner {
   private classificationResult: ImageClassifierResult = {classifications: []};
   private readonly options = new ImageClassifierGraphOptions();

@@ -97,7 +99,9 @@ export class ImageClassifier extends VisionTaskRunner {
   constructor(
       wasmModule: WasmModule,
       glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) {
-    super(new VisionGraphRunner(wasmModule, glCanvas));
+    super(
+        new VisionGraphRunner(wasmModule, glCanvas), IMAGE_STREAM,
+        NORM_RECT_STREAM);
     this.options.setBaseOptions(new BaseOptionsProto());
   }

@@ -130,10 +134,15 @@ export class ImageClassifier extends VisionTaskRunner {
    * ImageClassifier is created with running mode `image`.
* * @param image An image to process. + * @param imageProcessingOptions the `ImageProcessingOptions` specifying how + * to process the input image before running inference. * @return The classification result of the image */ - classify(image: ImageSource): ImageClassifierResult { - return this.processImageData(image); + classify(image: ImageSource, imageProcessingOptions?: ImageProcessingOptions): + ImageClassifierResult { + this.classificationResult = {classifications: []}; + this.processImageData(image, imageProcessingOptions); + return this.classificationResult; } /** @@ -143,28 +152,23 @@ export class ImageClassifier extends VisionTaskRunner { * * @param videoFrame A video frame to process. * @param timestamp The timestamp of the current frame, in ms. + * @param imageProcessingOptions the `ImageProcessingOptions` specifying how + * to process the input image before running inference. * @return The classification result of the image */ - classifyForVideo(videoFrame: ImageSource, timestamp: number): - ImageClassifierResult { - return this.processVideoData(videoFrame, timestamp); - } - - /** Runs the image classification graph and blocks on the response. */ - protected override process(imageSource: ImageSource, timestamp: number): - ImageClassifierResult { - // Get classification result by running our MediaPipe graph. + classifyForVideo( + videoFrame: ImageSource, timestamp: number, + imageProcessingOptions?: ImageProcessingOptions): ImageClassifierResult { this.classificationResult = {classifications: []}; - this.graphRunner.addGpuBufferAsImageToStream( - imageSource, INPUT_STREAM, timestamp ?? performance.now()); - this.finishProcessing(); + this.processVideoData(videoFrame, imageProcessingOptions, timestamp); return this.classificationResult; } /** Updates the MediaPipe graph configuration. */ protected override refreshGraph(): void { const graphConfig = new CalculatorGraphConfig(); - graphConfig.addInputStream(INPUT_STREAM); + graphConfig.addInputStream(IMAGE_STREAM); + graphConfig.addInputStream(NORM_RECT_STREAM); graphConfig.addOutputStream(CLASSIFICATIONS_STREAM); const calculatorOptions = new CalculatorOptions(); @@ -175,7 +179,8 @@ export class ImageClassifier extends VisionTaskRunner { // are built-in. 
    const classifierNode = new CalculatorGraphConfig.Node();
     classifierNode.setCalculator(IMAGE_CLASSIFIER_GRAPH);
-    classifierNode.addInputStream('IMAGE:' + INPUT_STREAM);
+    classifierNode.addInputStream('IMAGE:' + IMAGE_STREAM);
+    classifierNode.addInputStream('NORM_RECT:' + NORM_RECT_STREAM);
     classifierNode.addOutputStream('CLASSIFICATIONS:' + CLASSIFICATIONS_STREAM);
     classifierNode.setOptions(calculatorOptions);

diff --git a/mediapipe/tasks/web/vision/image_embedder/BUILD b/mediapipe/tasks/web/vision/image_embedder/BUILD
index de4785e6c..449cee9bb 100644
--- a/mediapipe/tasks/web/vision/image_embedder/BUILD
+++ b/mediapipe/tasks/web/vision/image_embedder/BUILD
@@ -26,6 +26,7 @@ mediapipe_ts_library(
         "//mediapipe/tasks/web/components/utils:cosine_similarity",
         "//mediapipe/tasks/web/core",
         "//mediapipe/tasks/web/core:embedder_options",
+        "//mediapipe/tasks/web/vision/core:image_processing_options",
         "//mediapipe/tasks/web/vision/core:vision_task_options",
         "//mediapipe/tasks/web/vision/core:vision_task_runner",
         "//mediapipe/web/graph_runner:graph_runner_ts",
diff --git a/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts b/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts
index 4651ae4ce..64a10f5f4 100644
--- a/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts
+++ b/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts
@@ -24,6 +24,7 @@ import {convertEmbedderOptionsToProto} from '../../../../tasks/web/components/pr
 import {convertFromEmbeddingResultProto} from '../../../../tasks/web/components/processors/embedder_result';
 import {computeCosineSimilarity} from '../../../../tasks/web/components/utils/cosine_similarity';
 import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
+import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options';
 import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
 import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner';
 // Placeholder for internal dependency on trusted resource url
@@ -31,10 +32,12 @@ import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner
 import {ImageEmbedderOptions} from './image_embedder_options';
 import {ImageEmbedderResult} from './image_embedder_result';

+
 // The OSS JS API does not support the builder pattern.
 // tslint:disable:jspb-use-builder-pattern

-const INPUT_STREAM = 'image_in';
+const IMAGE_STREAM = 'image_in';
+const NORM_RECT_STREAM = 'norm_rect';
 const EMBEDDINGS_STREAM = 'embeddings_out';
 const TEXT_EMBEDDER_CALCULATOR =
     'mediapipe.tasks.vision.image_embedder.ImageEmbedderGraph';
@@ -44,7 +47,7 @@ export * from './image_embedder_result';
 export {ImageSource};  // Used in the public API

 /** Performs embedding extraction on images. */
-export class ImageEmbedder extends VisionTaskRunner<ImageEmbedderResult> {
+export class ImageEmbedder extends VisionTaskRunner {
   private readonly options = new ImageEmbedderGraphOptions();
   private embeddings: ImageEmbedderResult = {embeddings: []};

@@ -99,7 +102,9 @@ export class ImageEmbedder extends VisionTaskRunner {
   constructor(
       wasmModule: WasmModule,
       glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) {
-    super(new VisionGraphRunner(wasmModule, glCanvas));
+    super(
+        new VisionGraphRunner(wasmModule, glCanvas), IMAGE_STREAM,
+        NORM_RECT_STREAM);
     this.options.setBaseOptions(new BaseOptionsProto());
   }

@@ -132,10 +137,14 @@ export class ImageEmbedder extends VisionTaskRunner {
    * ImageEmbedder is created with running mode `image`.
   *
   * @param image The image to process.
+   * @param imageProcessingOptions the `ImageProcessingOptions` specifying how
+   *     to process the input image before running inference.
   * @return The embedding result of the image
   */
-  embed(image: ImageSource): ImageEmbedderResult {
-    return this.processImageData(image);
+  embed(image: ImageSource, imageProcessingOptions?: ImageProcessingOptions):
+      ImageEmbedderResult {
+    this.processImageData(image, imageProcessingOptions);
+    return this.embeddings;
   }

  /**
@@ -145,11 +154,15 @@ export class ImageEmbedder extends VisionTaskRunner {
   *
   * @param imageFrame The image frame to process.
   * @param timestamp The timestamp of the current frame, in ms.
+   * @param imageProcessingOptions the `ImageProcessingOptions` specifying how
+   *     to process the input image before running inference.
   * @return The embedding result of the image
   */
-  embedForVideo(imageFrame: ImageSource, timestamp: number):
-      ImageEmbedderResult {
-    return this.processVideoData(imageFrame, timestamp);
+  embedForVideo(
+      imageFrame: ImageSource, timestamp: number,
+      imageProcessingOptions?: ImageProcessingOptions): ImageEmbedderResult {
+    this.processVideoData(imageFrame, imageProcessingOptions, timestamp);
+    return this.embeddings;
   }

   /**
@@ -165,16 +178,6 @@ export class ImageEmbedder extends VisionTaskRunner {
     return computeCosineSimilarity(u, v);
   }

-  /** Runs the embedding extraction and blocks on the response. */
-  protected process(image: ImageSource, timestamp: number):
-      ImageEmbedderResult {
-    // Get embeddings by running our MediaPipe graph.
-    this.graphRunner.addGpuBufferAsImageToStream(
-        image, INPUT_STREAM, timestamp ?? performance.now());
-    this.finishProcessing();
-    return this.embeddings;
-  }
-
   /**
    * Internal function for converting raw data into an embedding, and setting it
    * as our embeddings result.
@@ -187,7 +190,8 @@ export class ImageEmbedder extends VisionTaskRunner {

   /** Updates the MediaPipe graph configuration.
*/ protected override refreshGraph(): void { const graphConfig = new CalculatorGraphConfig(); - graphConfig.addInputStream(INPUT_STREAM); + graphConfig.addInputStream(IMAGE_STREAM); + graphConfig.addInputStream(NORM_RECT_STREAM); graphConfig.addOutputStream(EMBEDDINGS_STREAM); const calculatorOptions = new CalculatorOptions(); @@ -195,7 +199,8 @@ export class ImageEmbedder extends VisionTaskRunner { const embedderNode = new CalculatorGraphConfig.Node(); embedderNode.setCalculator(TEXT_EMBEDDER_CALCULATOR); - embedderNode.addInputStream('IMAGE:' + INPUT_STREAM); + embedderNode.addInputStream('IMAGE:' + IMAGE_STREAM); + embedderNode.addInputStream('NORM_RECT:' + NORM_RECT_STREAM); embedderNode.addOutputStream('EMBEDDINGS:' + EMBEDDINGS_STREAM); embedderNode.setOptions(calculatorOptions); diff --git a/mediapipe/tasks/web/vision/object_detector/BUILD b/mediapipe/tasks/web/vision/object_detector/BUILD index fc206a2d7..76fa589c8 100644 --- a/mediapipe/tasks/web/vision/object_detector/BUILD +++ b/mediapipe/tasks/web/vision/object_detector/BUILD @@ -23,6 +23,7 @@ mediapipe_ts_library( "//mediapipe/tasks/cc/vision/object_detector/proto:object_detector_options_jspb_proto", "//mediapipe/tasks/web/components/containers:category", "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/vision/core:image_processing_options", "//mediapipe/tasks/web/vision/core:vision_task_runner", "//mediapipe/web/graph_runner:graph_runner_ts", ], diff --git a/mediapipe/tasks/web/vision/object_detector/object_detector.ts b/mediapipe/tasks/web/vision/object_detector/object_detector.ts index ac489ec00..3a79c1b00 100644 --- a/mediapipe/tasks/web/vision/object_detector/object_detector.ts +++ b/mediapipe/tasks/web/vision/object_detector/object_detector.ts @@ -20,6 +20,7 @@ import {Detection as DetectionProto} from '../../../../framework/formats/detecti import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; import {ObjectDetectorOptions as ObjectDetectorOptionsProto} from '../../../../tasks/cc/vision/object_detector/proto/object_detector_options_pb'; import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; +import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options'; import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url @@ -27,7 +28,8 @@ import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner import {ObjectDetectorOptions} from './object_detector_options'; import {Detection} from './object_detector_result'; -const INPUT_STREAM = 'input_frame_gpu'; +const IMAGE_STREAM = 'input_frame_gpu'; +const NORM_RECT_STREAM = 'norm_rect'; const DETECTIONS_STREAM = 'detections'; const OBJECT_DETECTOR_GRAPH = 'mediapipe.tasks.vision.ObjectDetectorGraph'; @@ -41,7 +43,7 @@ export {ImageSource}; // Used in the public API // tslint:disable:jspb-use-builder-pattern /** Performs object detection on images. 
 */
-export class ObjectDetector extends VisionTaskRunner<Detection[]> {
+export class ObjectDetector extends VisionTaskRunner {
   private detections: Detection[] = [];
   private readonly options = new ObjectDetectorOptionsProto();

@@ -96,7 +98,9 @@ export class ObjectDetector extends VisionTaskRunner {
   constructor(
       wasmModule: WasmModule,
       glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) {
-    super(new VisionGraphRunner(wasmModule, glCanvas));
+    super(
+        new VisionGraphRunner(wasmModule, glCanvas), IMAGE_STREAM,
+        NORM_RECT_STREAM);
     this.options.setBaseOptions(new BaseOptionsProto());
   }

@@ -160,10 +164,15 @@ export class ObjectDetector extends VisionTaskRunner {
    * ObjectDetector is created with running mode `image`.
    *
    * @param image An image to process.
+   * @param imageProcessingOptions the `ImageProcessingOptions` specifying how
+   *     to process the input image before running inference.
    * @return The list of detected objects
    */
-  detect(image: ImageSource): Detection[] {
-    return this.processImageData(image);
+  detect(image: ImageSource, imageProcessingOptions?: ImageProcessingOptions):
+      Detection[] {
+    this.detections = [];
+    this.processImageData(image, imageProcessingOptions);
+    return [...this.detections];
   }

   /**
@@ -173,20 +182,15 @@ export class ObjectDetector extends VisionTaskRunner {
    *
    * @param videoFrame A video frame to process.
    * @param timestamp The timestamp of the current frame, in ms.
+   * @param imageProcessingOptions the `ImageProcessingOptions` specifying how
+   *     to process the input image before running inference.
    * @return The list of detected objects
    */
-  detectForVideo(videoFrame: ImageSource, timestamp: number): Detection[] {
-    return this.processVideoData(videoFrame, timestamp);
-  }
-
-  /** Runs the object detector graph and blocks on the response. */
-  protected override process(imageSource: ImageSource, timestamp: number):
-      Detection[] {
-    // Get detections by running our MediaPipe graph.
+  detectForVideo(
+      videoFrame: ImageSource, timestamp: number,
+      imageProcessingOptions?: ImageProcessingOptions): Detection[] {
     this.detections = [];
-    this.graphRunner.addGpuBufferAsImageToStream(
-        imageSource, INPUT_STREAM, timestamp ?? performance.now());
-    this.finishProcessing();
+    this.processVideoData(videoFrame, imageProcessingOptions, timestamp);
     return [...this.detections];
   }

@@ -226,7 +230,8 @@ export class ObjectDetector extends VisionTaskRunner {

   /** Updates the MediaPipe graph configuration.
 */
  protected override refreshGraph(): void {
    const graphConfig = new CalculatorGraphConfig();
-    graphConfig.addInputStream(INPUT_STREAM);
+    graphConfig.addInputStream(IMAGE_STREAM);
+    graphConfig.addInputStream(NORM_RECT_STREAM);
    graphConfig.addOutputStream(DETECTIONS_STREAM);

    const calculatorOptions = new CalculatorOptions();
@@ -235,7 +240,8 @@ export class ObjectDetector extends VisionTaskRunner {

    const detectorNode = new CalculatorGraphConfig.Node();
    detectorNode.setCalculator(OBJECT_DETECTOR_GRAPH);
-    detectorNode.addInputStream('IMAGE:' + INPUT_STREAM);
+    detectorNode.addInputStream('IMAGE:' + IMAGE_STREAM);
+    detectorNode.addInputStream('NORM_RECT:' + NORM_RECT_STREAM);
    detectorNode.addOutputStream('DETECTIONS:' + DETECTIONS_STREAM);
    detectorNode.setOptions(calculatorOptions);

From b4ede6db7bf85071893c7edd29cde9e5d7a288f9 Mon Sep 17 00:00:00 2001
From: Sebastian Schmidt
Date: Fri, 6 Jan 2023 21:00:22 -0800
Subject: [PATCH 344/469] Fix typo in Category.java

PiperOrigin-RevId: 500324008
---
 .../mediapipe/tasks/components/containers/Category.java | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/Category.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/Category.java
index e955605e4..ab3fd0bd8 100644
--- a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/Category.java
+++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/Category.java
@@ -19,9 +19,9 @@ import com.google.mediapipe.formats.proto.ClassificationProto;
 import java.util.Objects;

 /**
- * Category is a util class, contains a category name, its display name, a float value as score, and
- * the index of the label in the corresponding label file. Typically it's used as result of
- * classification or detection tasks.
+ * Category is a util class that contains a category name, its display name, a float value as
+ * score, and the index of the label in the corresponding label file. Typically it's used as result
+ * of classification or detection tasks.
 */
 @AutoValue
 public abstract class Category {

From ed0054836a62d52771520aa4f07be6b1c5ad3962 Mon Sep 17 00:00:00 2001
From: Sebastian Schmidt
Date: Fri, 6 Jan 2023 21:04:57 -0800
Subject: [PATCH 345/469] Allow task to recover after a failed graph start

PiperOrigin-RevId: 500324587
---
 mediapipe/tasks/web/core/task_runner.ts      | 21 +++++++++++---------
 mediapipe/tasks/web/core/task_runner_test.ts | 12 +++++++++++
 2 files changed, 24 insertions(+), 9 deletions(-)

diff --git a/mediapipe/tasks/web/core/task_runner.ts b/mediapipe/tasks/web/core/task_runner.ts
index a3df7adf5..c2679b773 100644
--- a/mediapipe/tasks/web/core/task_runner.ts
+++ b/mediapipe/tasks/web/core/task_runner.ts
@@ -164,16 +164,19 @@ export abstract class TaskRunner {

  /** Throws the error from the error listener if an error was raised.
*/ private handleErrors() { - const errorCount = this.processingErrors.length; - if (errorCount === 1) { - // Re-throw error to get a more meaningful stacktrace - throw new Error(this.processingErrors[0].message); - } else if (errorCount > 1) { - throw new Error( - 'Encountered multiple errors: ' + - this.processingErrors.map(e => e.message).join(', ')); + try { + const errorCount = this.processingErrors.length; + if (errorCount === 1) { + // Re-throw error to get a more meaningful stacktrace + throw new Error(this.processingErrors[0].message); + } else if (errorCount > 1) { + throw new Error( + 'Encountered multiple errors: ' + + this.processingErrors.map(e => e.message).join(', ')); + } + } finally { + this.processingErrors = []; } - this.processingErrors = []; } /** Configures the `externalFile` option */ diff --git a/mediapipe/tasks/web/core/task_runner_test.ts b/mediapipe/tasks/web/core/task_runner_test.ts index 684beb70c..9a8aa32eb 100644 --- a/mediapipe/tasks/web/core/task_runner_test.ts +++ b/mediapipe/tasks/web/core/task_runner_test.ts @@ -139,6 +139,18 @@ describe('TaskRunner', () => { }).toThrowError(/Test error 1, Test error 2/); }); + it('clears errors once thrown', () => { + taskRunner.enqueueError('Test error'); + + expect(() => { + taskRunner.setGraph(new Uint8Array(0), /* isBinary= */ true); + }).toThrowError(/Test error/); + + expect(() => { + taskRunner.setGraph(new Uint8Array(0), /* isBinary= */ true); + }).not.toThrow(); + }); + it('verifies that at least one model asset option is provided', () => { expect(() => { taskRunner.setOptions({}); From c9ebc6fa606888542ad89b978c2658c127d4226f Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Fri, 6 Jan 2023 21:34:46 -0800 Subject: [PATCH 346/469] Use synthetic timestamps in Web when none provided PiperOrigin-RevId: 500327275 --- .../audio_classifier/audio_classifier.ts | 5 ++++- .../audio/audio_embedder/audio_embedder.ts | 18 ++++++++++++------ .../tasks/web/audio/core/audio_task_runner.ts | 5 ++++- mediapipe/tasks/web/core/task_runner.ts | 18 +++++++++++++++++- .../text/text_classifier/text_classifier.ts | 10 ++++++---- .../web/text/text_embedder/text_embedder.ts | 19 ++++++++++++------- .../web/vision/core/vision_task_runner.ts | 6 +++++- .../gesture_recognizer/gesture_recognizer.ts | 12 ++++++++---- .../vision/hand_landmarker/hand_landmarker.ts | 9 ++++++--- .../image_classifier/image_classifier.ts | 3 ++- .../vision/image_embedder/image_embedder.ts | 8 +++++--- .../vision/object_detector/object_detector.ts | 5 +++-- 12 files changed, 84 insertions(+), 34 deletions(-) diff --git a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts index 92fca93ad..e26ead6a9 100644 --- a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts +++ b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts @@ -126,6 +126,8 @@ export class AudioClassifier extends AudioTaskRunner { return this.applyOptions(options); } + // TODO: Add a classifyStream() that takes a timestamp + /** * Performs audio classification on the provided audio clip and waits * synchronously for the response. 
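A minimal sketch of the synthetic-timestamp scheme this commit introduces (illustrative only, not part of the patch; `TimestampTracker` is a hypothetical stand-in for the bookkeeping that `TaskRunner` gains in the hunks below):

// Illustrative TypeScript sketch of the synthetic-timestamp bookkeeping.
// Graph output listeners report each packet's timestamp; inputs that arrive
// without a caller-supplied timestamp are stamped one millisecond past the
// largest output timestamp seen, so inputs stay monotonically increasing.
class TimestampTracker {
  private latestOutputTimestamp = 0;

  // Called from graph output listeners with the packet timestamp (in ms).
  setLatestOutputTimestamp(timestamp: number): void {
    this.latestOutputTimestamp =
        Math.max(this.latestOutputTimestamp, timestamp);
  }

  // Synthetic timestamp for the next input when none is provided.
  nextSyntheticTimestamp(): number {
    return this.latestOutputTimestamp + 1;
  }
}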
@@ -194,8 +196,9 @@ export class AudioClassifier extends AudioTaskRunner {
     graphConfig.addNode(classifierNode);

     this.graphRunner.attachProtoVectorListener(
-        TIMESTAMPED_CLASSIFICATIONS_STREAM, binaryProtos => {
+        TIMESTAMPED_CLASSIFICATIONS_STREAM, (binaryProtos, timestamp) => {
           this.addJsAudioClassificationResults(binaryProtos);
+          this.setLatestOutputTimestamp(timestamp);
         });

     const binaryGraph = graphConfig.serializeBinary();
diff --git a/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts
index 2e210f969..7411f95ef 100644
--- a/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts
+++ b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts
@@ -128,6 +128,8 @@ export class AudioEmbedder extends AudioTaskRunner {
     return this.applyOptions(options);
   }

+  // TODO: Add an embedStream() that takes a timestamp
+
   /**
    * Performs embedding extraction on the provided audio clip and waits
    * synchronously for the response.
@@ -193,20 +195,24 @@ export class AudioEmbedder extends AudioTaskRunner {

     graphConfig.addNode(embedderNode);

-    this.graphRunner.attachProtoListener(EMBEDDINGS_STREAM, binaryProto => {
-      const embeddingResult = EmbeddingResult.deserializeBinary(binaryProto);
-      this.embeddingResults.push(
-          convertFromEmbeddingResultProto(embeddingResult));
-    });
+    this.graphRunner.attachProtoListener(
+        EMBEDDINGS_STREAM, (binaryProto, timestamp) => {
+          const embeddingResult =
+              EmbeddingResult.deserializeBinary(binaryProto);
+          this.embeddingResults.push(
+              convertFromEmbeddingResultProto(embeddingResult));
+          this.setLatestOutputTimestamp(timestamp);
+        });

     this.graphRunner.attachProtoVectorListener(
-        TIMESTAMPED_EMBEDDINGS_STREAM, data => {
+        TIMESTAMPED_EMBEDDINGS_STREAM, (data, timestamp) => {
           for (const binaryProto of data) {
             const embeddingResult =
                 EmbeddingResult.deserializeBinary(binaryProto);
             this.embeddingResults.push(
                 convertFromEmbeddingResultProto(embeddingResult));
           }
+          this.setLatestOutputTimestamp(timestamp);
         });

     const binaryGraph = graphConfig.serializeBinary();
diff --git a/mediapipe/tasks/web/audio/core/audio_task_runner.ts b/mediapipe/tasks/web/audio/core/audio_task_runner.ts
index 24d78378d..ff39185f2 100644
--- a/mediapipe/tasks/web/audio/core/audio_task_runner.ts
+++ b/mediapipe/tasks/web/audio/core/audio_task_runner.ts
@@ -36,8 +36,11 @@ export abstract class AudioTaskRunner extends TaskRunner {

   /** Sends a single audio clip to the graph and awaits results. */
   protected processAudioClip(audioData: Float32Array, sampleRate?: number): T {
+    // Increment the timestamp by 1 millisecond to guarantee that we send
+    // monotonically increasing timestamps to the graph.
+    const syntheticTimestamp = this.getLatestOutputTimestamp() + 1;
     return this.process(
-        audioData, sampleRate ?? this.defaultSampleRate, performance.now());
+        audioData, sampleRate ?? this.defaultSampleRate, syntheticTimestamp);
   }
 }

diff --git a/mediapipe/tasks/web/core/task_runner.ts b/mediapipe/tasks/web/core/task_runner.ts
index c2679b773..8d483d9ff 100644
--- a/mediapipe/tasks/web/core/task_runner.ts
+++ b/mediapipe/tasks/web/core/task_runner.ts
@@ -50,7 +50,7 @@ export async function createTaskRunner(
     }
   };

-  // Initialize a canvas if requested. If OffscreenCanvas is availble, we
+  // Initialize a canvas if requested. If OffscreenCanvas is available, we
   // let the graph runner initialize it by passing `undefined`.
   const canvas = initializeCanvas ? (typeof OffscreenCanvas === 'undefined' ?
document.createElement('canvas') : @@ -66,6 +66,7 @@ export async function createTaskRunner( export abstract class TaskRunner { protected abstract baseOptions: BaseOptionsProto; private processingErrors: Error[] = []; + private latestOutputTimestamp = 0; /** * Creates a new instance of a Mediapipe Task. Determines if SIMD is @@ -162,6 +163,21 @@ export abstract class TaskRunner { this.handleErrors(); } + /* + * Sets the latest output timestamp received from the graph (in ms). + * Timestamps that are smaller than the currently latest output timestamp are + * ignored. + */ + protected setLatestOutputTimestamp(timestamp: number): void { + this.latestOutputTimestamp = + Math.max(this.latestOutputTimestamp, timestamp); + } + + /** Returns the latest output timestamp. */ + protected getLatestOutputTimestamp() { + return this.latestOutputTimestamp; + } + /** Throws the error from the error listener if an error was raised. */ private handleErrors() { try { diff --git a/mediapipe/tasks/web/text/text_classifier/text_classifier.ts b/mediapipe/tasks/web/text/text_classifier/text_classifier.ts index 6aef1b3e4..ff314cfc3 100644 --- a/mediapipe/tasks/web/text/text_classifier/text_classifier.ts +++ b/mediapipe/tasks/web/text/text_classifier/text_classifier.ts @@ -131,10 +131,11 @@ export class TextClassifier extends TaskRunner { * @return The classification result of the text */ classify(text: string): TextClassifierResult { - // Get classification result by running our MediaPipe graph. + // Increment the timestamp by 1 millisecond to guarantee that we send + // monotonically increasing timestamps to the graph. + const syntheticTimestamp = this.getLatestOutputTimestamp() + 1; this.classificationResult = {classifications: []}; - this.graphRunner.addStringToStream( - text, INPUT_STREAM, /* timestamp= */ performance.now()); + this.graphRunner.addStringToStream(text, INPUT_STREAM, syntheticTimestamp); this.finishProcessing(); return this.classificationResult; } @@ -158,9 +159,10 @@ export class TextClassifier extends TaskRunner { graphConfig.addNode(classifierNode); this.graphRunner.attachProtoListener( - CLASSIFICATIONS_STREAM, binaryProto => { + CLASSIFICATIONS_STREAM, (binaryProto, timestamp) => { this.classificationResult = convertFromClassificationResultProto( ClassificationResult.deserializeBinary(binaryProto)); + this.setLatestOutputTimestamp(timestamp); }); const binaryGraph = graphConfig.serializeBinary(); diff --git a/mediapipe/tasks/web/text/text_embedder/text_embedder.ts b/mediapipe/tasks/web/text/text_embedder/text_embedder.ts index db7986dec..daa1d24ed 100644 --- a/mediapipe/tasks/web/text/text_embedder/text_embedder.ts +++ b/mediapipe/tasks/web/text/text_embedder/text_embedder.ts @@ -135,9 +135,10 @@ export class TextEmbedder extends TaskRunner { * @return The embedding resuls of the text */ embed(text: string): TextEmbedderResult { - // Get text embeddings by running our MediaPipe graph. - this.graphRunner.addStringToStream( - text, INPUT_STREAM, /* timestamp= */ performance.now()); + // Increment the timestamp by 1 millisecond to guarantee that we send + // monotonically increasing timestamps to the graph. 
+ const syntheticTimestamp = this.getLatestOutputTimestamp() + 1; + this.graphRunner.addStringToStream(text, INPUT_STREAM, syntheticTimestamp); this.finishProcessing(); return this.embeddingResult; } @@ -173,10 +174,14 @@ export class TextEmbedder extends TaskRunner { graphConfig.addNode(embedderNode); - this.graphRunner.attachProtoListener(EMBEDDINGS_STREAM, binaryProto => { - const embeddingResult = EmbeddingResult.deserializeBinary(binaryProto); - this.embeddingResult = convertFromEmbeddingResultProto(embeddingResult); - }); + this.graphRunner.attachProtoListener( + EMBEDDINGS_STREAM, (binaryProto, timestamp) => { + const embeddingResult = + EmbeddingResult.deserializeBinary(binaryProto); + this.embeddingResult = + convertFromEmbeddingResultProto(embeddingResult); + this.setLatestOutputTimestamp(timestamp); + }); const binaryGraph = graphConfig.serializeBinary(); this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true); diff --git a/mediapipe/tasks/web/vision/core/vision_task_runner.ts b/mediapipe/tasks/web/vision/core/vision_task_runner.ts index 9adc810fc..9ed9ffdb2 100644 --- a/mediapipe/tasks/web/vision/core/vision_task_runner.ts +++ b/mediapipe/tasks/web/vision/core/vision_task_runner.ts @@ -71,7 +71,11 @@ export abstract class VisionTaskRunner extends TaskRunner { 'Task is not initialized with image mode. ' + '\'runningMode\' must be set to \'image\'.'); } - this.process(image, imageProcessingOptions, performance.now()); + + // Increment the timestamp by 1 millisecond to guarantee that we send + // monotonically increasing timestamps to the graph. + const syntheticTimestamp = this.getLatestOutputTimestamp() + 1; + this.process(image, imageProcessingOptions, syntheticTimestamp); } /** Sends a single video frame to the graph and awaits results. */ diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts index e0c6affcb..48efc4855 100644 --- a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts +++ b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts @@ -380,23 +380,27 @@ export class GestureRecognizer extends VisionTaskRunner { graphConfig.addNode(recognizerNode); this.graphRunner.attachProtoVectorListener( - LANDMARKS_STREAM, binaryProto => { + LANDMARKS_STREAM, (binaryProto, timestamp) => { this.addJsLandmarks(binaryProto); + this.setLatestOutputTimestamp(timestamp); }); this.graphRunner.attachProtoVectorListener( - WORLD_LANDMARKS_STREAM, binaryProto => { + WORLD_LANDMARKS_STREAM, (binaryProto, timestamp) => { this.adddJsWorldLandmarks(binaryProto); + this.setLatestOutputTimestamp(timestamp); }); this.graphRunner.attachProtoVectorListener( - HAND_GESTURES_STREAM, binaryProto => { + HAND_GESTURES_STREAM, (binaryProto, timestamp) => { // Gesture index is not used, because the final gesture result comes // from multiple classifiers. 
this.gestures.push( ...this.toJsCategories(binaryProto, /* populateIndex= */ false)); + this.setLatestOutputTimestamp(timestamp); }); this.graphRunner.attachProtoVectorListener( - HANDEDNESS_STREAM, binaryProto => { + HANDEDNESS_STREAM, (binaryProto, timestamp) => { this.handednesses.push(...this.toJsCategories(binaryProto)); + this.setLatestOutputTimestamp(timestamp); }); const binaryGraph = graphConfig.serializeBinary(); diff --git a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts index e238bc96f..b51fb6a52 100644 --- a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts +++ b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts @@ -313,16 +313,19 @@ export class HandLandmarker extends VisionTaskRunner { graphConfig.addNode(landmarkerNode); this.graphRunner.attachProtoVectorListener( - LANDMARKS_STREAM, binaryProto => { + LANDMARKS_STREAM, (binaryProto, timestamp) => { this.addJsLandmarks(binaryProto); + this.setLatestOutputTimestamp(timestamp); }); this.graphRunner.attachProtoVectorListener( - WORLD_LANDMARKS_STREAM, binaryProto => { + WORLD_LANDMARKS_STREAM, (binaryProto, timestamp) => { this.adddJsWorldLandmarks(binaryProto); + this.setLatestOutputTimestamp(timestamp); }); this.graphRunner.attachProtoVectorListener( - HANDEDNESS_STREAM, binaryProto => { + HANDEDNESS_STREAM, (binaryProto, timestamp) => { this.handednesses.push(...this.toJsCategories(binaryProto)); + this.setLatestOutputTimestamp(timestamp); }); const binaryGraph = graphConfig.serializeBinary(); diff --git a/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts b/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts index 2ad4a821d..cb2849cd8 100644 --- a/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts +++ b/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts @@ -187,9 +187,10 @@ export class ImageClassifier extends VisionTaskRunner { graphConfig.addNode(classifierNode); this.graphRunner.attachProtoListener( - CLASSIFICATIONS_STREAM, binaryProto => { + CLASSIFICATIONS_STREAM, (binaryProto, timestamp) => { this.classificationResult = convertFromClassificationResultProto( ClassificationResult.deserializeBinary(binaryProto)); + this.setLatestOutputTimestamp(timestamp); }); const binaryGraph = graphConfig.serializeBinary(); diff --git a/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts b/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts index 64a10f5f4..788646e6d 100644 --- a/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts +++ b/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts @@ -206,9 +206,11 @@ export class ImageEmbedder extends VisionTaskRunner { graphConfig.addNode(embedderNode); - this.graphRunner.attachProtoListener(EMBEDDINGS_STREAM, binaryProto => { - this.addJsImageEmdedding(binaryProto); - }); + this.graphRunner.attachProtoListener( + EMBEDDINGS_STREAM, (binaryProto, timestamp) => { + this.addJsImageEmdedding(binaryProto); + this.setLatestOutputTimestamp(timestamp); + }); const binaryGraph = graphConfig.serializeBinary(); this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true); diff --git a/mediapipe/tasks/web/vision/object_detector/object_detector.ts b/mediapipe/tasks/web/vision/object_detector/object_detector.ts index 3a79c1b00..5741a3a0c 100644 --- a/mediapipe/tasks/web/vision/object_detector/object_detector.ts +++ b/mediapipe/tasks/web/vision/object_detector/object_detector.ts @@ -176,7 +176,7 @@ export class 
ObjectDetector extends VisionTaskRunner {
   }
 
   /**
-   * Performs object detection on the provided vidoe frame and waits
+   * Performs object detection on the provided video frame and waits
    * synchronously for the response. Only use this method when the
    * ObjectDetector is created with running mode `video`.
    *
@@ -248,8 +248,9 @@ export class ObjectDetector extends VisionTaskRunner {
     graphConfig.addNode(detectorNode);
 
     this.graphRunner.attachProtoVectorListener(
-        DETECTIONS_STREAM, binaryProto => {
+        DETECTIONS_STREAM, (binaryProto, timestamp) => {
           this.addJsObjectDetections(binaryProto);
+          this.setLatestOutputTimestamp(timestamp);
         });
 
     const binaryGraph = graphConfig.serializeBinary();

From 7f043b7de1f4230359c4b16e5deae58cb9ea50b2 Mon Sep 17 00:00:00 2001
From: Sebastian Schmidt
Date: Fri, 6 Jan 2023 21:40:09 -0800
Subject: [PATCH 347/469] Allow split_vector_calculator to be built with iOS
 and MEDIAPIPE_DISABLE_GPU

PiperOrigin-RevId: 500327774
---
 mediapipe/calculators/core/BUILD | 12 ++++++++++--
 1 file changed, 10 insertions(+), 2 deletions(-)

diff --git a/mediapipe/calculators/core/BUILD b/mediapipe/calculators/core/BUILD
index b3378a74e..df54c5800 100644
--- a/mediapipe/calculators/core/BUILD
+++ b/mediapipe/calculators/core/BUILD
@@ -13,12 +13,21 @@
 # limitations under the License.
 #
 
+load("@bazel_skylib//lib:selects.bzl", "selects")
 load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library")
 
 licenses(["notice"])
 
 package(default_visibility = ["//visibility:public"])
 
+selects.config_setting_group(
+    name = "ios_or_disable_gpu",
+    match_any = [
+        "//mediapipe/gpu:disable_gpu",
+        "//mediapipe:ios",
+    ],
+)
+
 mediapipe_proto_library(
     name = "concatenate_vector_calculator_proto",
     srcs = ["concatenate_vector_calculator.proto"],
@@ -899,8 +908,7 @@ cc_library(
         "@org_tensorflow//tensorflow/lite:framework",
         "@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
     ] + select({
-        "//mediapipe/gpu:disable_gpu": [],
-        "//mediapipe:ios": [],
+        ":ios_or_disable_gpu": [],
         "//conditions:default": [
             "@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_buffer",
         ],

From e0a254789a1ec05f3c09411b45a6c59d0ed3075e Mon Sep 17 00:00:00 2001
From: Nikolay Chirkov
Date: Fri, 6 Jan 2023 22:13:13 -0800
Subject: [PATCH 348/469] Internal change.

PiperOrigin-RevId: 500331015
---
 mediapipe/framework/formats/tensor/BUILD | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mediapipe/framework/formats/tensor/BUILD b/mediapipe/framework/formats/tensor/BUILD
index c634b0dda..3895fc82e 100644
--- a/mediapipe/framework/formats/tensor/BUILD
+++ b/mediapipe/framework/formats/tensor/BUILD
@@ -13,7 +13,7 @@
 # limitations under the License.
 package(
-    default_visibility = ["//visibility:public"],
+    default_visibility = ["//visibility:private"],
     features = ["-layering_check"],
 )
 

From 1bbe065647b30f7b457df56747b24510c225258d Mon Sep 17 00:00:00 2001
From: Sebastian Schmidt
Date: Mon, 9 Jan 2023 09:11:37 -0800
Subject: [PATCH 349/469] Simplify default options for GestureRecognizer

PiperOrigin-RevId: 500729643
---
 mediapipe/tasks/testdata/vision/BUILD         |  2 +
 .../gesture_recognizer/gesture_recognizer.ts  | 39 +++++--------------
 third_party/external_files.bzl                |  6 +++
 3 files changed, 17 insertions(+), 30 deletions(-)

diff --git a/mediapipe/tasks/testdata/vision/BUILD b/mediapipe/tasks/testdata/vision/BUILD
index 95b721fdb..607245700 100644
--- a/mediapipe/tasks/testdata/vision/BUILD
+++ b/mediapipe/tasks/testdata/vision/BUILD
@@ -38,6 +38,7 @@ mediapipe_files(srcs = [
     "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29_with_dummy_score_calibration.tflite",
     "deeplabv3.tflite",
     "fist.jpg",
+    "fist.png",
     "hand_landmark_full.tflite",
     "hand_landmark_lite.tflite",
     "hand_landmarker.task",
@@ -95,6 +96,7 @@ filegroup(
         "cats_and_dogs_no_resizing.jpg",
         "cats_and_dogs_rotated.jpg",
         "fist.jpg",
+        "fist.png",
         "hand_landmark_full.tflite",
         "hand_landmark_lite.tflite",
         "left_hands.jpg",
diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts
index 48efc4855..1b7201b9a 100644
--- a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts
+++ b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts
@@ -54,7 +54,7 @@ const GESTURE_RECOGNIZER_GRAPH =
     'mediapipe.tasks.vision.gesture_recognizer.GestureRecognizerGraph';
 
 const DEFAULT_NUM_HANDS = 1;
-const DEFAULT_SCORE_THRESHOLD = 0.5;
+const DEFAULT_CONFIDENCE = 0.5;
 const DEFAULT_CATEGORY_INDEX = -1;
 
 /** Performs hand gesture recognition on images. */
@@ -143,8 +143,6 @@ export class GestureRecognizer extends VisionTaskRunner {
         new HandGestureRecognizerGraphOptions();
     this.options.setHandGestureRecognizerGraphOptions(
         this.handGestureRecognizerGraphOptions);
-
-    this.initDefaults();
   }
 
   protected override get baseOptions(): BaseOptionsProto {
@@ -165,22 +163,14 @@ export class GestureRecognizer extends VisionTaskRunner {
    * @param options The options for the gesture recognizer.
    */
   override setOptions(options: GestureRecognizerOptions): Promise {
-    if ('numHands' in options) {
-      this.handDetectorGraphOptions.setNumHands(
-          options.numHands ?? DEFAULT_NUM_HANDS);
-    }
-    if ('minHandDetectionConfidence' in options) {
-      this.handDetectorGraphOptions.setMinDetectionConfidence(
-          options.minHandDetectionConfidence ?? DEFAULT_SCORE_THRESHOLD);
-    }
-    if ('minHandPresenceConfidence' in options) {
-      this.handLandmarksDetectorGraphOptions.setMinDetectionConfidence(
-          options.minHandPresenceConfidence ?? DEFAULT_SCORE_THRESHOLD);
-    }
-    if ('minTrackingConfidence' in options) {
-      this.handLandmarkerGraphOptions.setMinTrackingConfidence(
-          options.minTrackingConfidence ?? DEFAULT_SCORE_THRESHOLD);
-    }
+    this.handDetectorGraphOptions.setNumHands(
+        options.numHands ?? DEFAULT_NUM_HANDS);
+    this.handDetectorGraphOptions.setMinDetectionConfidence(
+        options.minHandDetectionConfidence ?? DEFAULT_CONFIDENCE);
+    this.handLandmarkerGraphOptions.setMinTrackingConfidence(
+        options.minTrackingConfidence ?? DEFAULT_CONFIDENCE);
+    this.handLandmarksDetectorGraphOptions.setMinDetectionConfidence(
+        options.minHandPresenceConfidence ??
DEFAULT_CONFIDENCE); if (options.cannedGesturesClassifierOptions) { // Note that we have to support both JSPB and ProtobufJS and cannot @@ -281,17 +271,6 @@ export class GestureRecognizer extends VisionTaskRunner { } } - /** Sets the default values for the graph. */ - private initDefaults(): void { - this.handDetectorGraphOptions.setNumHands(DEFAULT_NUM_HANDS); - this.handDetectorGraphOptions.setMinDetectionConfidence( - DEFAULT_SCORE_THRESHOLD); - this.handLandmarksDetectorGraphOptions.setMinDetectionConfidence( - DEFAULT_SCORE_THRESHOLD); - this.handLandmarkerGraphOptions.setMinTrackingConfidence( - DEFAULT_SCORE_THRESHOLD); - } - /** Converts the proto data to a Category[][] structure. */ private toJsCategories(data: Uint8Array[], populateIndex = true): Category[][] { diff --git a/third_party/external_files.bzl b/third_party/external_files.bzl index 72ca95e66..790486676 100644 --- a/third_party/external_files.bzl +++ b/third_party/external_files.bzl @@ -286,6 +286,12 @@ def external_files(): urls = ["https://storage.googleapis.com/mediapipe-assets/fist_landmarks.pbtxt?generation=1666999360561864"], ) + http_file( + name = "com_google_mediapipe_fist_png", + sha256 = "4397b3d3f590c88a8de7d21c08d73a0df4a97fd93f92cbd086eef37fd246daaa", + urls = ["https://storage.googleapis.com/mediapipe-assets/fist.png?generation=1672952068696274"], + ) + http_file( name = "com_google_mediapipe_general_meta_json", sha256 = "b95363e4bae89b9c2af484498312aaad4efc7ff57c7eadcc4e5e7adca641445f", From 2b9299959cddc5505cb1d28fc50a2f9d46702f12 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Mon, 9 Jan 2023 09:14:05 -0800 Subject: [PATCH 350/469] Internal change PiperOrigin-RevId: 500730237 --- .../web/vision/object_detector/object_detector_result.d.ts | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mediapipe/tasks/web/vision/object_detector/object_detector_result.d.ts b/mediapipe/tasks/web/vision/object_detector/object_detector_result.d.ts index e9e3843bc..c9c87a1bf 100644 --- a/mediapipe/tasks/web/vision/object_detector/object_detector_result.d.ts +++ b/mediapipe/tasks/web/vision/object_detector/object_detector_result.d.ts @@ -16,6 +16,8 @@ import {Category} from '../../../../tasks/web/components/containers/category'; +export {Category}; + /** An integer bounding box, axis aligned. */ export declare interface BoundingBox { /** The X coordinate of the top-left corner, in pixels. */ From c6cf598774810fdf45f325a8b5cb083884a13e6d Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Mon, 9 Jan 2023 09:52:04 -0800 Subject: [PATCH 351/469] Minor fix for max_queue_size documentation PiperOrigin-RevId: 500738798 --- mediapipe/framework/calculator.proto | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mediapipe/framework/calculator.proto b/mediapipe/framework/calculator.proto index 7c5e8b144..eecd033c9 100644 --- a/mediapipe/framework/calculator.proto +++ b/mediapipe/framework/calculator.proto @@ -382,7 +382,7 @@ message CalculatorGraphConfig { // is empty and no other nodes are running (to prevent possible deadlocks due // to a incorrectly specified value). This global parameter is set to 100 // packets by default to enable pipelining. If any node indicates that it - // buffers packets before emitting them, then the max(node_buffer_size, + // buffers packets before emitting them, then the max(buffer_size_hint, // max_queue_size) is used. Set this parameter to -1 to disable throttling // (i.e. the graph will use as much memory as it requires). If not specified, // the limit is 100 packets. 
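The synthetic-timestamp change above (PATCH 346) replaces `performance.now()` with a counter derived from the graph's own outputs: every stream listener records the newest output timestamp it has seen, and the next clip, text, or image input is stamped one millisecond later, keeping input timestamps strictly increasing without mixing clock sources. A condensed sketch of that bookkeeping (illustrative names; the real logic is spread across `TaskRunner` and its subclasses):

```
// Illustrative sketch of the synthetic-timestamp bookkeeping from PATCH 346.
class TimestampedRunner {
  private latestOutputTimestamp = 0;

  // Every output-stream listener reports the packet timestamp it received.
  // Math.max() makes late or out-of-order callbacks harmless.
  protected setLatestOutputTimestamp(timestamp: number): void {
    this.latestOutputTimestamp =
        Math.max(this.latestOutputTimestamp, timestamp);
  }

  // MediaPipe graphs require monotonically increasing input timestamps, so
  // the next input is stamped 1 ms after the newest output seen so far.
  protected nextSyntheticTimestamp(): number {
    return this.latestOutputTimestamp + 1;
  }
}
```

One consequence, visible throughout the diffs above, is that every attached proto listener must call `setLatestOutputTimestamp()`; otherwise the synthetic clock would stall at zero.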
From 73f4636292b4ee65c36863a664b3dfb9e11b36a5 Mon Sep 17 00:00:00 2001
From: Sebastian Schmidt
Date: Mon, 9 Jan 2023 10:34:26 -0800
Subject: [PATCH 352/469] Create README.md files for NPM packages

PiperOrigin-RevId: 500750516
---
 mediapipe/tasks/web/BUILD            |  3 ++
 mediapipe/tasks/web/audio/BUILD      |  2 +
 mediapipe/tasks/web/audio/README.md  | 31 +++++++++++
 mediapipe/tasks/web/text/BUILD       |  2 +
 mediapipe/tasks/web/text/README.md   | 34 ++++++++++++
 mediapipe/tasks/web/vision/BUILD     |  2 +
 mediapipe/tasks/web/vision/README.md | 78 ++++++++++++++++++++++++++++
 7 files changed, 152 insertions(+)
 create mode 100644 mediapipe/tasks/web/audio/README.md
 create mode 100644 mediapipe/tasks/web/text/README.md
 create mode 100644 mediapipe/tasks/web/vision/README.md

diff --git a/mediapipe/tasks/web/BUILD b/mediapipe/tasks/web/BUILD
index bc9e84147..02bd70dd0 100644
--- a/mediapipe/tasks/web/BUILD
+++ b/mediapipe/tasks/web/BUILD
@@ -65,6 +65,7 @@ pkg_npm(
         "wasm/audio_wasm_nosimd_internal.js",
         "wasm/audio_wasm_nosimd_internal.wasm",
         ":audio_bundle",
+        "//mediapipe/tasks/web/audio:README.md",
     ],
 )
 
@@ -108,6 +109,7 @@ pkg_npm(
         "wasm/text_wasm_nosimd_internal.js",
         "wasm/text_wasm_nosimd_internal.wasm",
         ":text_bundle",
+        "//mediapipe/tasks/web/text:README.md",
     ],
 )
 
@@ -151,5 +153,6 @@ pkg_npm(
         "wasm/vision_wasm_nosimd_internal.js",
         "wasm/vision_wasm_nosimd_internal.wasm",
         ":vision_bundle",
+        "//mediapipe/tasks/web/vision:README.md",
     ],
 )
diff --git a/mediapipe/tasks/web/audio/BUILD b/mediapipe/tasks/web/audio/BUILD
index 9d26f1118..50a611f41 100644
--- a/mediapipe/tasks/web/audio/BUILD
+++ b/mediapipe/tasks/web/audio/BUILD
@@ -4,6 +4,8 @@ load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library")
 
 package(default_visibility = ["//mediapipe/tasks:internal"])
 
+exports_files(["README.md"])
+
 mediapipe_ts_library(
     name = "audio_lib",
     srcs = ["index.ts"],
diff --git a/mediapipe/tasks/web/audio/README.md b/mediapipe/tasks/web/audio/README.md
new file mode 100644
index 000000000..834785709
--- /dev/null
+++ b/mediapipe/tasks/web/audio/README.md
@@ -0,0 +1,31 @@
+# MediaPipe Tasks Audio Package
+
+This package contains the audio tasks for MediaPipe.
+
+## Audio Classification
+
+The MediaPipe Audio Classification task performs classification on audio data.
+
+```
+const audio = await FilesetResolver.forAudioTasks(
+    "https://cdn.jsdelivr.net/npm/@mediapipe/tasks-audio@latest/wasm"
+);
+const audioClassifier = await AudioClassifier.createFromModelPath(audio,
+    "https://storage.googleapis.com/mediapipe-tasks/audio_classifier/yamnet_audio_classifier_with_metadata.tflite"
+);
+const classifications = audioClassifier.classify(audioData);
+```
+
+## Audio Embedding
+
+The MediaPipe Audio Embedding task extracts embeddings from audio data.
+
+```
+const audio = await FilesetResolver.forAudioTasks(
+    "https://cdn.jsdelivr.net/npm/@mediapipe/tasks-audio@latest/wasm"
+);
+const audioEmbedder = await AudioEmbedder.createFromModelPath(audio,
+    "model.tflite"
+);
+const embeddings = audioEmbedder.embed(audioData);
+```
diff --git a/mediapipe/tasks/web/text/BUILD b/mediapipe/tasks/web/text/BUILD
index 32f43d4b6..077b25645 100644
--- a/mediapipe/tasks/web/text/BUILD
+++ b/mediapipe/tasks/web/text/BUILD
@@ -4,6 +4,8 @@ load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library")
 
 package(default_visibility = ["//mediapipe/tasks:internal"])
 
+exports_files(["README.md"])
+
 mediapipe_ts_library(
     name = "text_lib",
     srcs = ["index.ts"],
diff --git a/mediapipe/tasks/web/text/README.md b/mediapipe/tasks/web/text/README.md
new file mode 100644
index 000000000..247dc6d30
--- /dev/null
+++ b/mediapipe/tasks/web/text/README.md
@@ -0,0 +1,34 @@
+# MediaPipe Tasks Text Package
+
+This package contains the text tasks for MediaPipe.
+
+## Text Classification
+
+The MediaPipe Text Classifier task lets you classify text into a set of defined
+categories, such as positive or negative sentiment.
+
+```
+const text = await FilesetResolver.forTextTasks(
+    "https://cdn.jsdelivr.net/npm/@mediapipe/tasks-text@latest/wasm"
+);
+const textClassifier = await TextClassifier.createFromModelPath(text,
+    "https://storage.googleapis.com/mediapipe-tasks/text_classifier/bert_text_classifier.tflite"
+);
+const classifications = textClassifier.classify(textData);
+```
+
+For more information, refer to the [Text Classification](https://developers.google.com/mediapipe/solutions/text/text_classifier/web_js) documentation.
+
+## Text Embedding
+
+The MediaPipe Text Embedding task extracts embeddings from text data.
+
+```
+const text = await FilesetResolver.forTextTasks(
+    "https://cdn.jsdelivr.net/npm/@mediapipe/tasks-text@latest/wasm"
+);
+const textEmbedder = await TextEmbedder.createFromModelPath(text,
+    "https://storage.googleapis.com/mediapipe-tasks/text_embedder/mobilebert_embedding_with_metadata.tflite"
+);
+const embeddings = textEmbedder.embed(textData);
+```
diff --git a/mediapipe/tasks/web/vision/BUILD b/mediapipe/tasks/web/vision/BUILD
index 93493e873..ea022e900 100644
--- a/mediapipe/tasks/web/vision/BUILD
+++ b/mediapipe/tasks/web/vision/BUILD
@@ -4,6 +4,8 @@ load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library")
 
 package(default_visibility = ["//mediapipe/tasks:internal"])
 
+exports_files(["README.md"])
+
 mediapipe_ts_library(
     name = "vision_lib",
     srcs = ["index.ts"],
diff --git a/mediapipe/tasks/web/vision/README.md b/mediapipe/tasks/web/vision/README.md
new file mode 100644
index 000000000..51f43821c
--- /dev/null
+++ b/mediapipe/tasks/web/vision/README.md
@@ -0,0 +1,78 @@
+# MediaPipe Tasks Vision Package
+
+This package contains the vision tasks for MediaPipe.
+
+## Object Detection
+
+The MediaPipe Object Detector task lets you detect the presence and location of
+multiple classes of objects within images or videos.
+
+```
+const vision = await FilesetResolver.forVisionTasks(
+    "https://cdn.jsdelivr.net/npm/@mediapipe/tasks-vision@latest/wasm"
+);
+const objectDetector = await ObjectDetector.createFromModelPath(vision,
+    "https://storage.googleapis.com/mediapipe-tasks/object_detector/efficientdet_lite0_uint8.tflite"
+);
+const image = document.getElementById("image") as HTMLImageElement;
+const detections = objectDetector.detect(image);
+```
+
+For more information, refer to the [Object Detector](https://developers.google.com/mediapipe/solutions/vision/object_detector/web_js) documentation.
+
+## Image Classification
+
+The MediaPipe Image Classifier task lets you perform classification on images.
+You can use this task to identify what an image represents among a set of
+categories defined at training time.
+
+```
+const vision = await FilesetResolver.forVisionTasks(
+    "https://cdn.jsdelivr.net/npm/@mediapipe/tasks-vision@latest/wasm"
+);
+const imageClassifier = await ImageClassifier.createFromModelPath(vision,
+    "https://storage.googleapis.com/mediapipe-tasks/image_classifier/efficientnet_lite0_uint8.tflite"
+);
+const image = document.getElementById("image") as HTMLImageElement;
+const classifications = imageClassifier.classify(image);
+```
+
+For more information, refer to the [Image Classification](https://developers.google.com/mediapipe/solutions/vision/image_classifier/web_js) documentation.
+
+## Gesture Recognition
+
+The MediaPipe Gesture Recognizer task lets you recognize hand gestures in real
+time, and provides the recognized hand gesture results along with the landmarks
+of the detected hands. You can use this task to recognize specific hand gestures
+from a user, and invoke application features that correspond to those gestures.
+
+```
+const vision = await FilesetResolver.forVisionTasks(
+    "https://cdn.jsdelivr.net/npm/@mediapipe/tasks-vision@latest/wasm"
+);
+const gestureRecognizer = await GestureRecognizer.createFromModelPath(vision,
+    "https://storage.googleapis.com/mediapipe-tasks/gesture_recognizer/gesture_recognizer.task"
+);
+const image = document.getElementById("image") as HTMLImageElement;
+const recognitions = gestureRecognizer.recognize(image);
+```
+
+## Hand Landmark Detection
+
+The MediaPipe Hand Landmarker task lets you detect the landmarks of the hands in
+an image. You can use this task to localize key points of the hands and render
+visual effects over the hands.
+
+```
+const vision = await FilesetResolver.forVisionTasks(
+    "https://cdn.jsdelivr.net/npm/@mediapipe/tasks-vision@latest/wasm"
+);
+const handLandmarker = await HandLandmarker.createFromModelPath(vision,
+    "https://storage.googleapis.com/mediapipe-tasks/hand_landmarker/hand_landmarker.task"
+);
+const image = document.getElementById("image") as HTMLImageElement;
+const landmarks = handLandmarker.detect(image);
+```
+
+For more information, refer to the [Hand Landmark Detection](https://developers.google.com/mediapipe/solutions/vision/hand_landmarker/web_js) documentation.
+

From d40fa6b16d9e14cf0ac7ff30efa45eef588567d5 Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Mon, 9 Jan 2023 11:02:48 -0800
Subject: [PATCH 353/469] Internal Model Maker change.
PiperOrigin-RevId: 500758488 --- .../python/core/tasks/classifier.py | 16 ++- .../python/core/utils/model_util.py | 4 +- .../python/vision/image_classifier/BUILD | 10 -- .../vision/image_classifier/__init__.py | 1 - .../image_classifier/image_classifier.py | 96 +++++++++-------- .../train_image_classifier_lib.py | 102 ------------------ 6 files changed, 67 insertions(+), 162 deletions(-) delete mode 100644 mediapipe/model_maker/python/vision/image_classifier/train_image_classifier_lib.py diff --git a/mediapipe/model_maker/python/core/tasks/classifier.py b/mediapipe/model_maker/python/core/tasks/classifier.py index f376edffa..0908dddf5 100644 --- a/mediapipe/model_maker/python/core/tasks/classifier.py +++ b/mediapipe/model_maker/python/core/tasks/classifier.py @@ -48,11 +48,12 @@ class Classifier(custom_model.CustomModel): self._hparams: hp.BaseHParams = None self._history: tf.keras.callbacks.History = None - # TODO: Integrate this into all Model Maker tasks. + # TODO: Integrate this into GestureRecognizer. def _train_model(self, train_data: classification_ds.ClassificationDataset, validation_data: classification_ds.ClassificationDataset, - preprocessor: Optional[Callable[..., bool]] = None): + preprocessor: Optional[Callable[..., bool]] = None, + checkpoint_path: Optional[str] = None): """Trains the classifier model. Compiles and fits the tf.keras `_model` and records the `_history`. @@ -62,6 +63,9 @@ class Classifier(custom_model.CustomModel): validation_data: Validation data. preprocessor: An optional data preprocessor that can be used when generating a tf.data.Dataset. + checkpoint_path: An optional directory for the checkpoint file to support + continual training. If provided, loads model weights from the latest + checkpoint in the directory. """ tf.compat.v1.logging.info('Training the models...') if len(train_data) < self._hparams.batch_size: @@ -88,6 +92,14 @@ class Classifier(custom_model.CustomModel): optimizer=self._optimizer, loss=self._loss_function, metrics=[self._metric_function]) + + latest_checkpoint = ( + tf.train.latest_checkpoint(checkpoint_path) + if checkpoint_path else None) + if latest_checkpoint: + print(f'Resuming from {latest_checkpoint}') + self._model.load_weights(latest_checkpoint) + self._history = self._model.fit( x=train_dataset, epochs=self._hparams.epochs, diff --git a/mediapipe/model_maker/python/core/utils/model_util.py b/mediapipe/model_maker/python/core/utils/model_util.py index f10d9390c..db02444df 100644 --- a/mediapipe/model_maker/python/core/utils/model_util.py +++ b/mediapipe/model_maker/python/core/utils/model_util.py @@ -42,7 +42,9 @@ def get_default_callbacks( checkpoint_path = os.path.join(export_dir, 'checkpoint') checkpoint_callback = tf.keras.callbacks.ModelCheckpoint( - checkpoint_path, save_weights_only=True) + os.path.join(checkpoint_path, 'model-{epoch:04d}'), + save_weights_only=True, + period=5) return [summary_callback, checkpoint_callback] diff --git a/mediapipe/model_maker/python/vision/image_classifier/BUILD b/mediapipe/model_maker/python/vision/image_classifier/BUILD index d7c47a359..bd916a92b 100644 --- a/mediapipe/model_maker/python/vision/image_classifier/BUILD +++ b/mediapipe/model_maker/python/vision/image_classifier/BUILD @@ -87,15 +87,6 @@ py_library( ], ) -py_library( - name = "train_image_classifier_lib", - srcs = ["train_image_classifier_lib.py"], - deps = [ - ":hyperparameters", - "//mediapipe/model_maker/python/core/utils:model_util", - ], -) - py_library( name = "image_classifier", srcs = ["image_classifier.py"], @@ 
-104,7 +95,6 @@ py_library( ":image_classifier_options", ":model_options", ":model_spec", - ":train_image_classifier_lib", "//mediapipe/model_maker/python/core/data:classification_dataset", "//mediapipe/model_maker/python/core/tasks:classifier", "//mediapipe/model_maker/python/core/utils:model_util", diff --git a/mediapipe/model_maker/python/vision/image_classifier/__init__.py b/mediapipe/model_maker/python/vision/image_classifier/__init__.py index 0f964ef66..4cde9e7e3 100644 --- a/mediapipe/model_maker/python/vision/image_classifier/__init__.py +++ b/mediapipe/model_maker/python/vision/image_classifier/__init__.py @@ -35,4 +35,3 @@ del image_classifier del image_classifier_options del model_options del model_spec -del train_image_classifier_lib # pylint: disable=undefined-variable diff --git a/mediapipe/model_maker/python/vision/image_classifier/image_classifier.py b/mediapipe/model_maker/python/vision/image_classifier/image_classifier.py index df71a8fef..c2181121c 100644 --- a/mediapipe/model_maker/python/vision/image_classifier/image_classifier.py +++ b/mediapipe/model_maker/python/vision/image_classifier/image_classifier.py @@ -28,7 +28,6 @@ from mediapipe.model_maker.python.vision.image_classifier import hyperparameters from mediapipe.model_maker.python.vision.image_classifier import image_classifier_options from mediapipe.model_maker.python.vision.image_classifier import model_options as model_opt from mediapipe.model_maker.python.vision.image_classifier import model_spec as ms -from mediapipe.model_maker.python.vision.image_classifier import train_image_classifier_lib from mediapipe.tasks.python.metadata.metadata_writers import image_classifier as image_classifier_writer from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer @@ -57,6 +56,10 @@ class ImageClassifier(classifier.Classifier): mean_rgb=self._model_spec.mean_rgb, stddev_rgb=self._model_spec.stddev_rgb, use_augmentation=hparams.do_data_augmentation) + self._callbacks = model_util.get_default_callbacks(self._hparams.export_dir) + self._loss_function = tf.keras.losses.CategoricalCrossentropy( + label_smoothing=self._hparams.label_smoothing) + self._metric_function = 'accuracy' self._history = None # Training history returned from `keras_model.fit`. @classmethod @@ -66,7 +69,7 @@ class ImageClassifier(classifier.Classifier): validation_data: classification_ds.ClassificationDataset, options: image_classifier_options.ImageClassifierOptions, ) -> 'ImageClassifier': - """Creates and trains an image classifier. + """Creates and trains an ImageClassifier. Loads data and trains the model based on data for image classification. If a checkpoint file exists in the {options.hparams.export_dir}/checkpoint/ @@ -93,58 +96,29 @@ class ImageClassifier(classifier.Classifier): label_names=train_data.label_names, hparams=options.hparams, model_options=options.model_options) - - image_classifier._create_model() - - tf.compat.v1.logging.info('Training the models...') - image_classifier._train( - train_data=train_data, validation_data=validation_data) - + image_classifier._create_and_train_model(train_data, validation_data) return image_classifier - # TODO: Migrate to the shared training library of Model Maker. - def _train(self, train_data: classification_ds.ClassificationDataset, - validation_data: classification_ds.ClassificationDataset): - """Trains the model with input train_data. - - The training results are recorded by a self._history object returned by - tf.keras.Model.fit(). 
+ def _create_and_train_model( + self, train_data: classification_ds.ClassificationDataset, + validation_data: classification_ds.ClassificationDataset): + """Creates and trains the model and optimizer. Args: train_data: Training data. validation_data: Validation data. """ - - tf.compat.v1.logging.info('Training the models...') - hparams = self._hparams - if len(train_data) < hparams.batch_size: - raise ValueError('The size of the train_data (%d) couldn\'t be smaller ' - 'than batch_size (%d). To solve this problem, set ' - 'the batch_size smaller or increase the size of the ' - 'train_data.' % (len(train_data), hparams.batch_size)) - - train_dataset = train_data.gen_tf_dataset( - batch_size=hparams.batch_size, - is_training=True, - shuffle=self._shuffle, - preprocess=self._preprocess) - hparams.steps_per_epoch = model_util.get_steps_per_epoch( - steps_per_epoch=hparams.steps_per_epoch, - batch_size=hparams.batch_size, + self._create_model() + self._hparams.steps_per_epoch = model_util.get_steps_per_epoch( + steps_per_epoch=self._hparams.steps_per_epoch, + batch_size=self._hparams.batch_size, train_data=train_data) - train_dataset = train_dataset.take(count=hparams.steps_per_epoch) - - validation_dataset = validation_data.gen_tf_dataset( - batch_size=hparams.batch_size, - is_training=False, - preprocess=self._preprocess) - - # Train the model. - self._history = train_image_classifier_lib.train_model( - model=self._model, - hparams=hparams, - train_ds=train_dataset, - validation_ds=validation_dataset) + self._optimizer = self._create_optimizer() + self._train_model( + train_data=train_data, + validation_data=validation_data, + preprocessor=self._preprocess, + checkpoint_path=os.path.join(self._hparams.export_dir, 'checkpoint')) def _create_model(self): """Creates the classifier model from TFHub pretrained models.""" @@ -198,3 +172,33 @@ class ImageClassifier(classifier.Classifier): model_util.save_tflite(tflite_model_with_metadata, tflite_file) with open(metadata_file, 'w') as f: f.write(metadata_json) + + def _create_optimizer(self) -> tf.keras.optimizers.Optimizer: + """Creates an optimizer with learning rate schedule. + + Uses Keras CosineDecay schedule for the learning rate by default. + + Returns: + A tf.keras.optimizers.Optimizer for model training. + """ + # Learning rate is linear to batch size. + init_lr = self._hparams.learning_rate * self._hparams.batch_size / 256 + + # Get decay steps. 
+ total_training_steps = self._hparams.steps_per_epoch * self._hparams.epochs + default_decay_steps = ( + self._hparams.decay_samples // self._hparams.batch_size) + decay_steps = max(total_training_steps, default_decay_steps) + + learning_rate_fn = tf.keras.experimental.CosineDecay( + initial_learning_rate=init_lr, decay_steps=decay_steps, alpha=0.0) + warmup_steps = self._hparams.warmup_epochs * self._hparams.steps_per_epoch + if warmup_steps: + learning_rate_fn = model_util.WarmUp( + initial_learning_rate=init_lr, + decay_schedule_fn=learning_rate_fn, + warmup_steps=warmup_steps) + optimizer = tf.keras.optimizers.RMSprop( + learning_rate=learning_rate_fn, rho=0.9, momentum=0.9, epsilon=0.001) + + return optimizer diff --git a/mediapipe/model_maker/python/vision/image_classifier/train_image_classifier_lib.py b/mediapipe/model_maker/python/vision/image_classifier/train_image_classifier_lib.py deleted file mode 100644 index c5b28cff5..000000000 --- a/mediapipe/model_maker/python/vision/image_classifier/train_image_classifier_lib.py +++ /dev/null @@ -1,102 +0,0 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Library to train model.""" - -import os -import tensorflow as tf - -from mediapipe.model_maker.python.core.utils import model_util -from mediapipe.model_maker.python.vision.image_classifier import hyperparameters as hp - - -def _create_optimizer(init_lr: float, decay_steps: int, - warmup_steps: int) -> tf.keras.optimizers.Optimizer: - """Creates an optimizer with learning rate schedule. - - Uses Keras CosineDecay schedule for the learning rate by default. - - Args: - init_lr: Initial learning rate. - decay_steps: Number of steps to decay over. - warmup_steps: Number of steps to do warmup for. - - Returns: - A tf.keras.optimizers.Optimizer for model training. - """ - learning_rate_fn = tf.keras.experimental.CosineDecay( - initial_learning_rate=init_lr, decay_steps=decay_steps, alpha=0.0) - if warmup_steps: - learning_rate_fn = model_util.WarmUp( - initial_learning_rate=init_lr, - decay_schedule_fn=learning_rate_fn, - warmup_steps=warmup_steps) - optimizer = tf.keras.optimizers.RMSprop( - learning_rate=learning_rate_fn, rho=0.9, momentum=0.9, epsilon=0.001) - - return optimizer - - -def train_model(model: tf.keras.Model, hparams: hp.HParams, - train_ds: tf.data.Dataset, - validation_ds: tf.data.Dataset) -> tf.keras.callbacks.History: - """Trains model with the input data and hyperparameters. - - Args: - model: Input tf.keras.Model. - hparams: Hyperparameters for training image classifier. - train_ds: tf.data.Dataset, training data to be fed in tf.keras.Model.fit(). - validation_ds: tf.data.Dataset, validation data to be fed in - tf.keras.Model.fit(). - - Returns: - The tf.keras.callbacks.History object returned by tf.keras.Model.fit(). - """ - - # Learning rate is linear to batch size. - learning_rate = hparams.learning_rate * hparams.batch_size / 256 - - # Get decay steps. 
- # NOMUTANTS--(b/256493858):Plan to test it in the unified training library. - total_training_steps = hparams.steps_per_epoch * hparams.epochs - default_decay_steps = hparams.decay_samples // hparams.batch_size - decay_steps = max(total_training_steps, default_decay_steps) - - warmup_steps = hparams.warmup_epochs * hparams.steps_per_epoch - optimizer = _create_optimizer( - init_lr=learning_rate, decay_steps=decay_steps, warmup_steps=warmup_steps) - - loss = tf.keras.losses.CategoricalCrossentropy( - label_smoothing=hparams.label_smoothing) - model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy']) - - summary_dir = os.path.join(hparams.export_dir, 'summaries') - summary_callback = tf.keras.callbacks.TensorBoard(summary_dir) - # Save checkpoint every 5 epochs. - checkpoint_path = os.path.join(hparams.export_dir, 'checkpoint') - checkpoint_callback = tf.keras.callbacks.ModelCheckpoint( - os.path.join(checkpoint_path, 'model-{epoch:04d}'), - save_weights_only=True, - period=5) - - latest_checkpoint = tf.train.latest_checkpoint(checkpoint_path) - if latest_checkpoint: - print(f'Resuming from {latest_checkpoint}') - model.load_weights(latest_checkpoint) - - # Train the model. - return model.fit( - x=train_ds, - epochs=hparams.epochs, - validation_data=validation_ds, - callbacks=[summary_callback, checkpoint_callback]) From 08310231145b3c82e3d72effa49e081960b6be58 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Mon, 9 Jan 2023 11:09:28 -0800 Subject: [PATCH 354/469] Use uppercase enum constants for RunningMode PiperOrigin-RevId: 500760402 --- .../tasks/web/vision/core/vision_task_options.d.ts | 2 +- .../web/vision/core/vision_task_runner.test.ts | 14 +++++++------- .../tasks/web/vision/core/vision_task_runner.ts | 6 +++--- .../vision/image_embedder/image_embedder_test.ts | 2 +- 4 files changed, 12 insertions(+), 12 deletions(-) diff --git a/mediapipe/tasks/web/vision/core/vision_task_options.d.ts b/mediapipe/tasks/web/vision/core/vision_task_options.d.ts index 76c0177a0..44b1660ff 100644 --- a/mediapipe/tasks/web/vision/core/vision_task_options.d.ts +++ b/mediapipe/tasks/web/vision/core/vision_task_options.d.ts @@ -21,7 +21,7 @@ import {TaskRunnerOptions} from '../../../../tasks/web/core/task_runner_options' * 1) The image mode for processing single image inputs. * 2) The video mode for processing decoded frames of a video. */ -export type RunningMode = 'image'|'video'; +export type RunningMode = 'IMAGE'|'VIDEO'; /** The options for configuring a MediaPipe vision task. 
*/ export declare interface VisionTaskOptions extends TaskRunnerOptions { diff --git a/mediapipe/tasks/web/vision/core/vision_task_runner.test.ts b/mediapipe/tasks/web/vision/core/vision_task_runner.test.ts index a48381038..4567134b8 100644 --- a/mediapipe/tasks/web/vision/core/vision_task_runner.test.ts +++ b/mediapipe/tasks/web/vision/core/vision_task_runner.test.ts @@ -118,19 +118,19 @@ describe('VisionTaskRunner', () => { }); it('can enable image mode', async () => { - await visionTaskRunner.setOptions({runningMode: 'image'}); + await visionTaskRunner.setOptions({runningMode: 'IMAGE'}); expect(visionTaskRunner.baseOptions.toObject()) .toEqual(jasmine.objectContaining({useStreamMode: false})); }); it('can enable video mode', async () => { - await visionTaskRunner.setOptions({runningMode: 'video'}); + await visionTaskRunner.setOptions({runningMode: 'VIDEO'}); expect(visionTaskRunner.baseOptions.toObject()) .toEqual(jasmine.objectContaining({useStreamMode: true})); }); it('can clear running mode', async () => { - await visionTaskRunner.setOptions({runningMode: 'video'}); + await visionTaskRunner.setOptions({runningMode: 'VIDEO'}); // Clear running mode await visionTaskRunner.setOptions( @@ -140,7 +140,7 @@ describe('VisionTaskRunner', () => { }); it('cannot process images with video mode', async () => { - await visionTaskRunner.setOptions({runningMode: 'video'}); + await visionTaskRunner.setOptions({runningMode: 'VIDEO'}); expect(() => { visionTaskRunner.processImageData( IMAGE, /* imageProcessingOptions= */ undefined); @@ -155,7 +155,7 @@ describe('VisionTaskRunner', () => { }).toThrowError(/Task is not initialized with video mode./); // Explicitly set to image mode - await visionTaskRunner.setOptions({runningMode: 'image'}); + await visionTaskRunner.setOptions({runningMode: 'IMAGE'}); expect(() => { visionTaskRunner.processVideoData( IMAGE, /* imageProcessingOptions= */ undefined, TIMESTAMP); @@ -163,7 +163,7 @@ describe('VisionTaskRunner', () => { }); it('sends packets to graph', async () => { - await visionTaskRunner.setOptions({runningMode: 'video'}); + await visionTaskRunner.setOptions({runningMode: 'VIDEO'}); visionTaskRunner.expectImage(IMAGE); visionTaskRunner.expectNormalizedRect(0.5, 0.5, 1, 1); @@ -172,7 +172,7 @@ describe('VisionTaskRunner', () => { }); it('sends packets to graph with image processing options', async () => { - await visionTaskRunner.setOptions({runningMode: 'video'}); + await visionTaskRunner.setOptions({runningMode: 'VIDEO'}); visionTaskRunner.expectImage(IMAGE); visionTaskRunner.expectNormalizedRect(0.3, 0.6, 0.2, 0.4); diff --git a/mediapipe/tasks/web/vision/core/vision_task_runner.ts b/mediapipe/tasks/web/vision/core/vision_task_runner.ts index 9ed9ffdb2..71cac920c 100644 --- a/mediapipe/tasks/web/vision/core/vision_task_runner.ts +++ b/mediapipe/tasks/web/vision/core/vision_task_runner.ts @@ -56,7 +56,7 @@ export abstract class VisionTaskRunner extends TaskRunner { override applyOptions(options: VisionTaskOptions): Promise { if ('runningMode' in options) { const useStreamMode = - !!options.runningMode && options.runningMode !== 'image'; + !!options.runningMode && options.runningMode !== 'IMAGE'; this.baseOptions.setUseStreamMode(useStreamMode); } return super.applyOptions(options); @@ -69,7 +69,7 @@ export abstract class VisionTaskRunner extends TaskRunner { if (!!this.baseOptions?.getUseStreamMode()) { throw new Error( 'Task is not initialized with image mode. 
' + - '\'runningMode\' must be set to \'image\'.'); + '\'runningMode\' must be set to \'IMAGE\'.'); } // Increment the timestamp by 1 millisecond to guarantee that we send @@ -86,7 +86,7 @@ export abstract class VisionTaskRunner extends TaskRunner { if (!this.baseOptions?.getUseStreamMode()) { throw new Error( 'Task is not initialized with video mode. ' + - '\'runningMode\' must be set to \'video\'.'); + '\'runningMode\' must be set to \'VIDEO\'.'); } this.process(imageFrame, imageProcessingOptions, timestamp); } diff --git a/mediapipe/tasks/web/vision/image_embedder/image_embedder_test.ts b/mediapipe/tasks/web/vision/image_embedder/image_embedder_test.ts index 01ec751e3..5a8293c44 100644 --- a/mediapipe/tasks/web/vision/image_embedder/image_embedder_test.ts +++ b/mediapipe/tasks/web/vision/image_embedder/image_embedder_test.ts @@ -143,7 +143,7 @@ describe('ImageEmbedder', () => { }); it('for video mode', async () => { - await imageEmbedder.setOptions({runningMode: 'video'}); + await imageEmbedder.setOptions({runningMode: 'VIDEO'}); // Invoke the video embedder const embeddingResult = From 704964be33d737c44e6154fde410b363df161e73 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Mon, 9 Jan 2023 14:03:42 -0800 Subject: [PATCH 355/469] Fix accidental suppressions of GLSL linker error reporting PiperOrigin-RevId: 500802177 --- mediapipe/gpu/shader_util.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mediapipe/gpu/shader_util.cc b/mediapipe/gpu/shader_util.cc index 2132cbda9..5de7e24f5 100644 --- a/mediapipe/gpu/shader_util.cc +++ b/mediapipe/gpu/shader_util.cc @@ -140,7 +140,7 @@ GLint GlhCreateProgram(const GLchar* vert_src, const GLchar* frag_src, glBindAttribLocation(*program, attr_locations[i], attr_names[i]); } - ok = GlhLinkProgram(*program); + ok = GlhLinkProgram(*program, force_log_errors); } if (vert_shader) glDeleteShader(vert_shader); From 76a7c9d5d488eb1c661bd6cb219eba35f7cd07ed Mon Sep 17 00:00:00 2001 From: Liam Miller-Cushon Date: Mon, 9 Jan 2023 14:47:21 -0800 Subject: [PATCH 356/469] Internal change PiperOrigin-RevId: 500813290 --- .../android/solutions/gradle/wrapper/gradle-wrapper.properties | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mediapipe/examples/android/solutions/gradle/wrapper/gradle-wrapper.properties b/mediapipe/examples/android/solutions/gradle/wrapper/gradle-wrapper.properties index 41dfb8790..070cb702f 100644 --- a/mediapipe/examples/android/solutions/gradle/wrapper/gradle-wrapper.properties +++ b/mediapipe/examples/android/solutions/gradle/wrapper/gradle-wrapper.properties @@ -1,5 +1,5 @@ distributionBase=GRADLE_USER_HOME distributionPath=wrapper/dists -distributionUrl=https\://services.gradle.org/distributions/gradle-7.4-bin.zip +distributionUrl=https\://services.gradle.org/distributions/gradle-7.6-bin.zip zipStoreBase=GRADLE_USER_HOME zipStorePath=wrapper/dists From d7ee875356012514d8d5287a360cb8ea391ad0b2 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Mon, 9 Jan 2023 16:15:52 -0800 Subject: [PATCH 357/469] Fix spacing issue in test name PiperOrigin-RevId: 500833769 --- .../web/vision/gesture_recognizer/gesture_recognizer_test.ts | 2 +- .../tasks/web/vision/hand_landmarker/hand_landmarker_test.ts | 2 +- .../tasks/web/vision/object_detector/object_detector_test.ts | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_test.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_test.ts index 3699033b2..dfc252eb6 
100644 --- a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_test.ts +++ b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_test.ts @@ -147,7 +147,7 @@ describe('GestureRecognizer', () => { ]); }); - describe('setOptions() ', () => { + describe('setOptions()', () => { interface TestCase { optionPath: [keyof GestureRecognizerOptions, ...string[]]; fieldPath: string[]; diff --git a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_test.ts b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_test.ts index bce0eac02..0abd1df27 100644 --- a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_test.ts +++ b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_test.ts @@ -129,7 +129,7 @@ describe('HandLandmarker', () => { ]); }); - describe('setOptions() ', () => { + describe('setOptions()', () => { interface TestCase { optionPath: [keyof HandLandmarkerOptions, ...string[]]; fieldPath: string[]; diff --git a/mediapipe/tasks/web/vision/object_detector/object_detector_test.ts b/mediapipe/tasks/web/vision/object_detector/object_detector_test.ts index 5bfb74ab6..ceb96acb1 100644 --- a/mediapipe/tasks/web/vision/object_detector/object_detector_test.ts +++ b/mediapipe/tasks/web/vision/object_detector/object_detector_test.ts @@ -111,7 +111,7 @@ describe('ObjectDetector', () => { verifyGraph(objectDetector, ['displayNamesLocale', 'en']); }); - describe('setOptions() ', () => { + describe('setOptions()', () => { interface TestCase { optionName: keyof ObjectDetectorOptions; protoName: string; From 6032604f94208bf9649a97a564046984ac538819 Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Tue, 10 Jan 2023 08:42:07 -0800 Subject: [PATCH 358/469] Hide base task api classes for MediaPipe Tasks Python from API docs PiperOrigin-RevId: 501004802 --- mediapipe/tasks/python/audio/core/base_audio_task_api.py | 3 +-- mediapipe/tasks/python/core/BUILD | 1 + mediapipe/tasks/python/core/task_info.py | 2 ++ mediapipe/tasks/python/text/core/base_text_task_api.py | 3 +-- mediapipe/tasks/python/vision/core/base_vision_task_api.py | 3 +-- 5 files changed, 6 insertions(+), 6 deletions(-) diff --git a/mediapipe/tasks/python/audio/core/base_audio_task_api.py b/mediapipe/tasks/python/audio/core/base_audio_task_api.py index b2197c142..5b08a2b76 100644 --- a/mediapipe/tasks/python/audio/core/base_audio_task_api.py +++ b/mediapipe/tasks/python/audio/core/base_audio_task_api.py @@ -29,6 +29,7 @@ _RunningMode = running_mode_module.AudioTaskRunningMode _Timestamp = timestamp_module.Timestamp +@doc_controls.do_not_generate_docs class BaseAudioTaskApi(object): """The base class of the user-facing mediapipe audio task api classes.""" @@ -133,12 +134,10 @@ class BaseAudioTaskApi(object): """ self._runner.close() - @doc_controls.do_not_generate_docs def __enter__(self): """Return `self` upon entering the runtime context.""" return self - @doc_controls.do_not_generate_docs def __exit__(self, unused_exc_type, unused_exc_value, unused_traceback): """Shuts down the mediapipe audio task instance on exit of the context manager. 
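For downstream code, the `RunningMode` change a few commits back (PATCH 354) is a breaking one: option literals that previously passed `runningMode: 'video'` must switch to the uppercase constants. A hedged usage sketch against the public API shown in these diffs (`detectForVideo` and its exact signature are assumptions based on the surrounding tests, not confirmed by this patch series):

```
import {FilesetResolver, HandLandmarker} from '@mediapipe/tasks-vision';

async function detectOnVideoFrame(video: HTMLVideoElement) {
  const vision = await FilesetResolver.forVisionTasks(
      'https://cdn.jsdelivr.net/npm/@mediapipe/tasks-vision@latest/wasm');
  const handLandmarker = await HandLandmarker.createFromModelPath(
      vision,
      'https://storage.googleapis.com/mediapipe-tasks/hand_landmarker/hand_landmarker.task');
  // After PATCH 354 the mode constants are uppercase: 'VIDEO', not 'video'.
  await handLandmarker.setOptions({runningMode: 'VIDEO'});
  // Video mode expects a monotonically increasing timestamp per frame;
  // image mode would instead use the synthetic timestamps from PATCH 346.
  const result = handLandmarker.detectForVideo(video, performance.now());
  console.log(result);
}
```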
diff --git a/mediapipe/tasks/python/core/BUILD b/mediapipe/tasks/python/core/BUILD index f14d59b99..6098fb5f5 100644 --- a/mediapipe/tasks/python/core/BUILD +++ b/mediapipe/tasks/python/core/BUILD @@ -43,6 +43,7 @@ py_library( name = "task_info", srcs = ["task_info.py"], deps = [ + ":optional_dependencies", "//mediapipe/calculators/core:flow_limiter_calculator_py_pb2", "//mediapipe/framework:calculator_options_py_pb2", "//mediapipe/framework:calculator_py_pb2", diff --git a/mediapipe/tasks/python/core/task_info.py b/mediapipe/tasks/python/core/task_info.py index 31605480f..6ea2cee7b 100644 --- a/mediapipe/tasks/python/core/task_info.py +++ b/mediapipe/tasks/python/core/task_info.py @@ -20,8 +20,10 @@ from typing import Any, List from mediapipe.calculators.core import flow_limiter_calculator_pb2 from mediapipe.framework import calculator_options_pb2 from mediapipe.framework import calculator_pb2 +from mediapipe.tasks.python.core.optional_dependencies import doc_controls +@doc_controls.do_not_generate_docs @dataclasses.dataclass class TaskInfo: """Specifications of a MediaPipe task graph. diff --git a/mediapipe/tasks/python/text/core/base_text_task_api.py b/mediapipe/tasks/python/text/core/base_text_task_api.py index b22bfff00..1d6311561 100644 --- a/mediapipe/tasks/python/text/core/base_text_task_api.py +++ b/mediapipe/tasks/python/text/core/base_text_task_api.py @@ -20,6 +20,7 @@ from mediapipe.tasks.python.core.optional_dependencies import doc_controls _TaskRunner = task_runner.TaskRunner +@doc_controls.do_not_generate_docs class BaseTextTaskApi(object): """The base class of the user-facing mediapipe text task api classes.""" @@ -40,12 +41,10 @@ class BaseTextTaskApi(object): """ self._runner.close() - @doc_controls.do_not_generate_docs def __enter__(self): """Returns `self` upon entering the runtime context.""" return self - @doc_controls.do_not_generate_docs def __exit__(self, unused_exc_type, unused_exc_value, unused_traceback): """Shuts down the mediapipe text task instance on exit of the context manager. diff --git a/mediapipe/tasks/python/vision/core/base_vision_task_api.py b/mediapipe/tasks/python/vision/core/base_vision_task_api.py index 016170398..0c8262d4b 100644 --- a/mediapipe/tasks/python/vision/core/base_vision_task_api.py +++ b/mediapipe/tasks/python/vision/core/base_vision_task_api.py @@ -31,6 +31,7 @@ _RunningMode = running_mode_module.VisionTaskRunningMode _ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions +@doc_controls.do_not_generate_docs class BaseVisionTaskApi(object): """The base class of the user-facing mediapipe vision task api classes.""" @@ -178,12 +179,10 @@ class BaseVisionTaskApi(object): """ self._runner.close() - @doc_controls.do_not_generate_docs def __enter__(self): """Return `self` upon entering the runtime context.""" return self - @doc_controls.do_not_generate_docs def __exit__(self, unused_exc_type, unused_exc_value, unused_traceback): """Shuts down the mediapipe vision task instance on exit of the context manager. From 25abd122b338de4598edc72987bd91a13104c84d Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Tue, 10 Jan 2023 09:44:04 -0800 Subject: [PATCH 359/469] Support AudioRecord in MediaPipe audio tasks in Java. 
PiperOrigin-RevId: 501019327
---
 .../tasks/audio/core/BaseAudioTaskApi.java    | 70 +++++++++++++++++++
 1 file changed, 70 insertions(+)

diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/core/BaseAudioTaskApi.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/core/BaseAudioTaskApi.java
index 2782f8d36..7abde72d5 100644
--- a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/core/BaseAudioTaskApi.java
+++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/core/BaseAudioTaskApi.java
@@ -14,6 +14,9 @@
 package com.google.mediapipe.tasks.audio.core;
 
+import android.media.AudioFormat;
+import android.media.AudioRecord;
+import android.media.MediaRecorder;
 import com.google.mediapipe.framework.MediaPipeException;
 import com.google.mediapipe.framework.Packet;
 import com.google.mediapipe.tasks.components.containers.AudioData;
@@ -149,4 +152,71 @@ public class BaseAudioTaskApi implements AutoCloseable {
   public void close() {
     runner.close();
   }
+
+  /**
+   * Creates an {@link android.media.AudioRecord} instance to record an audio stream. The returned
+   * AudioRecord instance is initialized and the client needs to call {@link
+   * android.media.AudioRecord#startRecording} method to start recording.
+   *
+   * <p>Note that MediaPipe Audio tasks will up/down sample automatically to fit the sample rate
+   * required by the model. The default sample rate of the MediaPipe pretrained audio model, Yamnet,
+   * is 16kHz.
+   *
+   * @param numChannels the number of audio channels.
+   * @param sampleRate the audio sample rate.
+   * @return an {@link android.media.AudioRecord} instance in {@link
+   *     android.media.AudioRecord#STATE_INITIALIZED}
+   * @throws IllegalArgumentException if the model required channel count is unsupported
+   * @throws IllegalStateException if the AudioRecord instance fails to initialize
+   */
+  public static AudioRecord createAudioRecord(int numChannels, int sampleRate) {
+    int channelConfig = 0;
+    switch (numChannels) {
+      case 1:
+        channelConfig = AudioFormat.CHANNEL_IN_MONO;
+        break;
+      case 2:
+        channelConfig = AudioFormat.CHANNEL_IN_STEREO;
+        break;
+      default:
+        throw new IllegalArgumentException(
+            "createAudioRecord method only supports 1 or 2 audio channels.");
+    }
+
+    int bufferSizeInBytes =
+        AudioRecord.getMinBufferSize(sampleRate, channelConfig, AudioFormat.ENCODING_PCM_FLOAT);
+    if (bufferSizeInBytes == AudioRecord.ERROR
+        || bufferSizeInBytes == AudioRecord.ERROR_BAD_VALUE) {
+      throw new IllegalStateException(
+          String.format("AudioRecord.getMinBufferSize failed. Returned: %d", bufferSizeInBytes));
+    }
+    AudioRecord audioRecord =
+        new AudioRecord(
+            // including MIC, UNPROCESSED, and CAMCORDER.
+            MediaRecorder.AudioSource.VOICE_RECOGNITION,
+            sampleRate,
+            channelConfig,
+            AudioFormat.ENCODING_PCM_FLOAT,
+            bufferSizeInBytes);
+    if (audioRecord.getState() != AudioRecord.STATE_INITIALIZED) {
+      throw new IllegalStateException("AudioRecord failed to initialize.");
+    }
+    return audioRecord;
+  }
+
+  /**
+   * Creates an {@link android.media.AudioRecord} instance to record an audio stream that has a
+   * mono channel at a sample rate of 16kHz, the sample rate required for models like Yamnet. The
+   * returned AudioRecord instance is initialized and the client needs to call {@link
+   * android.media.AudioRecord#startRecording} method to start recording.
+   *
+   * @return an {@link android.media.AudioRecord} instance in {@link
+   *     android.media.AudioRecord#STATE_INITIALIZED}
+   * @throws IllegalArgumentException if the model required channel count is unsupported
+   * @throws IllegalStateException if the AudioRecord instance fails to initialize
+   */
+  public static AudioRecord createAudioRecord() {
+    // TODO: Support creating AudioRecord based on the model specifications.
+    return createAudioRecord(1, 16000);
+  }
 }

From 54268594dd8d6aa75222a408cc03a049b82be467 Mon Sep 17 00:00:00 2001
From: Nikolay Chirkov
Date: Tue, 10 Jan 2023 17:35:57 -0800
Subject: [PATCH 360/469] Internal change.
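The CPU buffer converters added below follow the usual affine quantization scheme. A standalone sketch of the arithmetic (plain C++; `zero_point` and `scale` mirror the buffer descriptor's quantization parameters):

    #include <cstdint>

    // Quantized int8 -> real value, as in DequantizationConvertFunction:
    float Dequantize(int8_t q, int zero_point, float scale) {
      return (q - zero_point) * scale;
    }

    // Real value -> quantized int8, as in QuantizationConvertFunction:
    int8_t Quantize(float v, int zero_point, float scale) {
      return static_cast<int8_t>(v / scale + zero_point);
    }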
PiperOrigin-RevId: 501136760 --- .../formats/tensor/cpu_buffer_converters.cc | 240 +++++++++++++++ .../tensor/cpu_buffer_converters_test.cc | 282 ++++++++++++++++++ 2 files changed, 522 insertions(+) create mode 100644 mediapipe/framework/formats/tensor/cpu_buffer_converters.cc create mode 100644 mediapipe/framework/formats/tensor/cpu_buffer_converters_test.cc diff --git a/mediapipe/framework/formats/tensor/cpu_buffer_converters.cc b/mediapipe/framework/formats/tensor/cpu_buffer_converters.cc new file mode 100644 index 000000000..e4e705be5 --- /dev/null +++ b/mediapipe/framework/formats/tensor/cpu_buffer_converters.cc @@ -0,0 +1,240 @@ +#include +#include +#include + +#include "mediapipe/framework/formats/tensor/backend.h" +#include "mediapipe/framework/formats/tensor/tensor2.h" +#include "mediapipe/framework/formats/tensor/views/buffer.h" +#include "mediapipe/framework/formats/tensor/views/cpu_buffer.h" +#include "third_party/FP16/include/fp16.h" + +namespace mediapipe { +namespace { + +template +auto ConverterCheckFunction() { + return + [](const Tensor2& tensor, uint64_t source_descriptor_type_id, + const Tensor2::ViewDescriptor& source_base_descriptor, + uint64_t destination_descriptor_type_id, + const Tensor2::ViewDescriptor& destination_base_descriptor) -> bool { + if (source_descriptor_type_id != TensorCpuView::kId || + destination_descriptor_type_id != TensorCpuView::kId) + return false; + auto source_descriptor = + static_cast(source_base_descriptor); + auto destination_descriptor = + static_cast( + destination_base_descriptor); + return source_descriptor.buffer.format == + TensorTypeToFormat::value && + destination_descriptor.buffer.format == + TensorTypeToFormat::value; + }; +} + +template +auto ConvertFunction() { + return [](const Tensor2& tensor, const Tensor2::View& source_base_view, + const Tensor2::View& destination_base_view) -> bool { + auto source = source_base_view.DownCast(); + auto destination = destination_base_view.DownCast(); + if (source->descriptor().buffer.format == + destination->descriptor().buffer.format) { + std::memcpy( + destination->data(), source->data(), + TensorBufferSize(destination->descriptor().buffer, tensor.shape())); + } else { + auto source_pointer = source->data(); + auto destination_pointer = destination->data(); + for (int i = 0; i < tensor.shape().NumElements(); i++) { + *destination_pointer++ = + GpuLikeTypeCast(*source_pointer++); + } + } + return true; + }; +} + +#define REGISTER_CONVERTER(SourceType, DestinationType) \ + TENSOR_REGISTER_CONVERTER( \ + {ConverterCheckFunction(), \ + ConvertFunction()}); + +REGISTER_CONVERTER(float, Float16); +REGISTER_CONVERTER(float, int8_t); +REGISTER_CONVERTER(float, uint8_t); +REGISTER_CONVERTER(float, int16_t); +REGISTER_CONVERTER(float, uint16_t); +REGISTER_CONVERTER(float, int32_t); +REGISTER_CONVERTER(float, uint32_t); + +REGISTER_CONVERTER(Float16, float); +REGISTER_CONVERTER(Float16, int8_t); +REGISTER_CONVERTER(Float16, uint8_t); +REGISTER_CONVERTER(Float16, int16_t); +REGISTER_CONVERTER(Float16, uint16_t); +REGISTER_CONVERTER(Float16, int32_t); +REGISTER_CONVERTER(Float16, uint32_t); + +REGISTER_CONVERTER(int8_t, float); +REGISTER_CONVERTER(int8_t, Float16); +REGISTER_CONVERTER(int8_t, uint8_t); +REGISTER_CONVERTER(int8_t, int16_t); +REGISTER_CONVERTER(int8_t, uint16_t); +REGISTER_CONVERTER(int8_t, int32_t); +REGISTER_CONVERTER(int8_t, uint32_t); + +REGISTER_CONVERTER(uint8_t, float); +REGISTER_CONVERTER(uint8_t, Float16); +REGISTER_CONVERTER(uint8_t, int8_t); +REGISTER_CONVERTER(uint8_t, 
int16_t);
+REGISTER_CONVERTER(uint8_t, uint16_t);
+REGISTER_CONVERTER(uint8_t, int32_t);
+REGISTER_CONVERTER(uint8_t, uint32_t);
+
+REGISTER_CONVERTER(int16_t, float);
+REGISTER_CONVERTER(int16_t, Float16);
+REGISTER_CONVERTER(int16_t, int8_t);
+REGISTER_CONVERTER(int16_t, uint8_t);
+REGISTER_CONVERTER(int16_t, uint16_t);
+REGISTER_CONVERTER(int16_t, int32_t);
+REGISTER_CONVERTER(int16_t, uint32_t);
+
+REGISTER_CONVERTER(uint16_t, float);
+REGISTER_CONVERTER(uint16_t, Float16);
+REGISTER_CONVERTER(uint16_t, int8_t);
+REGISTER_CONVERTER(uint16_t, uint8_t);
+REGISTER_CONVERTER(uint16_t, int16_t);
+REGISTER_CONVERTER(uint16_t, int32_t);
+REGISTER_CONVERTER(uint16_t, uint32_t);
+
+REGISTER_CONVERTER(int32_t, float);
+REGISTER_CONVERTER(int32_t, Float16);
+REGISTER_CONVERTER(int32_t, int8_t);
+REGISTER_CONVERTER(int32_t, uint8_t);
+REGISTER_CONVERTER(int32_t, int16_t);
+REGISTER_CONVERTER(int32_t, uint16_t);
+REGISTER_CONVERTER(int32_t, uint32_t);
+
+REGISTER_CONVERTER(uint32_t, float);
+REGISTER_CONVERTER(uint32_t, Float16);
+REGISTER_CONVERTER(uint32_t, int8_t);
+REGISTER_CONVERTER(uint32_t, uint8_t);
+REGISTER_CONVERTER(uint32_t, int16_t);
+REGISTER_CONVERTER(uint32_t, uint16_t);
+REGISTER_CONVERTER(uint32_t, int32_t);
+
+template <typename DestinationType>
+auto DequantizationCheckFunction() {
+  return
+      [](const Tensor2& tensor, uint64_t source_descriptor_type_id,
+         const Tensor2::ViewDescriptor& source_base_descriptor,
+         uint64_t destination_descriptor_type_id,
+         const Tensor2::ViewDescriptor& destination_base_descriptor) -> bool {
+        if (source_descriptor_type_id != TensorCpuView::kId ||
+            destination_descriptor_type_id != TensorCpuView::kId)
+          return false;
+        auto source_descriptor =
+            static_cast<const TensorCpuViewDescriptor&>(source_base_descriptor);
+        auto destination_descriptor =
+            static_cast<const TensorCpuViewDescriptor&>(
+                destination_base_descriptor);
+        return source_descriptor.buffer.format ==
+                   TensorBufferDescriptor::Format::kQuantizedInt8 &&
+               destination_descriptor.buffer.format ==
+                   TensorTypeToFormat<DestinationType>::value;
+      };
+}
+
+template <typename DestinationType>
+auto DequantizationConvertFunction() {
+  return [](const Tensor2& tensor, const Tensor2::View& source_base_view,
+            const Tensor2::View& destination_base_view) -> bool {
+    auto source = source_base_view.DownCast<TensorCpuView>();
+    auto destination = destination_base_view.DownCast<TensorCpuView>();
+    auto source_pointer = source->data<int8_t>();
+    auto destination_pointer = destination->data<DestinationType>();
+    int zero_point =
+        source->descriptor().buffer.quantization_parameters.zero_point;
+    float scale = source->descriptor().buffer.quantization_parameters.scale;
+    for (int i = 0; i < tensor.shape().NumElements(); i++) {
+      *destination_pointer++ = static_cast<DestinationType>(
+          (*source_pointer++ - zero_point) * scale);
+    }
+    return true;
+  };
+}
+
+#define REGISTER_DEQUANTIZATION_CONVERTER(DestinationType)   \
+  TENSOR_REGISTER_CONVERTER(                                 \
+      {DequantizationCheckFunction<DestinationType>(),       \
+       DequantizationConvertFunction<DestinationType>()});
+
+REGISTER_DEQUANTIZATION_CONVERTER(float);
+REGISTER_DEQUANTIZATION_CONVERTER(Float16);
+REGISTER_DEQUANTIZATION_CONVERTER(int8_t);
+REGISTER_DEQUANTIZATION_CONVERTER(uint8_t);
+REGISTER_DEQUANTIZATION_CONVERTER(int16_t);
+REGISTER_DEQUANTIZATION_CONVERTER(uint16_t);
+REGISTER_DEQUANTIZATION_CONVERTER(int32_t);
+REGISTER_DEQUANTIZATION_CONVERTER(uint32_t);
+
+template <typename SourceType>
+auto QuantizationCheckFunction() {
+  return
+      [](const Tensor2& tensor, uint64_t source_descriptor_type_id,
+         const Tensor2::ViewDescriptor& source_base_descriptor,
+         uint64_t destination_descriptor_type_id,
+         const Tensor2::ViewDescriptor& destination_base_descriptor) -> bool {
+        if (source_descriptor_type_id != 
TensorCpuView::kId || + destination_descriptor_type_id != TensorCpuView::kId) + return false; + auto source_descriptor = + static_cast(source_base_descriptor); + auto destination_descriptor = + static_cast( + destination_base_descriptor); + bool same = source_descriptor.buffer.format == + TensorTypeToFormat::value && + destination_descriptor.buffer.format == + TensorBufferDescriptor::Format::kQuantizedInt8; + return same; + }; +} + +template +auto QuantizationConvertFunction() { + return [](const Tensor2& tensor, const Tensor2::View& source_base_view, + const Tensor2::View& destination_base_view) -> bool { + auto source = source_base_view.DownCast(); + auto destination = destination_base_view.DownCast(); + auto source_pointer = source->data(); + auto destination_pointer = destination->data(); + int zero_point = + destination->descriptor().buffer.quantization_parameters.zero_point; + float scale = + destination->descriptor().buffer.quantization_parameters.scale; + for (int i = 0; i < tensor.shape().NumElements(); i++) { + *destination_pointer++ = + static_cast(*source_pointer++ / scale + zero_point); + } + return true; + }; +} + +#define REGISTER_QUANTIZATION_CONVERTER(SourceType) \ + TENSOR_REGISTER_CONVERTER({QuantizationCheckFunction(), \ + QuantizationConvertFunction()}); + +REGISTER_QUANTIZATION_CONVERTER(float); +REGISTER_QUANTIZATION_CONVERTER(Float16); +REGISTER_QUANTIZATION_CONVERTER(int8_t); +REGISTER_QUANTIZATION_CONVERTER(uint8_t); +REGISTER_QUANTIZATION_CONVERTER(int16_t); +REGISTER_QUANTIZATION_CONVERTER(uint16_t); +REGISTER_QUANTIZATION_CONVERTER(int32_t); +REGISTER_QUANTIZATION_CONVERTER(uint32_t); + +} // namespace +} // namespace mediapipe diff --git a/mediapipe/framework/formats/tensor/cpu_buffer_converters_test.cc b/mediapipe/framework/formats/tensor/cpu_buffer_converters_test.cc new file mode 100644 index 000000000..3619ad531 --- /dev/null +++ b/mediapipe/framework/formats/tensor/cpu_buffer_converters_test.cc @@ -0,0 +1,282 @@ +#include + +#include "mediapipe/framework/formats/tensor/tensor2.h" +#include "mediapipe/framework/formats/tensor/views/buffer.h" +#include "mediapipe/framework/formats/tensor/views/cpu_buffer.h" +#include "testing/base/public/gmock.h" +#include "testing/base/public/gunit.h" + +MATCHER_P(NearWithPrecision, precision, "") { + return std::abs(std::get<0>(arg) - std::get<1>(arg)) < precision; +} +MATCHER_P(IntegerEqual, precision, "") { + return std::get<0>(arg) == std::get<1>(arg); +} + +namespace mediapipe { + +TEST(TensorCpuViewTest, TestWrite32ThenRead16) { + Tensor2 tensor{Tensor2::Shape({1})}; + { + MP_ASSERT_OK_AND_ASSIGN( + auto view, + tensor.GetView( + TensorCpuViewDescriptor{ + .buffer = {.format = + TensorBufferDescriptor::Format::kFloat32}})); + ASSERT_NE(view->data(), nullptr); + *view->data() = 1234.0f; + } + { + MP_ASSERT_OK_AND_ASSIGN( + auto view, + tensor.GetView( + TensorCpuViewDescriptor{ + .buffer = {.format = + TensorBufferDescriptor::Format::kFloat16}})); + ASSERT_NE(view->data(), nullptr); + EXPECT_EQ(*view->data(), 1234.0f); + } +} + +TEST(TensorCpuViewTest, TestWrite16ThenRead32) { + Tensor2 tensor{Tensor2::Shape({1})}; + { + MP_ASSERT_OK_AND_ASSIGN( + auto view, + tensor.GetView( + TensorCpuViewDescriptor{ + .buffer = {.format = + TensorBufferDescriptor::Format::kFloat16}})); + ASSERT_NE(view->data(), nullptr); + *view->data() = 1234.0f; + } + { + MP_ASSERT_OK_AND_ASSIGN( + auto view, + tensor.GetView( + TensorCpuViewDescriptor{ + .buffer = {.format = + TensorBufferDescriptor::Format::kFloat32}})); + 
ASSERT_NE(view->data(), nullptr); + EXPECT_EQ(*view->data(), 1234.0f); + } +} + +TEST(TensorCpuViewTest, TestWriteFloat32ThenReadInt8) { + Tensor2 tensor{Tensor2::Shape({1})}; + { + MP_ASSERT_OK_AND_ASSIGN( + auto view, + tensor.GetView( + TensorCpuViewDescriptor{ + .buffer = {.format = + TensorBufferDescriptor::Format::kFloat32}})); + ASSERT_NE(view->data(), nullptr); + *view->data() = 0.121569f; + } + { + MP_ASSERT_OK_AND_ASSIGN( + auto view, + tensor.GetView( + TensorCpuViewDescriptor{ + .buffer = {.format = TensorBufferDescriptor::Format::kUInt8}})); + ASSERT_NE(view->data(), nullptr); + EXPECT_EQ( + *view->data(), + static_cast(0.121569f * std::numeric_limits::max())); + } +} + +TEST(TensorCpuViewTest, TestWriteInt8ThenReadFloat32) { + Tensor2 tensor{Tensor2::Shape({1})}; + { + MP_ASSERT_OK_AND_ASSIGN( + auto view, + tensor.GetView( + TensorCpuViewDescriptor{ + .buffer = {.format = TensorBufferDescriptor::Format::kUInt8}})); + ASSERT_NE(view->data(), nullptr); + *view->data() = + static_cast(0.123f * std::numeric_limits::max()); + } + { + MP_ASSERT_OK_AND_ASSIGN( + auto view, + tensor.GetView( + TensorCpuViewDescriptor{ + .buffer = {.format = + TensorBufferDescriptor::Format::kFloat32}})); + ASSERT_NE(view->data(), nullptr); + EXPECT_NEAR(*view->data(), 0.123f, + 1.0f / std::numeric_limits::max()); + } +} + +TEST(TensorCpuViewTest, TestWriteUInt8ThenReadUInt16) { + Tensor2 tensor{Tensor2::Shape({1})}; + { + MP_ASSERT_OK_AND_ASSIGN( + auto view, + tensor.GetView( + TensorCpuViewDescriptor{ + .buffer = {.format = TensorBufferDescriptor::Format::kUInt8}})); + ASSERT_NE(view->data(), nullptr); + *view->data() = 123; + } + { + MP_ASSERT_OK_AND_ASSIGN( + auto view, + tensor.GetView( + TensorCpuViewDescriptor{ + .buffer = {.format = + TensorBufferDescriptor::Format::kUInt16}})); + ASSERT_NE(view->data(), nullptr); + EXPECT_EQ(*view->data(), uint16_t{123} << 8); + } +} + +TEST(TensorCpuViewTest, TestWriteUInt16ThenReadUInt8) { + Tensor2 tensor{Tensor2::Shape({1})}; + { + MP_ASSERT_OK_AND_ASSIGN( + auto view, + tensor.GetView( + TensorCpuViewDescriptor{ + .buffer = {.format = + TensorBufferDescriptor::Format::kUInt16}})); + ASSERT_NE(view->data(), nullptr); + *view->data() = uint16_t{123} << 8; + } + { + MP_ASSERT_OK_AND_ASSIGN( + auto view, + tensor.GetView( + TensorCpuViewDescriptor{ + .buffer = {.format = TensorBufferDescriptor::Format::kUInt8}})); + ASSERT_NE(view->data(), nullptr); + EXPECT_EQ(*view->data(), 123); + } +} + +TEST(TensorCpuViewTest, TestWriteNegativeInt8ThenReadUInt8) { + Tensor2 tensor{Tensor2::Shape({1})}; + { + MP_ASSERT_OK_AND_ASSIGN( + auto view, + tensor.GetView( + TensorCpuViewDescriptor{ + .buffer = {.format = TensorBufferDescriptor::Format::kInt8}})); + ASSERT_NE(view->data(), nullptr); + *view->data() = -123; + } + { + MP_ASSERT_OK_AND_ASSIGN( + auto view, + tensor.GetView( + TensorCpuViewDescriptor{ + .buffer = {.format = TensorBufferDescriptor::Format::kUInt8}})); + ASSERT_NE(view->data(), nullptr); + EXPECT_EQ(*view->data(), 0); + } +} + +TEST(TensorCpuViewTest, TestWritePositiveInt8ThenReadUInt8) { + Tensor2 tensor{Tensor2::Shape({1})}; + { + MP_ASSERT_OK_AND_ASSIGN( + auto view, + tensor.GetView( + TensorCpuViewDescriptor{ + .buffer = {.format = TensorBufferDescriptor::Format::kInt8}})); + ASSERT_NE(view->data(), nullptr); + *view->data() = 123; + } + { + MP_ASSERT_OK_AND_ASSIGN( + auto view, + tensor.GetView( + TensorCpuViewDescriptor{ + .buffer = {.format = TensorBufferDescriptor::Format::kUInt8}})); + ASSERT_NE(view->data(), nullptr); + 
EXPECT_EQ(*view->data(), 123 * 2); + } +} + +TEST(TensorCpuViewTest, TestDequantization) { + constexpr int num_elements = 20; + // Gives quantization values in range [-100, 90]. + constexpr int zero_point = -100; + constexpr float scale = 2.0f; + Tensor2 tensor{Tensor2::Shape({num_elements})}; + { + MP_ASSERT_OK_AND_ASSIGN( + auto view, + tensor.GetView( + TensorCpuViewDescriptor{ + .buffer = { + .format = TensorBufferDescriptor::Format::kQuantizedInt8, + .quantization_parameters = {.scale = scale, + .zero_point = zero_point}}})); + ASSERT_NE(view->data(), nullptr); + auto data = view->data(); + for (int i = 0; i < num_elements; ++i) { + // Add some bias (+1) to make round-up take place. + data[i] = (i * 20 + 1) / scale + zero_point; + } + } + { + MP_ASSERT_OK_AND_ASSIGN( + auto view, + tensor.GetView( + TensorCpuViewDescriptor{ + .buffer = {.format = + TensorBufferDescriptor::Format::kFloat32}})); + ASSERT_NE(view->data(), nullptr); + std::vector reference(num_elements); + for (int i = 0; i < num_elements; ++i) { + reference[i] = i * 20.0f + 1.0f; + } + EXPECT_THAT(absl::Span(view->data(), num_elements), + testing::Pointwise(NearWithPrecision(1.001), reference)); + } +} + +TEST(TensorCpuViewTest, TestQuantization) { + constexpr int num_elements = 20; + // Gives quantization values in range [-100, 90]. + constexpr int zero_point = -100; + constexpr float scale = 2.0f; + Tensor2 tensor{Tensor2::Shape({num_elements})}; + { + MP_ASSERT_OK_AND_ASSIGN( + auto view, + tensor.GetView( + TensorCpuViewDescriptor{ + .buffer = {.format = + TensorBufferDescriptor::Format::kFloat32}})); + ASSERT_NE(view->data(), nullptr); + auto data = view->data(); + for (int i = 0; i < num_elements; ++i) { + // Add some bias (+1) to make round-up take place. + data[i] = i * 20 + 1; + } + } + { + TensorCpuViewDescriptor d{ + .buffer = {.format = TensorBufferDescriptor::Format::kQuantizedInt8, + .quantization_parameters = {.scale = scale, + .zero_point = zero_point}}}; + MP_ASSERT_OK_AND_ASSIGN( + auto view, tensor.GetView(d)); + ASSERT_NE(view->data(), nullptr); + std::vector reference(num_elements); + for (int i = 0; i < num_elements; ++i) { + reference[i] = (i * 20 + 1) / scale + zero_point; + } + EXPECT_THAT(absl::Span(view->data(), num_elements), + testing::Pointwise(IntegerEqual(0), reference)); + } +} + +} // namespace mediapipe From ed6abbbe43df13ddb1145ee94cec681dbf9d6473 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Wed, 11 Jan 2023 16:21:28 +0530 Subject: [PATCH 361/469] Added iOS text classifier options --- .../tasks/ios/text/text_classifier/BUILD | 27 ++++++++ .../sources/MPPTextClassifierOptions.h | 62 +++++++++++++++++++ .../sources/MPPTextClassifierOptions.m | 40 ++++++++++++ 3 files changed, 129 insertions(+) create mode 100644 mediapipe/tasks/ios/text/text_classifier/BUILD create mode 100644 mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.h create mode 100644 mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.m diff --git a/mediapipe/tasks/ios/text/text_classifier/BUILD b/mediapipe/tasks/ios/text/text_classifier/BUILD new file mode 100644 index 000000000..dff39baab --- /dev/null +++ b/mediapipe/tasks/ios/text/text_classifier/BUILD @@ -0,0 +1,27 @@ +# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +objc_library( + name = "MPPTextClassifierOptions", + srcs = ["sources/MPPTextClassifierOptions.m"], + hdrs = ["sources/MPPTextClassifierOptions.h"], + deps = [ + "//mediapipe/tasks/ios/core:MPPTaskOptions", + ], +) + diff --git a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.h b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.h new file mode 100644 index 000000000..d43d801d4 --- /dev/null +++ b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.h @@ -0,0 +1,62 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import + +#import "mediapipe/tasks/ios/core/sources/MPPTaskOptions.h" + +NS_ASSUME_NONNULL_BEGIN + +/** + * Options for setting up a `MPPTextClassifierOptions`. + */ +NS_SWIFT_NAME(TextClassifierOptions) +@interface MPPTextClassifierOptions : MPPTaskOptions + +/** + * The locale to use for display names specified through the TFLite Model + * Metadata, if any. Defaults to English. + */ +@property(nonatomic, copy) NSString *displayNamesLocale; + +/** + * The maximum number of top-scored classification results to return. If < 0, + * all available results will be returned. If 0, an invalid argument error is + * returned. + */ +@property(nonatomic) NSInteger maxResults; + +/** + * Score threshold to override the one provided in the model metadata (if any). + * Results below this value are rejected. + */ +@property(nonatomic) float scoreThreshold; + +/** + * The allowlist of category names. If non-empty, detection results whose + * category name is not in this set will be filtered out. Duplicate or unknown + * category names are ignored. Mutually exclusive with categoryDenylist. + */ +@property(nonatomic, copy) NSArray *categoryAllowlist; + +/** + * The denylist of category names. If non-empty, detection results whose + * category name is in this set will be filtered out. Duplicate or unknown + * category names are ignored. Mutually exclusive with categoryAllowlist. + */ +@property(nonatomic, copy) NSArray *categoryDenylist; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.m b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.m new file mode 100644 index 000000000..2d5c17cda --- /dev/null +++ b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.m @@ -0,0 +1,40 @@ +// Copyright 2023 The MediaPipe Authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.h" + +@implementation MPPTextClassifierOptions + +- (instancetype)init { + self = [super init]; + if (self) { + _maxResults = -1; + _scoreThreshold = 0; + } + return self; +} + +- (id)copyWithZone:(NSZone *)zone { + MPPTextClassifierOptions *textClassifierOptions = [super copyWithZone:zone]; + + textClassifierOptions.scoreThreshold = self.scoreThreshold; + textClassifierOptions.maxResults = self.maxResults; + textClassifierOptions.categoryDenylist = self.categoryDenylist; + textClassifierOptions.categoryAllowlist = self.categoryAllowlist; + textClassifierOptions.displayNamesLocale = self.displayNamesLocale; + + return textClassifierOptions; +} + +@end From 1161ebce9d720e544d3cf740b5e6a5aa446979ed Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Wed, 11 Jan 2023 16:22:09 +0530 Subject: [PATCH 362/469] Added iOS text classifier result --- .../tasks/ios/text/text_classifier/BUILD | 10 +++++ .../sources/MPPTextClassifierResult.h | 44 +++++++++++++++++++ .../sources/MPPTextClassifierResult.m | 28 ++++++++++++ 3 files changed, 82 insertions(+) create mode 100644 mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.h create mode 100644 mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.m diff --git a/mediapipe/tasks/ios/text/text_classifier/BUILD b/mediapipe/tasks/ios/text/text_classifier/BUILD index dff39baab..59ef601bf 100644 --- a/mediapipe/tasks/ios/text/text_classifier/BUILD +++ b/mediapipe/tasks/ios/text/text_classifier/BUILD @@ -25,3 +25,13 @@ objc_library( ], ) +objc_library( + name = "MPPTextClassifierResult", + srcs = ["sources/MPPTextClassifierResult.m"], + hdrs = ["sources/MPPTextClassifierResult.h"], + deps = [ + "//mediapipe/tasks/ios/core:MPPTaskResult", + "//mediapipe/tasks/ios/components/containers:MPPClassificationResult", + ], +) + diff --git a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.h b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.h new file mode 100644 index 000000000..63bb92352 --- /dev/null +++ b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.h @@ -0,0 +1,44 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#import +#import "mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.h" +#import "mediapipe/tasks/ios/core/sources/MPPTaskResult.h" + +NS_ASSUME_NONNULL_BEGIN + +/** Represents the classification results generated by `MPPTextClassifier`. **/ +NS_SWIFT_NAME(TextClassifierResult) +@interface MPPTextClassifierResult : MPPTaskResult + +/** The `MPPClassificationResult` instance containing one set of results per classifier head. **/ +@property(nonatomic, readonly) MPPClassificationResult *classificationResult; + +/** + * Initializes a new `MPPTextClassifierResult` with the given `MPPClassificationResult` and + * timestamp (in milliseconds). + * + * @param classificationResult The `MPPClassificationResult` instance containing one set of results + * per classifier head. + * @param timestampMs The timestamp for this result. + * + * @return An instance of `MPPTextClassifierResult` initialized with the given + * `MPPClassificationResult` and timestamp (in milliseconds). + */ +- (instancetype)initWithClassificationResult:(MPPClassificationResult *)classificationResult + timestampMs:(NSInteger)timestampMs; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.m b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.m new file mode 100644 index 000000000..4d5c1104a --- /dev/null +++ b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.m @@ -0,0 +1,28 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.h" + +@implementation MPPTextClassifierResult + +- (instancetype)initWithClassificationResult:(MPPClassificationResult *)classificationResult + timestampMs:(NSInteger)timestampMs { + self = [super initWithTimestampMs:timestampMs]; + if (self) { + _classificationResult = classificationResult; + } + return self; +} + +@end From 54161cc1abaa11701bc2a51d8ef331db55db0b19 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Wed, 11 Jan 2023 20:22:02 +0530 Subject: [PATCH 363/469] Added iOS text classifier options helpers --- .../ios/text/text_classifier/utils/BUILD | 31 ++++++++++ .../MPPTextClassifierOptions+Helpers.h | 26 +++++++++ .../MPPTextClassifierOptions+Helpers.mm | 56 +++++++++++++++++++ 3 files changed, 113 insertions(+) create mode 100644 mediapipe/tasks/ios/text/text_classifier/utils/BUILD create mode 100644 mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.h create mode 100644 mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.mm diff --git a/mediapipe/tasks/ios/text/text_classifier/utils/BUILD b/mediapipe/tasks/ios/text/text_classifier/utils/BUILD new file mode 100644 index 000000000..9b01c763e --- /dev/null +++ b/mediapipe/tasks/ios/text/text_classifier/utils/BUILD @@ -0,0 +1,31 @@ +# Copyright 2023 The MediaPipe Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +objc_library( + name = "MPPTextClassifierOptionsHelpers", + srcs = ["sources/MPPTextClassifierOptions+Helpers.mm"], + hdrs = ["sources/MPPTextClassifierOptions+Helpers.h"], + deps = [ + "//mediapipe/tasks/ios/common/utils:NSStringHelpers", + "//mediapipe/tasks/ios/text/text_classifier:MPPTextClassifierOptions", + "//mediapipe/tasks/ios/core/utils:MPPBaseOptionsHelpers", + "//mediapipe/tasks/ios/core:MPPTaskOptionsProtocol", + "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_cc_proto", + ], +) + diff --git a/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.h b/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.h new file mode 100644 index 000000000..1e52e5c87 --- /dev/null +++ b/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.h @@ -0,0 +1,26 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "mediapipe/tasks/ios/core/sources/MPPTaskOptionsProtocol.h" +#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.h" + +NS_ASSUME_NONNULL_BEGIN + +@interface MPPTextClassifierOptions (Helpers) + +- (void)copyToProto:(mediapipe::CalculatorOptions *)optionsProto; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.mm b/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.mm new file mode 100644 index 000000000..c370f11ef --- /dev/null +++ b/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.mm @@ -0,0 +1,56 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#import "mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.h" + +#import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h" +#import "mediapipe/tasks/ios/core/utils/sources/MPPBaseOptions+Helpers.h" + +#include "mediapipe/tasks/cc/text/text_classifier/proto/text_classifier_graph_options.pb.h" + +namespace { +using CalculatorOptionsProto = ::mediapipe::CalculatorOptions; +using TextClassifierGraphOptionsProto = + ::mediapipe::tasks::text::text_classifier::proto::TextClassifierGraphOptions; +using ClassifierOptionsProto = ::mediapipe::tasks::components::processors::proto::ClassifierOptions; +} // namespace + +@implementation MPPTextClassifierOptions (Helpers) + +- (void)copyToProto:(CalculatorOptionsProto *)optionsProto { + TextClassifierGraphOptionsProto *graphOptions = + optionsProto->MutableExtension(TextClassifierGraphOptionsProto::ext); + [self.baseOptions copyToProto:graphOptions->mutable_base_options()]; + + ClassifierOptionsProto *classifierOptionsProto = graphOptions->mutable_classifier_options(); + classifierOptionsProto->Clear(); + + if (self.displayNamesLocale) { + classifierOptionsProto->set_display_names_locale(self.displayNamesLocale.cppString); + } + + classifierOptionsProto->set_max_results((int)self.maxResults); + classifierOptionsProto->set_score_threshold(self.scoreThreshold); + + for (NSString *category in self.categoryAllowlist) { + classifierOptionsProto->add_category_allowlist(category.cppString); + } + + for (NSString *category in self.categoryDenylist) { + classifierOptionsProto->add_category_denylist(category.cppString); + } + +} + +@end From a0220de2338e4fbc308c95d9fef91383dd817ca4 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Wed, 11 Jan 2023 20:22:20 +0530 Subject: [PATCH 364/469] Added iOS text classifier result helpers --- .../ios/text/text_classifier/utils/BUILD | 10 +++++ .../sources/MPPTextClassifierResult+Helpers.h | 28 ++++++++++++ .../MPPTextClassifierResult+Helpers.mm | 43 +++++++++++++++++++ 3 files changed, 81 insertions(+) create mode 100644 mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierResult+Helpers.h create mode 100644 mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierResult+Helpers.mm diff --git a/mediapipe/tasks/ios/text/text_classifier/utils/BUILD b/mediapipe/tasks/ios/text/text_classifier/utils/BUILD index 9b01c763e..299050b32 100644 --- a/mediapipe/tasks/ios/text/text_classifier/utils/BUILD +++ b/mediapipe/tasks/ios/text/text_classifier/utils/BUILD @@ -29,3 +29,13 @@ objc_library( ], ) +objc_library( + name = "MPPTextClassifierResultHelpers", + srcs = ["sources/MPPTextClassifierResult+Helpers.mm"], + hdrs = ["sources/MPPTextClassifierResult+Helpers.h"], + deps = [ + "//mediapipe/tasks/ios/text/text_classifier:MPPTextClassifierResult", + "//mediapipe/tasks/ios/components/containers/utils:MPPClassificationResultHelpers", + "//mediapipe/framework:packet", + ], +) diff --git a/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierResult+Helpers.h b/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierResult+Helpers.h new file mode 100644 index 000000000..f1b728b0a --- /dev/null +++ b/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierResult+Helpers.h @@ -0,0 +1,28 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.h"
+
+#include "mediapipe/framework/packet.h"
+
+NS_ASSUME_NONNULL_BEGIN
+
+@interface MPPTextClassifierResult (Helpers)
+
++ (MPPTextClassifierResult *)textClassifierResultWithClassificationsPacket:
+    (const mediapipe::Packet &)packet;
+
+@end
+
+NS_ASSUME_NONNULL_END
diff --git a/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierResult+Helpers.mm b/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierResult+Helpers.mm
new file mode 100644
index 000000000..62e0d8cb1
--- /dev/null
+++ b/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierResult+Helpers.mm
@@ -0,0 +1,43 @@
+// Copyright 2023 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#import "mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierResult+Helpers.h"
+
+#import "mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.h"
+
+#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
+
+static const int kMicroSecondsPerMilliSecond = 1000;
+
+namespace {
+using ClassificationResultProto =
+    ::mediapipe::tasks::components::containers::proto::ClassificationResult;
+using ::mediapipe::Packet;
+}  // namespace
+
+@implementation MPPTextClassifierResult (Helpers)
+
++ (MPPTextClassifierResult *)textClassifierResultWithClassificationsPacket:(const Packet &)packet {
+  MPPClassificationResult *classificationResult = [MPPClassificationResult
+      classificationResultWithProto:packet.Get<ClassificationResultProto>()];
+
+  return [[MPPTextClassifierResult alloc]
+      initWithClassificationResult:classificationResult
+                       timestampMs:(NSInteger)(packet.Timestamp().Value() /
+                                               kMicroSecondsPerMilliSecond)];
+}
+
+@end

From b1ded2f700a424a2c6782a4f571bcd70e554fc6b Mon Sep 17 00:00:00 2001
From: Prianka Liz Kariat
Date: Wed, 11 Jan 2023 20:22:33 +0530
Subject: [PATCH 365/469] Added iOS text classifier
---
 .../tasks/ios/text/text_classifier/BUILD     |  21 ++++
 .../sources/MPPTextClassifier.h              | 103 ++++++++++++++++++
 .../sources/MPPTextClassifier.mm             |  98 +++++++++++++++++
 3 files changed, 222 insertions(+)
 create mode 100644 mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h
 create mode 100644 mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.mm

diff --git a/mediapipe/tasks/ios/text/text_classifier/BUILD b/mediapipe/tasks/ios/text/text_classifier/BUILD
index 59ef601bf..e5242f50d 100644
--- a/mediapipe/tasks/ios/text/text_classifier/BUILD
+++ b/mediapipe/tasks/ios/text/text_classifier/BUILD
@@ -35,3 +35,24 @@ objc_library(
     ],
 )
 
+objc_library(
+    name = "MPPTextClassifier",
+    srcs = ["sources/MPPTextClassifier.mm"],
+    hdrs = ["sources/MPPTextClassifier.h"],
+    copts = [
+        "-ObjC++",
+        "-std=c++17",
+    ],
+    deps = [
+        "//mediapipe/tasks/cc/text/text_classifier:text_classifier_graph",
+        "//mediapipe/tasks/ios/core:MPPTaskOptions",
+        "//mediapipe/tasks/ios/core:MPPTaskInfo",
+        "//mediapipe/tasks/ios/text/core:MPPTextTaskRunner",
+        "//mediapipe/tasks/ios/core:MPPTextPacketCreator",
+        "//mediapipe/tasks/ios/text/text_classifier/utils:MPPTextClassifierOptionsHelpers",
+        "//mediapipe/tasks/ios/text/text_classifier/utils:MPPTextClassifierResultHelpers",
+        "//mediapipe/tasks/ios/common/utils:MPPCommonUtils",
+        "//mediapipe/tasks/ios/common/utils:NSStringHelpers",
+        ":MPPTextClassifierOptions",
+    ],
+)
diff --git a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h
new file mode 100644
index 000000000..10bccad3d
--- /dev/null
+++ b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h
@@ -0,0 +1,103 @@
+// Copyright 2023 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#import <Foundation/Foundation.h>
+
+#import "mediapipe/tasks/ios/core/sources/MPPTaskOptions.h"
+#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.h"
+#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.h"
+
+NS_ASSUME_NONNULL_BEGIN
+
+/**
+ * @brief Performs classification on text.
+ *
+ * This API expects a TFLite model with (optional) [TFLite Model
+ * Metadata](https://www.tensorflow.org/lite/convert/metadata) that contains the mandatory
+ * (described below) input tensors, output tensor, and the optional (but recommended) label
+ * items as AssociatedFiles with type TENSOR_AXIS_LABELS per output classification tensor.
+ *
+ * Metadata is required for models with int32 input tensors because it contains the input
+ * process unit for the model's Tokenizer. No metadata is required for models with string
+ * input tensors.
+ *
+ * Input tensors
+ * - Three input tensors `kTfLiteInt32` of shape `[batch_size x bert_max_seq_len]`
+ * representing the input ids, mask ids, and segment ids. This input signature requires
+ * a Bert Tokenizer process unit in the model metadata.
+ * - Or one input tensor `kTfLiteInt32` of shape `[batch_size x max_seq_len]` representing
+ * the input ids. This input signature requires a Regex Tokenizer process unit in the
+ * model metadata.
+ * - Or one input tensor (`kTfLiteString`) that is shapeless or has shape `[1]` containing
+ * the input string.
+ *
+ * At least one output tensor (`kTfLiteFloat32/kBool`) with:
+ * - `N` classes and shape `[1 x N]`
+ * - optional (but recommended) label map(s) as AssociatedFiles with type
+ * TENSOR_AXIS_LABELS,
+ * containing one label per line. 
The first such AssociatedFile (if any) is used to fill + * the `categoryName` field of the results. The `displayName` field is filled from the + * AssociatedFile (if any) whose locale matches the `displayNamesLocale` field of the + * `MPPTextClassifierOptions` used at creation time ("en" by default, i.e. English). If + * none of these are available, only the `index` field of the results will be filled. + */ +NS_SWIFT_NAME(TextClassifier) +@interface MPPTextClassifier : NSObject + +/** + * Creates a new instance of `MPPTextClassifier` from an absolute path to a TensorFlow Lite + * model file stored locally on the device and the default `MPPTextClassifierOptions`. + * + * @param modelPath An absolute path to a TensorFlow Lite model file stored locally on the + * device. + * @param error An optional error parameter populated when there is an error in initializing + * the text classifier. + * + * @return A new instance of `MPPTextClassifier` with the given model path. `nil` if there is an + * error in initializing the text classifier. + */ +- (instancetype)initWithModelPath:(NSString *)modelPath error:(NSError **)error; + +/** + * Creates a new instance of `MPPTextClassifier` from the given `MPPTextClassifierOptions`. + * + * @param options The options of type `MPPTextClassifierOptions` to use for configuring the + * `MPPTextClassifier`. + * @param error An optional error parameter populated when there is an error in initializing + * the text classifier. + * + * @return A new instance of `MPPTextClassifier` with the given options. `nil` if there is an + * error in initializing the text classifier. + */ +- (instancetype)initWithOptions:(MPPTextClassifierOptions *)options + error:(NSError **)error NS_DESIGNATED_INITIALIZER; + +/** + * Performs classification on the input text. + * + * @param text The `NSString` on which classification is to be performed. + * @param error An optional error parameter populated when there is an error in performing + * classification on the input text. + * + * @return A `MPPTextClassifierResult` object that contains a list of text classifications. + */ +- (nullable MPPTextClassifierResult *)classifyText:(NSString *)text error:(NSError **)error; + +- (instancetype)init NS_UNAVAILABLE; + ++ (instancetype)new NS_UNAVAILABLE; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.mm b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.mm new file mode 100644 index 000000000..aed05ec37 --- /dev/null +++ b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.mm @@ -0,0 +1,98 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h" + +#import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h" +#import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h" +#import "mediapipe/tasks/ios/core/sources/MPPTaskInfo.h" +#import "mediapipe/tasks/ios/core/sources/MPPTextPacketCreator.h" +#import "mediapipe/tasks/ios/text/core/sources/MPPTextTaskRunner.h" +#import "mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.h" +#import "mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierResult+Helpers.h" + +#include "absl/status/statusor.h" +#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" + +namespace { +using ::mediapipe::Packet; +using ::mediapipe::tasks::components::containers::proto::ClassificationResult; +using ::mediapipe::tasks::core::PacketMap; +} // namespace + +static NSString *const kClassificationsStreamName = @"classifications_out"; +static NSString *const kClassificationsTag = @"CLASSIFICATIONS"; +static NSString *const kTextInStreamName = @"text_in"; +static NSString *const kTextTag = @"TEXT"; +static NSString *const kTaskGraphName = @"mediapipe.tasks.text.text_classifier.TextClassifierGraph"; + +@interface MPPTextClassifier () { + /** iOS Text Task Runner */ + MPPTextTaskRunner *_textTaskRunner; +} +@end + +@implementation MPPTextClassifier + +- (instancetype)initWithOptions:(MPPTextClassifierOptions *)options error:(NSError **)error { + MPPTaskInfo *taskInfo = [[MPPTaskInfo alloc] + initWithTaskGraphName:kTaskGraphName + inputStreams:@[ [NSString stringWithFormat:@"%@:%@", kTextTag, kTextInStreamName] ] + outputStreams:@[ [NSString stringWithFormat:@"%@:%@", kClassificationsTag, + kClassificationsStreamName] ] + taskOptions:options + enableFlowLimiting:NO + error:error]; + + if (!taskInfo) { + return nil; + } + + _textTaskRunner = + [[MPPTextTaskRunner alloc] initWithCalculatorGraphConfig:[taskInfo generateGraphConfig] + error:error]; + + if (!_textTaskRunner) { + return nil; + } + + self = [super init]; + + return self; +} + +- (instancetype)initWithModelPath:(NSString *)modelPath error:(NSError **)error { + MPPTextClassifierOptions *options = [[MPPTextClassifierOptions alloc] init]; + + options.baseOptions.modelAssetPath = modelPath; + + return [self initWithOptions:options error:error]; +} + +- (nullable MPPTextClassifierResult *)classifyText:(NSString *)text error:(NSError **)error { + Packet packet = [MPPTextPacketCreator createWithText:text]; + + std::map packetMap = {{kTextInStreamName.cppString, packet}}; + absl::StatusOr statusOrOutputPacketMap = [_textTaskRunner process:packetMap]; + + if (![MPPCommonUtils checkCppError:statusOrOutputPacketMap.status() toError:error]) { + return nil; + } + + return [MPPTextClassifierResult + textClassifierResultWithClassificationsPacket:statusOrOutputPacketMap.value() + [kClassificationsStreamName.cppString]]; +} + +@end From fe05a8d201a12011dc7cc82eaf4dc0f4fad42b20 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Wed, 11 Jan 2023 20:24:17 +0530 Subject: [PATCH 366/469] Reformatted code --- .../sources/MPPTextClassifier.h | 20 +++++++++---------- .../sources/MPPTextClassifier.mm | 6 +++--- .../sources/MPPTextClassifierResult.h | 2 +- .../MPPTextClassifierOptions+Helpers.mm | 1 - 4 files changed, 14 insertions(+), 15 deletions(-) diff --git a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h index 
10bccad3d..48498edca 100644
--- a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h
+++ b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h
@@ -22,19 +22,19 @@ NS_ASSUME_NONNULL_BEGIN
 
 /**
  * @brief Performs classification on text.
- * 
+ *
  * This API expects a TFLite model with (optional) [TFLite Model
 * Metadata](https://www.tensorflow.org/lite/convert/metadata) that contains the mandatory
- * (described below) input tensors, output tensor, and the optional (but recommended) label 
+ * (described below) input tensors, output tensor, and the optional (but recommended) label
  * items as AssociatedFiles with type TENSOR_AXIS_LABELS per output classification tensor.
  *
- * Metadata is required for models with int32 input tensors because it contains the input 
- * process unit for the model's Tokenizer. No metadata is required for models with string 
+ * Metadata is required for models with int32 input tensors because it contains the input
+ * process unit for the model's Tokenizer. No metadata is required for models with string
  * input tensors.
  *
  * Input tensors
  * - Three input tensors `kTfLiteInt32` of shape `[batch_size x bert_max_seq_len]`
- * representing the input ids, mask ids, and segment ids. This input signature requires 
+ * representing the input ids, mask ids, and segment ids. This input signature requires
  * a Bert Tokenizer process unit in the model metadata.
  * - Or one input tensor `kTfLiteInt32` of shape `[batch_size x max_seq_len]` representing
  * the input ids. This input signature requires a Regex Tokenizer process unit in the
@@ -44,12 +44,12 @@ NS_ASSUME_NONNULL_BEGIN
  *
  * At least one output tensor (`kTfLiteFloat32/kBool`) with:
  * - `N` classes and shape `[1 x N]`
- * - optional (but recommended) label map(s) as AssociatedFiles with type 
+ * - optional (but recommended) label map(s) as AssociatedFiles with type
  * TENSOR_AXIS_LABELS,
- * containing one label per line. The first such AssociatedFile (if any) is used to fill 
- * the `categoryName` field of the results. The `displayName` field is filled from the 
- * AssociatedFile (if any) whose locale matches the `displayNamesLocale` field of the 
- * `MPPTextClassifierOptions` used at creation time ("en" by default, i.e. English). If 
+ * containing one label per line. The first such AssociatedFile (if any) is used to fill
+ * the `categoryName` field of the results. The `displayName` field is filled from the
+ * AssociatedFile (if any) whose locale matches the `displayNamesLocale` field of the
+ * `MPPTextClassifierOptions` used at creation time ("en" by default, i.e. English). If
 * none of these are available, only the `index` field of the results will be filled. 
*/ NS_SWIFT_NAME(TextClassifier) diff --git a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.mm b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.mm index aed05ec37..59b5423bb 100644 --- a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.mm +++ b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.mm @@ -62,11 +62,11 @@ static NSString *const kTaskGraphName = @"mediapipe.tasks.text.text_classifier.T _textTaskRunner = [[MPPTextTaskRunner alloc] initWithCalculatorGraphConfig:[taskInfo generateGraphConfig] error:error]; - + if (!_textTaskRunner) { return nil; - } - + } + self = [super init]; return self; diff --git a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.h b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.h index 63bb92352..6744a8e16 100644 --- a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.h +++ b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.h @@ -26,7 +26,7 @@ NS_SWIFT_NAME(TextClassifierResult) @property(nonatomic, readonly) MPPClassificationResult *classificationResult; /** - * Initializes a new `MPPTextClassifierResult` with the given `MPPClassificationResult` and + * Initializes a new `MPPTextClassifierResult` with the given `MPPClassificationResult` and * timestamp (in milliseconds). * * @param classificationResult The `MPPClassificationResult` instance containing one set of results diff --git a/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.mm b/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.mm index c370f11ef..de64d970c 100644 --- a/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.mm +++ b/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.mm @@ -50,7 +50,6 @@ using ClassifierOptionsProto = ::mediapipe::tasks::components::processors::proto for (NSString *category in self.categoryDenylist) { classifierOptionsProto->add_category_denylist(category.cppString); } - } @end From c7e36f87207731c02c4d1b72399491c2c6d73f24 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Wed, 11 Jan 2023 20:31:46 +0530 Subject: [PATCH 367/469] Re-ordered dependencies in build file --- mediapipe/tasks/ios/text/text_classifier/BUILD | 16 ++++++++-------- .../tasks/ios/text/text_classifier/utils/BUILD | 12 ++++++------ 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/mediapipe/tasks/ios/text/text_classifier/BUILD b/mediapipe/tasks/ios/text/text_classifier/BUILD index e5242f50d..a6315840b 100644 --- a/mediapipe/tasks/ios/text/text_classifier/BUILD +++ b/mediapipe/tasks/ios/text/text_classifier/BUILD @@ -30,8 +30,8 @@ objc_library( srcs = ["sources/MPPTextClassifierResult.m"], hdrs = ["sources/MPPTextClassifierResult.h"], deps = [ - "//mediapipe/tasks/ios/core:MPPTaskResult", "//mediapipe/tasks/ios/components/containers:MPPClassificationResult", + "//mediapipe/tasks/ios/core:MPPTaskResult", ], ) @@ -44,15 +44,15 @@ objc_library( "-std=c++17", ], deps = [ + ":MPPTextClassifierOptions", "//mediapipe/tasks/cc/text/text_classifier:text_classifier_graph", - "//mediapipe/tasks/ios/core:MPPTaskOptions", - "//mediapipe/tasks/ios/core:MPPTaskInfo", - "//mediapipe/tasks/ios/text/core:MPPTextTaskRunner", - "//mediapipe/tasks/ios/core:MPPTextPacketCreator", - "//mediapipe/tasks/ios/text/text_classifier/utils:MPPTextClassifierOptionsHelpers", - 
"//mediapipe/tasks/ios/text/text_classifier/utils:MPPTextClassifierResultHelpers", "//mediapipe/tasks/ios/common/utils:MPPCommonUtils", "//mediapipe/tasks/ios/common/utils:NSStringHelpers", - ":MPPTextClassifierOptions", + "//mediapipe/tasks/ios/core:MPPTaskInfo", + "//mediapipe/tasks/ios/core:MPPTaskOptions", + "//mediapipe/tasks/ios/core:MPPTextPacketCreator", + "//mediapipe/tasks/ios/text/core:MPPTextTaskRunner", + "//mediapipe/tasks/ios/text/text_classifier/utils:MPPTextClassifierOptionsHelpers", + "//mediapipe/tasks/ios/text/text_classifier/utils:MPPTextClassifierResultHelpers", ], ) diff --git a/mediapipe/tasks/ios/text/text_classifier/utils/BUILD b/mediapipe/tasks/ios/text/text_classifier/utils/BUILD index 299050b32..23627391c 100644 --- a/mediapipe/tasks/ios/text/text_classifier/utils/BUILD +++ b/mediapipe/tasks/ios/text/text_classifier/utils/BUILD @@ -21,11 +21,11 @@ objc_library( srcs = ["sources/MPPTextClassifierOptions+Helpers.mm"], hdrs = ["sources/MPPTextClassifierOptions+Helpers.h"], deps = [ - "//mediapipe/tasks/ios/common/utils:NSStringHelpers", - "//mediapipe/tasks/ios/text/text_classifier:MPPTextClassifierOptions", - "//mediapipe/tasks/ios/core/utils:MPPBaseOptionsHelpers", - "//mediapipe/tasks/ios/core:MPPTaskOptionsProtocol", "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_cc_proto", + "//mediapipe/tasks/ios/common/utils:NSStringHelpers", + "//mediapipe/tasks/ios/core:MPPTaskOptionsProtocol", + "//mediapipe/tasks/ios/core/utils:MPPBaseOptionsHelpers", + "//mediapipe/tasks/ios/text/text_classifier:MPPTextClassifierOptions", ], ) @@ -34,8 +34,8 @@ objc_library( srcs = ["sources/MPPTextClassifierResult+Helpers.mm"], hdrs = ["sources/MPPTextClassifierResult+Helpers.h"], deps = [ - "//mediapipe/tasks/ios/text/text_classifier:MPPTextClassifierResult", - "//mediapipe/tasks/ios/components/containers/utils:MPPClassificationResultHelpers", "//mediapipe/framework:packet", + "//mediapipe/tasks/ios/components/containers/utils:MPPClassificationResultHelpers", + "//mediapipe/tasks/ios/text/text_classifier:MPPTextClassifierResult", ], ) From 0e56bd38f3123ced3de8c0c2862b7f7c55549078 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Wed, 11 Jan 2023 12:54:58 -0800 Subject: [PATCH 368/469] Fix for CHECK failure due to pointer description sometimes being larger than allocated string space PiperOrigin-RevId: 501355568 --- mediapipe/framework/tool/sink.cc | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/mediapipe/framework/tool/sink.cc b/mediapipe/framework/tool/sink.cc index 4a181b43f..f8abf4925 100644 --- a/mediapipe/framework/tool/sink.cc +++ b/mediapipe/framework/tool/sink.cc @@ -87,7 +87,8 @@ void AddVectorSink(const std::string& stream_name, // node->mutable_options()->MutableExtension( CallbackPacketCalculatorOptions::ext); options->set_type(CallbackPacketCalculatorOptions::VECTOR_PACKET); - char address[17]; + // Up to 64-bit pointer in hex (16 characters) and an optional "0x" prepended. + char address[19]; int written = snprintf(address, sizeof(address), "%p", dumped_data); CHECK(written > 0 && written < sizeof(address)); options->set_pointer(address); @@ -112,7 +113,8 @@ void AddPostStreamPacketSink(const std::string& stream_name, node->mutable_options()->MutableExtension( CallbackPacketCalculatorOptions::ext); options->set_type(CallbackPacketCalculatorOptions::POST_STREAM_PACKET); - char address[17]; + // Up to 64-bit pointer in hex (16 characters) and an optional "0x" prepended. 
+ char address[19]; int written = snprintf(address, sizeof(address), "%p", post_stream_packet); CHECK(written > 0 && written < sizeof(address)); options->set_pointer(address); From 5612af68cdb6cd157d8a86f13425913521e2de49 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Wed, 11 Jan 2023 13:00:37 -0800 Subject: [PATCH 369/469] Propagate compatible_with for drishti_proto_library PiperOrigin-RevId: 501356895 --- mediapipe/framework/tool/mediapipe_graph.bzl | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/mediapipe/framework/tool/mediapipe_graph.bzl b/mediapipe/framework/tool/mediapipe_graph.bzl index 45d98b1eb..ef5182a53 100644 --- a/mediapipe/framework/tool/mediapipe_graph.bzl +++ b/mediapipe/framework/tool/mediapipe_graph.bzl @@ -67,7 +67,8 @@ def data_as_c_string( name, srcs, outs = None, - testonly = None): + testonly = None, + compatible_with = None): """Encodes the data from a file as a C string literal. This produces a text file containing the quoted C string literal. It can be @@ -79,6 +80,7 @@ def data_as_c_string( outs: A list containing a single item, the name of the output text file. Defaults to the rule name. testonly: pass 1 if the graph is to be used only for tests. + compatible_with: a list of environments the rule is compatible with. """ if len(srcs) != 1: fail("srcs must be a single-element list") @@ -92,6 +94,7 @@ def data_as_c_string( cmd = "$(location %s) \"$<\" > \"$@\"" % encode_as_c_string, tools = [encode_as_c_string], testonly = testonly, + compatible_with = compatible_with, ) def mediapipe_simple_subgraph( @@ -208,6 +211,7 @@ def mediapipe_options_library( deps = [], visibility = None, testonly = None, + compatible_with = None, **kwargs): """Registers options protobuf metadata for defining options packets. @@ -217,6 +221,7 @@ def mediapipe_options_library( deps: any additional protobuf dependencies. visibility: The list of packages the subgraph should be visible to. testonly: pass 1 if the graph is to be used only for tests. + compatible_with: a list of environments the rule is compatible with. **kwargs: Remaining keyword args, forwarded to cc_library. 
""" @@ -224,16 +229,19 @@ def mediapipe_options_library( name = proto_lib + "_transitive", deps = [proto_lib], testonly = testonly, + compatible_with = compatible_with, ) direct_descriptor_set( name = proto_lib + "_direct", deps = [proto_lib], testonly = testonly, + compatible_with = compatible_with, ) data_as_c_string( name = name + "_inc", srcs = [proto_lib + "_transitive-transitive-descriptor-set.proto.bin"], outs = [proto_lib + "_descriptors.inc"], + compatible_with = compatible_with, ) native.genrule( name = name + "_type_name", @@ -245,6 +253,7 @@ def mediapipe_options_library( tools = ["//mediapipe/framework/tool:message_type_util"], visibility = visibility, testonly = testonly, + compatible_with = compatible_with, ) expand_template( name = name + "_cc", @@ -256,6 +265,7 @@ def mediapipe_options_library( "{{DESCRIPTOR_INC_FILE_PATH}}": native.package_name() + "/" + proto_lib + "_descriptors.inc", }, testonly = testonly, + compatible_with = compatible_with, ) native.cc_library( name = proto_lib.replace("_proto", "_options_registry"), @@ -274,6 +284,7 @@ def mediapipe_options_library( visibility = visibility, testonly = testonly, features = ["-no_undefined"], + compatible_with = compatible_with, **kwargs ) mediapipe_reexport_library( From 36be94f861e57dd21e4e7a64e5a1650e73313753 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Wed, 11 Jan 2023 14:21:02 -0800 Subject: [PATCH 370/469] Internal change PiperOrigin-RevId: 501378130 --- .../desktop/autoflip/autoflip_messages.proto | 4 +++ .../calculators/scene_cropping_calculator.cc | 21 +++++++++-- .../scene_cropping_calculator_test.cc | 35 +++++++++++++++++++ 3 files changed, 57 insertions(+), 3 deletions(-) diff --git a/mediapipe/examples/desktop/autoflip/autoflip_messages.proto b/mediapipe/examples/desktop/autoflip/autoflip_messages.proto index 8507c9ad7..c89a6aea6 100644 --- a/mediapipe/examples/desktop/autoflip/autoflip_messages.proto +++ b/mediapipe/examples/desktop/autoflip/autoflip_messages.proto @@ -185,6 +185,10 @@ message ExternalRenderFrame { // original dimensions of the input video. The first step to render this // frame is to crop this rect from the input frame. optional Rect crop_from_location = 1; + // Rect that must be cropped out of the input frame. It is defined in the + // ratio of the frame of the input video. The first step to render this frame + // is to crop this rect from the input frame. + optional Rect normalized_crop_from_location = 7; // The placement location where the above rect is placed on the output frame. // This will always have the same aspect ratio as the above rect but scaling // may be required. 
diff --git a/mediapipe/examples/desktop/autoflip/calculators/scene_cropping_calculator.cc b/mediapipe/examples/desktop/autoflip/calculators/scene_cropping_calculator.cc index 89170dc6a..7e286b743 100644 --- a/mediapipe/examples/desktop/autoflip/calculators/scene_cropping_calculator.cc +++ b/mediapipe/examples/desktop/autoflip/calculators/scene_cropping_calculator.cc @@ -201,13 +201,26 @@ absl::Status ParseAspectRatioString(const std::string& aspect_ratio_string, void ConstructExternalRenderMessage( const cv::Rect& crop_from_location, const cv::Rect& render_to_location, const cv::Scalar& padding_color, const uint64 timestamp_us, - ExternalRenderFrame* external_render_message) { + ExternalRenderFrame* external_render_message, int frame_width, + int frame_height) { auto crop_from_message = external_render_message->mutable_crop_from_location(); crop_from_message->set_x(crop_from_location.x); crop_from_message->set_y(crop_from_location.y); crop_from_message->set_width(crop_from_location.width); crop_from_message->set_height(crop_from_location.height); + + auto normalized_crop_from_message = + external_render_message->mutable_normalized_crop_from_location(); + normalized_crop_from_message->set_x(crop_from_location.x / + static_cast(frame_width)); + normalized_crop_from_message->set_y(crop_from_location.y / + static_cast(frame_height)); + normalized_crop_from_message->set_width(crop_from_location.width / + static_cast(frame_width)); + normalized_crop_from_message->set_height(crop_from_location.height / + static_cast(frame_height)); + auto render_to_message = external_render_message->mutable_render_to_location(); render_to_message->set_x(render_to_location.x); @@ -627,7 +640,8 @@ absl::Status SceneCroppingCalculator::ProcessScene(const bool is_end_of_scene, auto external_render_message = absl::make_unique(); ConstructExternalRenderMessage( crop_from_locations[i], render_to_locations[i], padding_colors[i], - scene_frame_timestamps_[i], external_render_message.get()); + scene_frame_timestamps_[i], external_render_message.get(), + frame_width_, frame_height_); cc->Outputs() .Tag(kExternalRenderingPerFrame) .Add(external_render_message.release(), @@ -640,7 +654,8 @@ absl::Status SceneCroppingCalculator::ProcessScene(const bool is_end_of_scene, ExternalRenderFrame render_frame; ConstructExternalRenderMessage(crop_from_locations[i], render_to_locations[i], padding_colors[i], - scene_frame_timestamps_[i], &render_frame); + scene_frame_timestamps_[i], &render_frame, + frame_width_, frame_height_); external_render_list_->push_back(render_frame); } } diff --git a/mediapipe/examples/desktop/autoflip/calculators/scene_cropping_calculator_test.cc b/mediapipe/examples/desktop/autoflip/calculators/scene_cropping_calculator_test.cc index 88728860a..c3285ea58 100644 --- a/mediapipe/examples/desktop/autoflip/calculators/scene_cropping_calculator_test.cc +++ b/mediapipe/examples/desktop/autoflip/calculators/scene_cropping_calculator_test.cc @@ -920,6 +920,41 @@ TEST(SceneCroppingCalculatorTest, OutputsCropMessageKinematicPathNoVideo) { EXPECT_EQ(ext_render_message.render_to_location().height(), 1124); } } + +// Checks external render message with default poly path solver using +// normalized crops. 
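+// (The expected values below are the pixel crop coordinates divided by the
+// input frame dimensions, e.g. x == 725 / kInputFrameWidth and
+// height == 720 / kInputFrameHeight.)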
+TEST(SceneCroppingCalculatorTest, OutputsCropMessagePolyPathNormalized) { + const CalculatorGraphConfig::Node config = + ParseTextProtoOrDie( + absl::Substitute(kExternalRenderConfig, kTargetWidth, kTargetHeight)); + auto runner = absl::make_unique(config); + const int num_frames = kSceneSize; + AddScene(0, num_frames, kInputFrameWidth, kInputFrameHeight, kKeyFrameWidth, + kKeyFrameHeight, 1, runner->MutableInputs()); + + MP_EXPECT_OK(runner->Run()); + const auto& outputs = runner->Outputs(); + const auto& ext_render_per_frame = + outputs.Tag(kExternalRenderingPerFrameTag).packets; + EXPECT_EQ(ext_render_per_frame.size(), num_frames); + + for (int i = 0; i < num_frames - 1; ++i) { + const auto& ext_render_message = + ext_render_per_frame[i].Get(); + EXPECT_EQ(ext_render_message.timestamp_us(), i * 20000); + EXPECT_EQ(ext_render_message.normalized_crop_from_location().x(), + 725 / static_cast(kInputFrameWidth)); + EXPECT_EQ(ext_render_message.normalized_crop_from_location().y(), 0); + EXPECT_EQ(ext_render_message.normalized_crop_from_location().width(), + 461 / static_cast(kInputFrameWidth)); + EXPECT_EQ(ext_render_message.normalized_crop_from_location().height(), + 720 / static_cast(kInputFrameHeight)); + EXPECT_EQ(ext_render_message.render_to_location().x(), 0); + EXPECT_EQ(ext_render_message.render_to_location().y(), 0); + EXPECT_EQ(ext_render_message.render_to_location().width(), 720); + EXPECT_EQ(ext_render_message.render_to_location().height(), 1124); + } +} } // namespace } // namespace autoflip } // namespace mediapipe From 8830eefa0b96ccc886feb53e1404bae2e0cdf4d1 Mon Sep 17 00:00:00 2001 From: Nikolay Chirkov Date: Wed, 11 Jan 2023 16:04:57 -0800 Subject: [PATCH 371/469] Internal change. PiperOrigin-RevId: 501403332 --- .../formats/tensor/cpu_buffer_converters.cc | 240 --------------- .../tensor/cpu_buffer_converters_test.cc | 282 ------------------ 2 files changed, 522 deletions(-) delete mode 100644 mediapipe/framework/formats/tensor/cpu_buffer_converters.cc delete mode 100644 mediapipe/framework/formats/tensor/cpu_buffer_converters_test.cc diff --git a/mediapipe/framework/formats/tensor/cpu_buffer_converters.cc b/mediapipe/framework/formats/tensor/cpu_buffer_converters.cc deleted file mode 100644 index e4e705be5..000000000 --- a/mediapipe/framework/formats/tensor/cpu_buffer_converters.cc +++ /dev/null @@ -1,240 +0,0 @@ -#include -#include -#include - -#include "mediapipe/framework/formats/tensor/backend.h" -#include "mediapipe/framework/formats/tensor/tensor2.h" -#include "mediapipe/framework/formats/tensor/views/buffer.h" -#include "mediapipe/framework/formats/tensor/views/cpu_buffer.h" -#include "third_party/FP16/include/fp16.h" - -namespace mediapipe { -namespace { - -template -auto ConverterCheckFunction() { - return - [](const Tensor2& tensor, uint64_t source_descriptor_type_id, - const Tensor2::ViewDescriptor& source_base_descriptor, - uint64_t destination_descriptor_type_id, - const Tensor2::ViewDescriptor& destination_base_descriptor) -> bool { - if (source_descriptor_type_id != TensorCpuView::kId || - destination_descriptor_type_id != TensorCpuView::kId) - return false; - auto source_descriptor = - static_cast(source_base_descriptor); - auto destination_descriptor = - static_cast( - destination_base_descriptor); - return source_descriptor.buffer.format == - TensorTypeToFormat::value && - destination_descriptor.buffer.format == - TensorTypeToFormat::value; - }; -} - -template -auto ConvertFunction() { - return [](const Tensor2& tensor, const Tensor2::View& 
source_base_view, - const Tensor2::View& destination_base_view) -> bool { - auto source = source_base_view.DownCast(); - auto destination = destination_base_view.DownCast(); - if (source->descriptor().buffer.format == - destination->descriptor().buffer.format) { - std::memcpy( - destination->data(), source->data(), - TensorBufferSize(destination->descriptor().buffer, tensor.shape())); - } else { - auto source_pointer = source->data(); - auto destination_pointer = destination->data(); - for (int i = 0; i < tensor.shape().NumElements(); i++) { - *destination_pointer++ = - GpuLikeTypeCast(*source_pointer++); - } - } - return true; - }; -} - -#define REGISTER_CONVERTER(SourceType, DestinationType) \ - TENSOR_REGISTER_CONVERTER( \ - {ConverterCheckFunction(), \ - ConvertFunction()}); - -REGISTER_CONVERTER(float, Float16); -REGISTER_CONVERTER(float, int8_t); -REGISTER_CONVERTER(float, uint8_t); -REGISTER_CONVERTER(float, int16_t); -REGISTER_CONVERTER(float, uint16_t); -REGISTER_CONVERTER(float, int32_t); -REGISTER_CONVERTER(float, uint32_t); - -REGISTER_CONVERTER(Float16, float); -REGISTER_CONVERTER(Float16, int8_t); -REGISTER_CONVERTER(Float16, uint8_t); -REGISTER_CONVERTER(Float16, int16_t); -REGISTER_CONVERTER(Float16, uint16_t); -REGISTER_CONVERTER(Float16, int32_t); -REGISTER_CONVERTER(Float16, uint32_t); - -REGISTER_CONVERTER(int8_t, float); -REGISTER_CONVERTER(int8_t, Float16); -REGISTER_CONVERTER(int8_t, uint8_t); -REGISTER_CONVERTER(int8_t, int16_t); -REGISTER_CONVERTER(int8_t, uint16_t); -REGISTER_CONVERTER(int8_t, int32_t); -REGISTER_CONVERTER(int8_t, uint32_t); - -REGISTER_CONVERTER(uint8_t, float); -REGISTER_CONVERTER(uint8_t, Float16); -REGISTER_CONVERTER(uint8_t, int8_t); -REGISTER_CONVERTER(uint8_t, int16_t); -REGISTER_CONVERTER(uint8_t, uint16_t); -REGISTER_CONVERTER(uint8_t, int32_t); -REGISTER_CONVERTER(uint8_t, uint32_t); - -REGISTER_CONVERTER(int16_t, float); -REGISTER_CONVERTER(int16_t, Float16); -REGISTER_CONVERTER(int16_t, int8_t); -REGISTER_CONVERTER(int16_t, uint8_t); -REGISTER_CONVERTER(int16_t, uint16_t); -REGISTER_CONVERTER(int16_t, uint32_t); -REGISTER_CONVERTER(int16_t, uint32_t); - -REGISTER_CONVERTER(uint16_t, float); -REGISTER_CONVERTER(uint16_t, Float16); -REGISTER_CONVERTER(uint16_t, int8_t); -REGISTER_CONVERTER(uint16_t, uint8_t); -REGISTER_CONVERTER(uint16_t, int16_t); -REGISTER_CONVERTER(uint16_t, int32_t); -REGISTER_CONVERTER(uint16_t, uint32_t); - -REGISTER_CONVERTER(int32_t, float); -REGISTER_CONVERTER(int32_t, Float16); -REGISTER_CONVERTER(int32_t, int8_t); -REGISTER_CONVERTER(int32_t, uint8_t); -REGISTER_CONVERTER(int32_t, int16_t); -REGISTER_CONVERTER(int32_t, uint16_t); -REGISTER_CONVERTER(int32_t, uint32_t); - -REGISTER_CONVERTER(uint32_t, float); -REGISTER_CONVERTER(uint32_t, Float16); -REGISTER_CONVERTER(uint32_t, int8_t); -REGISTER_CONVERTER(uint32_t, uint8_t); -REGISTER_CONVERTER(uint32_t, int16_t); -REGISTER_CONVERTER(uint32_t, uint16_t); -REGISTER_CONVERTER(uint32_t, int32_t); - -template -auto DequantizationCheckFunction() { - return - [](const Tensor2& tensor, uint64_t source_descriptor_type_id, - const Tensor2::ViewDescriptor& source_base_descriptor, - uint64_t destination_descriptor_type_id, - const Tensor2::ViewDescriptor& destination_base_descriptor) -> bool { - if (source_descriptor_type_id != TensorCpuView::kId || - destination_descriptor_type_id != TensorCpuView::kId) - return false; - auto source_descriptor = - static_cast(source_base_descriptor); - auto destination_descriptor = - static_cast( - destination_base_descriptor); - return 
source_descriptor.buffer.format == - TensorBufferDescriptor::Format::kQuantizedInt8 && - destination_descriptor.buffer.format == - TensorTypeToFormat::value; - }; -} - -template -auto DequantizationConvertFunction() { - return [](const Tensor2& tensor, const Tensor2::View& source_base_view, - const Tensor2::View& destination_base_view) -> bool { - auto source = source_base_view.DownCast(); - auto destination = destination_base_view.DownCast(); - auto source_pointer = source->data(); - auto destination_pointer = destination->data(); - int zero_point = - source->descriptor().buffer.quantization_parameters.zero_point; - float scale = source->descriptor().buffer.quantization_parameters.scale; - for (int i = 0; i < tensor.shape().NumElements(); i++) { - *destination_pointer++ = static_cast( - (*source_pointer++ - zero_point) * scale); - } - return true; - }; -} - -#define REGISTER_DEQUANTIZATION_CONVERTER(DestinationType) \ - TENSOR_REGISTER_CONVERTER( \ - {DequantizationCheckFunction(), \ - DequantizationConvertFunction()}); - -REGISTER_DEQUANTIZATION_CONVERTER(float); -REGISTER_DEQUANTIZATION_CONVERTER(Float16); -REGISTER_DEQUANTIZATION_CONVERTER(int8_t); -REGISTER_DEQUANTIZATION_CONVERTER(uint8_t); -REGISTER_DEQUANTIZATION_CONVERTER(int16_t); -REGISTER_DEQUANTIZATION_CONVERTER(uint16_t); -REGISTER_DEQUANTIZATION_CONVERTER(int32_t); -REGISTER_DEQUANTIZATION_CONVERTER(uint32_t); - -template -auto QuantizationCheckFunction() { - return - [](const Tensor2& tensor, uint64_t source_descriptor_type_id, - const Tensor2::ViewDescriptor& source_base_descriptor, - uint64_t destination_descriptor_type_id, - const Tensor2::ViewDescriptor& destination_base_descriptor) -> bool { - if (source_descriptor_type_id != TensorCpuView::kId || - destination_descriptor_type_id != TensorCpuView::kId) - return false; - auto source_descriptor = - static_cast(source_base_descriptor); - auto destination_descriptor = - static_cast( - destination_base_descriptor); - bool same = source_descriptor.buffer.format == - TensorTypeToFormat::value && - destination_descriptor.buffer.format == - TensorBufferDescriptor::Format::kQuantizedInt8; - return same; - }; -} - -template -auto QuantizationConvertFunction() { - return [](const Tensor2& tensor, const Tensor2::View& source_base_view, - const Tensor2::View& destination_base_view) -> bool { - auto source = source_base_view.DownCast(); - auto destination = destination_base_view.DownCast(); - auto source_pointer = source->data(); - auto destination_pointer = destination->data(); - int zero_point = - destination->descriptor().buffer.quantization_parameters.zero_point; - float scale = - destination->descriptor().buffer.quantization_parameters.scale; - for (int i = 0; i < tensor.shape().NumElements(); i++) { - *destination_pointer++ = - static_cast(*source_pointer++ / scale + zero_point); - } - return true; - }; -} - -#define REGISTER_QUANTIZATION_CONVERTER(SourceType) \ - TENSOR_REGISTER_CONVERTER({QuantizationCheckFunction(), \ - QuantizationConvertFunction()}); - -REGISTER_QUANTIZATION_CONVERTER(float); -REGISTER_QUANTIZATION_CONVERTER(Float16); -REGISTER_QUANTIZATION_CONVERTER(int8_t); -REGISTER_QUANTIZATION_CONVERTER(uint8_t); -REGISTER_QUANTIZATION_CONVERTER(int16_t); -REGISTER_QUANTIZATION_CONVERTER(uint16_t); -REGISTER_QUANTIZATION_CONVERTER(int32_t); -REGISTER_QUANTIZATION_CONVERTER(uint32_t); - -} // namespace -} // namespace mediapipe diff --git a/mediapipe/framework/formats/tensor/cpu_buffer_converters_test.cc 
b/mediapipe/framework/formats/tensor/cpu_buffer_converters_test.cc deleted file mode 100644 index 3619ad531..000000000 --- a/mediapipe/framework/formats/tensor/cpu_buffer_converters_test.cc +++ /dev/null @@ -1,282 +0,0 @@ -#include - -#include "mediapipe/framework/formats/tensor/tensor2.h" -#include "mediapipe/framework/formats/tensor/views/buffer.h" -#include "mediapipe/framework/formats/tensor/views/cpu_buffer.h" -#include "testing/base/public/gmock.h" -#include "testing/base/public/gunit.h" - -MATCHER_P(NearWithPrecision, precision, "") { - return std::abs(std::get<0>(arg) - std::get<1>(arg)) < precision; -} -MATCHER_P(IntegerEqual, precision, "") { - return std::get<0>(arg) == std::get<1>(arg); -} - -namespace mediapipe { - -TEST(TensorCpuViewTest, TestWrite32ThenRead16) { - Tensor2 tensor{Tensor2::Shape({1})}; - { - MP_ASSERT_OK_AND_ASSIGN( - auto view, - tensor.GetView( - TensorCpuViewDescriptor{ - .buffer = {.format = - TensorBufferDescriptor::Format::kFloat32}})); - ASSERT_NE(view->data(), nullptr); - *view->data() = 1234.0f; - } - { - MP_ASSERT_OK_AND_ASSIGN( - auto view, - tensor.GetView( - TensorCpuViewDescriptor{ - .buffer = {.format = - TensorBufferDescriptor::Format::kFloat16}})); - ASSERT_NE(view->data(), nullptr); - EXPECT_EQ(*view->data(), 1234.0f); - } -} - -TEST(TensorCpuViewTest, TestWrite16ThenRead32) { - Tensor2 tensor{Tensor2::Shape({1})}; - { - MP_ASSERT_OK_AND_ASSIGN( - auto view, - tensor.GetView( - TensorCpuViewDescriptor{ - .buffer = {.format = - TensorBufferDescriptor::Format::kFloat16}})); - ASSERT_NE(view->data(), nullptr); - *view->data() = 1234.0f; - } - { - MP_ASSERT_OK_AND_ASSIGN( - auto view, - tensor.GetView( - TensorCpuViewDescriptor{ - .buffer = {.format = - TensorBufferDescriptor::Format::kFloat32}})); - ASSERT_NE(view->data(), nullptr); - EXPECT_EQ(*view->data(), 1234.0f); - } -} - -TEST(TensorCpuViewTest, TestWriteFloat32ThenReadInt8) { - Tensor2 tensor{Tensor2::Shape({1})}; - { - MP_ASSERT_OK_AND_ASSIGN( - auto view, - tensor.GetView( - TensorCpuViewDescriptor{ - .buffer = {.format = - TensorBufferDescriptor::Format::kFloat32}})); - ASSERT_NE(view->data(), nullptr); - *view->data() = 0.121569f; - } - { - MP_ASSERT_OK_AND_ASSIGN( - auto view, - tensor.GetView( - TensorCpuViewDescriptor{ - .buffer = {.format = TensorBufferDescriptor::Format::kUInt8}})); - ASSERT_NE(view->data(), nullptr); - EXPECT_EQ( - *view->data(), - static_cast(0.121569f * std::numeric_limits::max())); - } -} - -TEST(TensorCpuViewTest, TestWriteInt8ThenReadFloat32) { - Tensor2 tensor{Tensor2::Shape({1})}; - { - MP_ASSERT_OK_AND_ASSIGN( - auto view, - tensor.GetView( - TensorCpuViewDescriptor{ - .buffer = {.format = TensorBufferDescriptor::Format::kUInt8}})); - ASSERT_NE(view->data(), nullptr); - *view->data() = - static_cast(0.123f * std::numeric_limits::max()); - } - { - MP_ASSERT_OK_AND_ASSIGN( - auto view, - tensor.GetView( - TensorCpuViewDescriptor{ - .buffer = {.format = - TensorBufferDescriptor::Format::kFloat32}})); - ASSERT_NE(view->data(), nullptr); - EXPECT_NEAR(*view->data(), 0.123f, - 1.0f / std::numeric_limits::max()); - } -} - -TEST(TensorCpuViewTest, TestWriteUInt8ThenReadUInt16) { - Tensor2 tensor{Tensor2::Shape({1})}; - { - MP_ASSERT_OK_AND_ASSIGN( - auto view, - tensor.GetView( - TensorCpuViewDescriptor{ - .buffer = {.format = TensorBufferDescriptor::Format::kUInt8}})); - ASSERT_NE(view->data(), nullptr); - *view->data() = 123; - } - { - MP_ASSERT_OK_AND_ASSIGN( - auto view, - tensor.GetView( - TensorCpuViewDescriptor{ - .buffer = {.format = - 
TensorBufferDescriptor::Format::kUInt16}})); - ASSERT_NE(view->data(), nullptr); - EXPECT_EQ(*view->data(), uint16_t{123} << 8); - } -} - -TEST(TensorCpuViewTest, TestWriteUInt16ThenReadUInt8) { - Tensor2 tensor{Tensor2::Shape({1})}; - { - MP_ASSERT_OK_AND_ASSIGN( - auto view, - tensor.GetView( - TensorCpuViewDescriptor{ - .buffer = {.format = - TensorBufferDescriptor::Format::kUInt16}})); - ASSERT_NE(view->data(), nullptr); - *view->data() = uint16_t{123} << 8; - } - { - MP_ASSERT_OK_AND_ASSIGN( - auto view, - tensor.GetView( - TensorCpuViewDescriptor{ - .buffer = {.format = TensorBufferDescriptor::Format::kUInt8}})); - ASSERT_NE(view->data(), nullptr); - EXPECT_EQ(*view->data(), 123); - } -} - -TEST(TensorCpuViewTest, TestWriteNegativeInt8ThenReadUInt8) { - Tensor2 tensor{Tensor2::Shape({1})}; - { - MP_ASSERT_OK_AND_ASSIGN( - auto view, - tensor.GetView( - TensorCpuViewDescriptor{ - .buffer = {.format = TensorBufferDescriptor::Format::kInt8}})); - ASSERT_NE(view->data(), nullptr); - *view->data() = -123; - } - { - MP_ASSERT_OK_AND_ASSIGN( - auto view, - tensor.GetView( - TensorCpuViewDescriptor{ - .buffer = {.format = TensorBufferDescriptor::Format::kUInt8}})); - ASSERT_NE(view->data(), nullptr); - EXPECT_EQ(*view->data(), 0); - } -} - -TEST(TensorCpuViewTest, TestWritePositiveInt8ThenReadUInt8) { - Tensor2 tensor{Tensor2::Shape({1})}; - { - MP_ASSERT_OK_AND_ASSIGN( - auto view, - tensor.GetView( - TensorCpuViewDescriptor{ - .buffer = {.format = TensorBufferDescriptor::Format::kInt8}})); - ASSERT_NE(view->data(), nullptr); - *view->data() = 123; - } - { - MP_ASSERT_OK_AND_ASSIGN( - auto view, - tensor.GetView( - TensorCpuViewDescriptor{ - .buffer = {.format = TensorBufferDescriptor::Format::kUInt8}})); - ASSERT_NE(view->data(), nullptr); - EXPECT_EQ(*view->data(), 123 * 2); - } -} - -TEST(TensorCpuViewTest, TestDequantization) { - constexpr int num_elements = 20; - // Gives quantization values in range [-100, 90]. - constexpr int zero_point = -100; - constexpr float scale = 2.0f; - Tensor2 tensor{Tensor2::Shape({num_elements})}; - { - MP_ASSERT_OK_AND_ASSIGN( - auto view, - tensor.GetView( - TensorCpuViewDescriptor{ - .buffer = { - .format = TensorBufferDescriptor::Format::kQuantizedInt8, - .quantization_parameters = {.scale = scale, - .zero_point = zero_point}}})); - ASSERT_NE(view->data(), nullptr); - auto data = view->data(); - for (int i = 0; i < num_elements; ++i) { - // Add some bias (+1) to make round-up take place. - data[i] = (i * 20 + 1) / scale + zero_point; - } - } - { - MP_ASSERT_OK_AND_ASSIGN( - auto view, - tensor.GetView( - TensorCpuViewDescriptor{ - .buffer = {.format = - TensorBufferDescriptor::Format::kFloat32}})); - ASSERT_NE(view->data(), nullptr); - std::vector reference(num_elements); - for (int i = 0; i < num_elements; ++i) { - reference[i] = i * 20.0f + 1.0f; - } - EXPECT_THAT(absl::Span(view->data(), num_elements), - testing::Pointwise(NearWithPrecision(1.001), reference)); - } -} - -TEST(TensorCpuViewTest, TestQuantization) { - constexpr int num_elements = 20; - // Gives quantization values in range [-100, 90]. - constexpr int zero_point = -100; - constexpr float scale = 2.0f; - Tensor2 tensor{Tensor2::Shape({num_elements})}; - { - MP_ASSERT_OK_AND_ASSIGN( - auto view, - tensor.GetView( - TensorCpuViewDescriptor{ - .buffer = {.format = - TensorBufferDescriptor::Format::kFloat32}})); - ASSERT_NE(view->data(), nullptr); - auto data = view->data(); - for (int i = 0; i < num_elements; ++i) { - // Add some bias (+1) to make round-up take place. 
- data[i] = i * 20 + 1; - } - } - { - TensorCpuViewDescriptor d{ - .buffer = {.format = TensorBufferDescriptor::Format::kQuantizedInt8, - .quantization_parameters = {.scale = scale, - .zero_point = zero_point}}}; - MP_ASSERT_OK_AND_ASSIGN( - auto view, tensor.GetView(d)); - ASSERT_NE(view->data(), nullptr); - std::vector reference(num_elements); - for (int i = 0; i < num_elements; ++i) { - reference[i] = (i * 20 + 1) / scale + zero_point; - } - EXPECT_THAT(absl::Span(view->data(), num_elements), - testing::Pointwise(IntegerEqual(0), reference)); - } -} - -} // namespace mediapipe From 9cbb76939dd069eacecae103c5c27b6e07c7e9c7 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Wed, 11 Jan 2023 20:33:26 -0800 Subject: [PATCH 372/469] Adds smaller MobileBERT model. PiperOrigin-RevId: 501451414 --- .../model_maker/models/text_classifier/BUILD | 45 ++++++++++ .../python/text/text_classifier/BUILD | 11 +++ .../python/text/text_classifier/model_spec.py | 13 +-- .../text/text_classifier/model_spec_test.py | 7 +- .../text/text_classifier/testdata/BUILD | 5 +- .../testdata/bert_metadata.json | 84 +++++++++++++++++++ .../text/text_classifier/text_classifier.py | 13 ++- .../text_classifier/text_classifier_test.py | 25 +++++- mediapipe/model_maker/setup.py | 12 ++- third_party/external_files.bzl | 30 +++++++ 10 files changed, 228 insertions(+), 17 deletions(-) create mode 100644 mediapipe/model_maker/models/text_classifier/BUILD create mode 100644 mediapipe/model_maker/python/text/text_classifier/testdata/bert_metadata.json diff --git a/mediapipe/model_maker/models/text_classifier/BUILD b/mediapipe/model_maker/models/text_classifier/BUILD new file mode 100644 index 000000000..4c54bbccc --- /dev/null +++ b/mediapipe/model_maker/models/text_classifier/BUILD @@ -0,0 +1,45 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +load( + "//mediapipe/framework/tool:mediapipe_files.bzl", + "mediapipe_files", +) + +licenses(["notice"]) + +package( + default_visibility = ["//mediapipe/model_maker/python/text/text_classifier:__subpackages__"], +) + +mediapipe_files( + srcs = [ + "mobilebert_tiny/assets/vocab.txt", + "mobilebert_tiny/keras_metadata.pb", + "mobilebert_tiny/saved_model.pb", + "mobilebert_tiny/variables/variables.data-00000-of-00001", + "mobilebert_tiny/variables/variables.index", + ], +) + +filegroup( + name = "mobilebert_tiny", + srcs = [ + "mobilebert_tiny/assets/vocab.txt", + "mobilebert_tiny/keras_metadata.pb", + "mobilebert_tiny/saved_model.pb", + "mobilebert_tiny/variables/variables.data-00000-of-00001", + "mobilebert_tiny/variables/variables.index", + ], +) diff --git a/mediapipe/model_maker/python/text/text_classifier/BUILD b/mediapipe/model_maker/python/text/text_classifier/BUILD index 7bb41351e..43f2b6c75 100644 --- a/mediapipe/model_maker/python/text/text_classifier/BUILD +++ b/mediapipe/model_maker/python/text/text_classifier/BUILD @@ -53,6 +53,7 @@ py_library( deps = [ ":model_options", "//mediapipe/model_maker/python/core:hyperparameters", + "//mediapipe/model_maker/python/core/utils:file_util", "//mediapipe/model_maker/python/text/core:bert_model_spec", ], ) @@ -88,6 +89,9 @@ py_library( py_test( name = "preprocessor_test", srcs = ["preprocessor_test.py"], + data = [ + "//mediapipe/model_maker/models/text_classifier:mobilebert_tiny", + ], tags = ["requires-net:external"], deps = [ ":dataset", @@ -109,6 +113,9 @@ py_library( py_library( name = "text_classifier", srcs = ["text_classifier.py"], + data = [ + "//mediapipe/model_maker/models/text_classifier:mobilebert_tiny", + ], deps = [ ":dataset", ":model_options", @@ -130,6 +137,7 @@ py_test( size = "large", srcs = ["text_classifier_test.py"], data = [ + "//mediapipe/model_maker/models/text_classifier:mobilebert_tiny", "//mediapipe/model_maker/python/text/text_classifier/testdata", ], tags = ["requires-net:external"], @@ -151,6 +159,9 @@ py_library( py_binary( name = "text_classifier_demo", srcs = ["text_classifier_demo.py"], + data = [ + "//mediapipe/model_maker/models/text_classifier:mobilebert_tiny", + ], deps = [ ":text_classifier_demo_lib", ], diff --git a/mediapipe/model_maker/python/text/text_classifier/model_spec.py b/mediapipe/model_maker/python/text/text_classifier/model_spec.py index 9df7e1039..a6bdd9522 100644 --- a/mediapipe/model_maker/python/text/text_classifier/model_spec.py +++ b/mediapipe/model_maker/python/text/text_classifier/model_spec.py @@ -18,12 +18,15 @@ import enum import functools from mediapipe.model_maker.python.core import hyperparameters as hp +from mediapipe.model_maker.python.core.utils import file_util from mediapipe.model_maker.python.text.core import bert_model_spec from mediapipe.model_maker.python.text.text_classifier import model_options as mo # BERT-based text classifier spec inherited from BertModelSpec BertClassifierSpec = bert_model_spec.BertModelSpec +MOBILEBERT_TINY_PATH = 'mediapipe/model_maker/models/text_classifier/mobilebert_tiny/' + @dataclasses.dataclass class AverageWordEmbeddingClassifierSpec: @@ -49,16 +52,14 @@ average_word_embedding_classifier_spec = functools.partial( mobilebert_classifier_spec = functools.partial( BertClassifierSpec, hparams=hp.BaseHParams( - epochs=3, - batch_size=48, - learning_rate=3e-5, - distribution_strategy='off'), + epochs=3, batch_size=48, learning_rate=3e-5, distribution_strategy='off' + ), name='MobileBert', - 
uri='https://tfhub.dev/tensorflow/mobilebert_en_uncased_L-24_H-128_B-512_A-4_F-4_OPT/1', + uri=file_util.get_absolute_path(MOBILEBERT_TINY_PATH), tflite_input_name={ 'ids': 'serving_default_input_1:0', 'mask': 'serving_default_input_3:0', - 'segment_ids': 'serving_default_input_2:0' + 'segment_ids': 'serving_default_input_2:0', }, ) diff --git a/mediapipe/model_maker/python/text/text_classifier/model_spec_test.py b/mediapipe/model_maker/python/text/text_classifier/model_spec_test.py index dd7f880f3..3ea019b44 100644 --- a/mediapipe/model_maker/python/text/text_classifier/model_spec_test.py +++ b/mediapipe/model_maker/python/text/text_classifier/model_spec_test.py @@ -28,9 +28,10 @@ class ModelSpecTest(tf.test.TestCase): model_spec_obj = ms.SupportedModels.MOBILEBERT_CLASSIFIER.value() self.assertIsInstance(model_spec_obj, ms.BertClassifierSpec) self.assertEqual(model_spec_obj.name, 'MobileBert') - self.assertEqual( - model_spec_obj.uri, 'https://tfhub.dev/tensorflow/' - 'mobilebert_en_uncased_L-24_H-128_B-512_A-4_F-4_OPT/1') + self.assertIn( + 'mediapipe/model_maker/models/text_classifier/mobilebert_tiny', + model_spec_obj.uri, + ) self.assertTrue(model_spec_obj.do_lower_case) self.assertEqual( model_spec_obj.tflite_input_name, { diff --git a/mediapipe/model_maker/python/text/text_classifier/testdata/BUILD b/mediapipe/model_maker/python/text/text_classifier/testdata/BUILD index 663c72082..a581462cf 100644 --- a/mediapipe/model_maker/python/text/text_classifier/testdata/BUILD +++ b/mediapipe/model_maker/python/text/text_classifier/testdata/BUILD @@ -19,5 +19,8 @@ package( filegroup( name = "testdata", - srcs = ["average_word_embedding_metadata.json"], + srcs = [ + "average_word_embedding_metadata.json", + "bert_metadata.json", + ], ) diff --git a/mediapipe/model_maker/python/text/text_classifier/testdata/bert_metadata.json b/mediapipe/model_maker/python/text/text_classifier/testdata/bert_metadata.json new file mode 100644 index 000000000..24214a80d --- /dev/null +++ b/mediapipe/model_maker/python/text/text_classifier/testdata/bert_metadata.json @@ -0,0 +1,84 @@ +{ + "name": "TextClassifier", + "description": "Classify the input text into a set of known categories.", + "subgraph_metadata": [ + { + "input_tensor_metadata": [ + { + "name": "ids", + "description": "Tokenized ids of the input text.", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + } + }, + "stats": { + } + }, + { + "name": "mask", + "description": "Mask with 1 for real tokens and 0 for padding tokens.", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + } + }, + "stats": { + } + }, + { + "name": "segment_ids", + "description": "0 for the first sequence, 1 for the second sequence if exists.", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + } + }, + "stats": { + } + } + ], + "output_tensor_metadata": [ + { + "name": "score", + "description": "Score of the labels respectively.", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + } + }, + "stats": { + "max": [ + 1.0 + ], + "min": [ + 0.0 + ] + }, + "associated_files": [ + { + "name": "labels.txt", + "description": "Labels for categories that the model can recognize.", + "type": "TENSOR_AXIS_LABELS" + } + ] + } + ], + "input_process_units": [ + { + "options_type": "BertTokenizerOptions", + "options": { + "vocab_file": [ + { + "name": "vocab.txt", + "description": "Vocabulary file to convert natural language words 
to embedding vectors.", + "type": "VOCABULARY" + } + ] + } + } + ] + } + ], + "min_parser_version": "1.1.0" +} diff --git a/mediapipe/model_maker/python/text/text_classifier/text_classifier.py b/mediapipe/model_maker/python/text/text_classifier/text_classifier.py index 1a338e345..f6abc8bf0 100644 --- a/mediapipe/model_maker/python/text/text_classifier/text_classifier.py +++ b/mediapipe/model_maker/python/text/text_classifier/text_classifier.py @@ -269,16 +269,21 @@ class _AverageWordEmbeddingClassifier(TextClassifier): """Creates an Average Word Embedding model.""" self._model = tf.keras.Sequential([ tf.keras.layers.InputLayer( - input_shape=[self._model_options.seq_len], dtype=tf.int32), + input_shape=[self._model_options.seq_len], + dtype=tf.int32, + name="input_ids", + ), tf.keras.layers.Embedding( len(self._text_preprocessor.get_vocab()), self._model_options.wordvec_dim, - input_length=self._model_options.seq_len), + input_length=self._model_options.seq_len, + ), tf.keras.layers.GlobalAveragePooling1D(), tf.keras.layers.Dense( - self._model_options.wordvec_dim, activation=tf.nn.relu), + self._model_options.wordvec_dim, activation=tf.nn.relu + ), tf.keras.layers.Dropout(self._model_options.dropout_rate), - tf.keras.layers.Dense(self._num_classes, activation="softmax") + tf.keras.layers.Dense(self._num_classes, activation="softmax"), ]) def _save_vocab(self, vocab_filepath: str): diff --git a/mediapipe/model_maker/python/text/text_classifier/text_classifier_test.py b/mediapipe/model_maker/python/text/text_classifier/text_classifier_test.py index eb4443b44..1ae2bc553 100644 --- a/mediapipe/model_maker/python/text/text_classifier/text_classifier_test.py +++ b/mediapipe/model_maker/python/text/text_classifier/text_classifier_test.py @@ -26,6 +26,9 @@ class TextClassifierTest(tf.test.TestCase): _AVERAGE_WORD_EMBEDDING_JSON_FILE = ( test_utils.get_test_data_path('average_word_embedding_metadata.json')) + _BERT_CLASSIFIER_JSON_FILE = test_utils.get_test_data_path( + 'bert_metadata.json' + ) def _get_data(self): labels_and_text = (('pos', 'super good'), (('neg', 'really bad'))) @@ -94,7 +97,27 @@ class TextClassifierTest(tf.test.TestCase): _, accuracy = bert_classifier.evaluate(validation_data) self.assertGreaterEqual(accuracy, 0.0) - # TODO: Add a unit test that does not run OOM. 
+ + # Test export_model + bert_classifier.export_model() + output_metadata_file = os.path.join( + options.hparams.export_dir, 'metadata.json' + ) + output_tflite_file = os.path.join( + options.hparams.export_dir, 'model.tflite' + ) + + self.assertTrue(os.path.exists(output_tflite_file)) + self.assertGreater(os.path.getsize(output_tflite_file), 0) + + self.assertTrue(os.path.exists(output_metadata_file)) + self.assertGreater(os.path.getsize(output_metadata_file), 0) + filecmp.clear_cache() + self.assertTrue( + filecmp.cmp( + output_metadata_file, self._BERT_CLASSIFIER_JSON_FILE, shallow=False + ) + ) def test_label_mismatch(self): options = ( diff --git a/mediapipe/model_maker/setup.py b/mediapipe/model_maker/setup.py index 7114e2080..1dac6301a 100644 --- a/mediapipe/model_maker/setup.py +++ b/mediapipe/model_maker/setup.py @@ -81,7 +81,10 @@ def _setup_build_dir(): file.write(filedata) # Use bazel to download GCS model files - model_build_files = ['models/gesture_recognizer/BUILD'] + model_build_files = [ + 'models/gesture_recognizer/BUILD', + 'models/text_classifier/BUILD', + ] for model_build_file in model_build_files: build_target_file = os.path.join(BUILD_MM_DIR, model_build_file) os.makedirs(os.path.dirname(build_target_file), exist_ok=True) @@ -95,7 +98,12 @@ def _setup_build_dir(): 'models/gesture_recognizer/gesture_embedder/saved_model.pb', 'models/gesture_recognizer/gesture_embedder/variables/variables.data-00000-of-00001', 'models/gesture_recognizer/gesture_embedder/variables/variables.index', - ] + 'models/text_classifier/mobilebert_tiny/keras_metadata.pb', + 'models/text_classifier/mobilebert_tiny/saved_model.pb', + 'models/text_classifier/mobilebert_tiny/assets/vocab.txt', + 'models/text_classifier/mobilebert_tiny/variables/variables.data-00000-of-00001', + 'models/text_classifier/mobilebert_tiny/variables/variables.index', + ] for elem in external_files: external_file = os.path.join(f'{SRC_NAME}/mediapipe_model_maker', elem) sys.stderr.write('downloading file: %s\n' % external_file) diff --git a/third_party/external_files.bzl b/third_party/external_files.bzl index 790486676..5adfbdfc6 100644 --- a/third_party/external_files.bzl +++ b/third_party/external_files.bzl @@ -1006,6 +1006,18 @@ def external_files(): urls = ["https://storage.googleapis.com/mediapipe-assets/gesture_embedder/saved_model.pb?generation=1668550484904822"], ) + http_file( + name = "com_google_mediapipe_mobilebert_tiny_keras_metadata_pb", + sha256 = "cef8131a414c602b9d4742ac57f4f90bc5d8a42baec36b65deece884e2d0cf0f", + urls = ["https://storage.googleapis.com/mediapipe-assets/mobilebert_tiny/keras_metadata.pb?generation=1673297965144159"], + ) + + http_file( + name = "com_google_mediapipe_mobilebert_tiny_saved_model_pb", + sha256 = "323c997cd3e17df1b2e3bdebe3cfe2b17c5ffd9488a26a4afb59ee819196837a", + urls = ["https://storage.googleapis.com/mediapipe-assets/mobilebert_tiny/saved_model.pb?generation=1673297968138825"], + ) + http_file( name = "com_google_mediapipe_object_detection_saved_model_model_ckpt_data-00000-of-00001", sha256 = "ad2f733f271dd5000a8c7f926bfea1083e6408b34d4f3b60679e5a6f96251c97", @@ -1053,3 +1065,21 @@ def external_files(): sha256 = "76ea482b8da6bdb3d65d3b2ea989c1699c9fa0d6df0cb6d80863d1dc6fe7c4bd", urls = ["https://storage.googleapis.com/mediapipe-assets/gesture_embedder/variables/variables.index?generation=1668550490691823"], ) + + http_file( + name = "com_google_mediapipe_mobilebert_tiny_assets_vocab_txt", + sha256 = "07eced375cec144d27c900241f3e339478dec958f92fddbc551f295c992038a3", + 
urls = ["https://storage.googleapis.com/mediapipe-assets/mobilebert_tiny/assets/vocab.txt?generation=1673297970948751"], + ) + + http_file( + name = "com_google_mediapipe_mobilebert_tiny_variables_variables_data-00000-of-00001", + sha256 = "c3857370046cd3a2f345657cf1bb259a4e7e09185d7f0808e57803e9d41ebba4", + urls = ["https://storage.googleapis.com/mediapipe-assets/mobilebert_tiny/variables/variables.data-00000-of-00001?generation=1673297975132568"], + ) + + http_file( + name = "com_google_mediapipe_mobilebert_tiny_variables_variables_index", + sha256 = "4df4d7c0fefe99903ab6ebf44b7478196ce613082d2ca692a5a37a7f24e562ed", + urls = ["https://storage.googleapis.com/mediapipe-assets/mobilebert_tiny/variables/variables.index?generation=1673297977586840"], + ) From 5c74ed2ae58eeb6b6f9b18aa47edf52e08a0eccb Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Thu, 12 Jan 2023 08:27:57 -0800 Subject: [PATCH 373/469] EmbeddingAggregationCalculator should fill in the `timestamp_ms` field of the embedding results in the stream mode. Per user feedback, the consistency between the packet timestamp and the timestamp field of the embedding result helps reducing the confusion. PiperOrigin-RevId: 501572379 --- .../calculators/embedding_aggregation_calculator.cc | 4 +++- .../calculators/embedding_aggregation_calculator_test.cc | 8 +++++--- .../processors/embedding_postprocessing_graph_test.cc | 7 ++++--- 3 files changed, 12 insertions(+), 7 deletions(-) diff --git a/mediapipe/tasks/cc/components/calculators/embedding_aggregation_calculator.cc b/mediapipe/tasks/cc/components/calculators/embedding_aggregation_calculator.cc index bae926b76..6e06c4e32 100644 --- a/mediapipe/tasks/cc/components/calculators/embedding_aggregation_calculator.cc +++ b/mediapipe/tasks/cc/components/calculators/embedding_aggregation_calculator.cc @@ -120,7 +120,9 @@ absl::Status EmbeddingAggregationCalculator::Process(CalculatorContext* cc) { } kTimestampedEmbeddingsOut(cc).Send(std::move(results)); } else { - kEmbeddingsOut(cc).Send(kEmbeddingsIn(cc)); + auto result = kEmbeddingsIn(cc).Get(); + result.set_timestamp_ms(cc->InputTimestamp().Value() / 1000); + kEmbeddingsOut(cc).Send(result); } RET_CHECK(cached_embeddings_.empty()); return absl::OkStatus(); diff --git a/mediapipe/tasks/cc/components/calculators/embedding_aggregation_calculator_test.cc b/mediapipe/tasks/cc/components/calculators/embedding_aggregation_calculator_test.cc index ebb4d8880..f2b2fa1d5 100644 --- a/mediapipe/tasks/cc/components/calculators/embedding_aggregation_calculator_test.cc +++ b/mediapipe/tasks/cc/components/calculators/embedding_aggregation_calculator_test.cc @@ -120,7 +120,7 @@ class EmbeddingAggregationCalculatorTest : public tflite_shims::testing::Test { CalculatorGraph calculator_graph_; }; -TEST_F(EmbeddingAggregationCalculatorTest, SucceedsWithoutTimestamps) { +TEST_F(EmbeddingAggregationCalculatorTest, SucceedsWithoutAggregation) { EmbeddingResult embedding = ParseTextProtoOrDie( R"pb(embeddings { head_index: 0 })pb"); @@ -129,10 +129,12 @@ TEST_F(EmbeddingAggregationCalculatorTest, SucceedsWithoutTimestamps) { MP_ASSERT_OK(Send(embedding)); MP_ASSERT_OK_AND_ASSIGN(auto result, GetResult(poller)); - EXPECT_THAT(result, EqualsProto(embedding)); + EXPECT_THAT(result, EqualsProto(ParseTextProtoOrDie( + R"pb(timestamp_ms: 0 + embeddings { head_index: 0 })pb"))); } -TEST_F(EmbeddingAggregationCalculatorTest, SucceedsWithTimestamps) { +TEST_F(EmbeddingAggregationCalculatorTest, SucceedsWithAggregation) { MP_ASSERT_OK_AND_ASSIGN(auto poller, 
BuildGraph(/*connect_timestamps=*/true));
   MP_ASSERT_OK(Send(ParseTextProtoOrDie(R"pb(embeddings { head_index: 0
diff --git a/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph_test.cc b/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph_test.cc
index 163e46ee8..809268a63 100644
--- a/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph_test.cc
+++ b/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph_test.cc
@@ -246,7 +246,7 @@ class PostprocessingTest : public tflite_shims::testing::Test {
       absl::make_unique>();
 };

-TEST_F(PostprocessingTest, SucceedsWithoutTimestamps) {
+TEST_F(PostprocessingTest, SucceedsWithoutAggregation) {
   // Build graph.
   proto::EmbedderOptions options;
   MP_ASSERT_OK_AND_ASSIGN(auto poller,
@@ -261,7 +261,8 @@
   MP_ASSERT_OK_AND_ASSIGN(auto results, GetResult(poller));

   // Validate results.
-  EXPECT_FALSE(results.has_timestamp_ms());
+  EXPECT_TRUE(results.has_timestamp_ms());
+  EXPECT_EQ(results.timestamp_ms(), 0);
   EXPECT_EQ(results.embeddings_size(), 1);
   EXPECT_EQ(results.embeddings(0).head_index(), 0);
   EXPECT_EQ(results.embeddings(0).head_name(), "feature");
@@ -273,7 +274,7 @@
   }
 }

-TEST_F(PostprocessingTest, SucceedsWithTimestamps) {
+TEST_F(PostprocessingTest, SucceedsWithAggregation) {
   // Build graph.
   proto::EmbedderOptions options;
   MP_ASSERT_OK_AND_ASSIGN(auto poller, BuildGraph(kMobileNetV3Embedder, options,

From 74b60780c7a5e9fe07c10513201b499a99fd137e Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Thu, 12 Jan 2023 09:58:34 -0800
Subject: [PATCH 374/469] Internal change

PiperOrigin-RevId: 501594400
---
 mediapipe/framework/deps/registration.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mediapipe/framework/deps/registration.h b/mediapipe/framework/deps/registration.h
index 1a33b2b24..9d80aafea 100644
--- a/mediapipe/framework/deps/registration.h
+++ b/mediapipe/framework/deps/registration.h
@@ -253,7 +253,7 @@ class FunctionRegistry {
     if (names[0].empty()) {
       names.erase(names.begin());
     } else {
-      CHECK_EQ(1, names.size())
+      CHECK_EQ(1u, names.size())
          << "A registered class name must be either fully qualified "
          << "with a leading :: or unqualified, got: " << name << ".";
     }

From 1683d572ed778c444d8c8e1b7f9f9a240a65667e Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Thu, 12 Jan 2023 10:20:09 -0800
Subject: [PATCH 375/469] Internal change

PiperOrigin-RevId: 501600938
---
 mediapipe/gpu/BUILD | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mediapipe/gpu/BUILD b/mediapipe/gpu/BUILD
index cc5e50dfc..9074daf61 100644
--- a/mediapipe/gpu/BUILD
+++ b/mediapipe/gpu/BUILD
@@ -1021,8 +1021,8 @@ objc_library(
     visibility = ["//visibility:public"],
     deps = [
         ":MPPMetalHelper",
+        ":copy_calculator_cc_proto",
         ":simple_shaders_mtl",
-        "//mediapipe/gpu:copy_calculator_cc_proto",
        "//mediapipe/objc:mediapipe_framework_ios",
        "//third_party/apple_frameworks:CoreVideo",
        "//third_party/apple_frameworks:Metal",

From 8156da341833e9d6c8042a0cef1494770458c8a0 Mon Sep 17 00:00:00 2001
From: Jiuqiang Tang
Date: Thu, 12 Jan 2023 13:52:27 -0800
Subject: [PATCH 376/469] ClassificationAggregationCalculator should fill in
 the `timestamp_ms` field of the classification results in stream mode. Per
 user feedback, the consistency between the packet timestamp and the
 timestamp field of the classification result helps reduce confusion.
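A minimal sketch, not from the patch itself, of the invariant this and the earlier EmbeddingAggregationCalculator change establish for stream mode; the division by 1000 is taken directly from the diffs, while the standalone harness is illustrative:

    #include <cassert>
    #include <cstdint>

    int main() {
      // The calculators now call result.set_timestamp_ms(cc->InputTimestamp().Value() / 1000),
      // so the result's timestamp_ms field matches the packet timestamp (given in microseconds).
      const int64_t input_timestamp_us = 975000;  // example packet timestamp, as in the audio test
      assert(input_timestamp_us / 1000 == 975);   // the 975 ms value the updated tests expect
      return 0;
    }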
PiperOrigin-RevId: 501657922 --- .../audio_classifier/audio_classifier_test.cc | 3 +- .../classification_aggregation_calculator.cc | 1 + ...ssification_aggregation_calculator_test.cc | 7 ++- ...lassification_postprocessing_graph_test.cc | 4 ++ .../test/vision/image_classifier_test.py | 57 ++++++++++++------- 5 files changed, 46 insertions(+), 26 deletions(-) diff --git a/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier_test.cc b/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier_test.cc index 596b910f8..2d5b221a9 100644 --- a/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier_test.cc +++ b/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier_test.cc @@ -143,8 +143,9 @@ void CheckStreamingModeResults(std::vector outputs) { EXPECT_EQ(outputs.size(), 5); // Ignore last result, which operates on a too small chunk to return relevant // results. + std::vector timestamps_ms = {0, 975, 1950, 2925}; for (int i = 0; i < outputs.size() - 1; i++) { - EXPECT_FALSE(outputs[i].timestamp_ms.has_value()); + EXPECT_EQ(outputs[i].timestamp_ms.value(), timestamps_ms[i]); EXPECT_EQ(outputs[i].classifications.size(), 1); EXPECT_EQ(outputs[i].classifications[0].head_index, 0); EXPECT_EQ(outputs[i].classifications[0].head_name, "scores"); diff --git a/mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.cc b/mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.cc index ad2c668c3..145076cd3 100644 --- a/mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.cc +++ b/mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.cc @@ -188,6 +188,7 @@ ClassificationAggregationCalculator::ConvertToClassificationResult( *classifications->mutable_classification_list() = std::move(classification_lists[i]); } + result.set_timestamp_ms(cc->InputTimestamp().Value() / 1000); cached_classifications_.erase(cc->InputTimestamp().Value()); return result; } diff --git a/mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator_test.cc b/mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator_test.cc index 1bc8cafd6..811d70544 100644 --- a/mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator_test.cc +++ b/mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator_test.cc @@ -150,14 +150,15 @@ class ClassificationAggregationCalculatorTest CalculatorGraph calculator_graph_; }; -TEST_F(ClassificationAggregationCalculatorTest, SucceedsWithoutTimestamps) { +TEST_F(ClassificationAggregationCalculatorTest, SucceedsWithoutAggregation) { MP_ASSERT_OK_AND_ASSIGN(auto poller, BuildGraph()); MP_ASSERT_OK(Send({MakeClassificationList(0), MakeClassificationList(1)})); MP_ASSERT_OK_AND_ASSIGN(auto result, GetResult(poller)); EXPECT_THAT(result, EqualsProto(ParseTextProtoOrDie( - R"pb(classifications { + R"pb(timestamp_ms: 0, + classifications { head_index: 0 head_name: "foo" classification_list { classification { index: 0 } } @@ -169,7 +170,7 @@ TEST_F(ClassificationAggregationCalculatorTest, SucceedsWithoutTimestamps) { })pb"))); } -TEST_F(ClassificationAggregationCalculatorTest, SucceedsWithTimestamps) { +TEST_F(ClassificationAggregationCalculatorTest, SucceedsWithAggregation) { MP_ASSERT_OK_AND_ASSIGN(auto poller, BuildGraph(/*connect_timestamps=*/true)); MP_ASSERT_OK(Send({MakeClassificationList(0), MakeClassificationList(1)})); MP_ASSERT_OK(Send( diff --git 
a/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph_test.cc b/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph_test.cc index 8eb6f3c3b..a11bad71a 100644 --- a/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph_test.cc +++ b/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph_test.cc @@ -534,6 +534,7 @@ TEST_F(PostprocessingTest, SucceedsWithoutMetadata) { // Validate results. EXPECT_THAT(results, EqualsProto(ParseTextProtoOrDie(R"pb( + timestamp_ms: 0, classifications { head_index: 0 classification_list { @@ -567,6 +568,7 @@ TEST_F(PostprocessingTest, SucceedsWithMetadata) { // Validate results. EXPECT_THAT( results, EqualsProto(ParseTextProtoOrDie(R"pb( + timestamp_ms: 0, classifications { head_index: 0 head_name: "probability" @@ -603,6 +605,7 @@ TEST_F(PostprocessingTest, SucceedsWithScoreCalibration) { // Validate results. EXPECT_THAT( results, EqualsProto(ParseTextProtoOrDie(R"pb( + timestamp_ms: 0, classifications { head_index: 0 head_name: "probability" @@ -646,6 +649,7 @@ TEST_F(PostprocessingTest, SucceedsWithMultipleHeads) { // Validate results. EXPECT_THAT( results, EqualsProto(ParseTextProtoOrDie(R"pb( + timestamp_ms: 0, classifications { head_index: 0 head_name: "yamnet_classification" diff --git a/mediapipe/tasks/python/test/vision/image_classifier_test.py b/mediapipe/tasks/python/test/vision/image_classifier_test.py index cbeaf36bd..b47efb32b 100644 --- a/mediapipe/tasks/python/test/vision/image_classifier_test.py +++ b/mediapipe/tasks/python/test/vision/image_classifier_test.py @@ -61,7 +61,7 @@ def _generate_empty_results() -> ImageClassifierResult: timestamp_ms=0) -def _generate_burger_results() -> ImageClassifierResult: +def _generate_burger_results(timestamp_ms=0) -> ImageClassifierResult: return ImageClassifierResult( classifications=[ _Classifications( @@ -70,30 +70,36 @@ def _generate_burger_results() -> ImageClassifierResult: index=934, score=0.793959, display_name='', - category_name='cheeseburger'), + category_name='cheeseburger', + ), _Category( index=932, score=0.0273929, display_name='', - category_name='bagel'), + category_name='bagel', + ), _Category( index=925, score=0.0193408, display_name='', - category_name='guacamole'), + category_name='guacamole', + ), _Category( index=963, score=0.00632786, display_name='', - category_name='meat loaf') + category_name='meat loaf', + ), ], head_index=0, - head_name='probability') + head_name='probability', + ) ], - timestamp_ms=0) + timestamp_ms=timestamp_ms, + ) -def _generate_soccer_ball_results() -> ImageClassifierResult: +def _generate_soccer_ball_results(timestamp_ms=0) -> ImageClassifierResult: return ImageClassifierResult( classifications=[ _Classifications( @@ -102,12 +108,15 @@ def _generate_soccer_ball_results() -> ImageClassifierResult: index=806, score=0.996527, display_name='', - category_name='soccer ball') + category_name='soccer ball', + ) ], head_index=0, - head_name='probability') + head_name='probability', + ) ], - timestamp_ms=0) + timestamp_ms=timestamp_ms, + ) class ModelFileType(enum.Enum): @@ -379,8 +388,11 @@ class ImageClassifierTest(parameterized.TestCase): for timestamp in range(0, 300, 30): classification_result = classifier.classify_for_video( self.test_image, timestamp) - test_utils.assert_proto_equals(self, classification_result.to_pb2(), - _generate_burger_results().to_pb2()) + test_utils.assert_proto_equals( + self, + classification_result.to_pb2(), + 
_generate_burger_results(timestamp).to_pb2(), + ) def test_classify_for_video_succeeds_with_region_of_interest(self): options = _ImageClassifierOptions( @@ -398,8 +410,11 @@ class ImageClassifierTest(parameterized.TestCase): for timestamp in range(0, 300, 30): classification_result = classifier.classify_for_video( test_image, timestamp, image_processing_options) - test_utils.assert_proto_equals(self, classification_result.to_pb2(), - _generate_soccer_ball_results().to_pb2()) + test_utils.assert_proto_equals( + self, + classification_result.to_pb2(), + _generate_soccer_ball_results(timestamp).to_pb2(), + ) def test_calling_classify_in_live_stream_mode(self): options = _ImageClassifierOptions( @@ -455,8 +470,7 @@ class ImageClassifierTest(parameterized.TestCase): score_threshold=threshold, result_callback=check_result) with _ImageClassifier.create_from_options(options) as classifier: - for timestamp in range(0, 300, 30): - classifier.classify_async(self.test_image, timestamp) + classifier.classify_async(self.test_image, 0) def test_classify_async_succeeds_with_region_of_interest(self): # Load the test image. @@ -470,8 +484,9 @@ class ImageClassifierTest(parameterized.TestCase): def check_result(result: ImageClassifierResult, output_image: _Image, timestamp_ms: int): - test_utils.assert_proto_equals(self, result.to_pb2(), - _generate_soccer_ball_results().to_pb2()) + test_utils.assert_proto_equals( + self, result.to_pb2(), _generate_soccer_ball_results(100).to_pb2() + ) self.assertEqual(output_image.width, test_image.width) self.assertEqual(output_image.height, test_image.height) self.assertLess(observed_timestamp_ms, timestamp_ms) @@ -483,9 +498,7 @@ class ImageClassifierTest(parameterized.TestCase): max_results=1, result_callback=check_result) with _ImageClassifier.create_from_options(options) as classifier: - for timestamp in range(0, 300, 30): - classifier.classify_async(test_image, timestamp, - image_processing_options) + classifier.classify_async(test_image, 100, image_processing_options) if __name__ == '__main__': From 5642980ab01466b1fce7c1abad701ba2f0f13a76 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Fri, 13 Jan 2023 21:04:03 +0530 Subject: [PATCH 377/469] Updated iOS error implementation to mimic java --- .../tasks/ios/common/sources/MPPCommon.h | 159 +++--------------- .../common/utils/sources/MPPCommonUtils.mm | 125 +++++++++----- 2 files changed, 104 insertions(+), 180 deletions(-) diff --git a/mediapipe/tasks/ios/common/sources/MPPCommon.h b/mediapipe/tasks/ios/common/sources/MPPCommon.h index 09a61e20d..f8047fc35 100644 --- a/mediapipe/tasks/ios/common/sources/MPPCommon.h +++ b/mediapipe/tasks/ios/common/sources/MPPCommon.h @@ -25,153 +25,44 @@ typedef NS_ENUM(NSUInteger, MPPTasksErrorCode) { // Generic error codes. - // Unspecified error. - MPPTasksErrorCodeError = 1, - // Invalid argument specified. - MPPTasksErrorCodeInvalidArgumentError = 2, - // Invalid FlatBuffer file or buffer specified. - MPPTasksErrorCodeInvalidFlatBufferError = 3, - // Model contains a builtin op that isn't supported by the OpResolver or - // delegates. - MPPTasksErrorCodeUnsupportedBuiltinOp = 4, - // Model contains a custom op that isn't supported by the OpResolver or - // delegates. - MPPTasksErrorCodeUnsupportedCustomOp = 5, + /** Indicates the operation was cancelled, typically by the caller. */ + MPPTasksErrorCodeCancelledError = 1, + /** Indicates an unknown error occurred. 
*/ + MPPTasksErrorCodeUnknownError = 2, + /** Indicates the caller specified an invalid argument, such as a malformed filename. */ + MPPTasksErrorCodeInvalidArgumentError = 3, + /** Indicates a deadline expired before the operation could complete. */ + MPPTasksErrorCodeDeadlineExceededError = 4, + /** Indicates some requested entity (such as a file or directory) was not found. */ + MPPTasksErrorCodeNotFoundError = 5, + /** Indicates that the entity a caller attempted to create (such as a file or directory) is already present. */ + MPPTasksErrorCodeAlreadyExistsError = 6, + /** Indicates that the caller does not have permission to execute the specified operation. */ + MPPTasksErrorCodePermissionDeniedError = 7, - // File I/O error codes. + MPPTasksErrorCodeResourceExhaustedError = 8, - // No such file. - MPPTasksErrorCodeFileNotFoundError = 100, - // Permission issue. - MPPTasksErrorCodeFilePermissionDeniedError, - // I/O error when reading file. - MPPTasksErrorCodeFileReadError, - // I/O error when mmap-ing file. - MPPTasksErrorCodeFileMmapError, - // ZIP I/O error when unpacking the zip file. - MPPTasksErrorCodeFileZipError, + MPPTasksErrorCodeFailedPreconditionError = 9, - // TensorFlow Lite metadata error codes. + MPPTasksErrorCodeAbortedError = 10, - // Unexpected schema version (aka file_identifier) in the Metadata FlatBuffer. - MPPTasksErrorCodeMetadataInvalidSchemaVersionError = 200, - // No such associated file within metadata, or file has not been packed. - MPPTasksErrorCodeMetadataAssociatedFileNotFoundError, - // ZIP I/O error when unpacking an associated file. - MPPTasksErrorCodeMetadataAssociatedFileZipError, - // Inconsistency error between the metadata and actual TF Lite model. - // E.g.: number of labels and output tensor values differ. - MPPTasksErrorCodeMetadataInconsistencyError, - // Invalid process units specified. - // E.g.: multiple ProcessUnits with the same type for a given tensor. - MPPTasksErrorCodeMetadataInvalidProcessUnitsError, - // Inconsistency error with the number of labels. - // E.g.: label files for different locales have a different number of labels. - MPPTasksErrorCodeMetadataNumLabelsMismatchError, - // Score calibration parameters parsing error. - // E.g.: too many parameters provided in the corresponding associated file. - MPPTasksErrorCodeMetadataMalformedScoreCalibrationError, - // Unexpected number of subgraphs for the current task. - // E.g.: image classification expects a single subgraph. - MPPTasksErrorCodeMetadataInvalidNumSubgraphsError, - // A given tensor requires NormalizationOptions but none were found. - // E.g.: float input tensor requires normalization to preprocess input images. - MPPTasksErrorCodeMetadataMissingNormalizationOptionsError, - // Invalid ContentProperties specified. - // E.g. expected ImageProperties, got BoundingBoxProperties. - MPPTasksErrorCodeMetadataInvalidContentPropertiesError, - // Metadata is mandatory but was not found. - // E.g. current task requires TFLite Model Metadata but none was found. - MPPTasksErrorCodeMetadataNotFoundError, - // Associated TENSOR_AXIS_LABELS or TENSOR_VALUE_LABELS file is mandatory but - // none was found or it was empty. - // E.g. current task requires labels but none were found. - MPPTasksErrorCodeMetadataMissingLabelsError, - // The ProcessingUnit for tokenizer is not correctly configured. - // E.g BertTokenizer doesn't have a valid vocab file associated. - MPPTasksErrorCodeMetadataInvalidTokenizerError, + MPPTasksErrorCodeOutOfRangeError = 11, - // Input tensor(s) error codes. 
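The replacement codes here drop the old library-specific ranges (file I/O, TFLite metadata, tensor, task runner, and task graph errors) in favor of the canonical absl::StatusCode values, numbered 1 through 16 to match; the rest of the enum continues below. Since the header keeps its NS_SWIFT_NAME(TasksErrorCode) annotation, Swift callers can branch on these codes directly. A minimal sketch, assuming the usual Swift prefix stripping for the case names (the helper function itself is hypothetical; the domain string is MPPTasksErrorDomain from MPPCommonUtils.mm later in this series):

  import Foundation

  // Hypothetical helper, for illustration only. TasksErrorCode is the Swift
  // projection of MPPTasksErrorCode; its raw values 1...16 mirror absl::StatusCode.
  func describeTaskError(_ error: NSError) -> String {
    guard error.domain == "com.google.mediapipe.tasks",
          let code = TasksErrorCode(rawValue: UInt(error.code)) else {
      return "not a MediaPipe Tasks error: \(error.localizedDescription)"
    }
    switch code {
    case .invalidArgumentError:
      return "invalid argument: \(error.localizedDescription)"
    case .notFoundError:
      return "missing file or entity: \(error.localizedDescription)"
    default:
      return "task failed with code \(error.code): \(error.localizedDescription)"
    }
  }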
+ MPPTasksErrorCodeUnimplementedError = 12, - // Unexpected number of input tensors for the current task. - // E.g. current task expects a single input tensor. - MPPTasksErrorCodeInvalidNumInputTensorsError = 300, - // Unexpected input tensor dimensions for the current task. - // E.g.: only 4D input tensors supported. - MPPTasksErrorCodeInvalidInputTensorDimensionsError, - // Unexpected input tensor type for the current task. - // E.g.: current task expects a uint8 pixel image as input. - MPPTasksErrorCodeInvalidInputTensorTypeError, - // Unexpected input tensor bytes size. - // E.g.: size in bytes does not correspond to the expected number of pixels. - MPPTasksErrorCodeInvalidInputTensorSizeError, - // No correct input tensor found for the model. - // E.g.: input tensor name is not part of the text model's input tensors. - MPPTasksErrorCodeInputTensorNotFoundError, + MPPTasksErrorCodeInternalError = 13, - // Output tensor(s) error codes. + MPPTasksErrorCodeUnavailableError = 14, - // Unexpected output tensor dimensions for the current task. - // E.g.: only a batch size of 1 is supported. - MPPTasksErrorCodeInvalidOutputTensorDimensionsError = 400, - // Unexpected input tensor type for the current task. - // E.g.: multi-head model with different output tensor types. - MPPTasksErrorCodeInvalidOutputTensorTypeError, - // No correct output tensor found for the model. - // E.g.: output tensor name is not part of the text model's output tensors. - MPPTasksErrorCodeOutputTensorNotFoundError, - // Unexpected number of output tensors for the current task. - // E.g.: current task expects a single output tensor. - MPPTasksErrorCodeInvalidNumOutputTensorsError, + MPPTasksErrorCodeDataLossError = 15, - // Image processing error codes. - - // Unspecified image processing failures. - MPPTasksErrorCodeImageProcessingError = 500, - // Unexpected input or output buffer metadata. - // E.g.: rotate RGBA buffer to Grayscale buffer by 90 degrees. - MPPTasksErrorCodeImageProcessingInvalidArgumentError, - // Image processing operation failures. - // E.g. libyuv rotation failed for an unknown reason. - MPPTasksErrorCodeImageProcessingBackendError, - - // Task runner error codes. - MPPTasksErrorCodeRunnerError = 600, - // Task runner is not initialized. - MPPTasksErrorCodeRunnerInitializationError, - // Task runner is not started successfully. - MPPTasksErrorCodeRunnerFailsToStartError, - // Task runner is not started. - MPPTasksErrorCodeRunnerNotStartedError, - // Task runner API is called in the wrong processing mode. - MPPTasksErrorCodeRunnerApiCalledInWrongModeError, - // Task runner receives/produces invalid MediaPipe packet timestamp. - MPPTasksErrorCodeRunnerInvalidTimestampError, - // Task runner receives unexpected MediaPipe graph input packet. - // E.g. The packet type doesn't match the graph input stream's data type. - MPPTasksErrorCodeRunnerUnexpectedInputError, - // Task runner produces unexpected MediaPipe graph output packet. - // E.g. The number of output packets is not equal to the number of graph - // output streams. - MPPTasksErrorCodeRunnerUnexpectedOutputError, - // Task runner is not closed successfully. - MPPTasksErrorCodeRunnerFailsToCloseError, - // Task runner's model resources cache service is unavailable or the - // targeting model resources bundle is not found. - MPPTasksErrorCodeRunnerModelResourcesCacheServiceError, - - // Task graph error codes. - MPPTasksErrorCodeGraphError = 700, - // Task graph is not implemented. 
- MPPTasksErrorCodeTaskGraphNotImplementedError, - // Task graph config is invalid. - MPPTasksErrorCodeInvalidTaskGraphConfigError, + MPPTasksErrorCodeUnauthenticatedError = 16, // The first error code in MPPTasksErrorCode (for internal use only). - MPPTasksErrorCodeFirst = MPPTasksErrorCodeError, + MPPTasksErrorCodeFirst = MPPTasksErrorCodeCancelledError, // The last error code in MPPTasksErrorCode (for internal use only). - MPPTasksErrorCodeLast = MPPTasksErrorCodeInvalidTaskGraphConfigError, + MPPTasksErrorCodeLast = MPPTasksErrorCodeUnauthenticatedError, } NS_SWIFT_NAME(TasksErrorCode); diff --git a/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.mm b/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.mm index 1a37f8465..9932dd13c 100644 --- a/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.mm +++ b/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.mm @@ -25,6 +25,10 @@ /** Error domain of MediaPipe task library errors. */ NSString *const MPPTasksErrorDomain = @"com.google.mediapipe.tasks"; +namespace { + using absl::StatusCode; +} + @implementation MPPCommonUtils + (void)createCustomError:(NSError **)error @@ -67,52 +71,6 @@ NSString *const MPPTasksErrorDomain = @"com.google.mediapipe.tasks"; if (status.ok()) { return YES; } - // Payload of absl::Status created by the MediaPipe task library stores an appropriate value of - // the enum MediaPipeTasksStatus. The integer value corresponding to the MediaPipeTasksStatus enum - // stored in the payload is extracted here to later map to the appropriate error code to be - // returned. In cases where the enum is not stored in (payload is NULL or the payload string - // cannot be converted to an integer), we set the error code value to be 1 - // (MPPTasksErrorCodeError of MPPTasksErrorCode used in the iOS library to signify - // any errors not falling into other categories.) Since payload is of type absl::Cord that can be - // type cast into an absl::optional, we use the std::stoi function to convert it into - // an integer code if possible. - NSUInteger genericErrorCode = MPPTasksErrorCodeError; - NSUInteger errorCode; - try { - // Try converting payload to integer if payload is not empty. Otherwise convert a string - // signifying generic error code MPPTasksErrorCodeError to integer. - errorCode = - (NSUInteger)std::stoi(static_cast>( - status.GetPayload(mediapipe::tasks::kMediaPipeTasksPayload)) - .value_or(std::to_string(genericErrorCode))); - } catch (std::invalid_argument &e) { - // If non empty payload string cannot be converted to an integer. Set error code to 1(kError). - errorCode = MPPTasksErrorCodeError; - } - - // If errorCode is outside the range of enum values possible or is - // MPPTasksErrorCodeError, we try to map the absl::Status::code() to assign - // appropriate MPPTasksErrorCode in default cases. Note: - // The mapping to absl::Status::code() is done to generate a more specific error code than - // MPPTasksErrorCodeError in cases when the payload can't be mapped to - // MPPTasksErrorCode. This can happen when absl::Status returned by TFLite library are in turn - // returned without modification by MediaPipe cc library methods. 
- if (errorCode > MPPTasksErrorCodeLast || errorCode <= MPPTasksErrorCodeFirst) { - switch (status.code()) { - case absl::StatusCode::kInternal: - errorCode = MPPTasksErrorCodeError; - break; - case absl::StatusCode::kInvalidArgument: - errorCode = MPPTasksErrorCodeInvalidArgumentError; - break; - case absl::StatusCode::kNotFound: - errorCode = MPPTasksErrorCodeError; - break; - default: - errorCode = MPPTasksErrorCodeError; - break; - } - } // Creates the NSEror with the appropriate error // MPPTasksErrorCode and message. MPPTasksErrorCode has a one to one @@ -129,6 +87,81 @@ NSString *const MPPTasksErrorDomain = @"com.google.mediapipe.tasks"; NSString *description = [NSString stringWithCString:status.ToString(absl::StatusToStringMode::kWithNoExtraData).c_str() encoding:NSUTF8StringEncoding]; + + // Payload of absl::Status created by the MediaPipe task library stores an appropriate value of + // the enum MediaPipeTasksStatus. The integer value corresponding to the MediaPipeTasksStatus enum + // stored in the payload is extracted here to later map to the appropriate error code to be + // returned. In cases where the enum is not stored in (payload is NULL or the payload string + // cannot be converted to an integer), we set the error code value to be 1 + // (MPPTasksErrorCodeError of MPPTasksErrorCode used in the iOS library to signify + // any errors not falling into other categories.) Since payload is of type absl::Cord that can be + // type cast into an absl::optional, we use the std::stoi function to convert it into + // an integer code if possible. + MPPTasksErrorCode genericErrorCode = MPPTasksErrorCodeUnknownError; + + MPPTasksErrorCode errorCode = genericErrorCode; + + // If errorCode is outside the range of enum values possible or is + // MPPTasksErrorCodeError, we try to map the absl::Status::code() to assign + // appropriate MPPTasksErrorCode in default cases. Note: + // The mapping to absl::Status::code() is done to generate a more specific error code than + // MPPTasksErrorCodeError in cases when the payload can't be mapped to + // MPPTasksErrorCode. This can happen when absl::Status returned by TFLite library are in turn + // returned without modification by MediaPipe cc library methods. 
+ switch (status.code()) { + case StatusCode::kCancelled: + errorCode = MPPTasksErrorCodeCancelledError; + break; + case StatusCode::kUnknown: + errorCode = MPPTasksErrorCodeUnknownError; + break; + case StatusCode::kInvalidArgument: + errorCode = MPPTasksErrorCodeInvalidArgumentError; + break; + case StatusCode::kDeadlineExceeded: + errorCode = MPPTasksErrorCodeDeadlineExceededError; + break; + case StatusCode::kNotFound: + errorCode = MPPTasksErrorCodeNotFoundError; + break; + case StatusCode::kAlreadyExists: + errorCode = MPPTasksErrorCodeAlreadyExistsError; + break; + case StatusCode::kPermissionDenied: + errorCode = MPPTasksErrorCodePermissionDeniedError; + break; + case StatusCode::kResourceExhausted: + errorCode = MPPTasksErrorCodeResourceExhaustedError; + break; + case StatusCode::kFailedPrecondition: + errorCode = MPPTasksErrorCodeFailedPreconditionError; + break; + case StatusCode::kAborted: + errorCode = MPPTasksErrorCodeAbortedError; + break; + case StatusCode::kOutOfRange: + errorCode = MPPTasksErrorCodeOutOfRangeError; + break; + case StatusCode::kUnimplemented: + errorCode = MPPTasksErrorCodeUnimplementedError; + break; + case StatusCode::kInternal: + errorCode = MPPTasksErrorCodeInternalError; + break; + case StatusCode::kUnavailable: + errorCode = MPPTasksErrorCodeUnavailableError; + break; + case StatusCode::kDataLoss: + errorCode = MPPTasksErrorCodeDataLossError; + break; + case StatusCode::kUnauthenticated: + errorCode = MPPTasksErrorCodeUnauthenticatedError; + break; + default: + errorCode = genericErrorCode; + break; + } + [MPPCommonUtils createCustomError:error withCode:errorCode description:description]; return NO; } From fa30100059330e9498469e4ca5065686a2079ee7 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Fri, 13 Jan 2023 21:04:17 +0530 Subject: [PATCH 378/469] Changed swift name of MPPCategory --- mediapipe/tasks/ios/components/containers/sources/MPPCategory.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mediapipe/tasks/ios/components/containers/sources/MPPCategory.h b/mediapipe/tasks/ios/components/containers/sources/MPPCategory.h index d05cfe13b..f360d46da 100644 --- a/mediapipe/tasks/ios/components/containers/sources/MPPCategory.h +++ b/mediapipe/tasks/ios/components/containers/sources/MPPCategory.h @@ -21,7 +21,7 @@ NS_ASSUME_NONNULL_BEGIN * index of the label in the corresponding label file. Typically it's used as the result of * classification tasks. */ -NS_SWIFT_NAME(ClassificationCategory) +NS_SWIFT_NAME(ResultCategory) @interface MPPCategory : NSObject /** From 0a707256e3b6a993447bf9b6206688e1e6bb58f0 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Fri, 13 Jan 2023 21:04:43 +0530 Subject: [PATCH 379/469] Updates to method signatures of iOS text classifier --- .../ios/text/text_classifier/sources/MPPTextClassifier.h | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h index 60aa94614..e33615dab 100644 --- a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h +++ b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h @@ -65,7 +65,7 @@ NS_SWIFT_NAME(TextClassifier) * @return A new instance of `MPPTextClassifier` with the given model path. `nil` if there is an * error in initializing the text classifier. 
*/ -- (instancetype)initWithModelPath:(NSString *)modelPath error:(NSError **)error; +- (nullable instancetype)initWithModelPath:(NSString *)modelPath error:(NSError **)error; /** * Creates a new instance of `MPPTextClassifier` from the given `MPPTextClassifierOptions`. @@ -78,7 +78,7 @@ NS_SWIFT_NAME(TextClassifier) * @return A new instance of `MPPTextClassifier` with the given options. `nil` if there is an * error in initializing the text classifier. */ -- (instancetype)initWithOptions:(MPPTextClassifierOptions *)options +- (nullable instancetype)initWithOptions:(MPPTextClassifierOptions *)options error:(NSError **)error NS_DESIGNATED_INITIALIZER; /** @@ -90,7 +90,8 @@ NS_SWIFT_NAME(TextClassifier) * * @return A `MPPTextClassifierResult` object that contains a list of text classifications. */ -- (nullable MPPTextClassifierResult *)classifyText:(NSString *)text error:(NSError **)error; +- (nullable MPPTextClassifierResult *)classifyText:(NSString *)text error:(NSError **)error NS_SWIFT_NAME(classify(text:)); + - (instancetype)init NS_UNAVAILABLE; From c40356c62852f9f174b04d790303511d8264fcef Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Fri, 13 Jan 2023 21:04:56 +0530 Subject: [PATCH 380/469] Added ios.bzl --- mediapipe/tasks/ios/ios.bzl | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 mediapipe/tasks/ios/ios.bzl diff --git a/mediapipe/tasks/ios/ios.bzl b/mediapipe/tasks/ios/ios.bzl new file mode 100644 index 000000000..8fe2a24a1 --- /dev/null +++ b/mediapipe/tasks/ios/ios.bzl @@ -0,0 +1,3 @@ +"""MediaPipe Task Library Helper Rules for iOS""" + +MPP_TASK_MINIMUM_OS_VERSION = "11.0" From 9e0b85c9b58b0395442c3e8cdeee46e45c8af380 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Fri, 13 Jan 2023 21:05:17 +0530 Subject: [PATCH 381/469] Added module name for iOS text classifier --- mediapipe/tasks/ios/text/text_classifier/BUILD | 1 + 1 file changed, 1 insertion(+) diff --git a/mediapipe/tasks/ios/text/text_classifier/BUILD b/mediapipe/tasks/ios/text/text_classifier/BUILD index aef68c9fe..1afddb5d4 100644 --- a/mediapipe/tasks/ios/text/text_classifier/BUILD +++ b/mediapipe/tasks/ios/text/text_classifier/BUILD @@ -46,6 +46,7 @@ objc_library( "-std=c++17", "-x objective-c++", ], + module_name = "MPPTextClassifier", deps = [ ":MPPTextClassifierOptions", ":MPPTextClassifierResult", From 2a53d78ae44bf27ec81ef795e51a5eb6fb863398 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Fri, 13 Jan 2023 21:05:44 +0530 Subject: [PATCH 382/469] Added swift and objective tests for iOS text classifier --- .../tasks/ios/test/text/text_classifier/BUILD | 82 +++++ .../text_classifier/MPPTextClassifierTests.m | 281 ++++++++++++++++++ .../text_classifier/TextClassifierTests.swift | 237 +++++++++++++++ 3 files changed, 600 insertions(+) create mode 100644 mediapipe/tasks/ios/test/text/text_classifier/BUILD create mode 100644 mediapipe/tasks/ios/test/text/text_classifier/MPPTextClassifierTests.m create mode 100644 mediapipe/tasks/ios/test/text/text_classifier/TextClassifierTests.swift diff --git a/mediapipe/tasks/ios/test/text/text_classifier/BUILD b/mediapipe/tasks/ios/test/text/text_classifier/BUILD new file mode 100644 index 000000000..b69202b64 --- /dev/null +++ b/mediapipe/tasks/ios/test/text/text_classifier/BUILD @@ -0,0 +1,82 @@ +load( + "@build_bazel_rules_apple//apple:ios.bzl", + "ios_unit_test", +) +load( + "@org_tensorflow//tensorflow/lite:special_rules.bzl", + "tflite_ios_lab_runner" +) +load( + "@build_bazel_rules_swift//swift:swift.bzl", + "swift_library" +) +load( + 
"//mediapipe/tasks:ios/ios.bzl", + "MPP_TASK_MINIMUM_OS_VERSION" +) + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +# Default tags for filtering iOS targets. Targets are restricted to Apple platforms. +TFL_DEFAULT_TAGS = [ + "apple", +] + +# Following sanitizer tests are not supported by iOS test targets. +TFL_DISABLED_SANITIZER_TAGS = [ + "noasan", + "nomsan", + "notsan", +] + +objc_library( + name = "MPPTextClassifierObjcTestLibrary", + testonly = 1, + srcs = ["MPPTextClassifierTests.m"], + data = [ + "//mediapipe/tasks/testdata/text:bert_text_classifier_models", + "//mediapipe/tasks/testdata/text:text_classifier_models", + ], + tags = [], + deps = [ + "//mediapipe/tasks/ios/text/text_classifier:MPPTextClassifier", + ], + +) + +ios_unit_test( + name = "MPPTextClassifierObjcTest", + minimum_os_version = MPP_TASK_MINIMUM_OS_VERSION, + runner = tflite_ios_lab_runner("IOS_LATEST"), + tags =[], + deps = [ + ":MPPTextClassifierObjcTestLibrary", + ], +) + +swift_library( + name = "MPPTextClassifierSwiftTestLibrary", + testonly = 1, + srcs = ["TextClassifierTests.swift"], + data = [ + "//mediapipe/tasks/testdata/text:bert_text_classifier_models", + "//mediapipe/tasks/testdata/text:text_classifier_models", + ], + tags = TFL_DEFAULT_TAGS, + deps = [ + "//mediapipe/tasks/ios/common:MPPCommon", + "//mediapipe/tasks/ios/text/text_classifier:MPPTextClassifier", + ], +) + +ios_unit_test( + name = "MPPTextClassifierSwiftTest", + minimum_os_version = MPP_TASK_MINIMUM_OS_VERSION, + runner = tflite_ios_lab_runner("IOS_LATEST"), + tags = TFL_DEFAULT_TAGS + TFL_DISABLED_SANITIZER_TAGS, + deps = [ + ":MPPTextClassifierSwiftTestLibrary", + ], +) diff --git a/mediapipe/tasks/ios/test/text/text_classifier/MPPTextClassifierTests.m b/mediapipe/tasks/ios/test/text/text_classifier/MPPTextClassifierTests.m new file mode 100644 index 000000000..3e2fe4bef --- /dev/null +++ b/mediapipe/tasks/ios/test/text/text_classifier/MPPTextClassifierTests.m @@ -0,0 +1,281 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#import <XCTest/XCTest.h> + +#import "mediapipe/tasks/ios/common/sources/MPPCommon.h" +#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h" + +static NSString *const kBertTextClassifierModelName = @"bert_text_classifier"; +static NSString *const kRegexTextClassifierModelName = + @"test_model_text_classifier_with_regex_tokenizer"; +static NSString *const kNegativeText = @"unflinchingly bleak and desperate"; +static NSString *const kPositiveText = @"it's a charming and often affecting journey"; +static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks"; + +#define AssertEqualErrors(error, expectedError) \ + XCTAssertNotNil(error); \ + XCTAssertEqualObjects(error.domain, expectedError.domain); \ + XCTAssertEqual(error.code, expectedError.code); \ + XCTAssertNotEqual( \ + [error.localizedDescription rangeOfString:expectedError.localizedDescription].location, \ + NSNotFound) + +#define AssertEqualCategoryArrays(categories, expectedCategories) \ + XCTAssertEqual(categories.count, expectedCategories.count); \ + for (int i = 0; i < categories.count; i++) { \ + XCTAssertEqual(categories[i].index, expectedCategories[i].index); \ + XCTAssertEqualWithAccuracy(categories[i].score, expectedCategories[i].score, 1e-6); \ + XCTAssertEqualObjects(categories[i].categoryName, expectedCategories[i].categoryName); \ + XCTAssertEqualObjects(categories[i].displayName, expectedCategories[i].displayName); \ + } + +#define AssertTextClassifierResultHasOneHead(textClassifierResult) \ + XCTAssertNotNil(textClassifierResult); \ + \ + XCTAssertNotNil(textClassifierResult.classificationResult); \ + XCTAssertEqual(textClassifierResult.classificationResult.classifications.count, 1); \ + XCTAssertEqual(textClassifierResult.classificationResult.classifications[0].headIndex, 0); + +@interface MPPTextClassifierTests : XCTestCase +@end + +@implementation MPPTextClassifierTests + +- (void)setUp { +} + +- (void)tearDown { + // Put teardown code here. This method is called after the invocation of each test method in the + // class.
+} + ++ (NSArray *)expectedBertResultCategoriesForNegativeText { + return @[ + [[MPPCategory alloc] initWithIndex:0 score:0.956187f categoryName:@"negative" displayName:nil], + [[MPPCategory alloc] initWithIndex:1 score:0.043812f categoryName:@"positive" displayName:nil] + ]; +} + ++ (NSArray *)expectedBertResultCategoriesForPositiveText { + return @[ + [[MPPCategory alloc] initWithIndex:1 score:0.999945f categoryName:@"positive" displayName:nil], + [[MPPCategory alloc] initWithIndex:0 score:0.000055f categoryName:@"negative" displayName:nil] + ]; +} + ++ (NSArray *)expectedRegexResultCategoriesForNegativeText { + return @[ + [[MPPCategory alloc] initWithIndex:0 score:0.6647746f categoryName:@"Negative" displayName:nil], + [[MPPCategory alloc] initWithIndex:1 score:0.33522537 categoryName:@"Positive" displayName:nil] + ]; +} + ++ (NSArray *)expectedRegexResultCategoriesForPositiveText { + return @[ + [[MPPCategory alloc] initWithIndex:0 score:0.5120041f categoryName:@"Negative" displayName:nil], + [[MPPCategory alloc] initWithIndex:1 score:0.48799595 categoryName:@"Positive" displayName:nil] + ]; +} + ++ (NSArray *)expectedBertResultCategoriesForEdgeCaseTests { + return @[ [[MPPCategory alloc] initWithIndex:0 + score:0.956187f + categoryName:@"negative" + displayName:nil] ]; +} + +- (NSString *)filePathWithName:(NSString *)fileName extension:(NSString *)extension { + NSString *filePath = [[NSBundle bundleForClass:self.class] pathForResource:fileName + ofType:extension]; + return filePath; +} + +- (MPPTextClassifierOptions *)textClassifierOptionsWithModelName:(NSString *)modelName { + NSString *modelPath = [self filePathWithName:modelName extension:@"tflite"]; + MPPTextClassifierOptions *textClassifierOptions = [[MPPTextClassifierOptions alloc] init]; + textClassifierOptions.baseOptions.modelAssetPath = modelPath; + + return textClassifierOptions; +} + +- (MPPTextClassifier *)textClassifierFromModelFileWithName:(NSString *)modelName { + NSString *modelPath = [self filePathWithName:modelName extension:@"tflite"]; + MPPTextClassifier *textClassifier = [[MPPTextClassifier alloc] initWithModelPath:modelPath + error:nil]; + XCTAssertNotNil(textClassifier); + + return textClassifier; +} + +- (void)assertCreateTextClassifierWithOptions:(MPPTextClassifierOptions *)textClassifierOptions + failsWithExpectedError:(NSError *)expectedError { + NSError *error = nil; + MPPTextClassifier *textClassifier = + [[MPPTextClassifier alloc] initWithOptions:textClassifierOptions error:&error]; + XCTAssertNil(textClassifier); + AssertEqualErrors(error, expectedError); +} + +- (void)assertResultsOfClassifyText:(NSString *)text + usingTextClassifier:(MPPTextClassifier *)textClassifier + equalsCategories:(NSArray *)expectedCategories { + MPPTextClassifierResult *negativeResult = [textClassifier classifyText:text error:nil]; + AssertTextClassifierResultHasOneHead(negativeResult); + AssertEqualCategoryArrays(negativeResult.classificationResult.classifications[0].categories, + expectedCategories); +} + +- (void)testCreateTextClassifierFailsWithMissingModelPath { + NSString *modelPath = [self filePathWithName:@"" extension:@""]; + + NSError *error = nil; + MPPTextClassifier *textClassifier = [[MPPTextClassifier alloc] initWithModelPath:modelPath + error:&error]; + XCTAssertNil(textClassifier); + + NSError *expectedError = [NSError + errorWithDomain:kExpectedErrorDomain + code:MPPTasksErrorCodeInvalidArgumentError + userInfo:@{ + NSLocalizedDescriptionKey : + @"INVALID_ARGUMENT: ExternalFile must specify at least one of 
'file_content', " + @"'file_name', 'file_pointer_meta' or 'file_descriptor_meta'." + }]; + AssertEqualErrors(error, expectedError); +} + +- (void)testCreateTextClassifierFailsWithBothAllowListAndDenyList { + MPPTextClassifierOptions *options = + [self textClassifierOptionsWithModelName:kBertTextClassifierModelName]; + options.categoryAllowlist = @[ @"positive" ]; + options.categoryDenylist = @[ @"negative" ]; + + [self assertCreateTextClassifierWithOptions:options + failsWithExpectedError: + [NSError + errorWithDomain:kExpectedErrorDomain + code:MPPTasksErrorCodeInvalidArgumentError + userInfo:@{ + NSLocalizedDescriptionKey : + @"INVALID_ARGUMENT: `category_allowlist` and " + @"`category_denylist` are mutually exclusive options." + }]]; +} + +- (void)testCreateTextClassifierFailsWithInvalidMaxResults { + MPPTextClassifierOptions *options = + [self textClassifierOptionsWithModelName:kBertTextClassifierModelName]; + options.maxResults = 0; + + [self assertCreateTextClassifierWithOptions:options + failsWithExpectedError: + [NSError errorWithDomain:kExpectedErrorDomain + code:MPPTasksErrorCodeInvalidArgumentError + userInfo:@{ + NSLocalizedDescriptionKey : + @"INVALID_ARGUMENT: Invalid `max_results` option: " + @"value must be != 0." + }]]; +} + +- (void)testClassifyWithBertSucceeds { + MPPTextClassifier *textClassifier = + [self textClassifierFromModelFileWithName:kBertTextClassifierModelName]; + + [self assertResultsOfClassifyText:kNegativeText + usingTextClassifier:textClassifier + equalsCategories:[MPPTextClassifierTests + expectedBertResultCategoriesForNegativeText]]; + + [self assertResultsOfClassifyText:kPositiveText + usingTextClassifier:textClassifier + equalsCategories:[MPPTextClassifierTests + expectedBertResultCategoriesForPositiveText]]; +} + +- (void)testClassifyWithRegexSucceeds { + MPPTextClassifier *textClassifier = + [self textClassifierFromModelFileWithName:kRegexTextClassifierModelName]; + + [self assertResultsOfClassifyText:kNegativeText + usingTextClassifier:textClassifier + equalsCategories:[MPPTextClassifierTests + expectedRegexResultCategoriesForNegativeText]]; + [self assertResultsOfClassifyText:kPositiveText + usingTextClassifier:textClassifier + equalsCategories:[MPPTextClassifierTests + expectedRegexResultCategoriesForPositiveText]]; +} + +- (void)testClassifyWithMaxResultsSucceeds { + MPPTextClassifierOptions *options = + [self textClassifierOptionsWithModelName:kBertTextClassifierModelName]; + options.maxResults = 1; + + MPPTextClassifier *textClassifier = [[MPPTextClassifier alloc] initWithOptions:options error:nil]; + XCTAssertNotNil(textClassifier); + + [self assertResultsOfClassifyText:kNegativeText + usingTextClassifier:textClassifier + equalsCategories:[MPPTextClassifierTests + expectedBertResultCategoriesForEdgeCaseTests]]; +} + +- (void)testClassifyWithCategoryAllowListSucceeds { + MPPTextClassifierOptions *options = + [self textClassifierOptionsWithModelName:kBertTextClassifierModelName]; + options.categoryAllowlist = @[ @"negative" ]; + + NSError *error = nil; + MPPTextClassifier *textClassifier = [[MPPTextClassifier alloc] initWithOptions:options + error:&error]; + XCTAssertNotNil(textClassifier); + XCTAssertNil(error); + + [self assertResultsOfClassifyText:kNegativeText + usingTextClassifier:textClassifier + equalsCategories:[MPPTextClassifierTests + expectedBertResultCategoriesForEdgeCaseTests]]; +} + +- (void)testClassifyWithCategoryDenyListSucceeds { + MPPTextClassifierOptions *options = + [self 
textClassifierOptionsWithModelName:kBertTextClassifierModelName]; + options.categoryDenylist = @[ @"positive" ]; + + MPPTextClassifier *textClassifier = [[MPPTextClassifier alloc] initWithOptions:options error:nil]; + XCTAssertNotNil(textClassifier); + + [self assertResultsOfClassifyText:kNegativeText + usingTextClassifier:textClassifier + equalsCategories:[MPPTextClassifierTests + expectedBertResultCategoriesForEdgeCaseTests]]; +} + +- (void)testClassifyWithScoreThresholdSucceeds { + MPPTextClassifierOptions *options = + [self textClassifierOptionsWithModelName:kBertTextClassifierModelName]; + options.scoreThreshold = 0.5f; + + MPPTextClassifier *textClassifier = [[MPPTextClassifier alloc] initWithOptions:options error:nil]; + XCTAssertNotNil(textClassifier); + + [self assertResultsOfClassifyText:kNegativeText + usingTextClassifier:textClassifier + equalsCategories:[MPPTextClassifierTests + expectedBertResultCategoriesForEdgeCaseTests]]; +} + +@end diff --git a/mediapipe/tasks/ios/test/text/text_classifier/TextClassifierTests.swift b/mediapipe/tasks/ios/test/text/text_classifier/TextClassifierTests.swift new file mode 100644 index 000000000..d2d433c22 --- /dev/null +++ b/mediapipe/tasks/ios/test/text/text_classifier/TextClassifierTests.swift @@ -0,0 +1,237 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
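The Swift tests that follow also depend on the NS_SWIFT_NAME(ResultCategory) rename from earlier in this series: Objective-C keeps the MPPCategory class name, while Swift sees ResultCategory. A small sketch of building one and comparing scores with the same 1e-6 tolerance the tests use (the comparison helper is made up for illustration; the initializer labels are the ones the tests exercise):

  import MPPTextClassifier

  // ResultCategory is the Swift projection of MPPCategory.
  let expected = ResultCategory(
    index: 0, score: 0.956187, categoryName: "negative", displayName: nil)

  // Hypothetical helper mirroring the accuracy-based checks in the tests.
  func roughlyEqual(_ a: ResultCategory, _ b: ResultCategory,
                    tolerance: Float = 1e-6) -> Bool {
    return a.index == b.index &&
        abs(a.score - b.score) <= tolerance &&
        a.categoryName == b.categoryName &&
        a.displayName == b.displayName
  }

  assert(roughlyEqual(expected, expected))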
+ +import XCTest + +import MPPCommon + +@testable import MPPTextClassifier + +class TextClassifierTests: XCTestCase { + + static let bundle = Bundle(for: TextClassifierTests.self) + + static let kBertModelPath = bundle.path( + forResource: "bert_text_classifier", + ofType: "tflite") + + static let kPositiveText = "it's a charming and often affecting journey" + + static let kNegativeText = "unflinchingly bleak and desperate" + + static let kBertNegativeTextResults = [ + ResultCategory( + index: 0, + score: 0.956187, + categoryName: "negative", + displayName: nil), + ResultCategory( + index: 1, + score: 0.043812, + categoryName: "positive", + displayName: nil) + ] + + static let kBertNegativeTextResultsForEdgeTestCases = [ + ResultCategory( + index: 0, + score: 0.956187, + categoryName: "negative", + displayName: nil), + ] + + func assertEqualErrorDescriptions( + _ error: Error, expectedLocalizedDescription:String) { + XCTAssertEqual( + error.localizedDescription, + expectedLocalizedDescription) + } + + func assertCategoriesAreEqual( + category: ResultCategory, + expectedCategory: ResultCategory) { + XCTAssertEqual( + category.index, + expectedCategory.index) + XCTAssertEqual( + category.score, + expectedCategory.score, + accuracy:1e-6) + XCTAssertEqual( + category.categoryName, + expectedCategory.categoryName) + XCTAssertEqual( + category.displayName, + expectedCategory.displayName) + } + + func assertEqualCategoryArrays( + categoryArray: [ResultCategory], + expectedCategoryArray:[ResultCategory]) { + + XCTAssertEqual(categoryArray.count, expectedCategoryArray.count) + + for (category, expectedCategory) in + zip(categoryArray, expectedCategoryArray) { + assertCategoriesAreEqual( + category:category, + expectedCategory:expectedCategory) + } + } + + func assertTextClassifierResultHasOneHead( + _ textClassifierResult: TextClassifierResult) { + XCTAssertEqual(textClassifierResult.classificationResult.classifications.count, 1); + XCTAssertEqual(textClassifierResult.classificationResult.classifications[0].headIndex, 0); + } + + func textClassifierOptionsWithModelPath( + _ modelPath: String?) 
throws -> TextClassifierOptions { + let modelPath = try XCTUnwrap(modelPath) + + let textClassifierOptions = TextClassifierOptions(); + textClassifierOptions.baseOptions.modelAssetPath = modelPath; + + return textClassifierOptions + } + + func assertCreateTextClassifierThrowsError( + textClassifierOptions: TextClassifierOptions, + expectedErrorDescription: String) { + do { + let textClassifier = try TextClassifier(options:textClassifierOptions) + XCTAssertNil(textClassifier) + } + catch { + assertEqualErrorDescriptions( + error, + expectedLocalizedDescription: expectedErrorDescription) + } + } + + func assertResultsForClassify( + text: String, + using textClassifier: TextClassifier, + equals expectedCategories: [ResultCategory]) throws { + let textClassifierResult = + try XCTUnwrap( + textClassifier.classify(text: text)); + assertTextClassifierResultHasOneHead(textClassifierResult); + assertEqualCategoryArrays( + categoryArray: + textClassifierResult.classificationResult.classifications[0].categories, + expectedCategoryArray: expectedCategories); + } + + func testCreateTextClassifierWithInvalidMaxResultsFails() throws { + let textClassifierOptions = + try XCTUnwrap( + textClassifierOptionsWithModelPath(TextClassifierTests.kBertModelPath)) + textClassifierOptions.maxResults = 0 + + assertCreateTextClassifierThrowsError( + textClassifierOptions: textClassifierOptions, + expectedErrorDescription: """ + INVALID_ARGUMENT: Invalid `max_results` option: value must be != 0. + """) + } + + func testCreateTextClassifierWithCategoryAllowlistandDenylistFails() throws { + + let textClassifierOptions = + try XCTUnwrap( + textClassifierOptionsWithModelPath(TextClassifierTests.kBertModelPath)) + textClassifierOptions.categoryAllowlist = ["positive"] + textClassifierOptions.categoryDenylist = ["positive"] + + assertCreateTextClassifierThrowsError( + textClassifierOptions: textClassifierOptions, + expectedErrorDescription: """ + INVALID_ARGUMENT: `category_allowlist` and `category_denylist` are \ + mutually exclusive options. 
+ """) + } + + func testClassifyWithBertSucceeds() throws { + + let modelPath = try XCTUnwrap(TextClassifierTests.kBertModelPath) + let textClassifier = try XCTUnwrap(TextClassifier(modelPath: modelPath)) + + try assertResultsForClassify( + text: TextClassifierTests.kNegativeText, + using: textClassifier, + equals: TextClassifierTests.kBertNegativeTextResults) + } + + func testClassifyWithMaxResultsSucceeds() throws { + let textClassifierOptions = + try XCTUnwrap( + textClassifierOptionsWithModelPath(TextClassifierTests.kBertModelPath)) + textClassifierOptions.maxResults = 1 + + let textClassifier = + try XCTUnwrap(TextClassifier(options: textClassifierOptions)) + + try assertResultsForClassify( + text: TextClassifierTests.kNegativeText, + using: textClassifier, + equals: TextClassifierTests.kBertNegativeTextResultsForEdgeTestCases) + } + + func testClassifyWithCategoryAllowlistSucceeds() throws { + let textClassifierOptions = + try XCTUnwrap( + textClassifierOptionsWithModelPath(TextClassifierTests.kBertModelPath)) + textClassifierOptions.categoryAllowlist = ["negative"]; + + let textClassifier = + try XCTUnwrap(TextClassifier(options: textClassifierOptions)) + + try assertResultsForClassify( + text: TextClassifierTests.kNegativeText, + using: textClassifier, + equals: TextClassifierTests.kBertNegativeTextResultsForEdgeTestCases) + } + + func testClassifyWithCategoryDenylistSucceeds() throws { + let textClassifierOptions = + try XCTUnwrap( + textClassifierOptionsWithModelPath(TextClassifierTests.kBertModelPath)) + textClassifierOptions.categoryDenylist = ["positive"]; + + let textClassifier = + try XCTUnwrap(TextClassifier(options: textClassifierOptions)) + + try assertResultsForClassify( + text: TextClassifierTests.kNegativeText, + using: textClassifier, + equals: TextClassifierTests.kBertNegativeTextResultsForEdgeTestCases) + } + + func testClassifyWithScoreThresholdSucceeds() throws { + let textClassifierOptions = + try XCTUnwrap( + textClassifierOptionsWithModelPath(TextClassifierTests.kBertModelPath)) + textClassifierOptions.scoreThreshold = 0.5; + + let textClassifier = + try XCTUnwrap(TextClassifier(options: textClassifierOptions)) + + try assertResultsForClassify( + text: TextClassifierTests.kNegativeText, + using: textClassifier, + equals: TextClassifierTests.kBertNegativeTextResultsForEdgeTestCases) + } + +} From c4c07acc1e5b2dbc37965b9c714fad2102705dbd Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Fri, 13 Jan 2023 21:18:01 +0530 Subject: [PATCH 383/469] Updated comments of MPPCommonUtils --- .../common/utils/sources/MPPCommonUtils.mm | 141 +++++++----------- 1 file changed, 58 insertions(+), 83 deletions(-) diff --git a/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.mm b/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.mm index 9932dd13c..27b75515d 100644 --- a/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.mm +++ b/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.mm @@ -26,7 +26,7 @@ NSString *const MPPTasksErrorDomain = @"com.google.mediapipe.tasks"; namespace { - using absl::StatusCode; +using absl::StatusCode; } @implementation MPPCommonUtils @@ -72,95 +72,70 @@ namespace { return YES; } - // Creates the NSEror with the appropriate error - // MPPTasksErrorCode and message. 
MPPTasksErrorCode has a one to one - // mapping with MediaPipeTasksStatus starting from the value 1(MPPTasksErrorCodeError) - // and hence will be correctly initialized if directly cast from the integer code derived from - // MediaPipeTasksStatus stored in its payload. MPPTasksErrorCode omits kOk = 0 of - // MediaPipeTasksStatusx. - // - // Stores a string including absl status code and message(if non empty) as the - // error message See - // https://github.com/abseil/abseil-cpp/blob/master/absl/status/status.h#L514 - // for explanation. absl::Status::message() can also be used but not always - // guaranteed to be non empty. + /** Converts the absl status message to an NSString. */ NSString *description = [NSString stringWithCString:status.ToString(absl::StatusToStringMode::kWithNoExtraData).c_str() encoding:NSUTF8StringEncoding]; - - // Payload of absl::Status created by the MediaPipe task library stores an appropriate value of - // the enum MediaPipeTasksStatus. The integer value corresponding to the MediaPipeTasksStatus enum - // stored in the payload is extracted here to later map to the appropriate error code to be - // returned. In cases where the enum is not stored in (payload is NULL or the payload string - // cannot be converted to an integer), we set the error code value to be 1 - // (MPPTasksErrorCodeError of MPPTasksErrorCode used in the iOS library to signify - // any errors not falling into other categories.) Since payload is of type absl::Cord that can be - // type cast into an absl::optional, we use the std::stoi function to convert it into - // an integer code if possible. + MPPTasksErrorCode genericErrorCode = MPPTasksErrorCodeUnknownError; MPPTasksErrorCode errorCode = genericErrorCode; - // If errorCode is outside the range of enum values possible or is - // MPPTasksErrorCodeError, we try to map the absl::Status::code() to assign - // appropriate MPPTasksErrorCode in default cases. Note: - // The mapping to absl::Status::code() is done to generate a more specific error code than - // MPPTasksErrorCodeError in cases when the payload can't be mapped to - // MPPTasksErrorCode. This can happen when absl::Status returned by TFLite library are in turn - // returned without modification by MediaPipe cc library methods. 
- switch (status.code()) { - case StatusCode::kCancelled: - errorCode = MPPTasksErrorCodeCancelledError; - break; - case StatusCode::kUnknown: - errorCode = MPPTasksErrorCodeUnknownError; - break; - case StatusCode::kInvalidArgument: - errorCode = MPPTasksErrorCodeInvalidArgumentError; - break; - case StatusCode::kDeadlineExceeded: - errorCode = MPPTasksErrorCodeDeadlineExceededError; - break; - case StatusCode::kNotFound: - errorCode = MPPTasksErrorCodeNotFoundError; - break; - case StatusCode::kAlreadyExists: - errorCode = MPPTasksErrorCodeAlreadyExistsError; - break; - case StatusCode::kPermissionDenied: - errorCode = MPPTasksErrorCodePermissionDeniedError; - break; - case StatusCode::kResourceExhausted: - errorCode = MPPTasksErrorCodeResourceExhaustedError; - break; - case StatusCode::kFailedPrecondition: - errorCode = MPPTasksErrorCodeFailedPreconditionError; - break; - case StatusCode::kAborted: - errorCode = MPPTasksErrorCodeAbortedError; - break; - case StatusCode::kOutOfRange: - errorCode = MPPTasksErrorCodeOutOfRangeError; - break; - case StatusCode::kUnimplemented: - errorCode = MPPTasksErrorCodeUnimplementedError; - break; - case StatusCode::kInternal: - errorCode = MPPTasksErrorCodeInternalError; - break; - case StatusCode::kUnavailable: - errorCode = MPPTasksErrorCodeUnavailableError; - break; - case StatusCode::kDataLoss: - errorCode = MPPTasksErrorCodeDataLossError; - break; - case StatusCode::kUnauthenticated: - errorCode = MPPTasksErrorCodeUnauthenticatedError; - break; - default: - errorCode = genericErrorCode; - break; - } + /** Maps the absl::StatusCode to the appropriate MPPTasksErrorCode. Note: MPPTasksErrorCode omits + * absl::StatusCode::kOk. */ + switch (status.code()) { + case StatusCode::kCancelled: + errorCode = MPPTasksErrorCodeCancelledError; + break; + case StatusCode::kUnknown: + errorCode = MPPTasksErrorCodeUnknownError; + break; + case StatusCode::kInvalidArgument: + errorCode = MPPTasksErrorCodeInvalidArgumentError; + break; + case StatusCode::kDeadlineExceeded: + errorCode = MPPTasksErrorCodeDeadlineExceededError; + break; + case StatusCode::kNotFound: + errorCode = MPPTasksErrorCodeNotFoundError; + break; + case StatusCode::kAlreadyExists: + errorCode = MPPTasksErrorCodeAlreadyExistsError; + break; + case StatusCode::kPermissionDenied: + errorCode = MPPTasksErrorCodePermissionDeniedError; + break; + case StatusCode::kResourceExhausted: + errorCode = MPPTasksErrorCodeResourceExhaustedError; + break; + case StatusCode::kFailedPrecondition: + errorCode = MPPTasksErrorCodeFailedPreconditionError; + break; + case StatusCode::kAborted: + errorCode = MPPTasksErrorCodeAbortedError; + break; + case StatusCode::kOutOfRange: + errorCode = MPPTasksErrorCodeOutOfRangeError; + break; + case StatusCode::kUnimplemented: + errorCode = MPPTasksErrorCodeUnimplementedError; + break; + case StatusCode::kInternal: + errorCode = MPPTasksErrorCodeInternalError; + break; + case StatusCode::kUnavailable: + errorCode = MPPTasksErrorCodeUnavailableError; + break; + case StatusCode::kDataLoss: + errorCode = MPPTasksErrorCodeDataLossError; + break; + case StatusCode::kUnauthenticated: + errorCode = MPPTasksErrorCodeUnauthenticatedError; + break; + default: + errorCode = genericErrorCode; + break; + } [MPPCommonUtils createCustomError:error withCode:errorCode description:description]; return NO; From 95f9f0fb88c209b147b7822c102f5003c22d3c16 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Fri, 13 Jan 2023 21:18:10 +0530 Subject: [PATCH 384/469] Updated formatting --- 
.../tasks/ios/common/sources/MPPCommon.h | 30 +++++++++++++++++-- .../sources/MPPTextClassifier.h | 6 ++-- 2 files changed, 30 insertions(+), 6 deletions(-) diff --git a/mediapipe/tasks/ios/common/sources/MPPCommon.h b/mediapipe/tasks/ios/common/sources/MPPCommon.h index f8047fc35..0f885a8c2 100644 --- a/mediapipe/tasks/ios/common/sources/MPPCommon.h +++ b/mediapipe/tasks/ios/common/sources/MPPCommon.h @@ -18,8 +18,7 @@ NS_ASSUME_NONNULL_BEGIN /** * @enum MPPTasksErrorCode - * This enum specifies error codes for MediaPipe Task Library. - * It maintains a 1:1 mapping to MediaPipeTasksStatus of the C ++libray. + * This enum specifies error codes for errors thrown by iOS MediaPipe Task Library. */ typedef NS_ENUM(NSUInteger, MPPTasksErrorCode) { @@ -27,35 +26,60 @@ typedef NS_ENUM(NSUInteger, MPPTasksErrorCode) { /** Indicates the operation was cancelled, typically by the caller. */ MPPTasksErrorCodeCancelledError = 1, + /** Indicates an unknown error occurred. */ MPPTasksErrorCodeUnknownError = 2, + /** Indicates the caller specified an invalid argument, such as a malformed filename. */ MPPTasksErrorCodeInvalidArgumentError = 3, + /** Indicates a deadline expired before the operation could complete. */ MPPTasksErrorCodeDeadlineExceededError = 4, + /** Indicates some requested entity (such as a file or directory) was not found. */ MPPTasksErrorCodeNotFoundError = 5, - /** Indicates that the entity a caller attempted to create (such as a file or directory) is already present. */ + + /** Indicates that the entity a caller attempted to create (such as a file or directory) is + already present. */ MPPTasksErrorCodeAlreadyExistsError = 6, + /** Indicates that the caller does not have permission to execute the specified operation. */ MPPTasksErrorCodePermissionDeniedError = 7, + /** Indicates some resource has been exhausted, perhaps a per-user quota, or perhaps the entire + file system is out of space. */ MPPTasksErrorCodeResourceExhaustedError = 8, + /** Indicates that the operation was rejected because the system is not in a state required for + the operation's execution. For example, a directory to be deleted may be non-empty, an "rmdir" + operation is applied to a non-directory, etc. */ MPPTasksErrorCodeFailedPreconditionError = 9, + /** Indicates the operation was aborted, typically due to a concurrency issue such as a sequencer + check failure or a failed transaction. */ MPPTasksErrorCodeAbortedError = 10, + /** Indicates the operation was attempted past the valid range, such as seeking or reading past an + end-of-file. */ MPPTasksErrorCodeOutOfRangeError = 11, + /** Indicates the operation is not implemented or supported in this service. In this case, the + operation should not be re-attempted. */ MPPTasksErrorCodeUnimplementedError = 12, + /** Indicates an internal error has occurred and some invariants expected by the underlying system + have not been satisfied. This error code is reserved for serious errors. */ MPPTasksErrorCodeInternalError = 13, + /** Indicates the service is currently unavailable and that this is most likely a transient + condition. */ MPPTasksErrorCodeUnavailableError = 14, + /** Indicates that unrecoverable data loss or corruption has occurred. */ MPPTasksErrorCodeDataLossError = 15, + /** Indicates that the request does not have valid authentication credentials for the operation. + */ MPPTasksErrorCodeUnauthenticatedError = 16, // The first error code in MPPTasksErrorCode (for internal use only). 
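With the documentation cleanup above, the contract is simple: any failing absl::Status from the C++ layer reaches Objective-C and Swift as an NSError in the com.google.mediapipe.tasks domain, carrying a canonical code from MPPCommon.h and a description that keeps absl's prefix (the tests earlier in this series match strings such as "INVALID_ARGUMENT: Invalid `max_results` option: value must be != 0."). A hedged Swift sketch of observing that from a deliberately bad configuration:

  import MPPTextClassifier

  let options = TextClassifierOptions()
  options.baseOptions.modelAssetPath = ""  // invalid on purpose, as in the tests
  options.maxResults = 0

  do {
    _ = try TextClassifier(options: options)
  } catch let error as NSError {
    // Expect the MediaPipe Tasks domain, an invalid-argument code, and a
    // description beginning with "INVALID_ARGUMENT:".
    print(error.domain, error.code, error.localizedDescription)
  }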
diff --git a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h index e33615dab..33d3c8970 100644 --- a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h +++ b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h @@ -79,7 +79,7 @@ NS_SWIFT_NAME(TextClassifier) * error in initializing the text classifier. */ - (nullable instancetype)initWithOptions:(MPPTextClassifierOptions *)options - error:(NSError **)error NS_DESIGNATED_INITIALIZER; + error:(NSError **)error NS_DESIGNATED_INITIALIZER; /** * Performs classification on the input text. @@ -90,8 +90,8 @@ NS_SWIFT_NAME(TextClassifier) * * @return A `MPPTextClassifierResult` object that contains a list of text classifications. */ -- (nullable MPPTextClassifierResult *)classifyText:(NSString *)text error:(NSError **)error NS_SWIFT_NAME(classify(text:)); - +- (nullable MPPTextClassifierResult *)classifyText:(NSString *)text + error:(NSError **)error NS_SWIFT_NAME(classify(text:)); - (instancetype)init NS_UNAVAILABLE; From 69757d7924f84dfbe50520bba1bf1fdd4f177f16 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Fri, 13 Jan 2023 09:03:46 -0800 Subject: [PATCH 385/469] Internal change PiperOrigin-RevId: 501862194 --- mediapipe/framework/BUILD | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/mediapipe/framework/BUILD b/mediapipe/framework/BUILD index 83346dad1..da8ef3b3e 100644 --- a/mediapipe/framework/BUILD +++ b/mediapipe/framework/BUILD @@ -96,7 +96,9 @@ mediapipe_proto_library( mediapipe_proto_library( name = "mediapipe_options_proto", srcs = ["mediapipe_options.proto"], - visibility = [":mediapipe_internal"], + visibility = [ + ":mediapipe_internal", + ], ) mediapipe_proto_library( From f997c0ab1a8bc69d0ef8760061a515313144af8c Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Fri, 13 Jan 2023 09:52:07 -0800 Subject: [PATCH 386/469] Reject RegionOfInterest in not supported tasks PiperOrigin-RevId: 501872455 --- .../vision/core/vision_task_runner.test.ts | 41 +++++++++++++++---- .../web/vision/core/vision_task_runner.ts | 9 +++- .../gesture_recognizer/gesture_recognizer.ts | 2 +- .../gesture_recognizer_test.ts | 8 ++++ .../vision/hand_landmarker/hand_landmarker.ts | 2 +- .../hand_landmarker/hand_landmarker_test.ts | 8 ++++ .../image_classifier/image_classifier.ts | 2 +- .../vision/image_embedder/image_embedder.ts | 2 +- .../vision/object_detector/object_detector.ts | 2 +- .../object_detector/object_detector_test.ts | 8 ++++ 10 files changed, 70 insertions(+), 14 deletions(-) diff --git a/mediapipe/tasks/web/vision/core/vision_task_runner.test.ts b/mediapipe/tasks/web/vision/core/vision_task_runner.test.ts index 4567134b8..4eb51afdb 100644 --- a/mediapipe/tasks/web/vision/core/vision_task_runner.test.ts +++ b/mediapipe/tasks/web/vision/core/vision_task_runner.test.ts @@ -41,14 +41,14 @@ class VisionTaskRunnerFake extends VisionTaskRunner { expectedImageSource?: ImageSource; expectedNormalizedRect?: NormalizedRect; - constructor() { + constructor(roiAllowed = true) { super( jasmine.createSpyObj([ 'addProtoToStream', 'addGpuBufferAsImageToStream', 'setAutoRenderToScreen', 'registerModelResourcesGraphService', 'finishProcessing' ]), - IMAGE_STREAM, NORM_RECT_STREAM); + IMAGE_STREAM, NORM_RECT_STREAM, roiAllowed); this.fakeGraphRunner = this.graphRunner as unknown as jasmine.SpyObj<VisionGraphRunner>; @@ -71,6 +71,9 @@ class VisionTaskRunnerFake extends VisionTaskRunner { expect(timestamp).toBe(TIMESTAMP);
expect(imageSource).toBe(this.expectedImageSource!); }); + + // SetOptions with a modelAssetBuffer runs synchronously + void this.setOptions({baseOptions: {modelAssetBuffer: new Uint8Array([])}}); } protected override refreshGraph(): void {} @@ -108,28 +111,26 @@ } describe('VisionTaskRunner', () => { - let visionTaskRunner: VisionTaskRunnerFake; - - beforeEach(async () => { + beforeEach(() => { addJasmineCustomFloatEqualityTester(); - visionTaskRunner = new VisionTaskRunnerFake(); - await visionTaskRunner.setOptions( - {baseOptions: {modelAssetBuffer: new Uint8Array([])}}); }); it('can enable image mode', async () => { + const visionTaskRunner = new VisionTaskRunnerFake(); await visionTaskRunner.setOptions({runningMode: 'IMAGE'}); expect(visionTaskRunner.baseOptions.toObject()) .toEqual(jasmine.objectContaining({useStreamMode: false})); }); it('can enable video mode', async () => { + const visionTaskRunner = new VisionTaskRunnerFake(); await visionTaskRunner.setOptions({runningMode: 'VIDEO'}); expect(visionTaskRunner.baseOptions.toObject()) .toEqual(jasmine.objectContaining({useStreamMode: true})); }); it('can clear running mode', async () => { + const visionTaskRunner = new VisionTaskRunnerFake(); await visionTaskRunner.setOptions({runningMode: 'VIDEO'}); // Clear running mode @@ -140,6 +141,7 @@ describe('VisionTaskRunner', () => { }); it('cannot process images with video mode', async () => { + const visionTaskRunner = new VisionTaskRunnerFake(); await visionTaskRunner.setOptions({runningMode: 'VIDEO'}); expect(() => { visionTaskRunner.processImageData( @@ -148,6 +150,7 @@ describe('VisionTaskRunner', () => { }); it('cannot process video with image mode', async () => { + const visionTaskRunner = new VisionTaskRunnerFake(); // Use default for `useStreamMode` expect(() => { visionTaskRunner.processVideoData( @@ -163,6 +166,7 @@ describe('VisionTaskRunner', () => { }); it('sends packets to graph', async () => { + const visionTaskRunner = new VisionTaskRunnerFake(); await visionTaskRunner.setOptions({runningMode: 'VIDEO'}); visionTaskRunner.expectImage(IMAGE); @@ -172,6 +176,7 @@ describe('VisionTaskRunner', () => { }); it('sends packets to graph with image processing options', async () => { + const visionTaskRunner = new VisionTaskRunnerFake(); await visionTaskRunner.setOptions({runningMode: 'VIDEO'}); visionTaskRunner.expectImage(IMAGE); @@ -184,6 +189,7 @@ describe('VisionTaskRunner', () => { describe('validates processing options', () => { it('with left > right', () => { + const visionTaskRunner = new VisionTaskRunnerFake(); expect(() => { visionTaskRunner.processImageData(IMAGE, { regionOfInterest: { @@ -197,6 +203,7 @@ describe('VisionTaskRunner', () => { }); it('with top > bottom', () => { + const visionTaskRunner = new VisionTaskRunnerFake(); expect(() => { visionTaskRunner.processImageData(IMAGE, { regionOfInterest: { @@ -210,6 +217,7 @@ describe('VisionTaskRunner', () => { }); it('with out of range values', () => { + const visionTaskRunner = new VisionTaskRunnerFake(); expect(() => { visionTaskRunner.processImageData(IMAGE, { regionOfInterest: { @@ -222,7 +230,24 @@ describe('VisionTaskRunner', () => { }).toThrowError('Expected RectF values to be in [0,1].'); }); + + it('without region of interest support', () => { + const visionTaskRunner = + new VisionTaskRunnerFake(/* roiAllowed= */ false); + expect(() => { + visionTaskRunner.processImageData(IMAGE, { + regionOfInterest: { + left: 0.1, + right: 0.2, + top: 0.1, + bottom: 0.2,
+ } + }); + }).toThrowError('This task doesn\'t support region-of-interest.'); + }); + it('with non-90 degree rotation', () => { + const visionTaskRunner = new VisionTaskRunnerFake(); expect(() => { visionTaskRunner.processImageData(IMAGE, {rotationDegrees: 42}); }).toThrowError('Expected rotation to be a multiple of 90°.'); diff --git a/mediapipe/tasks/web/vision/core/vision_task_runner.ts b/mediapipe/tasks/web/vision/core/vision_task_runner.ts index 71cac920c..b3e8ed4db 100644 --- a/mediapipe/tasks/web/vision/core/vision_task_runner.ts +++ b/mediapipe/tasks/web/vision/core/vision_task_runner.ts @@ -42,13 +42,16 @@ export abstract class VisionTaskRunner extends TaskRunner { * @param normRectStreamName the name of the input normalized rect image * stream used to provide (mandatory) rotation and (optional) * region-of-interest. + * @param roiAllowed Whether this task supports Region-Of-Interest + * pre-processing * * @hideconstructor protected */ constructor( protected override readonly graphRunner: VisionGraphRunner, private readonly imageStreamName: string, - private readonly normRectStreamName: string) { + private readonly normRectStreamName: string, + private readonly roiAllowed: boolean) { super(graphRunner); } @@ -96,6 +99,10 @@ export abstract class VisionTaskRunner extends TaskRunner { const normalizedRect = new NormalizedRect(); if (imageProcessingOptions?.regionOfInterest) { + if (!this.roiAllowed) { + throw new Error('This task doesn\'t support region-of-interest.'); + } + const roi = imageProcessingOptions.regionOfInterest; if (roi.left >= roi.right || roi.top >= roi.bottom) { diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts index 1b7201b9a..beea263ce 100644 --- a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts +++ b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts @@ -126,7 +126,7 @@ export class GestureRecognizer extends VisionTaskRunner { glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { super( new VisionGraphRunner(wasmModule, glCanvas), IMAGE_STREAM, - NORM_RECT_STREAM); + NORM_RECT_STREAM, /* roiAllowed= */ false); this.options = new GestureRecognizerGraphOptions(); this.options.setBaseOptions(new BaseOptionsProto()); diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_test.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_test.ts index dfc252eb6..b2a2c0d72 100644 --- a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_test.ts +++ b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_test.ts @@ -250,6 +250,14 @@ describe('GestureRecognizer', () => { } }); + it('doesn\'t support region of interest', () => { + expect(() => { + gestureRecognizer.recognize( + {} as HTMLImageElement, + {regionOfInterest: {left: 0, right: 0, top: 0, bottom: 0}}); + }).toThrowError('This task doesn\'t support region-of-interest.'); + }); + it('transforms results', async () => { // Pass the test data to our listener gestureRecognizer.fakeWasmModule._waitUntilIdle.and.callFake(() => { diff --git a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts index b51fb6a52..cd0459372 100644 --- a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts +++ b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts @@ -116,7 +116,7 @@ export class HandLandmarker extends VisionTaskRunner { glCanvas?: 
HTMLCanvasElement|OffscreenCanvas|null) { super( new VisionGraphRunner(wasmModule, glCanvas), IMAGE_STREAM, - NORM_RECT_STREAM); + NORM_RECT_STREAM, /* roiAllowed= */ false); this.options = new HandLandmarkerGraphOptions(); this.options.setBaseOptions(new BaseOptionsProto()); diff --git a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_test.ts b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_test.ts index 0abd1df27..5fd493424 100644 --- a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_test.ts +++ b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_test.ts @@ -203,6 +203,14 @@ describe('HandLandmarker', () => { } }); + it('doesn\'t support region of interest', () => { + expect(() => { + handLandmarker.detect( + {} as HTMLImageElement, + {regionOfInterest: {left: 0, right: 0, top: 0, bottom: 0}}); + }).toThrowError('This task doesn\'t support region-of-interest.'); + }); + it('transforms results', async () => { // Pass the test data to our listener handLandmarker.fakeWasmModule._waitUntilIdle.and.callFake(() => { diff --git a/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts b/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts index cb2849cd8..071513b19 100644 --- a/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts +++ b/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts @@ -101,7 +101,7 @@ export class ImageClassifier extends VisionTaskRunner { glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { super( new VisionGraphRunner(wasmModule, glCanvas), IMAGE_STREAM, - NORM_RECT_STREAM); + NORM_RECT_STREAM, /* roiAllowed= */ true); this.options.setBaseOptions(new BaseOptionsProto()); } diff --git a/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts b/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts index 788646e6d..fdeb92f3f 100644 --- a/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts +++ b/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts @@ -104,7 +104,7 @@ export class ImageEmbedder extends VisionTaskRunner { glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { super( new VisionGraphRunner(wasmModule, glCanvas), IMAGE_STREAM, - NORM_RECT_STREAM); + NORM_RECT_STREAM, /* roiAllowed= */ true); this.options.setBaseOptions(new BaseOptionsProto()); } diff --git a/mediapipe/tasks/web/vision/object_detector/object_detector.ts b/mediapipe/tasks/web/vision/object_detector/object_detector.ts index 5741a3a0c..5b581432d 100644 --- a/mediapipe/tasks/web/vision/object_detector/object_detector.ts +++ b/mediapipe/tasks/web/vision/object_detector/object_detector.ts @@ -100,7 +100,7 @@ export class ObjectDetector extends VisionTaskRunner { glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { super( new VisionGraphRunner(wasmModule, glCanvas), IMAGE_STREAM, - NORM_RECT_STREAM); + NORM_RECT_STREAM, /* roiAllowed= */ false); this.options.setBaseOptions(new BaseOptionsProto()); } diff --git a/mediapipe/tasks/web/vision/object_detector/object_detector_test.ts b/mediapipe/tasks/web/vision/object_detector/object_detector_test.ts index ceb96acb1..9dd64c0b6 100644 --- a/mediapipe/tasks/web/vision/object_detector/object_detector_test.ts +++ b/mediapipe/tasks/web/vision/object_detector/object_detector_test.ts @@ -170,6 +170,14 @@ describe('ObjectDetector', () => { } }); + it('doesn\'t support region of interest', () => { + expect(() => { + objectDetector.detect( + {} as HTMLImageElement, + {regionOfInterest: {left: 0, right: 0, top: 0, bottom: 0}}); + }).toThrowError('This task 
doesn\'t support region-of-interest.'); + }); + it('transforms results', async () => { const detectionProtos: Uint8Array[] = []; From aef4cca40610ced2efd0ed45a465c43368f4a893 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Fri, 13 Jan 2023 13:45:47 -0800 Subject: [PATCH 387/469] Copy README.md to NPM package root PiperOrigin-RevId: 501929871 --- mediapipe/tasks/web/BUILD | 161 +--------------------------- mediapipe/tasks/web/audio.ts | 25 ----- mediapipe/tasks/web/audio/BUILD | 58 +++++++++- mediapipe/tasks/web/audio/index.ts | 14 ++- mediapipe/tasks/web/text.ts | 25 ----- mediapipe/tasks/web/text/BUILD | 56 +++++++++- mediapipe/tasks/web/text/index.ts | 14 ++- mediapipe/tasks/web/vision.ts | 35 ------ mediapipe/tasks/web/vision/BUILD | 56 +++++++++- mediapipe/tasks/web/vision/index.ts | 30 ++++-- 10 files changed, 216 insertions(+), 258 deletions(-) delete mode 100644 mediapipe/tasks/web/audio.ts delete mode 100644 mediapipe/tasks/web/text.ts delete mode 100644 mediapipe/tasks/web/vision.ts diff --git a/mediapipe/tasks/web/BUILD b/mediapipe/tasks/web/BUILD index 02bd70dd0..ff947ef54 100644 --- a/mediapipe/tasks/web/BUILD +++ b/mediapipe/tasks/web/BUILD @@ -1,158 +1,5 @@ -# This contains the MediaPipe Tasks NPM package definitions. - -load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") -load("@build_bazel_rules_nodejs//:index.bzl", "pkg_npm") -load("@npm//@bazel/rollup:index.bzl", "rollup_bundle") -load( - "//mediapipe/framework/tool:mediapipe_files.bzl", - "mediapipe_files", -) - -package(default_visibility = ["//mediapipe/tasks:internal"]) - -mediapipe_files(srcs = [ - "wasm/audio_wasm_internal.js", - "wasm/audio_wasm_internal.wasm", - "wasm/audio_wasm_nosimd_internal.js", - "wasm/audio_wasm_nosimd_internal.wasm", - "wasm/text_wasm_internal.js", - "wasm/text_wasm_internal.wasm", - "wasm/text_wasm_nosimd_internal.js", - "wasm/text_wasm_nosimd_internal.wasm", - "wasm/vision_wasm_internal.js", - "wasm/vision_wasm_internal.wasm", - "wasm/vision_wasm_nosimd_internal.js", - "wasm/vision_wasm_nosimd_internal.wasm", +exports_files([ + "karma.conf.ts", + "package.json", + "rollup.config.mjs", ]) - -# Audio - -mediapipe_ts_library( - name = "audio_lib", - srcs = ["audio.ts"], - deps = ["//mediapipe/tasks/web/audio:audio_lib"], -) - -rollup_bundle( - name = "audio_bundle", - config_file = "rollup.config.mjs", - entry_point = "audio.ts", - format = "esm", - output_dir = False, - sourcemap = "false", - deps = [ - ":audio_lib", - "@npm//@rollup/plugin-commonjs", - "@npm//@rollup/plugin-node-resolve", - "@npm//@rollup/plugin-terser", - "@npm//google-protobuf", - ], -) - -pkg_npm( - name = "audio_pkg", - package_name = "@mediapipe/tasks-__NAME__", - srcs = ["package.json"], - substitutions = { - "__NAME__": "audio", - "__DESCRIPTION__": "MediaPipe Audio Tasks", - "__TYPES__": "audio.d.ts", - }, - tgz = "audio.tgz", - deps = [ - "wasm/audio_wasm_internal.js", - "wasm/audio_wasm_internal.wasm", - "wasm/audio_wasm_nosimd_internal.js", - "wasm/audio_wasm_nosimd_internal.wasm", - ":audio_bundle", - "//mediapipe/tasks/web/audio:README.md", - ], -) - -# Text - -mediapipe_ts_library( - name = "text_lib", - srcs = ["text.ts"], - deps = ["//mediapipe/tasks/web/text:text_lib"], -) - -rollup_bundle( - name = "text_bundle", - config_file = "rollup.config.mjs", - entry_point = "text.ts", - format = "esm", - output_dir = False, - sourcemap = "false", - deps = [ - ":text_lib", - "@npm//@rollup/plugin-commonjs", - "@npm//@rollup/plugin-node-resolve", - "@npm//@rollup/plugin-terser", - 
"@npm//google-protobuf", - ], -) - -pkg_npm( - name = "text_pkg", - package_name = "@mediapipe/tasks-__NAME__", - srcs = ["package.json"], - substitutions = { - "__NAME__": "text", - "__DESCRIPTION__": "MediaPipe Text Tasks", - "__TYPES__": "text.d.ts", - }, - tgz = "text.tgz", - deps = [ - "wasm/text_wasm_internal.js", - "wasm/text_wasm_internal.wasm", - "wasm/text_wasm_nosimd_internal.js", - "wasm/text_wasm_nosimd_internal.wasm", - ":text_bundle", - "//mediapipe/tasks/web/text:README.md", - ], -) - -# Vision - -mediapipe_ts_library( - name = "vision_lib", - srcs = ["vision.ts"], - deps = ["//mediapipe/tasks/web/vision:vision_lib"], -) - -rollup_bundle( - name = "vision_bundle", - config_file = "rollup.config.mjs", - entry_point = "vision.ts", - format = "esm", - output_dir = False, - sourcemap = "false", - deps = [ - ":vision_lib", - "@npm//@rollup/plugin-commonjs", - "@npm//@rollup/plugin-node-resolve", - "@npm//@rollup/plugin-terser", - "@npm//google-protobuf", - ], -) - -pkg_npm( - name = "vision_pkg", - package_name = "@mediapipe/tasks-__NAME__", - srcs = ["package.json"], - substitutions = { - "__NAME__": "vision", - "__DESCRIPTION__": "MediaPipe Vision Tasks", - "__TYPES__": "vision.d.ts", - }, - tgz = "vision_pkg.tgz", - deps = [ - "wasm/vision_wasm_internal.js", - "wasm/vision_wasm_internal.wasm", - "wasm/vision_wasm_nosimd_internal.js", - "wasm/vision_wasm_nosimd_internal.wasm", - ":vision_bundle", - "//mediapipe/tasks/web/vision:README.md", - ], -) diff --git a/mediapipe/tasks/web/audio.ts b/mediapipe/tasks/web/audio.ts deleted file mode 100644 index 2f4fb0315..000000000 --- a/mediapipe/tasks/web/audio.ts +++ /dev/null @@ -1,25 +0,0 @@ -/** - * Copyright 2022 The MediaPipe Authors. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -import {AudioClassifier as AudioClassifierImpl, AudioEmbedder as AudioEmbedderImpl, FilesetResolver as FilesetResolverImpl} from '../../tasks/web/audio/index'; - -// Declare the variables locally so that Rollup in OSS includes them explcilty -// as exports. -const AudioClassifier = AudioClassifierImpl; -const AudioEmbedder = AudioEmbedderImpl; -const FilesetResolver = FilesetResolverImpl; - -export {AudioClassifier, AudioEmbedder, FilesetResolver}; diff --git a/mediapipe/tasks/web/audio/BUILD b/mediapipe/tasks/web/audio/BUILD index 50a611f41..7e05263fe 100644 --- a/mediapipe/tasks/web/audio/BUILD +++ b/mediapipe/tasks/web/audio/BUILD @@ -1,11 +1,15 @@ # This contains the MediaPipe Audio Tasks. 
 load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library")
+load("@build_bazel_rules_nodejs//:index.bzl", "pkg_npm")
+load("@npm//@bazel/rollup:index.bzl", "rollup_bundle")
+load(
+    "//mediapipe/framework/tool:mediapipe_files.bzl",
+    "mediapipe_files",
+)

 package(default_visibility = ["//mediapipe/tasks:internal"])

-exports_files(["README.md"])
-
 mediapipe_ts_library(
     name = "audio_lib",
     srcs = ["index.ts"],
@@ -16,3 +20,53 @@ mediapipe_ts_library(
         "//mediapipe/tasks/web/core:fileset_resolver",
     ],
 )
+
+mediapipe_files(srcs = [
+    "wasm/audio_wasm_internal.js",
+    "wasm/audio_wasm_internal.wasm",
+    "wasm/audio_wasm_nosimd_internal.js",
+    "wasm/audio_wasm_nosimd_internal.wasm",
+])
+
+rollup_bundle(
+    name = "audio_bundle",
+    config_file = "//mediapipe/tasks/web:rollup.config.mjs",
+    entry_point = "index.ts",
+    format = "esm",
+    output_dir = False,
+    sourcemap = "false",
+    deps = [
+        ":audio_lib",
+        "@npm//@rollup/plugin-commonjs",
+        "@npm//@rollup/plugin-node-resolve",
+        "@npm//@rollup/plugin-terser",
+        "@npm//google-protobuf",
+    ],
+)
+
+genrule(
+    name = "package_json",
+    srcs = ["//mediapipe/tasks/web:package.json"],
+    outs = ["package.json"],
+    cmd = "cp $< $@",
+)
+
+pkg_npm(
+    name = "audio_pkg",
+    package_name = "@mediapipe/tasks-__NAME__",
+    srcs = ["README.md"],
+    substitutions = {
+        "__NAME__": "audio",
+        "__DESCRIPTION__": "MediaPipe Audio Tasks",
+        "__TYPES__": "audio.d.ts",
+    },
+    tgz = "audio.tgz",
+    deps = [
+        "wasm/audio_wasm_internal.js",
+        "wasm/audio_wasm_internal.wasm",
+        "wasm/audio_wasm_nosimd_internal.js",
+        "wasm/audio_wasm_nosimd_internal.wasm",
+        ":audio_bundle",
+        ":package_json",
+    ],
+)
diff --git a/mediapipe/tasks/web/audio/index.ts b/mediapipe/tasks/web/audio/index.ts
index dbad8c617..44fa7eb25 100644
--- a/mediapipe/tasks/web/audio/index.ts
+++ b/mediapipe/tasks/web/audio/index.ts
@@ -14,6 +14,14 @@
  * limitations under the License.
  */

-export * from '../../../tasks/web/audio/audio_classifier/audio_classifier';
-export * from '../../../tasks/web/audio/audio_embedder/audio_embedder';
-export * from '../../../tasks/web/core/fileset_resolver';
+import {AudioClassifier as AudioClassifierImpl} from '../../../tasks/web/audio/audio_classifier/audio_classifier';
+import {AudioEmbedder as AudioEmbedderImpl} from '../../../tasks/web/audio/audio_embedder/audio_embedder';
+import {FilesetResolver as FilesetResolverImpl} from '../../../tasks/web/core/fileset_resolver';
+
+// Declare the variables locally so that Rollup in OSS includes them explicitly
+// as exports.
+const AudioClassifier = AudioClassifierImpl;
+const AudioEmbedder = AudioEmbedderImpl;
+const FilesetResolver = FilesetResolverImpl;
+
+export {AudioClassifier, AudioEmbedder, FilesetResolver};
diff --git a/mediapipe/tasks/web/text.ts b/mediapipe/tasks/web/text.ts
deleted file mode 100644
index 0636714b8..000000000
--- a/mediapipe/tasks/web/text.ts
+++ /dev/null
@@ -1,25 +0,0 @@
-/**
- * Copyright 2022 The MediaPipe Authors. All Rights Reserved.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-import {FilesetResolver as FilesetResolverImpl, TextClassifier as TextClassifierImpl, TextEmbedder as TextEmbedderImpl} from '../../tasks/web/text/index';
-
-// Declare the variables locally so that Rollup in OSS includes them explicitly
-// as exports.
-const FilesetResolver = FilesetResolverImpl;
-const TextClassifier = TextClassifierImpl;
-const TextEmbedder = TextEmbedderImpl;
-
-export {FilesetResolver, TextClassifier, TextEmbedder};
diff --git a/mediapipe/tasks/web/text/BUILD b/mediapipe/tasks/web/text/BUILD
index 077b25645..6f019aca1 100644
--- a/mediapipe/tasks/web/text/BUILD
+++ b/mediapipe/tasks/web/text/BUILD
@@ -1,10 +1,21 @@
 # This contains the MediaPipe Text Tasks.

 load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library")
+load("@build_bazel_rules_nodejs//:index.bzl", "pkg_npm")
+load("@npm//@bazel/rollup:index.bzl", "rollup_bundle")
+load(
+    "//mediapipe/framework/tool:mediapipe_files.bzl",
+    "mediapipe_files",
+)

 package(default_visibility = ["//mediapipe/tasks:internal"])

-exports_files(["README.md"])
+mediapipe_files(srcs = [
+    "wasm/text_wasm_internal.js",
+    "wasm/text_wasm_internal.wasm",
+    "wasm/text_wasm_nosimd_internal.js",
+    "wasm/text_wasm_nosimd_internal.wasm",
+])

 mediapipe_ts_library(
     name = "text_lib",
@@ -16,3 +27,46 @@ mediapipe_ts_library(
         "//mediapipe/tasks/web/text/text_embedder",
     ],
 )
+
+rollup_bundle(
+    name = "text_bundle",
+    config_file = "//mediapipe/tasks/web:rollup.config.mjs",
+    entry_point = "index.ts",
+    format = "esm",
+    output_dir = False,
+    sourcemap = "false",
+    deps = [
+        ":text_lib",
+        "@npm//@rollup/plugin-commonjs",
+        "@npm//@rollup/plugin-node-resolve",
+        "@npm//@rollup/plugin-terser",
+        "@npm//google-protobuf",
+    ],
+)
+
+genrule(
+    name = "package_json",
+    srcs = ["//mediapipe/tasks/web:package.json"],
+    outs = ["package.json"],
+    cmd = "cp $< $@",
+)
+
+pkg_npm(
+    name = "text_pkg",
+    package_name = "@mediapipe/tasks-__NAME__",
+    srcs = ["README.md"],
+    substitutions = {
+        "__NAME__": "text",
+        "__DESCRIPTION__": "MediaPipe Text Tasks",
+        "__TYPES__": "text.d.ts",
+    },
+    tgz = "text.tgz",
+    deps = [
+        "wasm/text_wasm_internal.js",
+        "wasm/text_wasm_internal.wasm",
+        "wasm/text_wasm_nosimd_internal.js",
+        "wasm/text_wasm_nosimd_internal.wasm",
+        ":package_json",
+        ":text_bundle",
+    ],
+)
diff --git a/mediapipe/tasks/web/text/index.ts b/mediapipe/tasks/web/text/index.ts
index a28e4dd1c..2c9e6fead 100644
--- a/mediapipe/tasks/web/text/index.ts
+++ b/mediapipe/tasks/web/text/index.ts
@@ -14,6 +14,14 @@
  * limitations under the License.
  */

-export * from '../../../tasks/web/text/text_classifier/text_classifier';
-export * from '../../../tasks/web/text/text_embedder/text_embedder';
-export * from '../../../tasks/web/core/fileset_resolver';
+import {FilesetResolver as FilesetResolverImpl} from '../../../tasks/web/core/fileset_resolver';
+import {TextClassifier as TextClassifierImpl} from '../../../tasks/web/text/text_classifier/text_classifier';
+import {TextEmbedder as TextEmbedderImpl} from '../../../tasks/web/text/text_embedder/text_embedder';
+
+// Declare the variables locally so that Rollup in OSS includes them explicitly
+// as exports.
+const FilesetResolver = FilesetResolverImpl;
+const TextClassifier = TextClassifierImpl;
+const TextEmbedder = TextEmbedderImpl;
+
+export {FilesetResolver, TextClassifier, TextEmbedder};
diff --git a/mediapipe/tasks/web/vision.ts b/mediapipe/tasks/web/vision.ts
deleted file mode 100644
index f1ced59af..000000000
--- a/mediapipe/tasks/web/vision.ts
+++ /dev/null
@@ -1,35 +0,0 @@
-/**
- * Copyright 2022 The MediaPipe Authors. All Rights Reserved.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-import {FilesetResolver as FilesetResolverImpl, GestureRecognizer as GestureRecognizerImpl, HandLandmarker as HandLandmarkerImpl, ImageClassifier as ImageClassifierImpl, ImageEmbedder as ImageEmbedderImpl, ObjectDetector as ObjectDetectorImpl} from '../../tasks/web/vision/index';
-
-// Declare the variables locally so that Rollup in OSS includes them explicitly
-// as exports.
-const FilesetResolver = FilesetResolverImpl;
-const GestureRecognizer = GestureRecognizerImpl;
-const HandLandmarker = HandLandmarkerImpl;
-const ImageClassifier = ImageClassifierImpl;
-const ImageEmbedder = ImageEmbedderImpl;
-const ObjectDetector = ObjectDetectorImpl;
-
-export {
-  FilesetResolver,
-  GestureRecognizer,
-  HandLandmarker,
-  ImageClassifier,
-  ImageEmbedder,
-  ObjectDetector
-};
diff --git a/mediapipe/tasks/web/vision/BUILD b/mediapipe/tasks/web/vision/BUILD
index ea022e900..76b0c084e 100644
--- a/mediapipe/tasks/web/vision/BUILD
+++ b/mediapipe/tasks/web/vision/BUILD
@@ -1,10 +1,21 @@
 # This contains the MediaPipe Vision Tasks.
 load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library")
+load("@build_bazel_rules_nodejs//:index.bzl", "pkg_npm")
+load("@npm//@bazel/rollup:index.bzl", "rollup_bundle")
+load(
+    "//mediapipe/framework/tool:mediapipe_files.bzl",
+    "mediapipe_files",
+)

 package(default_visibility = ["//mediapipe/tasks:internal"])

-exports_files(["README.md"])
+mediapipe_files(srcs = [
+    "wasm/vision_wasm_internal.js",
+    "wasm/vision_wasm_internal.wasm",
+    "wasm/vision_wasm_nosimd_internal.js",
+    "wasm/vision_wasm_nosimd_internal.wasm",
+])

 mediapipe_ts_library(
     name = "vision_lib",
@@ -19,3 +30,46 @@ mediapipe_ts_library(
         "//mediapipe/tasks/web/vision/object_detector",
     ],
 )
+
+rollup_bundle(
+    name = "vision_bundle",
+    config_file = "//mediapipe/tasks/web:rollup.config.mjs",
+    entry_point = "index.ts",
+    format = "esm",
+    output_dir = False,
+    sourcemap = "false",
+    deps = [
+        ":vision_lib",
+        "@npm//@rollup/plugin-commonjs",
+        "@npm//@rollup/plugin-node-resolve",
+        "@npm//@rollup/plugin-terser",
+        "@npm//google-protobuf",
+    ],
+)
+
+genrule(
+    name = "package_json",
+    srcs = ["//mediapipe/tasks/web:package.json"],
+    outs = ["package.json"],
+    cmd = "cp $< $@",
+)
+
+pkg_npm(
+    name = "vision_pkg",
+    package_name = "@mediapipe/tasks-__NAME__",
+    srcs = ["README.md"],
+    substitutions = {
+        "__NAME__": "vision",
+        "__DESCRIPTION__": "MediaPipe Vision Tasks",
+        "__TYPES__": "vision.d.ts",
+    },
+    tgz = "vision_pkg.tgz",
+    deps = [
+        "wasm/vision_wasm_internal.js",
+        "wasm/vision_wasm_internal.wasm",
+        "wasm/vision_wasm_nosimd_internal.js",
+        "wasm/vision_wasm_nosimd_internal.wasm",
+        ":package_json",
+        ":vision_bundle",
+    ],
+)
diff --git a/mediapipe/tasks/web/vision/index.ts b/mediapipe/tasks/web/vision/index.ts
index 0337a0f2f..e13f8183f 100644
--- a/mediapipe/tasks/web/vision/index.ts
+++ b/mediapipe/tasks/web/vision/index.ts
@@ -14,9 +14,27 @@
  * limitations under the License.
  */

-export * from '../../../tasks/web/vision/image_classifier/image_classifier';
-export * from '../../../tasks/web/vision/image_embedder/image_embedder';
-export * from '../../../tasks/web/vision/gesture_recognizer/gesture_recognizer';
-export * from '../../../tasks/web/vision/hand_landmarker/hand_landmarker';
-export * from '../../../tasks/web/vision/object_detector/object_detector';
-export * from '../../../tasks/web/core/fileset_resolver';
+import {FilesetResolver as FilesetResolverImpl} from '../../../tasks/web/core/fileset_resolver';
+import {GestureRecognizer as GestureRecognizerImpl} from '../../../tasks/web/vision/gesture_recognizer/gesture_recognizer';
+import {HandLandmarker as HandLandmarkerImpl} from '../../../tasks/web/vision/hand_landmarker/hand_landmarker';
+import {ImageClassifier as ImageClassifierImpl} from '../../../tasks/web/vision/image_classifier/image_classifier';
+import {ImageEmbedder as ImageEmbedderImpl} from '../../../tasks/web/vision/image_embedder/image_embedder';
+import {ObjectDetector as ObjectDetectorImpl} from '../../../tasks/web/vision/object_detector/object_detector';
+
+// Declare the variables locally so that Rollup in OSS includes them explicitly
+// as exports.
+const FilesetResolver = FilesetResolverImpl; +const GestureRecognizer = GestureRecognizerImpl; +const HandLandmarker = HandLandmarkerImpl; +const ImageClassifier = ImageClassifierImpl; +const ImageEmbedder = ImageEmbedderImpl; +const ObjectDetector = ObjectDetectorImpl; + +export { + FilesetResolver, + GestureRecognizer, + HandLandmarker, + ImageClassifier, + ImageEmbedder, + ObjectDetector +}; From 92a2e02ace78105cf120bb68a17c07df6dbd027f Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Fri, 13 Jan 2023 17:02:59 -0800 Subject: [PATCH 388/469] Internal change PiperOrigin-RevId: 501971410 --- mediapipe/framework/deps/BUILD | 12 ++++++++++++ mediapipe/framework/profiler/BUILD | 1 + mediapipe/util/BUILD | 1 + 3 files changed, 14 insertions(+) diff --git a/mediapipe/framework/deps/BUILD b/mediapipe/framework/deps/BUILD index 7ff004f1e..7994aae75 100644 --- a/mediapipe/framework/deps/BUILD +++ b/mediapipe/framework/deps/BUILD @@ -88,6 +88,7 @@ cc_library( name = "message_matchers", testonly = True, hdrs = ["message_matchers.h"], + # Use this library through "mediapipe/framework/port:gtest_main". visibility = [ "//mediapipe/framework/port:__pkg__", @@ -145,6 +146,7 @@ cc_library( cc_library( name = "map_util", hdrs = ["map_util.h"], + # Use this library through "mediapipe/framework/port:map_util". visibility = ["//mediapipe/framework/port:__pkg__"], deps = ["//mediapipe/framework/port:logging"], @@ -180,6 +182,7 @@ cc_library( cc_library( name = "point", hdrs = ["point2.h"], + # Use this library through "mediapipe/framework/port:point". visibility = ["//mediapipe/framework/port:__pkg__"], deps = [ @@ -199,6 +202,7 @@ cc_library( cc_library( name = "rectangle", hdrs = ["rectangle.h"], + # Use this library through "mediapipe/framework/port:rectangle". visibility = ["//mediapipe/framework/port:__pkg__"], deps = [ @@ -239,6 +243,7 @@ cc_library( cc_library( name = "singleton", hdrs = ["singleton.h"], + # Use this library through "mediapipe/framework/port:singleton". visibility = ["//mediapipe/framework/port:__pkg__"], deps = [ @@ -249,6 +254,7 @@ cc_library( cc_library( name = "source_location", hdrs = ["source_location.h"], + # Use this library through "mediapipe/framework/port:source_location". visibility = ["//mediapipe/framework/port:__pkg__"], ) @@ -265,6 +271,7 @@ cc_library( "status_builder.h", "status_macros.h", ], + # Use this library through "mediapipe/framework/port:status". visibility = ["//mediapipe/framework/port:__pkg__"], deps = [ @@ -288,6 +295,7 @@ cc_library( name = "status_matchers", testonly = 1, hdrs = ["status_matchers.h"], + # Use this library through "mediapipe/framework/port:gtest_main". visibility = ["//mediapipe/framework/port:__pkg__"], deps = [ @@ -301,6 +309,7 @@ cc_library( name = "ret_check", srcs = ["ret_check.cc"], hdrs = ["ret_check.h"], + # Use this library through "mediapipe/framework/port:ret_check". visibility = ["//mediapipe/framework/port:__pkg__"], deps = [ @@ -321,6 +330,7 @@ cc_library( "//conditions:default": ["threadpool_pthread_impl.cc"], }), hdrs = ["threadpool.h"], + # Use this library through "mediapipe/framework/port:threadpool". visibility = ["//mediapipe/framework/port:__pkg__"], deps = [ @@ -335,6 +345,7 @@ cc_library( name = "topologicalsorter", srcs = ["topologicalsorter.cc"], hdrs = ["topologicalsorter.h"], + # Use this library through "mediapipe/framework/port:topologicalsorter". 
visibility = ["//mediapipe/framework/port:__pkg__"], deps = [ @@ -345,6 +356,7 @@ cc_library( cc_library( name = "vector", hdrs = ["vector.h"], + # Use this library through "mediapipe/framework/port:vector". visibility = ["//mediapipe/framework/port:__pkg__"], deps = [ diff --git a/mediapipe/framework/profiler/BUILD b/mediapipe/framework/profiler/BUILD index 3b6976fc8..7ebfd3b8c 100644 --- a/mediapipe/framework/profiler/BUILD +++ b/mediapipe/framework/profiler/BUILD @@ -284,6 +284,7 @@ cc_library( "//mediapipe:ios": ["profiler_resource_util_ios.cc"], }), hdrs = ["profiler_resource_util.h"], + # We use Objective-C++ on iOS. copts = select({ "//conditions:default": [], diff --git a/mediapipe/util/BUILD b/mediapipe/util/BUILD index 55c1df59f..555569552 100644 --- a/mediapipe/util/BUILD +++ b/mediapipe/util/BUILD @@ -186,6 +186,7 @@ cc_library( hdrs = [ "resource_util.h", ], + # We use Objective-C++ on iOS. copts = select({ "//conditions:default": [], From 30533be321744ddea7f37fea0bf77298596b9b92 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Mon, 16 Jan 2023 13:00:10 +0530 Subject: [PATCH 389/469] Reformatted comments --- .../tasks/ios/common/sources/MPPCommon.h | 57 ++++++++++++------- 1 file changed, 37 insertions(+), 20 deletions(-) diff --git a/mediapipe/tasks/ios/common/sources/MPPCommon.h b/mediapipe/tasks/ios/common/sources/MPPCommon.h index 0f885a8c2..3f0a1a7b9 100644 --- a/mediapipe/tasks/ios/common/sources/MPPCommon.h +++ b/mediapipe/tasks/ios/common/sources/MPPCommon.h @@ -39,53 +39,70 @@ typedef NS_ENUM(NSUInteger, MPPTasksErrorCode) { /** Indicates some requested entity (such as a file or directory) was not found. */ MPPTasksErrorCodeNotFoundError = 5, - /** Indicates that the entity a caller attempted to create (such as a file or directory) is - already present. */ + /** + * Indicates that the entity a caller attempted to create (such as a file or directory) is + * already present. + */ MPPTasksErrorCodeAlreadyExistsError = 6, /** Indicates that the caller does not have permission to execute the specified operation. */ MPPTasksErrorCodePermissionDeniedError = 7, - /** Indicates some resource has been exhausted, perhaps a per-user quota, or perhaps the entire - file system is out of space. */ + /** + * Indicates some resource has been exhausted, perhaps a per-user quota, or perhaps the entire + * file system is out of space. + */ MPPTasksErrorCodeResourceExhaustedError = 8, - /** Indicates that the operation was rejected because the system is not in a state required for - the operation's execution. For example, a directory to be deleted may be non-empty, an "rmdir" - operation is applied to a non-directory, etc. */ + /** + * Indicates that the operation was rejected because the system is not in a state required for + * the operation's execution. For example, a directory to be deleted may be non-empty, an "rmdir" + * operation is applied to a non-directory, etc. + */ MPPTasksErrorCodeFailedPreconditionError = 9, - /** Indicates the operation was aborted, typically due to a concurrency issue such as a sequencer - check failure or a failed transaction. */ + /** + * Indicates the operation was aborted, typically due to a concurrency issue such as a sequencer + * check failure or a failed transaction. + */ MPPTasksErrorCodeAbortedError = 10, - /** Indicates the operation was attempted past the valid range, such as seeking or reading past an - end-of-file. 
*/ + /** + * Indicates the operation was attempted past the valid range, such as seeking or reading past an + * end-of-file. + */ MPPTasksErrorCodeOutOfRangeError = 11, - /** Indicates the operation is not implemented or supported in this service. In this case, the - operation should not be re-attempted. */ + /** + * Indicates the operation is not implemented or supported in this service. In this case, the + * operation should not be re-attempted. + */ MPPTasksErrorCodeUnimplementedError = 12, - /** Indicates an internal error has occurred and some invariants expected by the underlying system - have not been satisfied. This error code is reserved for serious errors. */ + /** + * Indicates an internal error has occurred and some invariants expected by the underlying system + * have not been satisfied. This error code is reserved for serious errors. + */ MPPTasksErrorCodeInternalError = 13, - /** Indicates the service is currently unavailable and that this is most likely a transient - condition. */ + /** + * Indicates the service is currently unavailable and that this is most likely a transient + * condition. + */ MPPTasksErrorCodeUnavailableError = 14, /** Indicates that unrecoverable data loss or corruption has occurred. */ MPPTasksErrorCodeDataLossError = 15, - /** Indicates that the request does not have valid authentication credentials for the operation. + /** + * Indicates that the request does not have valid authentication credentials for the operation. */ MPPTasksErrorCodeUnauthenticatedError = 16, - // The first error code in MPPTasksErrorCode (for internal use only). + /** The first error code in MPPTasksErrorCode (for internal use only). */ MPPTasksErrorCodeFirst = MPPTasksErrorCodeCancelledError, - // The last error code in MPPTasksErrorCode (for internal use only). + /** The last error code in MPPTasksErrorCode (for internal use only). */ MPPTasksErrorCodeLast = MPPTasksErrorCodeUnauthenticatedError, } NS_SWIFT_NAME(TasksErrorCode); From 8ecf77f760c49fd319b80a2bd5daefaba5a7cd72 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Mon, 16 Jan 2023 13:02:33 +0530 Subject: [PATCH 390/469] Updated comment style in methods --- mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.mm | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.mm b/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.mm index 27b75515d..538023df6 100644 --- a/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.mm +++ b/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.mm @@ -72,7 +72,7 @@ using absl::StatusCode; return YES; } - /** Converts the absl status message to an NSString. */ + // Converts the absl status message to an NSString. NSString *description = [NSString stringWithCString:status.ToString(absl::StatusToStringMode::kWithNoExtraData).c_str() encoding:NSUTF8StringEncoding]; @@ -81,8 +81,8 @@ using absl::StatusCode; MPPTasksErrorCode errorCode = genericErrorCode; - /** Maps the absl::StatusCode to the appropriate MPPTasksErrorCode. Note: MPPTasksErrorCode omits - * absl::StatusCode::kOk. */ + // Maps the absl::StatusCode to the appropriate MPPTasksErrorCode. Note: MPPTasksErrorCode omits + // absl::StatusCode::kOk. 
switch (status.code()) { case StatusCode::kCancelled: errorCode = MPPTasksErrorCodeCancelledError; From f7fc8a6eca14b2c93fbe7a8c1c5162a1f9d59223 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Mon, 16 Jan 2023 13:05:29 +0530 Subject: [PATCH 391/469] Updated method names in tests --- .../ios/test/text/text_classifier/MPPTextClassifierTests.m | 7 +++---- .../test/text/text_classifier/TextClassifierTests.swift | 2 +- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/mediapipe/tasks/ios/test/text/text_classifier/MPPTextClassifierTests.m b/mediapipe/tasks/ios/test/text/text_classifier/MPPTextClassifierTests.m index 3e2fe4bef..a8e541014 100644 --- a/mediapipe/tasks/ios/test/text/text_classifier/MPPTextClassifierTests.m +++ b/mediapipe/tasks/ios/test/text/text_classifier/MPPTextClassifierTests.m @@ -43,7 +43,6 @@ static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks"; #define AssertTextClassifierResultHasOneHead(textClassifierResult) \ XCTAssertNotNil(textClassifierResult); \ - \ XCTAssertNotNil(textClassifierResult.classificationResult); \ XCTAssertEqual(textClassifierResult.classificationResult.classifications.count, 1); \ XCTAssertEqual(textClassifierResult.classificationResult.classifications[0].headIndex, 0); @@ -156,7 +155,7 @@ static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks"; AssertEqualErrors(error, expectedError); } -- (void)testCreateTextClassifierFailsWithBothAllowListAndDenyList { +- (void)testCreateTextClassifierFailsWithBothAllowlistAndDenylist { MPPTextClassifierOptions *options = [self textClassifierOptionsWithModelName:kBertTextClassifierModelName]; options.categoryAllowlist = @[ @"positive" ]; @@ -233,7 +232,7 @@ static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks"; expectedBertResultCategoriesForEdgeCaseTests]]; } -- (void)testClassifyWithCategoryAllowListSucceeds { +- (void)testClassifyWithCategoryAllowlistSucceeds { MPPTextClassifierOptions *options = [self textClassifierOptionsWithModelName:kBertTextClassifierModelName]; options.categoryAllowlist = @[ @"negative" ]; @@ -250,7 +249,7 @@ static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks"; expectedBertResultCategoriesForEdgeCaseTests]]; } -- (void)testClassifyWithCategoryDenyListSucceeds { +- (void)testClassifyWithCategoryDenylistSucceeds { MPPTextClassifierOptions *options = [self textClassifierOptionsWithModelName:kBertTextClassifierModelName]; options.categoryDenylist = @[ @"positive" ]; diff --git a/mediapipe/tasks/ios/test/text/text_classifier/TextClassifierTests.swift b/mediapipe/tasks/ios/test/text/text_classifier/TextClassifierTests.swift index d2d433c22..01b5748cf 100644 --- a/mediapipe/tasks/ios/test/text/text_classifier/TextClassifierTests.swift +++ b/mediapipe/tasks/ios/test/text/text_classifier/TextClassifierTests.swift @@ -147,7 +147,7 @@ class TextClassifierTests: XCTestCase { """) } - func testCreateTextClassifierWithCategoryAllowlistandDenylistFails() throws { + func testCreateTextClassifierWithCategoryAllowlistAndDenylistFails() throws { let textClassifierOptions = try XCTUnwrap( From a0b3e620e4d024259bd2637198dd3141767f12d9 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Mon, 16 Jan 2023 13:12:27 +0530 Subject: [PATCH 392/469] Removed unused methods --- .../test/text/text_classifier/MPPTextClassifierTests.m | 8 -------- 1 file changed, 8 deletions(-) diff --git a/mediapipe/tasks/ios/test/text/text_classifier/MPPTextClassifierTests.m 
b/mediapipe/tasks/ios/test/text/text_classifier/MPPTextClassifierTests.m index a8e541014..5c0964e68 100644 --- a/mediapipe/tasks/ios/test/text/text_classifier/MPPTextClassifierTests.m +++ b/mediapipe/tasks/ios/test/text/text_classifier/MPPTextClassifierTests.m @@ -52,14 +52,6 @@ static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks"; @implementation MPPTextClassifierTests -- (void)setUp { -} - -- (void)tearDown { - // Put teardown code here. This method is called after the invocation of each test method in the - // class. -} - + (NSArray *)expectedBertResultCategoriesForNegativeText { return @[ [[MPPCategory alloc] initWithIndex:0 score:0.956187f categoryName:@"negative" displayName:nil], From cf945d3aebc0b705117946cddd583ae1066ef97b Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Mon, 16 Jan 2023 13:59:51 +0530 Subject: [PATCH 393/469] Removed unused variable --- mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.mm | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.mm b/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.mm index 538023df6..f3d9ecc79 100644 --- a/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.mm +++ b/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.mm @@ -77,9 +77,7 @@ using absl::StatusCode; stringWithCString:status.ToString(absl::StatusToStringMode::kWithNoExtraData).c_str() encoding:NSUTF8StringEncoding]; - MPPTasksErrorCode genericErrorCode = MPPTasksErrorCodeUnknownError; - - MPPTasksErrorCode errorCode = genericErrorCode; + MPPTasksErrorCode errorCode = MPPTasksErrorCodeUnknownError; // Maps the absl::StatusCode to the appropriate MPPTasksErrorCode. Note: MPPTasksErrorCode omits // absl::StatusCode::kOk. 
@@ -133,7 +131,6 @@ using absl::StatusCode;
       errorCode = MPPTasksErrorCodeUnauthenticatedError;
       break;
     default:
-      errorCode = genericErrorCode;
       break;
   }

From 67735a6fd30518bb68843a140841547540b0ee61 Mon Sep 17 00:00:00 2001
From: Prianka Liz Kariat
Date: Mon, 16 Jan 2023 14:01:10 +0530
Subject: [PATCH 394/469] Added category indices in iOS failure description

---
 .../text_classifier/MPPTextClassifierTests.m  | 17 ++++----
 .../text_classifier/TextClassifierTests.swift | 40 ++++++++++++++-----
 2 files changed, 41 insertions(+), 16 deletions(-)

diff --git a/mediapipe/tasks/ios/test/text/text_classifier/MPPTextClassifierTests.m b/mediapipe/tasks/ios/test/text/text_classifier/MPPTextClassifierTests.m
index 5c0964e68..ebeaf863f 100644
--- a/mediapipe/tasks/ios/test/text/text_classifier/MPPTextClassifierTests.m
+++ b/mediapipe/tasks/ios/test/text/text_classifier/MPPTextClassifierTests.m
@@ -32,13 +32,16 @@ static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks";
       [error.localizedDescription rangeOfString:expectedError.localizedDescription].location, \
       NSNotFound)

-#define AssertEqualCategoryArrays(categories, expectedCategories) \
-  XCTAssertEqual(categories.count, expectedCategories.count); \
-  for (int i = 0; i < categories.count; i++) { \
-    XCTAssertEqual(categories[i].index, expectedCategories[i].index); \
-    XCTAssertEqualWithAccuracy(categories[i].score, expectedCategories[i].score, 1e-6); \
-    XCTAssertEqualObjects(categories[i].categoryName, expectedCategories[i].categoryName); \
-    XCTAssertEqualObjects(categories[i].displayName, expectedCategories[i].displayName); \
+#define AssertEqualCategoryArrays(categories, expectedCategories) \
+  XCTAssertEqual(categories.count, expectedCategories.count); \
+  for (int i = 0; i < categories.count; i++) { \
+    XCTAssertEqual(categories[i].index, expectedCategories[i].index, @"index i = %d", i); \
+    XCTAssertEqualWithAccuracy(categories[i].score, expectedCategories[i].score, 1e-6, \
+                               @"index i = %d", i); \
+    XCTAssertEqualObjects(categories[i].categoryName, expectedCategories[i].categoryName, \
+                          @"index i = %d", i); \
+    XCTAssertEqualObjects(categories[i].displayName, expectedCategories[i].displayName, \
+                          @"index i = %d", i); \
   }

 #define AssertTextClassifierResultHasOneHead(textClassifierResult) \
diff --git a/mediapipe/tasks/ios/test/text/text_classifier/TextClassifierTests.swift b/mediapipe/tasks/ios/test/text/text_classifier/TextClassifierTests.swift
index d2d433c22..186887778 100644
--- a/mediapipe/tasks/ios/test/text/text_classifier/TextClassifierTests.swift
+++ b/mediapipe/tasks/ios/test/text/text_classifier/TextClassifierTests.swift
@@ -60,33 +60,55 @@ class TextClassifierTests: XCTestCase {

   func assertCategoriesAreEqual(
     category: ResultCategory,
-    expectedCategory: ResultCategory) {
+    expectedCategory: ResultCategory,
+    indexInCategoryList: Int) {
     XCTAssertEqual(
       category.index,
-      expectedCategory.index)
+      expectedCategory.index,
+      String(
+        format: """
+              category[%d].index and expectedCategory[%d].index are not equal.
+              """, indexInCategoryList, indexInCategoryList))
     XCTAssertEqual(
       category.score,
       expectedCategory.score,
-      accuracy:1e-6)
+      accuracy:1e-6,
+      String(
+        format: """
+              category[%d].score and expectedCategory[%d].score are not equal.
+              """, indexInCategoryList, indexInCategoryList))
     XCTAssertEqual(
      category.categoryName,
-      expectedCategory.categoryName)
+      expectedCategory.categoryName,
+      String(
+        format: """
+              category[%d].categoryName and expectedCategory[%d].categoryName are \
+              not equal.
+              """, indexInCategoryList, indexInCategoryList))
    XCTAssertEqual(
       category.displayName,
-      expectedCategory.displayName)
+      expectedCategory.displayName,
+      String(
+        format: """
+              category[%d].displayName and expectedCategory[%d].displayName are \
+              not equal.
+              """, indexInCategoryList, indexInCategoryList))
   }

   func assertEqualCategoryArrays(
     categoryArray: [ResultCategory],
     expectedCategoryArray:[ResultCategory]) {
-    XCTAssertEqual(categoryArray.count, expectedCategoryArray.count)
+    XCTAssertEqual(
+      categoryArray.count,
+      expectedCategoryArray.count)

-    for (category, expectedCategory) in
-      zip(categoryArray, expectedCategoryArray) {
+    for (index, (category, expectedCategory)) in
+      zip(categoryArray, expectedCategoryArray).enumerated() {
       assertCategoriesAreEqual(
         category:category,
-        expectedCategory:expectedCategory)
+        expectedCategory:expectedCategory,
+        indexInCategoryList:index)
     }
   }

From ffd8486d0dc045af18c6ab0e1c7bf732e5a9f3ca Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Mon, 16 Jan 2023 08:35:56 -0800
Subject: [PATCH 395/469] Add a stub WriteProfile method to GraphProfilerStub.

PiperOrigin-RevId: 502388455
---
 mediapipe/framework/profiler/graph_profiler_stub.h | 1 +
 1 file changed, 1 insertion(+)

diff --git a/mediapipe/framework/profiler/graph_profiler_stub.h b/mediapipe/framework/profiler/graph_profiler_stub.h
index 12a024fe8..72d5d7275 100644
--- a/mediapipe/framework/profiler/graph_profiler_stub.h
+++ b/mediapipe/framework/profiler/graph_profiler_stub.h
@@ -93,6 +93,7 @@ class GraphProfilerStub {
       PopulateGraphConfig populate_config = PopulateGraphConfig::kNo) {
     return absl::OkStatus();
   }
+  inline absl::Status WriteProfile() { return absl::OkStatus(); }
   inline void Pause() {}
   inline void Resume() {}
   inline void Reset() {}

From c1f5920ecf3beed2457d9df9ba0bdb7cd7e5a47c Mon Sep 17 00:00:00 2001
From: Camillo Lugaresi
Date: Mon, 16 Jan 2023 12:57:44 -0800
Subject: [PATCH 396/469] Add web performance tracing to the
 MEDIAPIPE_PROFILING repertoire

This records the MEDIAPIPE_PROFILING tracing annotations to the browser's
trace using the user timing API.
See https://developer.mozilla.org/en-US/docs/Web/API/User_Timing_API

To enable, build with
--define MEDIAPIPE_WEB_PROFILING=1 --define DRISHTI_PROFILING=1

PiperOrigin-RevId: 502422030
---
 mediapipe/framework/profiler/BUILD            | 20 ++++++
 .../profiler/web_performance_profiling.h      | 68 +++++++++++++++++++
 2 files changed, 88 insertions(+)
 create mode 100644 mediapipe/framework/profiler/web_performance_profiling.h

diff --git a/mediapipe/framework/profiler/BUILD b/mediapipe/framework/profiler/BUILD
index 7ebfd3b8c..6184ed45b 100644
--- a/mediapipe/framework/profiler/BUILD
+++ b/mediapipe/framework/profiler/BUILD
@@ -127,6 +127,7 @@ cc_library(
         "//mediapipe/framework/port:status",
         "//mediapipe/framework/port:advanced_proto_lite",
         "//mediapipe/framework/tool:name_util",
+        ":web_performance_profiling",
     ] + select({
         "//conditions:default": [],
     }) + select({
@@ -275,6 +276,25 @@ cc_test(
     ],
 )

+config_setting(
+    name = "mediapipe_web_profiling_enabled",
+    values = {
+        "define": "MEDIAPIPE_WEB_PROFILING=1",
+    },
+    visibility = ["//visibility:private"],
+)
+
+cc_library(
+    name = "web_performance_profiling",
+    hdrs = ["web_performance_profiling.h"],
+    defines = select({
+        ":mediapipe_web_profiling_enabled": ["MEDIAPIPE_WEB_PROFILING_ENABLED"],
+        "//conditions:default": [],
+    }),
+    visibility = ["//mediapipe:__subpackages__"],
+    deps = ["@com_google_absl//absl/strings"],
+)
+
 cc_library(
     name = "profiler_resource_util",
     srcs = ["profiler_resource_util_common.cc"] + select({
diff --git a/mediapipe/framework/profiler/web_performance_profiling.h b/mediapipe/framework/profiler/web_performance_profiling.h
new file mode 100644
index 000000000..47b76fe88
--- /dev/null
+++ b/mediapipe/framework/profiler/web_performance_profiling.h
@@ -0,0 +1,68 @@
+#ifndef MEDIAPIPE_FRAMEWORK_PROFILER_WEB_PERFORMANCE_PROFILING_H_
+#define MEDIAPIPE_FRAMEWORK_PROFILER_WEB_PERFORMANCE_PROFILING_H_
+
+#if MEDIAPIPE_WEB_PROFILING_ENABLED && __EMSCRIPTEN__
+#include <emscripten/emscripten.h>
+
+#include "absl/strings/str_cat.h"
+
+// This records MediaPipe profiling events in the browser's performance trace.
+// To use, build with:
+// --define MEDIAPIPE_PROFILING=1 --define MEDIAPIPE_WEB_PROFILING=1
+
+namespace mediapipe {
+
+class WebPerformanceTraceScope {
+ public:
+  explicit WebPerformanceTraceScope(TraceEvent::EventType event_type,
+                                    const char* event_type_str,
+                                    CalculatorContext* cc)
+      : event_type_str_(event_type_str), cc_(cc) {
+    const auto& calculator_name = cc->NodeName();
+    std::string start_name =
+        absl::StrCat(calculator_name, "::", event_type_str_, "_start");
+    std::string timestamp_str = cc->InputTimestamp().DebugString();
+    EM_ASM(
+        {
+          const startName = UTF8ToString($0);
+          const timestamp = UTF8ToString($1);
+          performance.mark(startName, {mp_timestamp : timestamp});
+        },
+        start_name.c_str(), timestamp_str.c_str());
+  }
+
+  ~WebPerformanceTraceScope() {
+    const auto& calculator_name = cc_->NodeName();
+    std::string start_name =
+        absl::StrCat(calculator_name, "::", event_type_str_, "_start");
+    std::string end_name =
+        absl::StrCat(calculator_name, "::", event_type_str_, "_end");
+    std::string measure_name =
+        absl::StrCat(calculator_name, "::", event_type_str_);
+    EM_ASM(
+        {
+          const startName = UTF8ToString($0);
+          const endName = UTF8ToString($1);
+          const measureName = UTF8ToString($2);
+          performance.mark(endName);
+          performance.measure(measureName, startName, endName);
+        },
+        start_name.c_str(), end_name.c_str(), measure_name.c_str());
+  }
+
+ private:
+  const char* event_type_str_;
+  CalculatorContext* cc_;
+};
+
+}  // namespace mediapipe
+
+#define MEDIAPIPE_WEB_PERFORMANCE_SCOPE(event_type, calculator_context) \
+  mediapipe::WebPerformanceTraceScope web_trace_scope( \
+      mediapipe::TraceEvent::event_type, #event_type, calculator_context)
+
+#else
+#define MEDIAPIPE_WEB_PERFORMANCE_SCOPE(event_type, calculator_context)
+#endif  // MEDIAPIPE_WEB_PROFILING_ENABLED && __EMSCRIPTEN__
+
+#endif  // MEDIAPIPE_FRAMEWORK_PROFILER_WEB_PERFORMANCE_PROFILING_H_

From 7974171c3d0364a1bd79b6dc615b60ff57b175e7 Mon Sep 17 00:00:00 2001
From: Jiuqiang Tang
Date: Tue, 17 Jan 2023 09:04:54 -0800
Subject: [PATCH 397/469] Merge `classificationResultList()` and
 `classificationResult()` to be `classificationResults()`, and similar for
 `embeddingResults()`.

PiperOrigin-RevId: 502601043
---
 .../AudioClassifierResult.java                | 27 +++++++---------
 .../audioembedder/AudioEmbedderResult.java    | 31 +++++++++----------
 2 files changed, 26 insertions(+), 32 deletions(-)

diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioclassifier/AudioClassifierResult.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioclassifier/AudioClassifierResult.java
index 3102aa8cd..258e5725b 100644
--- a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioclassifier/AudioClassifierResult.java
+++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioclassifier/AudioClassifierResult.java
@@ -20,7 +20,6 @@ import com.google.mediapipe.tasks.components.containers.proto.ClassificationsPro
 import com.google.mediapipe.tasks.core.TaskResult;
 import java.util.ArrayList;
 import java.util.List;
-import java.util.Optional;

 /** Represents the classification results generated by {@link AudioClassifier}.
  */
 @AutoValue
@@ -40,8 +39,7 @@ public abstract class AudioClassifierResult implements TaskResult {
     for (ClassificationsProto.ClassificationResult proto : protoList) {
       classificationResultList.add(ClassificationResult.createFromProto(proto));
     }
-    return new AutoValue_AudioClassifierResult(
-        Optional.of(classificationResultList), Optional.empty(), timestampMs);
+    return new AutoValue_AudioClassifierResult(classificationResultList, timestampMs);
   }

   /**
@@ -53,23 +51,22 @@ public abstract class AudioClassifierResult implements TaskResult {
    */
   static AudioClassifierResult createFromProto(
       ClassificationsProto.ClassificationResult proto, long timestampMs) {
-    return new AutoValue_AudioClassifierResult(
-        Optional.empty(), Optional.of(ClassificationResult.createFromProto(proto)), timestampMs);
+    List<ClassificationResult> classificationResultList = new ArrayList<>();
+    classificationResultList.add(ClassificationResult.createFromProto(proto));
+    return new AutoValue_AudioClassifierResult(classificationResultList, timestampMs);
   }

   /**
    * A list of timestamped {@link ClassificationResult} objects, each containing one set of results
-   * per classifier head. The list represents the audio classification result of an audio clip, and
-   * is only available when running with the audio clips mode.
+   * per classifier head.
+   *
+   * <p>In the "audio stream" mode, the list only contains one element, representing the
+   * classification result of the audio block that starts at {@link
+   * ClassificationResult.timestampMs} in the audio stream. Otherwise, in the "audio clips" mode,
+   * the list may include multiple {@link ClassificationResult} objects, each classifying an
+   * interval of the entire audio clip that starts at {@link ClassificationResult.timestampMs}.
    */
-  public abstract Optional<List<ClassificationResult>> classificationResultList();
-
-  /**
-   * Contains one set of results per classifier head. A {@link ClassificationResult} usually
-   * represents one audio classification result in an audio stream, and is only available when
-   * running with the audio stream mode.
-   */
-  public abstract Optional<ClassificationResult> classificationResult();
+  public abstract List<ClassificationResult> classificationResults();

   @Override
   public abstract long timestampMs();
diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioembedder/AudioEmbedderResult.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioembedder/AudioEmbedderResult.java
index a986048f0..0cfd2297c 100644
--- a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioembedder/AudioEmbedderResult.java
+++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioembedder/AudioEmbedderResult.java
@@ -20,7 +20,6 @@ import com.google.mediapipe.tasks.components.containers.proto.EmbeddingsProto;
 import com.google.mediapipe.tasks.core.TaskResult;
 import java.util.ArrayList;
 import java.util.List;
-import java.util.Optional;

 /** Represents the embedding results generated by {@link AudioEmbedder}. */
 @AutoValue
@@ -35,12 +34,11 @@ public abstract class AudioEmbedderResult implements TaskResult {
    */
   static AudioEmbedderResult createFromProtoList(
       List<EmbeddingsProto.EmbeddingResult> protoList, long timestampMs) {
-    List<EmbeddingResult> classificationResultList = new ArrayList<>();
+    List<EmbeddingResult> embeddingResultList = new ArrayList<>();
     for (EmbeddingsProto.EmbeddingResult proto : protoList) {
-      classificationResultList.add(EmbeddingResult.createFromProto(proto));
+      embeddingResultList.add(EmbeddingResult.createFromProto(proto));
     }
-    return new AutoValue_AudioEmbedderResult(
-        Optional.of(classificationResultList), Optional.empty(), timestampMs);
+    return new AutoValue_AudioEmbedderResult(embeddingResultList, timestampMs);
   }

   /**
@@ -52,23 +50,22 @@ public abstract class AudioEmbedderResult implements TaskResult {
    */
   static AudioEmbedderResult createFromProto(
       EmbeddingsProto.EmbeddingResult proto, long timestampMs) {
-    return new AutoValue_AudioEmbedderResult(
-        Optional.empty(), Optional.of(EmbeddingResult.createFromProto(proto)), timestampMs);
+    List<EmbeddingResult> embeddingResultList = new ArrayList<>();
+    embeddingResultList.add(EmbeddingResult.createFromProto(proto));
+    return new AutoValue_AudioEmbedderResult(embeddingResultList, timestampMs);
   }

   /**
    * A list of timestamped {@link EmbeddingResult} objects, each containing one set of results per
-   * embedder head. The list represents the audio embedding result of an audio clip, and is only
-   * available when running with the audio clips mode.
+   * embedder head.
+   *
+   * <p>In the "audio stream" mode, the list only contains one element, representing the embedding
+   * result of the audio block that starts at {@link EmbeddingResult.timestampMs} in the audio
+   * stream. Otherwise, in the "audio clips" mode, the list may include multiple {@link
+   * EmbeddingResult} objects, each containing the embedding of an interval of the entire audio
+   * clip that starts at {@link EmbeddingResult.timestampMs}.
    */
-  public abstract Optional<List<EmbeddingResult>> embeddingResultList();
-
-  /**
-   * Contains one set of results per classifier head. A {@link EmbeddingResult} usually represents
-   * one audio embedding result in an audio stream, and is only available when running with the
-   * audio stream mode.
-   */
-  public abstract Optional<EmbeddingResult> embeddingResult();
+  public abstract List<EmbeddingResult> embeddingResults();

   @Override
   public abstract long timestampMs();

From 7a4b450c501ca14f2a34dc6d2810361d7424e03d Mon Sep 17 00:00:00 2001
From: Jiuqiang Tang
Date: Tue, 17 Jan 2023 10:51:04 -0800
Subject: [PATCH 398/469] Resolve the error "call to 'abs' is ambiguous".

PiperOrigin-RevId: 502630518
---
 mediapipe/tasks/cc/components/containers/rect.h | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/mediapipe/tasks/cc/components/containers/rect.h b/mediapipe/tasks/cc/components/containers/rect.h
index 551d91588..72c7a8acb 100644
--- a/mediapipe/tasks/cc/components/containers/rect.h
+++ b/mediapipe/tasks/cc/components/containers/rect.h
@@ -16,6 +16,7 @@ limitations under the License.
 #ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_RECT_H_
 #define MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_RECT_H_

+#include <cmath>
 #include <ostream>

 namespace mediapipe::tasks::components::containers {
@@ -48,10 +49,10 @@ struct RectF {
 };

 inline bool operator==(const RectF& lhs, const RectF& rhs) {
-  return abs(lhs.left - rhs.left) < kRectFTolerance &&
-         abs(lhs.top - rhs.top) < kRectFTolerance &&
-         abs(lhs.right - rhs.right) < kRectFTolerance &&
-         abs(lhs.bottom - rhs.bottom) < kRectFTolerance;
+  return std::fabs(lhs.left - rhs.left) < kRectFTolerance &&
+         std::fabs(lhs.top - rhs.top) < kRectFTolerance &&
+         std::fabs(lhs.right - rhs.right) < kRectFTolerance &&
+         std::fabs(lhs.bottom - rhs.bottom) < kRectFTolerance;
 }

 RectF ToRectF(const Rect& rect, int image_height, int image_width);

From 088249eb3697865dcd05c19dfb9065ddcf498d7e Mon Sep 17 00:00:00 2001
From: Sebastian Schmidt
Date: Tue, 17 Jan 2023 11:56:05 -0800
Subject: [PATCH 399/469] Export all input and output types

PiperOrigin-RevId: 502649430
---
 mediapipe/tasks/web/audio/index.ts  | 14 +++-----------
 mediapipe/tasks/web/text/index.ts   | 14 +++-----------
 mediapipe/tasks/web/vision/index.ts | 30 ++++++-----------------------
 3 files changed, 12 insertions(+), 46 deletions(-)

diff --git a/mediapipe/tasks/web/audio/index.ts b/mediapipe/tasks/web/audio/index.ts
index 44fa7eb25..dbad8c617 100644
--- a/mediapipe/tasks/web/audio/index.ts
+++ b/mediapipe/tasks/web/audio/index.ts
@@ -14,14 +14,6 @@
  * limitations under the License.
  */

-import {AudioClassifier as AudioClassifierImpl} from '../../../tasks/web/audio/audio_classifier/audio_classifier';
-import {AudioEmbedder as AudioEmbedderImpl} from '../../../tasks/web/audio/audio_embedder/audio_embedder';
-import {FilesetResolver as FilesetResolverImpl} from '../../../tasks/web/core/fileset_resolver';
-
-// Declare the variables locally so that Rollup in OSS includes them explicitly
-// as exports.
-const AudioClassifier = AudioClassifierImpl; -const AudioEmbedder = AudioEmbedderImpl; -const FilesetResolver = FilesetResolverImpl; - -export {AudioClassifier, AudioEmbedder, FilesetResolver}; +export * from '../../../tasks/web/audio/audio_classifier/audio_classifier'; +export * from '../../../tasks/web/audio/audio_embedder/audio_embedder'; +export * from '../../../tasks/web/core/fileset_resolver'; diff --git a/mediapipe/tasks/web/text/index.ts b/mediapipe/tasks/web/text/index.ts index 2c9e6fead..f32c16c36 100644 --- a/mediapipe/tasks/web/text/index.ts +++ b/mediapipe/tasks/web/text/index.ts @@ -14,14 +14,6 @@ * limitations under the License. */ -import {FilesetResolver as FilesetResolverImpl} from '../../../tasks/web/core/fileset_resolver'; -import {TextClassifier as TextClassifierImpl} from '../../../tasks/web/text/text_classifier/text_classifier'; -import {TextEmbedder as TextEmbedderImpl} from '../../../tasks/web/text/text_embedder/text_embedder'; - -// Declare the variables locally so that Rollup in OSS includes them explcilty -// as exports. -const FilesetResolver = FilesetResolverImpl; -const TextClassifier = TextClassifierImpl; -const TextEmbedder = TextEmbedderImpl; - -export {FilesetResolver, TextClassifier, TextEmbedder}; +export * from '../../../tasks/web/core/fileset_resolver'; +export * from '../../../tasks/web/text/text_classifier/text_classifier'; +export * from '../../../tasks/web/text/text_embedder/text_embedder'; diff --git a/mediapipe/tasks/web/vision/index.ts b/mediapipe/tasks/web/vision/index.ts index e13f8183f..2ba6ca812 100644 --- a/mediapipe/tasks/web/vision/index.ts +++ b/mediapipe/tasks/web/vision/index.ts @@ -14,27 +14,9 @@ * limitations under the License. */ -import {FilesetResolver as FilesetResolverImpl} from '../../../tasks/web/core/fileset_resolver'; -import {GestureRecognizer as GestureRecognizerImpl} from '../../../tasks/web/vision/gesture_recognizer/gesture_recognizer'; -import {HandLandmarker as HandLandmarkerImpl} from '../../../tasks/web/vision/hand_landmarker/hand_landmarker'; -import {ImageClassifier as ImageClassifierImpl} from '../../../tasks/web/vision/image_classifier/image_classifier'; -import {ImageEmbedder as ImageEmbedderImpl} from '../../../tasks/web/vision/image_embedder/image_embedder'; -import {ObjectDetector as ObjectDetectorImpl} from '../../../tasks/web/vision/object_detector/object_detector'; - -// Declare the variables locally so that Rollup in OSS includes them explcilty -// as exports. 
-const FilesetResolver = FilesetResolverImpl; -const GestureRecognizer = GestureRecognizerImpl; -const HandLandmarker = HandLandmarkerImpl; -const ImageClassifier = ImageClassifierImpl; -const ImageEmbedder = ImageEmbedderImpl; -const ObjectDetector = ObjectDetectorImpl; - -export { - FilesetResolver, - GestureRecognizer, - HandLandmarker, - ImageClassifier, - ImageEmbedder, - ObjectDetector -}; +export * from '../../../tasks/web/core/fileset_resolver'; +export * from '../../../tasks/web/vision/gesture_recognizer/gesture_recognizer'; +export * from '../../../tasks/web/vision/hand_landmarker/hand_landmarker'; +export * from '../../../tasks/web/vision/image_classifier/image_classifier'; +export * from '../../../tasks/web/vision/image_embedder/image_embedder'; +export * from '../../../tasks/web/vision/object_detector/object_detector'; From 7894c92ab7edded4810665958cc904b4e768e29a Mon Sep 17 00:00:00 2001 From: Camillo Lugaresi Date: Tue, 17 Jan 2023 15:45:15 -0800 Subject: [PATCH 400/469] Internal change PiperOrigin-RevId: 502709070 --- mediapipe/util/log_fatal_to_breakpad.cc | 50 +++++++++++++++++++++++++ mediapipe/util/log_fatal_to_breakpad.h | 15 ++++++++ 2 files changed, 65 insertions(+) create mode 100644 mediapipe/util/log_fatal_to_breakpad.cc create mode 100644 mediapipe/util/log_fatal_to_breakpad.h diff --git a/mediapipe/util/log_fatal_to_breakpad.cc b/mediapipe/util/log_fatal_to_breakpad.cc new file mode 100644 index 000000000..45087f2e3 --- /dev/null +++ b/mediapipe/util/log_fatal_to_breakpad.cc @@ -0,0 +1,50 @@ +#include "mediapipe/util/log_fatal_to_breakpad.h" + +#import + +#include "absl/log/log.h" +#include "absl/log/log_sink.h" +#include "absl/log/log_sink_registry.h" +#import "googlemac/iPhone/Shared/GoogleIOSBreakpad/Classes/GoogleBreakpadController.h" + +namespace mediapipe { +namespace { +NSString* MakeNSString(absl::string_view str) { + return [[NSString alloc] initWithBytes:str.data() + length:str.length() + encoding:NSUTF8StringEncoding]; +} +} // namespace + +static NSString* const kFatalLogMessageKey = @"fatal_log_message"; + +class BreakpadFatalLogSink : public absl::LogSink { + public: + BreakpadFatalLogSink() + : breakpad_controller_([GoogleBreakpadController sharedInstance]) {} + void Send(const absl::LogEntry& entry) override { + if (entry.log_severity() != absl::LogSeverity::kFatal) return; + __block NSString* message = MakeNSString(entry.text_message_with_prefix()); + [breakpad_controller_ withBreakpadRef:^(BreakpadRef breakpad) { + // NOTE: This block runs on Breakpad's background queue. + if (!breakpad) return; + BreakpadAddUploadParameter(breakpad, kFatalLogMessageKey, message); + }]; + } + + private: + GoogleBreakpadController* breakpad_controller_; +}; + +absl::LogSink* GetBreakpadFatalLogSink() { + static BreakpadFatalLogSink sink; + return &sink; +} + +// This log sink is automatically enabled when including this library. +static const auto kRegisterLogSink = [] { + absl::AddLogSink(GetBreakpadFatalLogSink()); + return true; +}(); + +} // namespace mediapipe diff --git a/mediapipe/util/log_fatal_to_breakpad.h b/mediapipe/util/log_fatal_to_breakpad.h new file mode 100644 index 000000000..1712a9af8 --- /dev/null +++ b/mediapipe/util/log_fatal_to_breakpad.h @@ -0,0 +1,15 @@ +#ifndef MEDIAPIPE_UTIL_LOG_FATAL_TO_BREAKPAD_H_ +#define MEDIAPIPE_UTIL_LOG_FATAL_TO_BREAKPAD_H_ + +#include "absl/log/log_sink.h" + +namespace mediapipe { + +// Returns a singleton instance of a log sink that sends FATAL log messages to +// Breakpad. 
This log sink is enabled by default when this library is included +// in your binary. +absl::LogSink* GetBreakpadFatalLogSink(); + +} // namespace mediapipe + +#endif // MEDIAPIPE_UTIL_LOG_FATAL_TO_BREAKPAD_H_ From 0b97c6e67d316d023ee4ea61366a6c9d886f7ac4 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Tue, 17 Jan 2023 15:45:28 -0800 Subject: [PATCH 401/469] Update the MP Wasm builds to latest version. PiperOrigin-RevId: 502709126 --- third_party/wasm_files.bzl | 48 +++++++++++++++++++------------------- 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/third_party/wasm_files.bzl b/third_party/wasm_files.bzl index 504f8567a..017d84466 100644 --- a/third_party/wasm_files.bzl +++ b/third_party/wasm_files.bzl @@ -12,72 +12,72 @@ def wasm_files(): http_file( name = "com_google_mediapipe_wasm_audio_wasm_internal_js", - sha256 = "42d2d0ade6e2e8b81425b23686be93eb1423b7777f043eb8f18ad671e2ca803f", - urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_internal.js?generation=1669173769507080"], + sha256 = "d4d205d08e3e1b09662a9a358d0107e8a8023827ba9b6982a3777bb6c040f936", + urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_internal.js?generation=1673996821002628"], ) http_file( name = "com_google_mediapipe_wasm_audio_wasm_internal_wasm", - sha256 = "20200ee9b0866d5176f633a9b375e8a44e53204c01ea2e159e2f9245afb00e80", - urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_internal.wasm?generation=1669173772528997"], + sha256 = "1b2ffe82b0a25d20188237a724a7cad68d068818a7738f91c69c782314f55965", + urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_internal.wasm?generation=1673996823772372"], ) http_file( name = "com_google_mediapipe_wasm_audio_wasm_nosimd_internal_js", - sha256 = "11bbf73d48723b19a5a6a13ec296ecdb2aa178cdc3db9d7bc54265a7d4b94c6a", - urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_nosimd_internal.js?generation=1669173774625527"], + sha256 = "1f367c2d667628b178251aec7fd464327351570edac4549450b11fb82f5f0fd4", + urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_nosimd_internal.js?generation=1673996826132845"], ) http_file( name = "com_google_mediapipe_wasm_audio_wasm_nosimd_internal_wasm", - sha256 = "d4528972219033996a83a62798952b6ee8b6b396bcffd96fd5bda5458d57d3a3", - urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_nosimd_internal.wasm?generation=1669173777474822"], + sha256 = "35c6ad888c06025dba1f9c8edb70e6c7be7e94e45dc2c0236a2fcfe61991dc44", + urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_nosimd_internal.wasm?generation=1673996828935550"], ) http_file( name = "com_google_mediapipe_wasm_text_wasm_internal_js", - sha256 = "29e72e177122f92bda6a3ecd463ebacf30b920559b06c97068112a22eeea4d0e", - urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_internal.js?generation=1669173779706893"], + sha256 = "68c0134e0b3cb986c3526cd645f74cc5a1f6ab19292276ca7d3558b89801e205", + urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_internal.js?generation=1673996831356232"], ) http_file( name = "com_google_mediapipe_wasm_text_wasm_internal_wasm", - sha256 = "84e5f5ac70f7718baeaa09a89b155abbea67386e7d50663301b3af7ef0941e74", - urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_internal.wasm?generation=1669173782728605"], + sha256 = "df82bb192ea852dc1bcc8f9f28fbd8c3d6b219dc4fec2b2a92451678d98ee1f0", + urls = 
["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_internal.wasm?generation=1673996834657078"], ) http_file( name = "com_google_mediapipe_wasm_text_wasm_nosimd_internal_js", - sha256 = "36f247673124e32535f217265b96508c1badee8fe2458c11c1efa95b6bec5daa", - urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_nosimd_internal.js?generation=1669173785027190"], + sha256 = "de1a4aabefb2e42ae4fee68b7e762e328623a163257a7ddc72365fc2502bd090", + urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_nosimd_internal.js?generation=1673996837104551"], ) http_file( name = "com_google_mediapipe_wasm_text_wasm_nosimd_internal_wasm", - sha256 = "cc74d90a8aaf6d006ec24048cc80c33f96baeeb0075a6c6739f30d41da54e450", - urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_nosimd_internal.wasm?generation=1669173787903754"], + sha256 = "828dd1e73fa9478a97a62539117f92b813833ab35d37a986c466df15a8cfdc7b", + urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_nosimd_internal.wasm?generation=1673996840120504"], ) http_file( name = "com_google_mediapipe_wasm_vision_wasm_internal_js", - sha256 = "c3451423186766b08008e07ef6d52f628fcc0aca75beedd9bb4d87d380f29edd", - urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_internal.js?generation=1669173790070986"], + sha256 = "c146b68523c256d41132230e811fc224dafb6a0bce6fc318c29dad37dfac06de", + urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_internal.js?generation=1673996842448396"], ) http_file( name = "com_google_mediapipe_wasm_vision_wasm_internal_wasm", - sha256 = "d1e8ad748913e3f190bfd3f72e0e8a4a308f78b918d54c79cec60a2cf30a49f0", - urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_internal.wasm?generation=1669173792993881"], + sha256 = "8dbccaaf944ef1251cf78190450ab7074abea233e18ebb37d2c2ce0f18d14a0c", + urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_internal.wasm?generation=1673996845499070"], ) http_file( name = "com_google_mediapipe_wasm_vision_wasm_nosimd_internal_js", - sha256 = "e5f1b5e8264ff9a90371653cb0fdbf9ce3b30b712acbd72068af18ebca2293ac", - urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_nosimd_internal.js?generation=1669173794969702"], + sha256 = "705f9e3c2c62d12903ea2cadc22d2c328bc890f96fffc47b51f989471196ecea", + urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_nosimd_internal.js?generation=1673996847915731"], ) http_file( name = "com_google_mediapipe_wasm_vision_wasm_nosimd_internal_wasm", - sha256 = "24351fe580e88f2065b1978b8b3c0f3ad7b90f1c95805aafa07971ce422b5854", - urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_nosimd_internal.wasm?generation=1669173797596874"], + sha256 = "c7ff6a7d8dc22380e2e8457a15a51b6bc1e70c6262fecca25825f54ecc593d1f", + urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_nosimd_internal.wasm?generation=1673996850980344"], ) From d5e60eb658c231424209d5274d9edb28bebca367 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Tue, 17 Jan 2023 20:51:40 -0800 Subject: [PATCH 402/469] Internal change PiperOrigin-RevId: 502764352 --- mediapipe/tasks/web/audio/BUILD | 19 ++++++++++++++----- mediapipe/tasks/web/audio/types.ts | 19 +++++++++++++++++++ mediapipe/tasks/web/text/BUILD | 19 ++++++++++++++----- mediapipe/tasks/web/text/types.ts | 19 +++++++++++++++++++ mediapipe/tasks/web/vision/BUILD | 25 +++++++++++++++++-------- mediapipe/tasks/web/vision/types.ts 
| 22 ++++++++++++++++++++++ 6 files changed, 105 insertions(+), 18 deletions(-) create mode 100644 mediapipe/tasks/web/audio/types.ts create mode 100644 mediapipe/tasks/web/text/types.ts create mode 100644 mediapipe/tasks/web/vision/types.ts diff --git a/mediapipe/tasks/web/audio/BUILD b/mediapipe/tasks/web/audio/BUILD index 7e05263fe..409836800 100644 --- a/mediapipe/tasks/web/audio/BUILD +++ b/mediapipe/tasks/web/audio/BUILD @@ -10,15 +10,24 @@ load( package(default_visibility = ["//mediapipe/tasks:internal"]) +AUDIO_LIBS = [ + "//mediapipe/tasks/web/audio/audio_classifier", + "//mediapipe/tasks/web/audio/audio_embedder", + "//mediapipe/tasks/web/core:fileset_resolver", +] + mediapipe_ts_library( name = "audio_lib", srcs = ["index.ts"], visibility = ["//visibility:public"], - deps = [ - "//mediapipe/tasks/web/audio/audio_classifier", - "//mediapipe/tasks/web/audio/audio_embedder", - "//mediapipe/tasks/web/core:fileset_resolver", - ], + deps = AUDIO_LIBS, +) + +mediapipe_ts_library( + name = "audio_types", + srcs = ["types.ts"], + visibility = ["//visibility:public"], + deps = AUDIO_LIBS, ) mediapipe_files(srcs = [ diff --git a/mediapipe/tasks/web/audio/types.ts b/mediapipe/tasks/web/audio/types.ts new file mode 100644 index 000000000..19073b708 --- /dev/null +++ b/mediapipe/tasks/web/audio/types.ts @@ -0,0 +1,19 @@ +/** + * Copyright 2023 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +export * from '../../../tasks/web/audio/audio_classifier/audio_classifier'; +export * from '../../../tasks/web/audio/audio_embedder/audio_embedder'; +export * from '../../../tasks/web/core/fileset_resolver'; diff --git a/mediapipe/tasks/web/text/BUILD b/mediapipe/tasks/web/text/BUILD index 6f019aca1..ebe3403b2 100644 --- a/mediapipe/tasks/web/text/BUILD +++ b/mediapipe/tasks/web/text/BUILD @@ -17,15 +17,24 @@ mediapipe_files(srcs = [ "wasm/text_wasm_nosimd_internal.wasm", ]) +TEXT_LIBS = [ + "//mediapipe/tasks/web/core:fileset_resolver", + "//mediapipe/tasks/web/text/text_classifier", + "//mediapipe/tasks/web/text/text_embedder", +] + mediapipe_ts_library( name = "text_lib", srcs = ["index.ts"], visibility = ["//visibility:public"], - deps = [ - "//mediapipe/tasks/web/core:fileset_resolver", - "//mediapipe/tasks/web/text/text_classifier", - "//mediapipe/tasks/web/text/text_embedder", - ], + deps = TEXT_LIBS, +) + +mediapipe_ts_library( + name = "text_types", + srcs = ["types.ts"], + visibility = ["//visibility:public"], + deps = TEXT_LIBS, ) rollup_bundle( diff --git a/mediapipe/tasks/web/text/types.ts b/mediapipe/tasks/web/text/types.ts new file mode 100644 index 000000000..bd01b1c6f --- /dev/null +++ b/mediapipe/tasks/web/text/types.ts @@ -0,0 +1,19 @@ +/** + * Copyright 2023 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +export * from '../../../tasks/web/core/fileset_resolver'; +export * from '../../../tasks/web/text/text_classifier/text_classifier'; +export * from '../../../tasks/web/text/text_embedder/text_embedder'; diff --git a/mediapipe/tasks/web/vision/BUILD b/mediapipe/tasks/web/vision/BUILD index 76b0c084e..8ba9c85b3 100644 --- a/mediapipe/tasks/web/vision/BUILD +++ b/mediapipe/tasks/web/vision/BUILD @@ -17,18 +17,27 @@ mediapipe_files(srcs = [ "wasm/vision_wasm_nosimd_internal.wasm", ]) +VISION_LIBS = [ + "//mediapipe/tasks/web/core:fileset_resolver", + "//mediapipe/tasks/web/vision/gesture_recognizer", + "//mediapipe/tasks/web/vision/hand_landmarker", + "//mediapipe/tasks/web/vision/image_classifier", + "//mediapipe/tasks/web/vision/image_embedder", + "//mediapipe/tasks/web/vision/object_detector", +] + mediapipe_ts_library( name = "vision_lib", srcs = ["index.ts"], visibility = ["//visibility:public"], - deps = [ - "//mediapipe/tasks/web/core:fileset_resolver", - "//mediapipe/tasks/web/vision/gesture_recognizer", - "//mediapipe/tasks/web/vision/hand_landmarker", - "//mediapipe/tasks/web/vision/image_classifier", - "//mediapipe/tasks/web/vision/image_embedder", - "//mediapipe/tasks/web/vision/object_detector", - ], + deps = VISION_LIBS, +) + +mediapipe_ts_library( + name = "vision_types", + srcs = ["types.ts"], + visibility = ["//visibility:public"], + deps = VISION_LIBS, ) rollup_bundle( diff --git a/mediapipe/tasks/web/vision/types.ts b/mediapipe/tasks/web/vision/types.ts new file mode 100644 index 000000000..dd1f58294 --- /dev/null +++ b/mediapipe/tasks/web/vision/types.ts @@ -0,0 +1,22 @@ +/** + * Copyright 2023 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +export * from '../../../tasks/web/core/fileset_resolver'; +export * from '../../../tasks/web/vision/gesture_recognizer/gesture_recognizer'; +export * from '../../../tasks/web/vision/hand_landmarker/hand_landmarker'; +export * from '../../../tasks/web/vision/image_classifier/image_classifier'; +export * from '../../../tasks/web/vision/image_embedder/image_embedder'; +export * from '../../../tasks/web/vision/object_detector/object_detector'; From e484bd681e03223a09619a6088dbb8b1a6c7557e Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Tue, 17 Jan 2023 20:53:07 -0800 Subject: [PATCH 403/469] Export all input and output types PiperOrigin-RevId: 502764544 --- mediapipe/tasks/web/audio/index.ts | 14 +++++++++++--- mediapipe/tasks/web/text/index.ts | 14 +++++++++++--- mediapipe/tasks/web/vision/index.ts | 30 +++++++++++++++++++++++------ 3 files changed, 46 insertions(+), 12 deletions(-) diff --git a/mediapipe/tasks/web/audio/index.ts b/mediapipe/tasks/web/audio/index.ts index dbad8c617..44fa7eb25 100644 --- a/mediapipe/tasks/web/audio/index.ts +++ b/mediapipe/tasks/web/audio/index.ts @@ -14,6 +14,14 @@ * limitations under the License. */ -export * from '../../../tasks/web/audio/audio_classifier/audio_classifier'; -export * from '../../../tasks/web/audio/audio_embedder/audio_embedder'; -export * from '../../../tasks/web/core/fileset_resolver'; +import {AudioClassifier as AudioClassifierImpl} from '../../../tasks/web/audio/audio_classifier/audio_classifier'; +import {AudioEmbedder as AudioEmbedderImpl} from '../../../tasks/web/audio/audio_embedder/audio_embedder'; +import {FilesetResolver as FilesetResolverImpl} from '../../../tasks/web/core/fileset_resolver'; + +// Declare the variables locally so that Rollup in OSS includes them explcilty +// as exports. +const AudioClassifier = AudioClassifierImpl; +const AudioEmbedder = AudioEmbedderImpl; +const FilesetResolver = FilesetResolverImpl; + +export {AudioClassifier, AudioEmbedder, FilesetResolver}; diff --git a/mediapipe/tasks/web/text/index.ts b/mediapipe/tasks/web/text/index.ts index f32c16c36..2c9e6fead 100644 --- a/mediapipe/tasks/web/text/index.ts +++ b/mediapipe/tasks/web/text/index.ts @@ -14,6 +14,14 @@ * limitations under the License. */ -export * from '../../../tasks/web/core/fileset_resolver'; -export * from '../../../tasks/web/text/text_classifier/text_classifier'; -export * from '../../../tasks/web/text/text_embedder/text_embedder'; +import {FilesetResolver as FilesetResolverImpl} from '../../../tasks/web/core/fileset_resolver'; +import {TextClassifier as TextClassifierImpl} from '../../../tasks/web/text/text_classifier/text_classifier'; +import {TextEmbedder as TextEmbedderImpl} from '../../../tasks/web/text/text_embedder/text_embedder'; + +// Declare the variables locally so that Rollup in OSS includes them explcilty +// as exports. +const FilesetResolver = FilesetResolverImpl; +const TextClassifier = TextClassifierImpl; +const TextEmbedder = TextEmbedderImpl; + +export {FilesetResolver, TextClassifier, TextEmbedder}; diff --git a/mediapipe/tasks/web/vision/index.ts b/mediapipe/tasks/web/vision/index.ts index 2ba6ca812..e13f8183f 100644 --- a/mediapipe/tasks/web/vision/index.ts +++ b/mediapipe/tasks/web/vision/index.ts @@ -14,9 +14,27 @@ * limitations under the License. 
*/ -export * from '../../../tasks/web/core/fileset_resolver'; -export * from '../../../tasks/web/vision/gesture_recognizer/gesture_recognizer'; -export * from '../../../tasks/web/vision/hand_landmarker/hand_landmarker'; -export * from '../../../tasks/web/vision/image_classifier/image_classifier'; -export * from '../../../tasks/web/vision/image_embedder/image_embedder'; -export * from '../../../tasks/web/vision/object_detector/object_detector'; +import {FilesetResolver as FilesetResolverImpl} from '../../../tasks/web/core/fileset_resolver'; +import {GestureRecognizer as GestureRecognizerImpl} from '../../../tasks/web/vision/gesture_recognizer/gesture_recognizer'; +import {HandLandmarker as HandLandmarkerImpl} from '../../../tasks/web/vision/hand_landmarker/hand_landmarker'; +import {ImageClassifier as ImageClassifierImpl} from '../../../tasks/web/vision/image_classifier/image_classifier'; +import {ImageEmbedder as ImageEmbedderImpl} from '../../../tasks/web/vision/image_embedder/image_embedder'; +import {ObjectDetector as ObjectDetectorImpl} from '../../../tasks/web/vision/object_detector/object_detector'; + +// Declare the variables locally so that Rollup in OSS includes them explcilty +// as exports. +const FilesetResolver = FilesetResolverImpl; +const GestureRecognizer = GestureRecognizerImpl; +const HandLandmarker = HandLandmarkerImpl; +const ImageClassifier = ImageClassifierImpl; +const ImageEmbedder = ImageEmbedderImpl; +const ObjectDetector = ObjectDetectorImpl; + +export { + FilesetResolver, + GestureRecognizer, + HandLandmarker, + ImageClassifier, + ImageEmbedder, + ObjectDetector +}; From 3688757d1706a5252de8196dfa56947dc0164671 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Wed, 18 Jan 2023 07:26:38 -0800 Subject: [PATCH 404/469] Fix `load_metadata_buffer` for empty metadata PiperOrigin-RevId: 502870428 --- mediapipe/tasks/python/metadata/metadata.py | 2 ++ .../python/test/metadata/metadata_test.py | 26 +++++++++++++++++-- 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/mediapipe/tasks/python/metadata/metadata.py b/mediapipe/tasks/python/metadata/metadata.py index 10a0b9b66..2327ebbdf 100644 --- a/mediapipe/tasks/python/metadata/metadata.py +++ b/mediapipe/tasks/python/metadata/metadata.py @@ -860,6 +860,8 @@ def get_metadata_buffer(model_buf): if meta.Name().decode("utf-8") == MetadataPopulator.METADATA_FIELD_NAME: buffer_index = meta.Buffer() metadata = tflite_model.Buffers(buffer_index) + if metadata.DataLength() == 0: + continue return metadata.DataAsNumpy().tobytes() return None diff --git a/mediapipe/tasks/python/test/metadata/metadata_test.py b/mediapipe/tasks/python/test/metadata/metadata_test.py index bed9c2833..d892f1b61 100644 --- a/mediapipe/tasks/python/test/metadata/metadata_test.py +++ b/mediapipe/tasks/python/test/metadata/metadata_test.py @@ -550,7 +550,7 @@ class MetadataPopulatorTest(MetadataTest): ("The number of output tensors (1) should match the number of " "output tensor metadata (0)"), str(error.exception)) - def testLoadMetadataAndAssociatedFilesShouldSucceeds(self): + def testLoadMetadataAndAssociatedFilesShouldSucceed(self): # Create a src model with metadata and two associated files. src_model_buf = self._create_model_buf() populator_src = _metadata.MetadataPopulator.with_model_buffer(src_model_buf) @@ -566,7 +566,7 @@ class MetadataPopulatorTest(MetadataTest): populator_src.get_model_buffer()) populator_dst.populate() - # Tests if the metadata and associated files are populated correctly. 
+ # Test if the metadata and associated files are populated correctly. dst_model_file = self.create_tempfile().full_path with open(dst_model_file, "wb") as f: f.write(populator_dst.get_model_buffer()) @@ -575,6 +575,28 @@ class MetadataPopulatorTest(MetadataTest): recorded_files = populator_dst.get_recorded_associated_file_list() self.assertEqual(set(recorded_files), set(self.expected_recorded_files)) + def testLoadMetadataAndAssociatedFilesShouldSucceedOnEmptyMetadata(self): + # When the user hasn't specified the metadata, but only the associated + # files, an empty metadata buffer is created. Previously, it caused an + # exception when reading. + + # Create a source model with two associated files but no metadata. + src_model_buf = self._create_model_buf() + populator_src = _metadata.MetadataPopulator.with_model_buffer(src_model_buf) + populator_src.load_associated_files([self._file1, self._file2]) + populator_src.populate() + + # Create a model to be populated with the files from `src_model_buf`. + dst_model_buf = self._create_model_buf() + populator_dst = _metadata.MetadataPopulator.with_model_buffer(dst_model_buf) + populator_dst.load_metadata_and_associated_files( + populator_src.get_model_buffer()) + populator_dst.populate() + + # Test if the metadata and associated files are populated correctly. + packed_files = populator_dst.get_packed_associated_file_list() + self.assertEqual(set(packed_files), set(self.expected_recorded_files)) + @parameterized.named_parameters( { "testcase_name": "InputTensorWithBert", From 29484702cef7908881262579382f4f4f8055170f Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Wed, 18 Jan 2023 08:00:48 -0800 Subject: [PATCH 405/469] Add `process_timestamp_bounds` into RectToRenderScaleCalculatorOptions. PiperOrigin-RevId: 502877541 --- mediapipe/calculators/util/rect_to_render_scale_calculator.cc | 4 +++- .../calculators/util/rect_to_render_scale_calculator.proto | 4 ++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/mediapipe/calculators/util/rect_to_render_scale_calculator.cc b/mediapipe/calculators/util/rect_to_render_scale_calculator.cc index 6ff6b3d51..85ed1db72 100644 --- a/mediapipe/calculators/util/rect_to_render_scale_calculator.cc +++ b/mediapipe/calculators/util/rect_to_render_scale_calculator.cc @@ -80,7 +80,9 @@ absl::Status RectToRenderScaleCalculator::GetContract(CalculatorContract* cc) { cc->Inputs().Tag(kNormRectTag).Set(); cc->Inputs().Tag(kImageSizeTag).Set>(); cc->Outputs().Tag(kRenderScaleTag).Set(); - + cc->SetProcessTimestampBounds( + cc->Options() + .process_timestamp_bounds()); return absl::OkStatus(); } diff --git a/mediapipe/calculators/util/rect_to_render_scale_calculator.proto b/mediapipe/calculators/util/rect_to_render_scale_calculator.proto index dda6e2c9c..377b12412 100644 --- a/mediapipe/calculators/util/rect_to_render_scale_calculator.proto +++ b/mediapipe/calculators/util/rect_to_render_scale_calculator.proto @@ -29,4 +29,8 @@ message RectToRenderScaleCalculatorOptions { // when actual object size on the image will be `B`, than all RenderData // primitives will be scaled with factor `B/A`. optional float multiplier = 1 [default = 0.01]; + + // When true, Process is called for every new timestamp bound, with or without + // new packets. 
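  // For illustration, a minimal node config enabling this field might look
  // like the sketch below (editorial sketch, not part of this patch; it
  // assumes the usual `ext` extension point that MediaPipe calculator
  // options declare):
  //
  //   node {
  //     calculator: "RectToRenderScaleCalculator"
  //     options {
  //       [mediapipe.RectToRenderScaleCalculatorOptions.ext] {
  //         process_timestamp_bounds: true
  //       }
  //     }
  //   }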
+ optional bool process_timestamp_bounds = 2 [default = false]; } From 5687d19dec64dbea7ec70337ea67dd015d366d77 Mon Sep 17 00:00:00 2001 From: Nikolay Chirkov Date: Wed, 18 Jan 2023 09:06:48 -0800 Subject: [PATCH 406/469] Tensor: remove unused and unimplemented SetPreferredStorageType methods. PiperOrigin-RevId: 502893019 --- mediapipe/framework/formats/tensor.h | 7 ------- 1 file changed, 7 deletions(-) diff --git a/mediapipe/framework/formats/tensor.h b/mediapipe/framework/formats/tensor.h index 4a952ae09..fe0be31d1 100644 --- a/mediapipe/framework/formats/tensor.h +++ b/mediapipe/framework/formats/tensor.h @@ -370,13 +370,6 @@ class Tensor { bool ready_as_opengl_texture_2d() const { return valid_ & kValidOpenGlTexture2d; } - // Sets the type of underlying resource that is going to be allocated. - enum class StorageType { - kDefault, - kAhwb, - }; - static void SetPreferredStorageType(StorageType type); - static StorageType GetPreferredStorageType(); private: void Move(Tensor*); From e56fa8f258dbc32458e595ecca8043e7a8aeb893 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Wed, 18 Jan 2023 10:59:56 -0800 Subject: [PATCH 407/469] Source/SideSource -> Stream/SidePacket PiperOrigin-RevId: 502923931 --- mediapipe/framework/api2/builder_test.cc | 50 ++++++++++++------------ 1 file changed, 25 insertions(+), 25 deletions(-) diff --git a/mediapipe/framework/api2/builder_test.cc b/mediapipe/framework/api2/builder_test.cc index b01c2b759..08f4f0ca1 100644 --- a/mediapipe/framework/api2/builder_test.cc +++ b/mediapipe/framework/api2/builder_test.cc @@ -53,20 +53,20 @@ TEST(BuilderTest, BuildGraph) { EXPECT_THAT(graph.GetConfig(), EqualsProto(expected)); } -TEST(BuilderTest, CopyableSource) { +TEST(BuilderTest, CopyableStream) { Graph graph; - Source a = graph.In("A").SetName("a").Cast(); - Source b = graph.In("B").SetName("b").Cast(); - SideSource side_a = + Stream a = graph.In("A").SetName("a").Cast(); + Stream b = graph.In("B").SetName("b").Cast(); + SidePacket side_a = graph.SideIn("SIDE_A").SetName("side_a").Cast(); - SideSource side_b = + SidePacket side_b = graph.SideIn("SIDE_B").SetName("side_b").Cast(); Destination out = graph.Out("OUT").Cast(); SideDestination side_out = graph.SideOut("SIDE_OUT").Cast(); - Source input = a; + Stream input = a; input = b; - SideSource side_input = side_b; + SidePacket side_input = side_b; side_input = side_a; input >> out; @@ -87,23 +87,23 @@ TEST(BuilderTest, CopyableSource) { TEST(BuilderTest, BuildGraphWithFunctions) { Graph graph; - Source base = graph.In("IN").SetName("base").Cast(); - SideSource side = graph.SideIn("SIDE").SetName("side").Cast(); + Stream base = graph.In("IN").SetName("base").Cast(); + SidePacket side = graph.SideIn("SIDE").SetName("side").Cast(); - auto foo_fn = [](Source base, SideSource side, Graph& graph) { + auto foo_fn = [](Stream base, SidePacket side, Graph& graph) { auto& foo = graph.AddNode("Foo"); base >> foo.In("BASE"); side >> foo.SideIn("SIDE"); return foo.Out("OUT")[0].Cast(); }; - Source foo_out = foo_fn(base, side, graph); + Stream foo_out = foo_fn(base, side, graph); - auto bar_fn = [](Source in, Graph& graph) { + auto bar_fn = [](Stream in, Graph& graph) { auto& bar = graph.AddNode("Bar"); in >> bar.In("IN"); return bar.Out("OUT")[0].Cast(); }; - Source bar_out = bar_fn(foo_out, graph); + Stream bar_out = bar_fn(foo_out, graph); bar_out.SetName("out") >> graph.Out("OUT"); @@ -375,26 +375,26 @@ class AnyAndSameTypeCalculator : public NodeIntf { TEST(BuilderTest, AnyAndSameTypeHandledProperly) { Graph graph; - 
Source any_input = graph.In("GRAPH_ANY_INPUT"); - Source int_input = graph.In("GRAPH_INT_INPUT").Cast(); + Stream any_input = graph.In("GRAPH_ANY_INPUT"); + Stream int_input = graph.In("GRAPH_INT_INPUT").Cast(); auto& node = graph.AddNode("AnyAndSameTypeCalculator"); any_input >> node[AnyAndSameTypeCalculator::kAnyTypeInput]; int_input >> node[AnyAndSameTypeCalculator::kIntInput]; - Source any_type_output = + Stream any_type_output = node[AnyAndSameTypeCalculator::kAnyTypeOutput]; any_type_output.SetName("any_type_output"); - Source same_type_output = + Stream same_type_output = node[AnyAndSameTypeCalculator::kSameTypeOutput]; same_type_output.SetName("same_type_output"); - Source recursive_same_type_output = + Stream recursive_same_type_output = node[AnyAndSameTypeCalculator::kRecursiveSameTypeOutput]; recursive_same_type_output.SetName("recursive_same_type_output"); - Source same_int_output = node[AnyAndSameTypeCalculator::kSameIntOutput]; + Stream same_int_output = node[AnyAndSameTypeCalculator::kSameIntOutput]; same_int_output.SetName("same_int_output"); - Source recursive_same_int_type_output = + Stream recursive_same_int_type_output = node[AnyAndSameTypeCalculator::kRecursiveSameIntOutput]; recursive_same_int_type_output.SetName("recursive_same_int_type_output"); @@ -418,12 +418,12 @@ TEST(BuilderTest, AnyAndSameTypeHandledProperly) { TEST(BuilderTest, AnyTypeCanBeCast) { Graph graph; - Source any_input = + Stream any_input = graph.In("GRAPH_ANY_INPUT").Cast(); auto& node = graph.AddNode("AnyAndSameTypeCalculator"); any_input >> node[AnyAndSameTypeCalculator::kAnyTypeInput]; - Source any_type_output = + Stream any_type_output = node[AnyAndSameTypeCalculator::kAnyTypeOutput] .SetName("any_type_output") .Cast(); @@ -462,7 +462,7 @@ TEST(BuilderTest, MultiPortIsCastToMultiPort) { TEST(BuilderTest, MultiPortCanBeSlicedToSinglePort) { Graph graph; MultiSource any_multi_input = graph.In("ANY_INPUT"); - Source any_input = any_multi_input; + Stream any_input = any_multi_input; MultiDestination any_multi_output = graph.Out("ANY_OUTPUT"); Destination any_output = any_multi_output; any_input >> any_output; @@ -477,8 +477,8 @@ TEST(BuilderTest, MultiPortCanBeSlicedToSinglePort) { TEST(BuilderTest, SinglePortAccessWorksThroughSlicing) { Graph graph; - Source int_input = graph.In("INT_INPUT").Cast(); - Source any_input = graph.In("ANY_OUTPUT"); + Stream int_input = graph.In("INT_INPUT").Cast(); + Stream any_input = graph.In("ANY_OUTPUT"); Destination int_output = graph.Out("INT_OUTPUT").Cast(); Destination any_output = graph.Out("ANY_OUTPUT"); int_input >> int_output; From 66634bbef88c390ccd5c85774b839a81ea73240f Mon Sep 17 00:00:00 2001 From: Hadon Nash Date: Wed, 18 Jan 2023 16:36:09 -0800 Subject: [PATCH 408/469] Internal change PiperOrigin-RevId: 503011674 --- mediapipe/framework/tool/switch/BUILD | 34 +++++++ .../framework/tool/switch/packet_processor.h | 88 +++++++++++++++++++ 2 files changed, 122 insertions(+) create mode 100644 mediapipe/framework/tool/switch/BUILD create mode 100644 mediapipe/framework/tool/switch/packet_processor.h diff --git a/mediapipe/framework/tool/switch/BUILD b/mediapipe/framework/tool/switch/BUILD new file mode 100644 index 000000000..62f9095ef --- /dev/null +++ b/mediapipe/framework/tool/switch/BUILD @@ -0,0 +1,34 @@ +# Copyright 2023 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+licenses(["notice"])
+
+package(default_visibility = ["//visibility:private"])
+
+cc_library(
+    name = "packet_processor",
+    hdrs = ["packet_processor.h"],
+    visibility = [
+        "//visibility:public",
+    ],
+    deps = [
+        "//mediapipe/framework:calculator_contract",
+        "//mediapipe/framework:collection_item_id",
+        "//mediapipe/framework:packet",
+        "//mediapipe/framework/port:logging",
+        "//mediapipe/framework/port:ret_check",
+        "//mediapipe/framework/port:status",
+    ],
+)
diff --git a/mediapipe/framework/tool/switch/packet_processor.h b/mediapipe/framework/tool/switch/packet_processor.h
new file mode 100644
index 000000000..1789a46c5
--- /dev/null
+++ b/mediapipe/framework/tool/switch/packet_processor.h
@@ -0,0 +1,88 @@
+// Copyright 2022 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef MEDIAPIPE_FRAMEWORK_TOOL_PACKET_PROCESSOR_H_
+#define MEDIAPIPE_FRAMEWORK_TOOL_PACKET_PROCESSOR_H_
+
+#include <memory>
+
+#include "mediapipe/framework/collection_item_id.h"
+#include "mediapipe/framework/packet.h"
+#include "mediapipe/framework/port/status.h"
+
+namespace mediapipe {
+
+// PacketConsumer accepts several tagged streams of packets.
+class PacketConsumer {
+ public:
+  virtual ~PacketConsumer() = default;
+
+  // Accepts a tagged input packet.
+  virtual absl::Status AddPacket(CollectionItemId id, Packet packet) = 0;
+
+  // Returns the id for each input tag.
+  virtual std::shared_ptr<tool::TagMap> InputTags() = 0;
+};
+
+// PacketProducer delivers several tagged streams of packets.
+class PacketProducer {
+ public:
+  virtual ~PacketProducer() = default;
+
+  // Connects a consumer to receive packets from this producer.
+  virtual void SetConsumer(PacketConsumer* consumer) = 0;
+};
+
+// SidePacketConsumer accepts several tagged constant packets.
+class SidePacketConsumer {
+ public:
+  virtual ~SidePacketConsumer() = default;
+
+  // Accepts a tagged input side-packet.
+  virtual absl::Status SetSidePacket(CollectionItemId id, Packet packet) = 0;
+
+  // Returns the id for each input side-packet tag.
+  virtual std::shared_ptr<tool::TagMap> SideInputTags() = 0;
+};
+
+// SidePacketProducer deleivers several tagged constant packets.
+class SidePacketProducer {
+ public:
+  virtual ~SidePacketProducer() = default;
+
+  // Connects a consumer to receive packets from this producer.
+  virtual void SetSideConsumer(SidePacketConsumer* consumer) = 0;
+};
+
+// PacketProcessor consumes and produces packet streams and constant packets.
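// The intended composition, as a reading of the interfaces above (an
// editorial note, not part of this patch): one PacketProcessor can be chained
// to another by passing the downstream stage to SetConsumer() or
// SetSideConsumer(), so packets produced by one stage are delivered directly
// into AddPacket() or SetSidePacket() of the next.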
+class PacketProcessor : public PacketConsumer, + public PacketProducer, + public SidePacketConsumer, + public SidePacketProducer { + public: + virtual ~PacketProcessor() = default; + + // Activate this PacketProcessor. + virtual absl::Status Start() = 0; + + // Block until this PacketProcessor has no remaining work to do. + virtual absl::Status WaitUntilIdle() = 0; + + // Deactivate this PacketProcessor. + virtual absl::Status Shutdown() = 0; +}; + +} // namespace mediapipe + +#endif // MEDIAPIPE_FRAMEWORK_TOOL_PACKET_PROCESSOR_H_ From 97af47ebf55e910b5c2125cba2f878e396be1b14 Mon Sep 17 00:00:00 2001 From: Hadon Nash Date: Wed, 18 Jan 2023 18:51:17 -0800 Subject: [PATCH 409/469] Internal change PiperOrigin-RevId: 503035081 --- mediapipe/framework/tool/switch/BUILD | 26 +++++ .../framework/tool/switch/graph_processor.cc | 110 ++++++++++++++++++ .../framework/tool/switch/graph_processor.h | 59 ++++++++++ .../framework/tool/switch/packet_processor.h | 2 +- 4 files changed, 196 insertions(+), 1 deletion(-) create mode 100644 mediapipe/framework/tool/switch/graph_processor.cc create mode 100644 mediapipe/framework/tool/switch/graph_processor.h diff --git a/mediapipe/framework/tool/switch/BUILD b/mediapipe/framework/tool/switch/BUILD index 62f9095ef..e7a3ba741 100644 --- a/mediapipe/framework/tool/switch/BUILD +++ b/mediapipe/framework/tool/switch/BUILD @@ -32,3 +32,29 @@ cc_library( "//mediapipe/framework/port:status", ], ) + +cc_library( + name = "graph_processor", + srcs = ["graph_processor.cc"], + hdrs = ["graph_processor.h"], + visibility = [ + "//visibility:public", + ], + deps = [ + ":packet_processor", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:collection_item_id", + "//mediapipe/framework:input_stream_shard", + "//mediapipe/framework:output_stream_shard", + "//mediapipe/framework:validated_graph_config", + "//mediapipe/framework/api2:node", + "//mediapipe/framework/api2:packet", + "//mediapipe/framework/port:integral_types", + "//mediapipe/framework/port:logging", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:status", + "//mediapipe/framework/stream_handler:immediate_input_stream_handler", + "@com_google_absl//absl/synchronization", + ], + alwayslink = 1, +) diff --git a/mediapipe/framework/tool/switch/graph_processor.cc b/mediapipe/framework/tool/switch/graph_processor.cc new file mode 100644 index 000000000..f35730761 --- /dev/null +++ b/mediapipe/framework/tool/switch/graph_processor.cc @@ -0,0 +1,110 @@ +#include "mediapipe/framework/tool/switch/graph_processor.h" + +#include "absl/synchronization/mutex.h" + +namespace mediapipe { + +// TODO: add support for input and output side packets. 
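// As an illustration only (not part of this change), a caller might drive
// this class roughly as follows, assuming a PacketConsumer implementation
// named MyConsumer and a graph input stream tagged "IN":
//
//   GraphProcessor processor;
//   MP_RETURN_IF_ERROR(processor.Initialize(graph_config));
//   MyConsumer consumer;  // receives the graph's output packets
//   processor.SetConsumer(&consumer);
//   MP_RETURN_IF_ERROR(processor.Start());
//   CollectionItemId id = processor.InputTags()->GetId("IN", 0);
//   MP_RETURN_IF_ERROR(
//       processor.AddPacket(id, MakePacket<int>(42).At(Timestamp(0))));
//   MP_RETURN_IF_ERROR(processor.WaitUntilIdle());
//   MP_RETURN_IF_ERROR(processor.Shutdown());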
+absl::Status GraphProcessor::Initialize(CalculatorGraphConfig graph_config) {
+  graph_config_ = graph_config;
+
+  ASSIGN_OR_RETURN(graph_input_map_,
+                   tool::TagMap::Create(graph_config_.input_stream()));
+  ASSIGN_OR_RETURN(graph_output_map_,
+                   tool::TagMap::Create(graph_config_.output_stream()));
+  return absl::OkStatus();
+}
+
+absl::Status GraphProcessor::AddPacket(CollectionItemId id, Packet packet) {
+  absl::MutexLock lock(&graph_mutex_);
+  const std::string& stream_name = graph_input_map_->Names().at(id.value());
+  return graph_->AddPacketToInputStream(stream_name, packet);
+}
+
+std::shared_ptr<tool::TagMap> GraphProcessor::InputTags() {
+  return graph_input_map_;
+}
+
+absl::Status GraphProcessor::SendPacket(CollectionItemId id, Packet packet) {
+  MP_RETURN_IF_ERROR(WaitUntilInitialized());
+  auto it = consumer_ids_.find(id);
+  if (it == consumer_ids_.end()) {
+    return absl::NotFoundError(
+        absl::StrCat("Consumer stream not found: ", id.value()));
+  }
+  return consumer_->AddPacket(it->second, packet);
+}
+
+void GraphProcessor::SetConsumer(PacketConsumer* consumer) {
+  absl::MutexLock lock(&graph_mutex_);
+  consumer_ = consumer;
+  auto input_map = consumer_->InputTags();
+  for (auto id = input_map->BeginId(); id != input_map->EndId(); ++id) {
+    auto tag_index = input_map->TagAndIndexFromId(id);
+    auto stream_id = graph_input_map_->GetId(tag_index.first, tag_index.second);
+    consumer_ids_[stream_id] = id;
+  }
+}
+
+absl::Status GraphProcessor::ObserveGraph() {
+  for (auto id = graph_output_map_->BeginId(); id != graph_output_map_->EndId();
+       ++id) {
+    std::string stream_name = graph_output_map_->Names().at(id.value());
+    MP_RETURN_IF_ERROR(graph_->ObserveOutputStream(
+        stream_name,
+        [this, id](const Packet& packet) { return SendPacket(id, packet); },
+        true));
+  }
+  return absl::OkStatus();
+}
+
+absl::Status GraphProcessor::WaitUntilInitialized() {
+  absl::MutexLock lock(&graph_mutex_);
+  auto is_initialized = [this]() ABSL_SHARED_LOCKS_REQUIRED(graph_mutex_) {
+    return graph_ != nullptr && consumer_ != nullptr;
+  };
+  graph_mutex_.AwaitWithTimeout(absl::Condition(&is_initialized),
+                                absl::Seconds(4));
+  RET_CHECK(is_initialized()) << "GraphProcessor initialization timed out.";
+  return absl::OkStatus();
+}
+
+absl::Status GraphProcessor::Start() {
+  absl::MutexLock lock(&graph_mutex_);
+  graph_ = std::make_unique<CalculatorGraph>();
+
+  // The graph is validated here with its specified inputs and outputs.
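  // (Editorial note: CalculatorGraph::Initialize builds a ValidatedGraphConfig
  // internally, which is why the BUILD target above depends on
  // validated_graph_config; a malformed config therefore fails here rather
  // than at StartRun.)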
+  MP_RETURN_IF_ERROR(graph_->Initialize(graph_config_, side_packets_));
+  MP_RETURN_IF_ERROR(ObserveGraph());
+  MP_RETURN_IF_ERROR(graph_->StartRun({}));
+  return absl::OkStatus();
+}
+
+absl::Status GraphProcessor::Shutdown() {
+  absl::MutexLock lock(&graph_mutex_);
+  if (!graph_) {
+    return absl::OkStatus();
+  }
+  MP_RETURN_IF_ERROR(graph_->CloseAllPacketSources());
+  MP_RETURN_IF_ERROR(graph_->WaitUntilDone());
+  graph_ = nullptr;
+  return absl::OkStatus();
+}
+
+absl::Status GraphProcessor::WaitUntilIdle() {
+  absl::MutexLock lock(&graph_mutex_);
+  return graph_->WaitUntilIdle();
+}
+
+// TODO
+absl::Status GraphProcessor::SetSidePacket(CollectionItemId id, Packet packet) {
+  return absl::OkStatus();
+}
+// TODO
+std::shared_ptr<tool::TagMap> GraphProcessor::SideInputTags() {
+  return nullptr;
+}
+// TODO
+void GraphProcessor::SetSideConsumer(SidePacketConsumer* consumer) {}
+
+}  // namespace mediapipe
diff --git a/mediapipe/framework/tool/switch/graph_processor.h b/mediapipe/framework/tool/switch/graph_processor.h
new file mode 100644
index 000000000..e2220b5dc
--- /dev/null
+++ b/mediapipe/framework/tool/switch/graph_processor.h
@@ -0,0 +1,59 @@
+#ifndef MEDIAPIPE_FRAMEWORK_TOOL_GRAPH_PROCESSOR_H_
+#define MEDIAPIPE_FRAMEWORK_TOOL_GRAPH_PROCESSOR_H_
+
+#include <memory>
+
+#include "absl/synchronization/mutex.h"
+#include "mediapipe/framework/calculator_framework.h"
+#include "mediapipe/framework/collection_item_id.h"
+#include "mediapipe/framework/port/status.h"
+#include "mediapipe/framework/tool/switch/packet_processor.h"
+
+namespace mediapipe {
+
+// Processes MediaPipe Packets using a MediaPipe CalculatorGraph.
+class GraphProcessor : public PacketProcessor {
+ public:
+  GraphProcessor() = default;
+
+  // Configures this GraphProcessor to create and run a CalculatorGraph.
+  absl::Status Initialize(CalculatorGraphConfig graph_config);
+
+ public:
+  // The PacketProcessor interface.
+  absl::Status AddPacket(CollectionItemId id, Packet packet) override;
+  std::shared_ptr<tool::TagMap> InputTags() override;
+  absl::Status SetSidePacket(CollectionItemId id, Packet packet) override;
+  std::shared_ptr<tool::TagMap> SideInputTags() override;
+  void SetConsumer(PacketConsumer* consumer) override;
+  void SetSideConsumer(SidePacketConsumer* consumer) override;
+  absl::Status Start() override;
+  absl::Status Shutdown() override;
+  absl::Status WaitUntilIdle() override;
+
+ private:
+  // Sends a tagged output packet.
+  absl::Status SendPacket(CollectionItemId id, Packet packet);
+
+  // Observes output packets from the calculator graph.
+  absl::Status ObserveGraph() ABSL_SHARED_LOCKS_REQUIRED(graph_mutex_);
+
+  // Blocks until this GraphProcessor is initialized.
+  absl::Status WaitUntilInitialized();
+
+ private:
+  CalculatorGraphConfig graph_config_;
+  std::shared_ptr<tool::TagMap> graph_input_map_;
+  std::shared_ptr<tool::TagMap> graph_output_map_;
+  std::map<CollectionItemId, CollectionItemId> consumer_ids_;
+
+  PacketConsumer* consumer_ = nullptr;
+  std::map<std::string, Packet> side_packets_;
+  std::unique_ptr<CalculatorGraph> graph_ ABSL_GUARDED_BY(graph_mutex_) =
+      nullptr;
+  absl::Mutex graph_mutex_;
+};
+
+}  // namespace mediapipe
+
+#endif  // MEDIAPIPE_FRAMEWORK_TOOL_GRAPH_PROCESSOR_H_
diff --git a/mediapipe/framework/tool/switch/packet_processor.h b/mediapipe/framework/tool/switch/packet_processor.h
index 1789a46c5..d97883c53 100644
--- a/mediapipe/framework/tool/switch/packet_processor.h
+++ b/mediapipe/framework/tool/switch/packet_processor.h
@@ -56,7 +56,7 @@ class SidePacketConsumer {
   virtual std::shared_ptr<tool::TagMap> SideInputTags() = 0;
 };
 
-// SidePacketProducer deleivers several tagged constant packets.
+// SidePacketProducer delivers several tagged constant packets. class SidePacketProducer { public: virtual ~SidePacketProducer() = default; From e2dedcbfe569d4a33ad24ac77fee51a2ed53d5b2 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Wed, 18 Jan 2023 19:40:19 -0800 Subject: [PATCH 410/469] Add SQRT_HANN window type to both SpectrogramCalculator and InverseSpectrogramCalculator. PiperOrigin-RevId: 503041493 --- mediapipe/calculators/audio/spectrogram_calculator.cc | 7 +++++++ mediapipe/calculators/audio/spectrogram_calculator.proto | 1 + 2 files changed, 8 insertions(+) diff --git a/mediapipe/calculators/audio/spectrogram_calculator.cc b/mediapipe/calculators/audio/spectrogram_calculator.cc index c038c0cd7..bd4d8f3bf 100644 --- a/mediapipe/calculators/audio/spectrogram_calculator.cc +++ b/mediapipe/calculators/audio/spectrogram_calculator.cc @@ -280,6 +280,13 @@ absl::Status SpectrogramCalculator::Open(CalculatorContext* cc) { audio_dsp::HammingWindow().GetPeriodicSamples(frame_duration_samples_, &window); break; + case SpectrogramCalculatorOptions::SQRT_HANN: { + audio_dsp::HannWindow().GetPeriodicSamples(frame_duration_samples_, + &window); + absl::c_transform(window, window.begin(), + [](double x) { return std::sqrt(x); }); + break; + } } // Propagate settings down to the actual Spectrogram object. diff --git a/mediapipe/calculators/audio/spectrogram_calculator.proto b/mediapipe/calculators/audio/spectrogram_calculator.proto index 8e1e18051..ddfca1d1c 100644 --- a/mediapipe/calculators/audio/spectrogram_calculator.proto +++ b/mediapipe/calculators/audio/spectrogram_calculator.proto @@ -68,6 +68,7 @@ message SpectrogramCalculatorOptions { HANN = 0; HAMMING = 1; COSINE = 2; + SQRT_HANN = 4; } optional WindowType window_type = 6 [default = HANN]; From 7a7cc77a8154c6ac873763d39e54b14ae4de403a Mon Sep 17 00:00:00 2001 From: Adam Cozzette Date: Thu, 19 Jan 2023 07:17:40 -0800 Subject: [PATCH 411/469] Internal change PiperOrigin-RevId: 503157344 --- .../unpack_media_sequence_calculator_test.cc | 2 +- .../framework/calculator_context_test.cc | 4 ++-- mediapipe/framework/port/proto_ns.h | 5 +++-- .../framework/profiler/graph_profiler_test.cc | 18 ++++++++------- .../framework/tool/options_lib_template.cc | 2 +- mediapipe/framework/tool/options_registry.cc | 22 ++++++++++--------- mediapipe/framework/tool/options_registry.h | 2 +- 7 files changed, 30 insertions(+), 25 deletions(-) diff --git a/mediapipe/calculators/tensorflow/unpack_media_sequence_calculator_test.cc b/mediapipe/calculators/tensorflow/unpack_media_sequence_calculator_test.cc index d8562ffc4..fbf775403 100644 --- a/mediapipe/calculators/tensorflow/unpack_media_sequence_calculator_test.cc +++ b/mediapipe/calculators/tensorflow/unpack_media_sequence_calculator_test.cc @@ -647,7 +647,7 @@ TEST_F(UnpackMediaSequenceCalculatorTest, GetAudioDecoderOptionsOverride) { TEST_F(UnpackMediaSequenceCalculatorTest, GetPacketResamplingOptions) { // TODO: Suport proto3 proto.Any in CalculatorOptions. - // TODO: Avoid proto2 extensions in "RESAMPLER_OPTIONS". + // TODO: Avoid google::protobuf extensions in "RESAMPLER_OPTIONS". 
CalculatorOptions options; options.MutableExtension(UnpackMediaSequenceCalculatorOptions::ext) ->set_padding_before_label(1); diff --git a/mediapipe/framework/calculator_context_test.cc b/mediapipe/framework/calculator_context_test.cc index e7612501a..be9103b4d 100644 --- a/mediapipe/framework/calculator_context_test.cc +++ b/mediapipe/framework/calculator_context_test.cc @@ -131,10 +131,10 @@ TEST(CalculatorTest, GetOptions) { auto calculator_state_3 = MakeCalculatorState(config.node(3), 3); auto cc_3 = MakeCalculatorContext(&*calculator_state_3); - // Get a proto2 options extension from Node::options. + // Get a google::protobuf options extension from Node::options. EXPECT_EQ(cc_0->Options().jitter(), 0.123); - // Get a proto2 options extension from Node::node_options. + // Get a google::protobuf options extension from Node::node_options. EXPECT_EQ(cc_1->Options().jitter(), 0.123); // Get a proto3 options protobuf::Any from Node::node_options. diff --git a/mediapipe/framework/port/proto_ns.h b/mediapipe/framework/port/proto_ns.h index 83aecdf49..53b854ff7 100644 --- a/mediapipe/framework/port/proto_ns.h +++ b/mediapipe/framework/port/proto_ns.h @@ -17,8 +17,9 @@ #include -// Temporary forward declarations for proto2 support on portable targets. -// Use proto_ns inside namespace mediapipe instead of proto2 namespace. +// Temporary forward declarations for google::protobuf support on portable +// targets. Use proto_ns inside namespace mediapipe instead of google::protobuf +// namespace. #include "google/protobuf/message.h" #include "google/protobuf/message_lite.h" #include "google/protobuf/repeated_field.h" diff --git a/mediapipe/framework/profiler/graph_profiler_test.cc b/mediapipe/framework/profiler/graph_profiler_test.cc index 75d1c7ebd..e9badaa25 100644 --- a/mediapipe/framework/profiler/graph_profiler_test.cc +++ b/mediapipe/framework/profiler/graph_profiler_test.cc @@ -39,13 +39,15 @@ constexpr char kDummyTestCalculatorName[] = "DummyTestCalculator"; CalculatorGraphConfig::Node CreateNodeConfig( const std::string& raw_node_config) { CalculatorGraphConfig::Node node_config; - QCHECK(proto2::TextFormat::ParseFromString(raw_node_config, &node_config)); + QCHECK(google::protobuf::TextFormat::ParseFromString(raw_node_config, + &node_config)); return node_config; } CalculatorGraphConfig CreateGraphConfig(const std::string& raw_graph_config) { CalculatorGraphConfig graph_config; - QCHECK(proto2::TextFormat::ParseFromString(raw_graph_config, &graph_config)); + QCHECK(google::protobuf::TextFormat::ParseFromString(raw_graph_config, + &graph_config)); return graph_config; } @@ -1167,7 +1169,7 @@ TEST_F(GraphProfilerTestPeer, AddProcessSampleWithStreamLatency) { TEST(GraphProfilerTest, ParallelReads) { // A graph that processes a certain number of packets before finishing. CalculatorGraphConfig config; - QCHECK(proto2::TextFormat::ParseFromString(R"( + QCHECK(google::protobuf::TextFormat::ParseFromString(R"( profiler_config { enable_profiler: true } @@ -1189,7 +1191,7 @@ TEST(GraphProfilerTest, ParallelReads) { } output_stream: "OUT:0:the_integers" )", - &config)); + &config)); // Start running the graph on its own threads. 
absl::Mutex out_1_mutex; @@ -1246,7 +1248,7 @@ std::set GetCalculatorNames(const CalculatorGraphConfig& config) { TEST(GraphProfilerTest, CalculatorProfileFilter) { CalculatorGraphConfig config; - QCHECK(proto2::TextFormat::ParseFromString(R"( + QCHECK(google::protobuf::TextFormat::ParseFromString(R"( profiler_config { enable_profiler: true } @@ -1268,7 +1270,7 @@ TEST(GraphProfilerTest, CalculatorProfileFilter) { } output_stream: "OUT:0:the_integers" )", - &config)); + &config)); std::set expected_names; expected_names = {"RangeCalculator", "PassThroughCalculator"}; @@ -1295,7 +1297,7 @@ TEST(GraphProfilerTest, CalculatorProfileFilter) { TEST(GraphProfilerTest, CaptureProfilePopulateConfig) { CalculatorGraphConfig config; - QCHECK(proto2::TextFormat::ParseFromString(R"( + QCHECK(google::protobuf::TextFormat::ParseFromString(R"( profiler_config { enable_profiler: true trace_enabled: true @@ -1310,7 +1312,7 @@ TEST(GraphProfilerTest, CaptureProfilePopulateConfig) { input_stream: "input_stream" } )", - &config)); + &config)); CalculatorGraph graph; MP_ASSERT_OK(graph.Initialize(config)); GraphProfile profile; diff --git a/mediapipe/framework/tool/options_lib_template.cc b/mediapipe/framework/tool/options_lib_template.cc index 21a5db10f..4861132a2 100644 --- a/mediapipe/framework/tool/options_lib_template.cc +++ b/mediapipe/framework/tool/options_lib_template.cc @@ -28,7 +28,7 @@ constexpr char kDescriptorContents[] = mediapipe::FieldData ReadFileDescriptorSet(const std::string& pb) { mediapipe::FieldData result; *result.mutable_message_value()->mutable_type_url() = - "proto2.FileDescriptorSet"; + "google::protobuf.FileDescriptorSet"; *result.mutable_message_value()->mutable_value() = pb; // Force linking of the generated options protobuf. diff --git a/mediapipe/framework/tool/options_registry.cc b/mediapipe/framework/tool/options_registry.cc index f6858be0a..07cc65a95 100644 --- a/mediapipe/framework/tool/options_registry.cc +++ b/mediapipe/framework/tool/options_registry.cc @@ -66,26 +66,28 @@ std::string GetFieldString(const FieldData& message_data, void RegisterDescriptorProtos( absl::flat_hash_map& result) { std::vector descriptors = { - {"proto2.FileDescriptorSet", + {"google::protobuf.FileDescriptorSet", { - {"file", 1, FieldType::TYPE_MESSAGE, "proto2.FileDescriptorProto"}, + {"file", 1, FieldType::TYPE_MESSAGE, + "google::protobuf.FileDescriptorProto"}, }}, - {"proto2.FileDescriptorProto", + {"google::protobuf.FileDescriptorProto", { {"package", 2, FieldType::TYPE_STRING, ""}, {"message_type", 4, FieldType::TYPE_MESSAGE, - "proto2.DescriptorProto"}, + "google::protobuf.DescriptorProto"}, }}, - {"proto2.DescriptorProto", + {"google::protobuf.DescriptorProto", { {"name", 1, FieldType::TYPE_STRING, ""}, - {"field", 2, FieldType::TYPE_MESSAGE, "proto2.FieldDescriptorProto"}, + {"field", 2, FieldType::TYPE_MESSAGE, + "google::protobuf.FieldDescriptorProto"}, {"extension", 6, FieldType::TYPE_MESSAGE, - "proto2.FieldDescriptorProto"}, + "google::protobuf.FieldDescriptorProto"}, {"nested_type", 3, FieldType::TYPE_MESSAGE, - "proto2.DescriptorProto"}, + "google::protobuf.DescriptorProto"}, }}, - {"proto2.FieldDescriptorProto", + {"google::protobuf.FieldDescriptorProto", { {"name", 1, FieldType::TYPE_STRING, ""}, {"number", 3, FieldType::TYPE_INT32, ""}, @@ -140,7 +142,7 @@ void OptionsRegistry::Register(const FieldData& message_type, const Descriptor* OptionsRegistry::GetProtobufDescriptor( const std::string& type_name) { - if (descriptors().count("proto2.DescriptorProto") == 0) { + if 
(descriptors().count("google::protobuf.DescriptorProto") == 0) { RegisterDescriptorProtos(descriptors()); } absl::ReaderMutexLock lock(&mutex()); diff --git a/mediapipe/framework/tool/options_registry.h b/mediapipe/framework/tool/options_registry.h index b843b113a..3b2d2be89 100644 --- a/mediapipe/framework/tool/options_registry.h +++ b/mediapipe/framework/tool/options_registry.h @@ -28,7 +28,7 @@ class OptionsRegistry { // Finds the descriptor for a protobuf. static const Descriptor* GetProtobufDescriptor(const std::string& type_name); - // Returns all known proto2 extensions to a type. + // Returns all known google::protobuf extensions to a type. static void FindAllExtensions(absl::string_view extendee, std::vector* result); From dcd2adad532f3f65703a7c387f182090a1229c51 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Thu, 19 Jan 2023 09:17:39 -0800 Subject: [PATCH 412/469] Removing broken links. They might not be relevant since we only support TfLite models. PiperOrigin-RevId: 503183358 --- docs/solutions/models.md | 2 -- 1 file changed, 2 deletions(-) diff --git a/docs/solutions/models.md b/docs/solutions/models.md index 18bcf0c8b..325c41f1b 100644 --- a/docs/solutions/models.md +++ b/docs/solutions/models.md @@ -94,8 +94,6 @@ one over the other. * [TFLite model](https://storage.googleapis.com/mediapipe-assets/ssdlite_object_detection.tflite) * [TFLite model quantized for EdgeTPU/Coral](https://github.com/google/mediapipe/tree/master/mediapipe/examples/coral/models/object-detector-quantized_edgetpu.tflite) -* [TensorFlow model](https://github.com/google/mediapipe/tree/master/mediapipe/models/object_detection_saved_model) -* [Model information](https://github.com/google/mediapipe/tree/master/mediapipe/models/object_detection_saved_model/README.md) ### [Objectron](https://google.github.io/mediapipe/solutions/objectron) From a02097ea083cc318d33edc236a3824a0d50002a8 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Thu, 19 Jan 2023 10:06:42 -0800 Subject: [PATCH 413/469] Fix comments PiperOrigin-RevId: 503195768 --- mediapipe/tasks/web/audio/index.ts | 2 +- mediapipe/tasks/web/text/index.ts | 2 +- mediapipe/tasks/web/vision/index.ts | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/mediapipe/tasks/web/audio/index.ts b/mediapipe/tasks/web/audio/index.ts index 44fa7eb25..e7465878b 100644 --- a/mediapipe/tasks/web/audio/index.ts +++ b/mediapipe/tasks/web/audio/index.ts @@ -18,7 +18,7 @@ import {AudioClassifier as AudioClassifierImpl} from '../../../tasks/web/audio/a import {AudioEmbedder as AudioEmbedderImpl} from '../../../tasks/web/audio/audio_embedder/audio_embedder'; import {FilesetResolver as FilesetResolverImpl} from '../../../tasks/web/core/fileset_resolver'; -// Declare the variables locally so that Rollup in OSS includes them explcilty +// Declare the variables locally so that Rollup in OSS includes them explicitly // as exports. 
const AudioClassifier = AudioClassifierImpl; const AudioEmbedder = AudioEmbedderImpl; diff --git a/mediapipe/tasks/web/text/index.ts b/mediapipe/tasks/web/text/index.ts index 2c9e6fead..cfa990e58 100644 --- a/mediapipe/tasks/web/text/index.ts +++ b/mediapipe/tasks/web/text/index.ts @@ -18,7 +18,7 @@ import {FilesetResolver as FilesetResolverImpl} from '../../../tasks/web/core/fi import {TextClassifier as TextClassifierImpl} from '../../../tasks/web/text/text_classifier/text_classifier'; import {TextEmbedder as TextEmbedderImpl} from '../../../tasks/web/text/text_embedder/text_embedder'; -// Declare the variables locally so that Rollup in OSS includes them explcilty +// Declare the variables locally so that Rollup in OSS includes them explicitly // as exports. const FilesetResolver = FilesetResolverImpl; const TextClassifier = TextClassifierImpl; diff --git a/mediapipe/tasks/web/vision/index.ts b/mediapipe/tasks/web/vision/index.ts index e13f8183f..49f23c243 100644 --- a/mediapipe/tasks/web/vision/index.ts +++ b/mediapipe/tasks/web/vision/index.ts @@ -21,7 +21,7 @@ import {ImageClassifier as ImageClassifierImpl} from '../../../tasks/web/vision/ import {ImageEmbedder as ImageEmbedderImpl} from '../../../tasks/web/vision/image_embedder/image_embedder'; import {ObjectDetector as ObjectDetectorImpl} from '../../../tasks/web/vision/object_detector/object_detector'; -// Declare the variables locally so that Rollup in OSS includes them explcilty +// Declare the variables locally so that Rollup in OSS includes them explicitly // as exports. const FilesetResolver = FilesetResolverImpl; const GestureRecognizer = GestureRecognizerImpl; From db1a89324e6ffc100bf7723fbfaf2673b4f36ecc Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Thu, 19 Jan 2023 10:39:05 -0800 Subject: [PATCH 414/469] Add mediapipe::Image output to the graph runner PiperOrigin-RevId: 503204918 --- .../graph_runner/graph_runner_image_lib.ts | 87 ++++++++++++++++--- 1 file changed, 74 insertions(+), 13 deletions(-) diff --git a/mediapipe/web/graph_runner/graph_runner_image_lib.ts b/mediapipe/web/graph_runner/graph_runner_image_lib.ts index 7a4ea09e2..9608ebcc7 100644 --- a/mediapipe/web/graph_runner/graph_runner_image_lib.ts +++ b/mediapipe/web/graph_runner/graph_runner_image_lib.ts @@ -1,4 +1,6 @@ -import {ImageSource, GraphRunner} from './graph_runner'; +import {GraphRunner, ImageSource} from './graph_runner'; + + /** * We extend from a GraphRunner constructor. This ensures our mixin has @@ -8,6 +10,12 @@ import {ImageSource, GraphRunner} from './graph_runner'; // tslint:disable-next-line:no-any type LibConstructor = new (...args: any[]) => GraphRunner; +/** An image returned from a MediaPipe graph. */ +export interface WasmImage { + data: Uint8Array|Float32Array; + width: number; + height: number; +} /** * Declarations for Emscripten's WebAssembly Module behavior, so TS compiler * doesn't break our JS/C++ bridge. @@ -16,26 +24,33 @@ export declare interface WasmImageModule { _addBoundTextureAsImageToStream: (streamNamePtr: number, width: number, height: number, timestamp: number) => void; + _attachImageListener: (streamNamePtr: number) => void; + _attachImageVectorListener: (streamNamePtr: number) => void; } /** * An implementation of GraphRunner that supports binding GPU image data as - * `mediapipe::Image` instances. We implement as a proper TS mixin, to allow for - * effective multiple inheritance. Example usage: - * `const GraphRunnerImageLib = SupportImage(GraphRunner);` + * `mediapipe::Image` instances. 
We implement as a proper TS mixin, to allow + * for effective multiple inheritance. Example usage: `const GraphRunnerImageLib + * = SupportImage(GraphRunner);` */ // tslint:disable-next-line:enforce-name-casing export function SupportImage(Base: TBase) { return class extends Base { + get wasmImageModule(): WasmImageModule { + return this.wasmModule as unknown as WasmImageModule; + } + /** - * Takes the relevant information from the HTML video or image element, and - * passes it into the WebGL-based graph for processing on the given stream - * at the given timestamp as a MediaPipe image. Processing will not occur - * until a blocking call (like processVideoGl or finishProcessing) is made. + * Takes the relevant information from the HTML video or image element, + * and passes it into the WebGL-based graph for processing on the given + * stream at the given timestamp as a MediaPipe image. Processing will not + * occur until a blocking call (like processVideoGl or finishProcessing) + * is made. * @param imageSource Reference to the video frame we wish to add into our * graph. - * @param streamName The name of the MediaPipe graph stream to add the frame - * to. + * @param streamName The name of the MediaPipe graph stream to add the + * frame to. * @param timestamp The timestamp of the input frame, in ms. */ addGpuBufferAsImageToStream( @@ -43,9 +58,55 @@ export function SupportImage(Base: TBase) { this.wrapStringPtr(streamName, (streamNamePtr: number) => { const [width, height] = this.bindTextureToStream(imageSource, streamNamePtr); - (this.wasmModule as unknown as WasmImageModule) - ._addBoundTextureAsImageToStream( - streamNamePtr, width, height, timestamp); + this.wasmImageModule._addBoundTextureAsImageToStream( + streamNamePtr, width, height, timestamp); + }); + } + + /** + * Attaches a mediapipe:Image packet listener to the specified output + * stream. + * @param outputStreamName The name of the graph output stream to grab + * mediapipe::Image data from. + * @param callbackFcn The function that will be called back with the data, + * as it is received. Note that the data is only guaranteed to exist + * for the duration of the callback, and the callback will be called + * inline, so it should not perform overly complicated (or any async) + * behavior. + */ + attachImageListener( + outputStreamName: string, + callbackFcn: (data: WasmImage, timestamp: number) => void): void { + // Set up our TS listener to receive any packets for this stream. + this.setListener(outputStreamName, callbackFcn); + + // Tell our graph to listen for mediapipe::Image packets on this stream. + this.wrapStringPtr(outputStreamName, (outputStreamNamePtr: number) => { + this.wasmImageModule._attachImageListener(outputStreamNamePtr); + }); + } + + /** + * Attaches a mediapipe:Image[] packet listener to the specified + * output_stream. + * @param outputStreamName The name of the graph output stream to grab + * std::vector data from. + * @param callbackFcn The function that will be called back with the data, + * as it is received. Note that the data is only guaranteed to exist + * for the duration of the callback, and the callback will be called + * inline, so it should not perform overly complicated (or any async) + * behavior. + */ + attachImageVectorListener( + outputStreamName: string, + callbackFcn: (data: WasmImage[], timestamp: number) => void): void { + // Set up our TS listener to receive any packets for this stream. 
+ this.setVectorListener(outputStreamName, callbackFcn); + + // Tell our graph to listen for std::vector packets on + // this stream. + this.wrapStringPtr(outputStreamName, (outputStreamNamePtr: number) => { + this.wasmImageModule._attachImageVectorListener(outputStreamNamePtr); }); } };
From 921b6a6befae381ba873fb61ba170d902a1c6b02 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Thu, 19 Jan 2023 22:09:55 -0800 Subject: [PATCH 415/469] Fix the typo _PALM_LANMARKS -> _PALM_LANDMARKS. PiperOrigin-RevId: 503352055 --- mediapipe/python/solutions/drawing_styles.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-)
diff --git a/mediapipe/python/solutions/drawing_styles.py b/mediapipe/python/solutions/drawing_styles.py index b43bca8d3..5d75d5b30 100644 --- a/mediapipe/python/solutions/drawing_styles.py +++ b/mediapipe/python/solutions/drawing_styles.py @@ -37,9 +37,10 @@ _THICKNESS_FINGER = 2 _THICKNESS_DOT = -1 # Hand landmarks -_PALM_LANMARKS = (HandLandmark.WRIST, HandLandmark.THUMB_CMC, - HandLandmark.INDEX_FINGER_MCP, HandLandmark.MIDDLE_FINGER_MCP, - HandLandmark.RING_FINGER_MCP, HandLandmark.PINKY_MCP) +_PALM_LANDMARKS = (HandLandmark.WRIST, HandLandmark.THUMB_CMC, + HandLandmark.INDEX_FINGER_MCP, + HandLandmark.MIDDLE_FINGER_MCP, HandLandmark.RING_FINGER_MCP, + HandLandmark.PINKY_MCP) _THUMP_LANDMARKS = (HandLandmark.THUMB_MCP, HandLandmark.THUMB_IP, HandLandmark.THUMB_TIP) _INDEX_FINGER_LANDMARKS = (HandLandmark.INDEX_FINGER_PIP, @@ -54,7 +55,7 @@ _RING_FINGER_LANDMARKS = (HandLandmark.RING_FINGER_PIP, _PINKY_FINGER_LANDMARKS = (HandLandmark.PINKY_PIP, HandLandmark.PINKY_DIP, HandLandmark.PINKY_TIP) _HAND_LANDMARK_STYLE = { - _PALM_LANMARKS: + _PALM_LANDMARKS: DrawingSpec( color=_RED, thickness=_THICKNESS_DOT, circle_radius=_RADIUS), _THUMP_LANDMARKS:
From 1124569c29edad16e86a77e57407ca7abf0dc4a2 Mon Sep 17 00:00:00 2001 From: Nikolay Chirkov Date: Mon, 23 Jan 2023 10:58:14 -0800 Subject: [PATCH 416/469] Tensor: Make Tensor not require the "-x objective-c++" option. Previously, tensor.h was compiled differently for C++ and Objective-C++ translation units, which violates the ODR (one-definition rule). Tensor has no conditionally compiled virtual methods, but it did have conditionally compiled Metal-related data members. Instead, a unique_ptr to MtlResources, declared as a forward (incomplete) struct, is now unconditionally defined in the Tensor class. MtlResources itself is defined only in the .cc file, which is compiled just once per project, so there is no ODR violation.
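In sketch form, the pattern looks like this (simplified from the diff below; constructor/destructor bodies and the move path are elided, so this illustrates the technique rather than reproducing the exact MediaPipe code):

// tensor.h -- included from both C++ and Objective-C++ translation units;
// no Metal types appear here, so plain C++ compilation works.
#include <memory>

struct MtlResources;  // Forward declaration only.

class Tensor {
 public:
  Tensor();
  ~Tensor();  // Defined in the .cc file, where MtlResources is complete.

 private:
  // A unique_ptr to an incomplete type is a valid member as long as the
  // destructor is defined where the type is complete.
  std::unique_ptr<MtlResources> mtl_resources_;
};

// tensor.cc -- compiled exactly once per project (as Objective-C++ when
// Metal is enabled), so the two alternative definitions below can never
// coexist within one build and the ODR holds.
#if MEDIAPIPE_METAL_ENABLED
struct MtlResources {
  id<MTLCommandBuffer> command_buffer = nil;
  id<MTLDevice> device = nil;
  id<MTLBuffer> metal_buffer = nil;
};
#else
struct MtlResources {};
#endif  // MEDIAPIPE_METAL_ENABLED

Tensor::Tensor() : mtl_resources_(std::make_unique<MtlResources>()) {}
Tensor::~Tensor() = default;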
PiperOrigin-RevId: 504029286 --- mediapipe/calculators/tensor/BUILD | 81 +----------- .../tensor/image_to_tensor_converter_metal.cc | 6 +- .../tensor/inference_calculator_metal.cc | 18 ++- .../tensor/tensor_converter_calculator.cc | 3 +- .../tensors_to_detections_calculator.cc | 23 ++-- .../tensors_to_segmentation_calculator.cc | 4 +- mediapipe/framework/formats/BUILD | 5 +- mediapipe/framework/formats/tensor.cc | 125 ++++++++++-------- mediapipe/framework/formats/tensor.h | 46 +------ .../formats/tensor_mtl_buffer_view.h | 61 +++++++++ .../tasks/cc/components/calculators/BUILD | 8 -- .../tasks/cc/components/processors/BUILD | 16 --- 12 files changed, 177 insertions(+), 219 deletions(-) create mode 100644 mediapipe/framework/formats/tensor_mtl_buffer_view.h diff --git a/mediapipe/calculators/tensor/BUILD b/mediapipe/calculators/tensor/BUILD index 127280107..69d666092 100644 --- a/mediapipe/calculators/tensor/BUILD +++ b/mediapipe/calculators/tensor/BUILD @@ -53,14 +53,6 @@ mediapipe_proto_library( cc_library( name = "audio_to_tensor_calculator", srcs = ["audio_to_tensor_calculator.cc"], - copts = select({ - # b/215212850 - "//mediapipe:apple": [ - "-x objective-c++", - "-fobjc-arc", - ], - "//conditions:default": [], - }), deps = [ ":audio_to_tensor_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -161,14 +153,6 @@ mediapipe_proto_library( cc_library( name = "feedback_tensors_calculator", srcs = ["feedback_tensors_calculator.cc"], - copts = select({ - # b/215212850 - "//mediapipe:apple": [ - "-x objective-c++", - "-fobjc-arc", - ], - "//conditions:default": [], - }), deps = [ ":feedback_tensors_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -207,14 +191,6 @@ mediapipe_proto_library( cc_library( name = "bert_preprocessor_calculator", srcs = ["bert_preprocessor_calculator.cc"], - copts = select({ - # TODO: fix tensor.h not to require this, if possible - "//mediapipe:apple": [ - "-x objective-c++", - "-fobjc-arc", # enable reference-counting - ], - "//conditions:default": [], - }), deps = [ ":bert_preprocessor_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -267,14 +243,6 @@ mediapipe_proto_library( cc_library( name = "regex_preprocessor_calculator", srcs = ["regex_preprocessor_calculator.cc"], - copts = select({ - # TODO: fix tensor.h not to require this, if possible - "//mediapipe:apple": [ - "-x objective-c++", - "-fobjc-arc", # enable reference-counting - ], - "//conditions:default": [], - }), deps = [ ":regex_preprocessor_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -316,14 +284,6 @@ cc_test( cc_library( name = "text_to_tensor_calculator", srcs = ["text_to_tensor_calculator.cc"], - copts = select({ - # TODO: fix tensor.h not to require this, if possible - "//mediapipe:apple": [ - "-x objective-c++", - "-fobjc-arc", # enable reference-counting - ], - "//conditions:default": [], - }), deps = [ "//mediapipe/framework:calculator_context", "//mediapipe/framework:calculator_framework", @@ -414,14 +374,6 @@ cc_library( name = "inference_calculator_interface", srcs = ["inference_calculator.cc"], hdrs = ["inference_calculator.h"], - copts = select({ - # TODO: fix tensor.h not to require this, if possible - "//mediapipe:apple": [ - "-x objective-c++", - "-fobjc-arc", # enable reference-counting - ], - "//conditions:default": [], - }), deps = [ ":inference_calculator_cc_proto", ":inference_calculator_options_lib", @@ -495,6 +447,7 @@ cc_library( tags = ["ios"], deps = [ "inference_calculator_interface", + 
"//mediapipe/framework/formats:tensor", "//mediapipe/gpu:MPPMetalHelper", "//mediapipe/gpu:MPPMetalUtil", "//mediapipe/gpu:gpu_buffer", @@ -513,14 +466,6 @@ cc_library( cc_library( name = "inference_runner", hdrs = ["inference_runner.h"], - copts = select({ - # TODO: fix tensor.h not to require this, if possible - "//mediapipe:apple": [ - "-x objective-c++", - "-fobjc-arc", # enable reference-counting - ], - "//conditions:default": [], - }), deps = [ "//mediapipe/framework:calculator_context", "//mediapipe/framework/formats:tensor", @@ -532,14 +477,6 @@ cc_library( name = "inference_interpreter_delegate_runner", srcs = ["inference_interpreter_delegate_runner.cc"], hdrs = ["inference_interpreter_delegate_runner.h"], - copts = select({ - # TODO: fix tensor.h not to require this, if possible - "//mediapipe:apple": [ - "-x objective-c++", - "-fobjc-arc", # enable reference-counting - ], - "//conditions:default": [], - }), deps = [ ":inference_runner", "//mediapipe/framework:mediapipe_profiling", @@ -561,14 +498,6 @@ cc_library( srcs = [ "inference_calculator_cpu.cc", ], - copts = select({ - # TODO: fix tensor.h not to require this, if possible - "//mediapipe:apple": [ - "-x objective-c++", - "-fobjc-arc", # enable reference-counting - ], - "//conditions:default": [], - }), deps = [ ":inference_calculator_interface", ":inference_calculator_utils", @@ -607,14 +536,6 @@ cc_library( srcs = [ "inference_calculator_xnnpack.cc", ], - copts = select({ - # TODO: fix tensor.h not to require this, if possible - "//mediapipe:apple": [ - "-x objective-c++", - "-fobjc-arc", # enable reference-counting - ], - "//conditions:default": [], - }), deps = [ ":inference_calculator_interface", ":inference_calculator_utils", diff --git a/mediapipe/calculators/tensor/image_to_tensor_converter_metal.cc b/mediapipe/calculators/tensor/image_to_tensor_converter_metal.cc index a8211d39b..354547042 100644 --- a/mediapipe/calculators/tensor/image_to_tensor_converter_metal.cc +++ b/mediapipe/calculators/tensor/image_to_tensor_converter_metal.cc @@ -36,6 +36,10 @@ #include "tensorflow/lite/delegates/gpu/common/shape.h" #include "tensorflow/lite/delegates/gpu/common/types.h" +#if MEDIAPIPE_METAL_ENABLED +#include "mediapipe/framework/formats/tensor_mtl_buffer_view.h" +#endif // MEDIAPIPE_METAL_ENABLED + namespace mediapipe { namespace { @@ -376,7 +380,7 @@ class MetalProcessor : public ImageToTensorConverter { id command_buffer = [metal_helper_ commandBuffer]; const auto& buffer_view = - output_tensor.GetMtlBufferWriteView(command_buffer); + MtlBufferView::GetWriteView(output_tensor, command_buffer); MP_RETURN_IF_ERROR(extractor_->Execute( texture, roi, /*flip_horizontaly=*/false, transform.scale, transform.offset, diff --git a/mediapipe/calculators/tensor/inference_calculator_metal.cc b/mediapipe/calculators/tensor/inference_calculator_metal.cc index 750f0456e..fba18a81c 100644 --- a/mediapipe/calculators/tensor/inference_calculator_metal.cc +++ b/mediapipe/calculators/tensor/inference_calculator_metal.cc @@ -24,6 +24,8 @@ #include "absl/memory/memory.h" #include "absl/strings/str_format.h" #include "mediapipe/calculators/tensor/inference_calculator.h" +#include "mediapipe/framework/formats/tensor.h" +#include "mediapipe/framework/formats/tensor_mtl_buffer_view.h" #import "mediapipe/gpu/MPPMetalHelper.h" #include "mediapipe/gpu/MPPMetalUtil.h" #include "mediapipe/gpu/gpu_buffer.h" @@ -150,11 +152,12 @@ absl::Status InferenceCalculatorMetalImpl::Process(CalculatorContext* cc) { command_buffer.label = @"InferenceCalculator"; // 
Explicit copy input with conversion float 32 bits to 16 bits. for (int i = 0; i < input_tensors.size(); ++i) { - auto input_view = input_tensors[i].GetMtlBufferReadView(command_buffer); + auto input_view = + MtlBufferView::GetReadView(input_tensors[i], command_buffer); // Reshape tensor. tflite::gpu::BHWC shape = BhwcFromTensorShape(input_tensors[i].shape()); auto gpu_buffer_view = - gpu_buffers_in_[i]->GetMtlBufferWriteView(command_buffer); + MtlBufferView::GetWriteView(*gpu_buffers_in_[i], command_buffer); id input_encoder = [command_buffer computeCommandEncoder]; [converter_to_BPHWC4_ convertWithEncoder:input_encoder @@ -174,9 +177,10 @@ absl::Status InferenceCalculatorMetalImpl::Process(CalculatorContext* cc) { output_shapes_[i]); // Reshape tensor. tflite::gpu::BHWC shape = BhwcFromTensorShape(output_shapes_[i]); - auto read_view = gpu_buffers_out_[i]->GetMtlBufferReadView(command_buffer); + auto read_view = + MtlBufferView::GetReadView(*gpu_buffers_out_[i], command_buffer); auto write_view = - output_tensors->at(i).GetMtlBufferWriteView(command_buffer); + MtlBufferView::GetWriteView(output_tensors->at(i), command_buffer); id output_encoder = [command_buffer computeCommandEncoder]; [converter_from_BPHWC4_ convertWithEncoder:output_encoder @@ -258,7 +262,7 @@ absl::Status InferenceCalculatorMetalImpl::CreateConverters( : Tensor::ElementType::kFloat32, Tensor::Shape{dims})); auto buffer_view = - gpu_buffers_in_[i]->GetMtlBufferWriteView(gpu_helper_.mtlDevice); + MtlBufferView::GetWriteView(*gpu_buffers_in_[i], gpu_helper_.mtlDevice); RET_CHECK_EQ(TFLGpuDelegateBindMetalBufferToTensor( delegate_.get(), input_indices[i], buffer_view.buffer()), true); @@ -286,8 +290,8 @@ absl::Status InferenceCalculatorMetalImpl::CreateConverters( Tensor::Shape{dims})); RET_CHECK_EQ(TFLGpuDelegateBindMetalBufferToTensor( delegate_.get(), output_indices[i], - gpu_buffers_out_[i] - ->GetMtlBufferWriteView(gpu_helper_.mtlDevice) + MtlBufferView::GetWriteView(*gpu_buffers_out_[i], + gpu_helper_.mtlDevice) .buffer()), true); } diff --git a/mediapipe/calculators/tensor/tensor_converter_calculator.cc b/mediapipe/calculators/tensor/tensor_converter_calculator.cc index 0b750b859..4b05488fd 100644 --- a/mediapipe/calculators/tensor/tensor_converter_calculator.cc +++ b/mediapipe/calculators/tensor/tensor_converter_calculator.cc @@ -31,6 +31,7 @@ #import #import +#include "mediapipe/framework/formats/tensor_mtl_buffer_view.h" #import "mediapipe/gpu/MPPMetalHelper.h" #elif MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30 #include "mediapipe/gpu/gl_calculator_helper.h" @@ -304,7 +305,7 @@ absl::Status TensorConverterCalculator::ProcessGPU(CalculatorContext* cc) { id src_texture = [gpu_helper_ metalTextureWithGpuBuffer:input]; [compute_encoder setTexture:src_texture atIndex:0]; auto output_view = - output_tensors->at(0).GetMtlBufferWriteView(command_buffer); + MtlBufferView::GetWriteView(output_tensors->at(0), command_buffer); [compute_encoder setBuffer:output_view.buffer() offset:0 atIndex:1]; MTLSize threads_per_group = MTLSizeMake(kWorkgroupSize, kWorkgroupSize, 1); MTLSize threadgroups = diff --git a/mediapipe/calculators/tensor/tensors_to_detections_calculator.cc b/mediapipe/calculators/tensor/tensors_to_detections_calculator.cc index 97ef01b4c..4bb3f0f57 100644 --- a/mediapipe/calculators/tensor/tensors_to_detections_calculator.cc +++ b/mediapipe/calculators/tensor/tensors_to_detections_calculator.cc @@ -41,6 +41,7 @@ #import #import +#include "mediapipe/framework/formats/tensor_mtl_buffer_view.h" #import 
"mediapipe/gpu/MPPMetalHelper.h" #include "mediapipe/gpu/MPPMetalUtil.h" #endif // MEDIAPIPE_METAL_ENABLED @@ -536,10 +537,11 @@ absl::Status TensorsToDetectionsCalculator::ProcessGPU( if (input_tensors.size() == kNumInputTensorsWithAnchors) { RET_CHECK_EQ(input_tensors.size(), kNumInputTensorsWithAnchors); auto command_buffer = [gpu_helper_ commandBuffer]; - auto src_buffer = input_tensors[tensor_mapping_.anchors_tensor_index()] - .GetMtlBufferReadView(command_buffer); + auto src_buffer = MtlBufferView::GetReadView( + input_tensors[tensor_mapping_.anchors_tensor_index()], + command_buffer); auto dest_buffer = - raw_anchors_buffer_->GetMtlBufferWriteView(command_buffer); + MtlBufferView::GetWriteView(*raw_anchors_buffer_, command_buffer); id blit_command = [command_buffer blitCommandEncoder]; [blit_command copyFromBuffer:src_buffer.buffer() @@ -571,15 +573,16 @@ absl::Status TensorsToDetectionsCalculator::ProcessGPU( [command_encoder setComputePipelineState:decode_program_]; { auto scored_boxes_view = - scored_boxes_buffer_->GetMtlBufferWriteView(command_buffer); + MtlBufferView::GetWriteView(*scored_boxes_buffer_, command_buffer); auto decoded_boxes_view = - decoded_boxes_buffer_->GetMtlBufferWriteView(command_buffer); + MtlBufferView::GetWriteView(*decoded_boxes_buffer_, command_buffer); [command_encoder setBuffer:decoded_boxes_view.buffer() offset:0 atIndex:0]; - auto input0_view = input_tensors[tensor_mapping_.detections_tensor_index()] - .GetMtlBufferReadView(command_buffer); + auto input0_view = MtlBufferView::GetReadView( + input_tensors[tensor_mapping_.detections_tensor_index()], + command_buffer); [command_encoder setBuffer:input0_view.buffer() offset:0 atIndex:1]; auto raw_anchors_view = - raw_anchors_buffer_->GetMtlBufferReadView(command_buffer); + MtlBufferView::GetReadView(*raw_anchors_buffer_, command_buffer); [command_encoder setBuffer:raw_anchors_view.buffer() offset:0 atIndex:2]; MTLSize decode_threads_per_group = MTLSizeMake(1, 1, 1); MTLSize decode_threadgroups = MTLSizeMake(num_boxes_, 1, 1); @@ -588,8 +591,8 @@ absl::Status TensorsToDetectionsCalculator::ProcessGPU( [command_encoder setComputePipelineState:score_program_]; [command_encoder setBuffer:scored_boxes_view.buffer() offset:0 atIndex:0]; - auto input1_view = input_tensors[tensor_mapping_.scores_tensor_index()] - .GetMtlBufferReadView(command_buffer); + auto input1_view = MtlBufferView::GetReadView( + input_tensors[tensor_mapping_.scores_tensor_index()], command_buffer); [command_encoder setBuffer:input1_view.buffer() offset:0 atIndex:1]; MTLSize score_threads_per_group = MTLSizeMake(1, num_classes_, 1); MTLSize score_threadgroups = MTLSizeMake(num_boxes_, 1, 1); diff --git a/mediapipe/calculators/tensor/tensors_to_segmentation_calculator.cc b/mediapipe/calculators/tensor/tensors_to_segmentation_calculator.cc index 172f70880..839451ab7 100644 --- a/mediapipe/calculators/tensor/tensors_to_segmentation_calculator.cc +++ b/mediapipe/calculators/tensor/tensors_to_segmentation_calculator.cc @@ -53,6 +53,7 @@ #import #import +#include "mediapipe/framework/formats/tensor_mtl_buffer_view.h" #import "mediapipe/gpu/MPPMetalHelper.h" #include "mediapipe/gpu/MPPMetalUtil.h" #endif // MEDIAPIPE_METAL_ENABLED @@ -485,7 +486,8 @@ absl::Status TensorsToSegmentationCalculator::ProcessGpu( [command_buffer computeCommandEncoder]; [command_encoder setComputePipelineState:mask_program_]; - auto read_view = input_tensors[0].GetMtlBufferReadView(command_buffer); + auto read_view = + MtlBufferView::GetReadView(input_tensors[0], 
command_buffer); [command_encoder setBuffer:read_view.buffer() offset:0 atIndex:0]; mediapipe::GpuBuffer small_mask_buffer = [metal_helper_ diff --git a/mediapipe/framework/formats/BUILD b/mediapipe/framework/formats/BUILD index 371f23ed1..10aa3fca0 100644 --- a/mediapipe/framework/formats/BUILD +++ b/mediapipe/framework/formats/BUILD @@ -431,7 +431,10 @@ cc_library( hdrs = [ "tensor.h", "//mediapipe/framework/formats/tensor:internal.h", - ], + ] + select({ + "//mediapipe:ios": ["tensor_mtl_buffer_view.h"], + "//conditions:default": [], + }), copts = select({ "//mediapipe:apple": [ "-x objective-c++", diff --git a/mediapipe/framework/formats/tensor.cc b/mediapipe/framework/formats/tensor.cc index 3f11d368a..1dbd8f8ac 100644 --- a/mediapipe/framework/formats/tensor.cc +++ b/mediapipe/framework/formats/tensor.cc @@ -25,8 +25,11 @@ #endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30 #if MEDIAPIPE_METAL_ENABLED +#import #include #include + +#include "mediapipe/framework/formats/tensor_mtl_buffer_view.h" #else #include #endif // MEDIAPIPE_METAL_ENABLED @@ -61,6 +64,12 @@ int BhwcDepthFromShape(const Tensor::Shape& shape) { // 3) pad/"unpad" the bitmap after transfer CPU <-> GPU #if MEDIAPIPE_METAL_ENABLED +// No ODR violation here because this file compiled just once per project. +struct MtlResources { + id command_buffer = nil; + id device = nil; + id metal_buffer = nil; +}; namespace { // MTLBuffer can use existing properly aligned and allocated CPU memory. size_t AlignToPageSize(size_t size) { @@ -83,52 +92,56 @@ void DeallocateVirtualMemory(void* pointer, size_t size) { } } // namespace -Tensor::MtlBufferView Tensor::GetMtlBufferReadView( - id command_buffer) const { - LOG_IF(FATAL, valid_ == kValidNone) +void MtlBufferView::AllocateMtlBuffer(const Tensor& tensor, + id device) { + tensor.mtl_resources_->device = device; + if (!tensor.cpu_buffer_) { + // It also means that the metal buffer is not allocated yet. + tensor.cpu_buffer_ = AllocateVirtualMemory(tensor.bytes()); + } + if (!tensor.mtl_resources_->metal_buffer) { + tensor.mtl_resources_->metal_buffer = [tensor.mtl_resources_->device + newBufferWithBytesNoCopy:tensor.cpu_buffer_ + length:AlignToPageSize(tensor.bytes()) + options:MTLResourceStorageModeShared | + MTLResourceCPUCacheModeDefaultCache + deallocator:^(void* pointer, NSUInteger length) { + DeallocateVirtualMemory(pointer, length); + }]; + } +} + +MtlBufferView MtlBufferView::GetReadView(const Tensor& tensor, + id command_buffer) { + LOG_IF(FATAL, tensor.valid_ == Tensor::kValidNone) << "Tensor must be written prior to read from."; - LOG_IF(FATAL, !(valid_ & (kValidCpu | kValidMetalBuffer))) + LOG_IF(FATAL, + !(tensor.valid_ & (Tensor::kValidCpu | Tensor::kValidMetalBuffer))) << "Tensor conversion between different GPU resources is not supported " "yet."; - auto lock(absl::make_unique(&view_mutex_)); - valid_ |= kValidMetalBuffer; - AllocateMtlBuffer([command_buffer device]); - return {metal_buffer_, std::move(lock)}; + auto lock(absl::make_unique(&tensor.view_mutex_)); + tensor.valid_ |= Tensor::kValidMetalBuffer; + AllocateMtlBuffer(tensor, [command_buffer device]); + return {tensor.mtl_resources_->metal_buffer, std::move(lock)}; } -Tensor::MtlBufferView Tensor::GetMtlBufferWriteView( - id command_buffer) const { +MtlBufferView MtlBufferView::GetWriteView(const Tensor& tensor, + id command_buffer) { // Don't overwrite command buffer at which the metal buffer has been written // so we can wait until completed. 
- command_buffer_ = command_buffer; - return GetMtlBufferWriteView([command_buffer device]); + tensor.mtl_resources_->command_buffer = command_buffer; + return GetWriteView(tensor, [command_buffer device]); } -Tensor::MtlBufferView Tensor::GetMtlBufferWriteView( - id device) const { - auto lock(absl::make_unique(&view_mutex_)); - valid_ = kValidMetalBuffer; - AllocateMtlBuffer(device); - return {metal_buffer_, std::move(lock)}; -} - -void Tensor::AllocateMtlBuffer(id device) const { - device_ = device; - if (!cpu_buffer_) { - // It also means that the metal buffer is not allocated yet. - cpu_buffer_ = AllocateVirtualMemory(bytes()); - } - if (!metal_buffer_) { - metal_buffer_ = - [device_ newBufferWithBytesNoCopy:cpu_buffer_ - length:AlignToPageSize(bytes()) - options:MTLResourceStorageModeShared | - MTLResourceCPUCacheModeDefaultCache - deallocator:^(void* pointer, NSUInteger length) { - DeallocateVirtualMemory(pointer, length); - }]; - } +MtlBufferView MtlBufferView::GetWriteView(const Tensor& tensor, + id device) { + auto lock(absl::make_unique(&tensor.view_mutex_)); + tensor.valid_ = Tensor::kValidMetalBuffer; + AllocateMtlBuffer(tensor, device); + return {tensor.mtl_resources_->metal_buffer, std::move(lock)}; } +#else +struct MtlResources {}; #endif // MEDIAPIPE_METAL_ENABLED #if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30 @@ -379,6 +392,9 @@ Tensor& Tensor::operator=(Tensor&& src) { return *this; } +Tensor::Tensor(Tensor&& src) { Move(&src); } +Tensor::~Tensor() { Invalidate(); } + void Tensor::Move(Tensor* src) { valid_ = src->valid_; src->valid_ = kValidNone; @@ -388,15 +404,7 @@ void Tensor::Move(Tensor* src) { cpu_buffer_ = src->cpu_buffer_; src->cpu_buffer_ = nullptr; ahwb_tracking_key_ = src->ahwb_tracking_key_; -#if MEDIAPIPE_METAL_ENABLED - device_ = src->device_; - src->device_ = nil; - command_buffer_ = src->command_buffer_; - src->command_buffer_ = nil; - metal_buffer_ = src->metal_buffer_; - src->metal_buffer_ = nil; -#endif // MEDIAPIPE_METAL_ENABLED - + mtl_resources_ = std::move(src->mtl_resources_); MoveAhwbStuff(src); #if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30 @@ -415,12 +423,15 @@ void Tensor::Move(Tensor* src) { } Tensor::Tensor(ElementType element_type, const Shape& shape) - : element_type_(element_type), shape_(shape) {} + : element_type_(element_type), + shape_(shape), + mtl_resources_(std::make_unique()) {} Tensor::Tensor(ElementType element_type, const Shape& shape, const QuantizationParameters& quantization_parameters) : element_type_(element_type), shape_(shape), - quantization_parameters_(quantization_parameters) {} + quantization_parameters_(quantization_parameters), + mtl_resources_(std::make_unique()) {} #if MEDIAPIPE_METAL_ENABLED void Tensor::Invalidate() { @@ -432,13 +443,16 @@ void Tensor::Invalidate() { absl::MutexLock lock(&view_mutex_); // If memory is allocated and not owned by the metal buffer. // TODO: Re-design cpu buffer memory management. - if (cpu_buffer_ && !metal_buffer_) { + if (cpu_buffer_ && !mtl_resources_->metal_buffer) { DeallocateVirtualMemory(cpu_buffer_, AlignToPageSize(bytes())); } - metal_buffer_ = nil; - command_buffer_ = nil; - device_ = nil; cpu_buffer_ = nullptr; + // This becomes NULL if the tensor is moved. 
+ if (mtl_resources_) { + mtl_resources_->metal_buffer = nil; + mtl_resources_->command_buffer = nil; + mtl_resources_->device = nil; + } #if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30 // Don't need to wait for the resource to be deleted bacause if will be // released on last reference deletion inside the OpenGL driver. @@ -532,10 +546,11 @@ Tensor::CpuReadView Tensor::GetCpuReadView() const { // GPU-to-CPU synchronization and read-back. #if MEDIAPIPE_METAL_ENABLED if (valid_ & kValidMetalBuffer) { - LOG_IF(FATAL, !command_buffer_) << "Metal -> CPU synchronization " - "requires MTLCommandBuffer to be set."; - if (command_buffer_) { - [command_buffer_ waitUntilCompleted]; + LOG_IF(FATAL, !mtl_resources_->command_buffer) + << "Metal -> CPU synchronization " + "requires MTLCommandBuffer to be set."; + if (mtl_resources_->command_buffer) { + [mtl_resources_->command_buffer waitUntilCompleted]; } } #endif // MEDIAPIPE_METAL_ENABLED diff --git a/mediapipe/framework/formats/tensor.h b/mediapipe/framework/formats/tensor.h index fe0be31d1..1d670d805 100644 --- a/mediapipe/framework/formats/tensor.h +++ b/mediapipe/framework/formats/tensor.h @@ -29,9 +29,6 @@ #include "mediapipe/framework/formats/tensor/internal.h" #include "mediapipe/framework/port.h" -#if MEDIAPIPE_METAL_ENABLED -#import -#endif // MEDIAPIPE_METAL_ENABLED #ifndef MEDIAPIPE_NO_JNI #if __ANDROID_API__ >= 26 || defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__) #define MEDIAPIPE_TENSOR_USE_AHWB 1 @@ -66,7 +63,6 @@ #endif namespace mediapipe { - // Tensor is a container of multi-dimensional data that supports sharing the // content across different backends and APIs, currently: CPU / Metal / OpenGL. // Texture2DView is limited to 4 dimensions. @@ -91,6 +87,7 @@ namespace mediapipe { // float* pointer = view.buffer(); // ...reading the cpu memory... +struct MtlResources; class Tensor { class View { public: @@ -144,9 +141,9 @@ class Tensor { Tensor(const Tensor&) = delete; Tensor& operator=(const Tensor&) = delete; // Move-only. - Tensor(Tensor&& src) { Move(&src); } + Tensor(Tensor&& src); Tensor& operator=(Tensor&&); - ~Tensor() { Invalidate(); } + ~Tensor(); template class CpuView : public View { @@ -182,33 +179,6 @@ class Tensor { uint64_t source_location_hash = tensor_internal::FnvHash64(builtin_FILE(), builtin_LINE())) const; -#if MEDIAPIPE_METAL_ENABLED - // TODO: id vs. MtlBufferView. - class MtlBufferView : public View { - public: - id buffer() const { return buffer_; } - MtlBufferView(MtlBufferView&& src) - : View(std::move(src)), buffer_(src.buffer_) { - src.buffer_ = nil; - } - - protected: - friend class Tensor; - MtlBufferView(id buffer, std::unique_ptr&& lock) - : View(std::move(lock)), buffer_(buffer) {} - id buffer_; - }; - // The command buffer status is checked for completeness if GPU-to-CPU - // synchronization is required. - // TODO: Design const and non-const view acquiring. - MtlBufferView GetMtlBufferReadView(id command_buffer) const; - MtlBufferView GetMtlBufferWriteView( - id command_buffer) const; - // Allocate new buffer. - // TODO: GPU-to-CPU design considerations. 
- MtlBufferView GetMtlBufferWriteView(id device) const; -#endif // MEDIAPIPE_METAL_ENABLED - #ifdef MEDIAPIPE_TENSOR_USE_AHWB using FinishingFunc = std::function; class AHardwareBufferView : public View { @@ -372,6 +342,7 @@ class Tensor { } private: + friend class MtlBufferView; void Move(Tensor*); void Invalidate(); @@ -396,12 +367,9 @@ class Tensor { mutable void* cpu_buffer_ = nullptr; void AllocateCpuBuffer() const; -#if MEDIAPIPE_METAL_ENABLED - mutable id command_buffer_ = nil; - mutable id device_ = nil; - mutable id metal_buffer_ = nil; - void AllocateMtlBuffer(id device) const; -#endif // MEDIAPIPE_METAL_ENABLED + // Forward declaration of the MtlResources provides compile-time verification + // of ODR if this header includes any actual code that uses MtlResources. + mutable std::unique_ptr mtl_resources_; #ifdef MEDIAPIPE_TENSOR_USE_AHWB mutable AHardwareBuffer* ahwb_ = nullptr; diff --git a/mediapipe/framework/formats/tensor_mtl_buffer_view.h b/mediapipe/framework/formats/tensor_mtl_buffer_view.h new file mode 100644 index 000000000..a61659d3d --- /dev/null +++ b/mediapipe/framework/formats/tensor_mtl_buffer_view.h @@ -0,0 +1,61 @@ +// Copyright 2020 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MEDIAPIPE_FRAMEWORK_FORMATS_TENSOR_MTL_BUFFER_VIEW_H_ +#define MEDIAPIPE_FRAMEWORK_FORMATS_TENSOR_MTL_BUFFER_VIEW_H_ + +#import + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/synchronization/mutex.h" +#include "mediapipe/framework/formats/tensor.h" +#include "mediapipe/framework/port.h" + +namespace mediapipe { +class MtlBufferView : public Tensor::View { + public: + // The command buffer status is checked for completeness if GPU-to-CPU + // synchronization is required. 
+ static MtlBufferView GetReadView(const Tensor& tensor, + id command_buffer); + static MtlBufferView GetWriteView(const Tensor& tensor, + id command_buffer); + static MtlBufferView GetWriteView(const Tensor& tensor, id device); + + id buffer() const { return buffer_; } + MtlBufferView(MtlBufferView&& src) + : Tensor::View(std::move(src)), buffer_(src.buffer_) { + src.buffer_ = nil; + } + + protected: + friend class Tensor; + static void AllocateMtlBuffer(const Tensor& tensor, id device); + MtlBufferView(id buffer, std::unique_ptr&& lock) + : Tensor::View(std::move(lock)), buffer_(buffer) {} + id buffer_; +}; + +} // namespace mediapipe + +#endif // MEDIAPIPE_FRAMEWORK_FORMATS_TENSOR_MTL_BUFFER_VIEW_H_ diff --git a/mediapipe/tasks/cc/components/calculators/BUILD b/mediapipe/tasks/cc/components/calculators/BUILD index bf31134e4..16931811c 100644 --- a/mediapipe/tasks/cc/components/calculators/BUILD +++ b/mediapipe/tasks/cc/components/calculators/BUILD @@ -79,14 +79,6 @@ mediapipe_proto_library( cc_library( name = "score_calibration_calculator", srcs = ["score_calibration_calculator.cc"], - copts = select({ - # TODO: fix tensor.h not to require this, if possible - "//mediapipe:apple": [ - "-x objective-c++", - "-fobjc-arc", # enable reference-counting - ], - "//conditions:default": [], - }), deps = [ ":score_calibration_calculator_cc_proto", "//mediapipe/framework:calculator_framework", diff --git a/mediapipe/tasks/cc/components/processors/BUILD b/mediapipe/tasks/cc/components/processors/BUILD index 517a27114..cec44a9e3 100644 --- a/mediapipe/tasks/cc/components/processors/BUILD +++ b/mediapipe/tasks/cc/components/processors/BUILD @@ -28,14 +28,6 @@ cc_library( name = "classification_postprocessing_graph", srcs = ["classification_postprocessing_graph.cc"], hdrs = ["classification_postprocessing_graph.h"], - copts = select({ - # TODO: fix tensor.h not to require this, if possible - "//mediapipe:apple": [ - "-x objective-c++", - "-fobjc-arc", # enable reference-counting - ], - "//conditions:default": [], - }), deps = [ "//mediapipe/calculators/core:split_vector_calculator", "//mediapipe/calculators/core:split_vector_calculator_cc_proto", @@ -148,14 +140,6 @@ cc_library( name = "text_preprocessing_graph", srcs = ["text_preprocessing_graph.cc"], hdrs = ["text_preprocessing_graph.h"], - copts = select({ - # TODO: fix tensor.h not to require this, if possible - "//mediapipe:apple": [ - "-x objective-c++", - "-fobjc-arc", # enable reference-counting - ], - "//conditions:default": [], - }), deps = [ "//mediapipe/calculators/tensor:bert_preprocessor_calculator", "//mediapipe/calculators/tensor:bert_preprocessor_calculator_cc_proto", From 69d354fc89173035007280daef793b7a640542fe Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Mon, 23 Jan 2023 12:09:41 -0800 Subject: [PATCH 417/469] Use c++ struct as hand landmark detection results. 
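For illustration, downstream callers can now consume plain C++ structs instead of protos. A hypothetical usage sketch (the `landmarker` setup and the `ProcessPoint` helper are placeholders, not part of this patch):

// Assumes a configured HandLandmarker instance `landmarker` and an input
// `image`; Detect() returns absl::StatusOr<HandLandmarkerResult>.
auto result = landmarker->Detect(image);
if (result.ok()) {
  // One NormalizedLandmarks struct per detected hand.
  for (const auto& hand : result->hand_landmarks) {
    for (const auto& lm : hand.landmarks) {
      // x/y/z are plain floats; visibility/presence are std::optional<float>
      // and stay unset when the model does not provide them.
      ProcessPoint(lm.x, lm.y, lm.z);  // placeholder consumer
    }
  }
  // Handedness is now a Classifications struct rather than a
  // ClassificationList proto.
  if (!result->handedness.empty() &&
      !result->handedness[0].categories.empty()) {
    const auto& top = result->handedness[0].categories[0];
    const std::string label = top.category_name.value_or("");  // e.g. "Left"
    const float confidence = top.score;
  }
}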
PiperOrigin-RevId: 504048095 --- .../tasks/cc/components/containers/BUILD | 9 ++ .../containers/classification_result.cc | 13 +++ .../containers/classification_result.h | 7 ++ .../cc/components/containers/landmark.cc | 65 +++++++++++ .../tasks/cc/components/containers/landmark.h | 103 ++++++++++++++++++ .../tasks/cc/vision/hand_landmarker/BUILD | 3 + .../vision/hand_landmarker/hand_landmarker.cc | 60 ++++++---- .../hand_landmarker/hand_landmarker_result.cc | 56 ++++++++++ .../hand_landmarker/hand_landmarker_result.h | 15 ++- .../hand_landmarker_result_test.cc | 88 +++++++++++++++ .../hand_landmarker/hand_landmarker_test.cc | 62 ++++++++--- 11 files changed, 439 insertions(+), 42 deletions(-) create mode 100644 mediapipe/tasks/cc/components/containers/landmark.cc create mode 100644 mediapipe/tasks/cc/components/containers/landmark.h create mode 100644 mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_result.cc create mode 100644 mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_result_test.cc diff --git a/mediapipe/tasks/cc/components/containers/BUILD b/mediapipe/tasks/cc/components/containers/BUILD index 0750a1482..a7307b2ce 100644 --- a/mediapipe/tasks/cc/components/containers/BUILD +++ b/mediapipe/tasks/cc/components/containers/BUILD @@ -62,3 +62,12 @@ cc_library( "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto", ], ) + +cc_library( + name = "landmark", + srcs = ["landmark.cc"], + hdrs = ["landmark.h"], + deps = [ + "//mediapipe/framework/formats:landmark_cc_proto", + ], +) diff --git a/mediapipe/tasks/cc/components/containers/classification_result.cc b/mediapipe/tasks/cc/components/containers/classification_result.cc index 98583ff15..f2d88406d 100644 --- a/mediapipe/tasks/cc/components/containers/classification_result.cc +++ b/mediapipe/tasks/cc/components/containers/classification_result.cc @@ -40,6 +40,19 @@ Classifications ConvertToClassifications(const proto::Classifications& proto) { return classifications; } +Classifications ConvertToClassifications( + const mediapipe::ClassificationList& proto, int head_index, + std::optional head_name) { + Classifications classifications; + classifications.categories.reserve(proto.classification_size()); + for (const auto& classification : proto.classification()) { + classifications.categories.push_back(ConvertToCategory(classification)); + } + classifications.head_index = head_index; + classifications.head_name = head_name; + return classifications; +} + ClassificationResult ConvertToClassificationResult( const proto::ClassificationResult& proto) { ClassificationResult classification_result; diff --git a/mediapipe/tasks/cc/components/containers/classification_result.h b/mediapipe/tasks/cc/components/containers/classification_result.h index 88273fd00..e359fb33e 100644 --- a/mediapipe/tasks/cc/components/containers/classification_result.h +++ b/mediapipe/tasks/cc/components/containers/classification_result.h @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "mediapipe/framework/formats/classification.pb.h" #include "mediapipe/tasks/cc/components/containers/category.h" #include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" @@ -58,6 +59,12 @@ struct ClassificationResult { // Classifications struct. Classifications ConvertToClassifications(const proto::Classifications& proto); +// Utility function to convert from ClassificationList proto to +// Classifications struct. 
+Classifications ConvertToClassifications( + const mediapipe::ClassificationList& proto, int head_index = 0, + std::optional head_name = std::nullopt); + // Utility function to convert from ClassificationResult proto to // ClassificationResult struct. ClassificationResult ConvertToClassificationResult( diff --git a/mediapipe/tasks/cc/components/containers/landmark.cc b/mediapipe/tasks/cc/components/containers/landmark.cc new file mode 100644 index 000000000..6d80cb835 --- /dev/null +++ b/mediapipe/tasks/cc/components/containers/landmark.cc @@ -0,0 +1,65 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/tasks/cc/components/containers/landmark.h" + +#include +#include + +#include "mediapipe/framework/formats/landmark.pb.h" + +namespace mediapipe::tasks::components::containers { + +Landmark ConvertToLandmark(const mediapipe::Landmark& proto) { + return {/*x=*/proto.x(), /*y=*/proto.y(), /*z=*/proto.z(), + /*visibility=*/proto.has_visibility() + ? std::optional(proto.visibility()) + : std::nullopt, + /*presence=*/proto.has_presence() + ? std::optional(proto.presence()) + : std::nullopt}; +} + +NormalizedLandmark ConvertToNormalizedLandmark( + const mediapipe::NormalizedLandmark& proto) { + return {/*x=*/proto.x(), /*y=*/proto.y(), /*z=*/proto.z(), + /*visibility=*/proto.has_visibility() + ? std::optional(proto.visibility()) + : std::nullopt, + /*presence=*/proto.has_presence() + ? std::optional(proto.presence()) + : std::nullopt}; +} + +Landmarks ConvertToLandmarks(const mediapipe::LandmarkList& proto) { + Landmarks landmarks; + landmarks.landmarks.reserve(proto.landmark_size()); + for (const auto& landmark : proto.landmark()) { + landmarks.landmarks.push_back(ConvertToLandmark(landmark)); + } + return landmarks; +} + +NormalizedLandmarks ConvertToNormalizedLandmarks( + const mediapipe::NormalizedLandmarkList& proto) { + NormalizedLandmarks landmarks; + landmarks.landmarks.reserve(proto.landmark_size()); + for (const auto& landmark : proto.landmark()) { + landmarks.landmarks.push_back(ConvertToNormalizedLandmark(landmark)); + } + return landmarks; +} + +} // namespace mediapipe::tasks::components::containers diff --git a/mediapipe/tasks/cc/components/containers/landmark.h b/mediapipe/tasks/cc/components/containers/landmark.h new file mode 100644 index 000000000..15b730001 --- /dev/null +++ b/mediapipe/tasks/cc/components/containers/landmark.h @@ -0,0 +1,103 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_LANDMARK_H_ +#define MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_LANDMARK_H_ + +#include +#include +#include + +#include "mediapipe/framework/formats/landmark.pb.h" + +namespace mediapipe::tasks::components::containers { +constexpr float kLandmarkTolerance = 1e-6; + +// Landmark represents a point in 3D space with x, y, z coordinates. The +// landmark coordinates are in meters. z represents the landmark depth, and the +// smaller the value the closer the world landmark is to the camera. +struct Landmark { + float x; + float y; + float z; + // Landmark visibility. Should stay unset if not supported. + // Float score of whether landmark is visible or occluded by other objects. + // Landmark considered as invisible also if it is not present on the screen + // (out of scene bounds). Depending on the model, visibility value is either a + // sigmoid or an argument of sigmoid. + std::optional visibility = std::nullopt; + // Landmark presence. Should stay unset if not supported. + // Float score of whether landmark is present on the scene (located within + // scene bounds). Depending on the model, presence value is either a result of + // sigmoid or an argument of sigmoid function to get landmark presence + // probability. + std::optional presence = std::nullopt; + // Landmark name. Should stay unset if not supported. + std::optional name = std::nullopt; +}; + +inline bool operator==(const Landmark& lhs, const Landmark& rhs) { + return abs(lhs.x - rhs.x) < kLandmarkTolerance && + abs(lhs.y - rhs.y) < kLandmarkTolerance && + abs(lhs.z - rhs.z) < kLandmarkTolerance; +} + +// A normalized version of above Landmark struct. All coordinates should be +// within [0, 1]. +struct NormalizedLandmark { + float x; + float y; + float z; + std::optional visibility = std::nullopt; + std::optional presence = std::nullopt; + std::optional name = std::nullopt; +}; + +inline bool operator==(const NormalizedLandmark& lhs, + const NormalizedLandmark& rhs) { + return abs(lhs.x - rhs.x) < kLandmarkTolerance && + abs(lhs.y - rhs.y) < kLandmarkTolerance && + abs(lhs.z - rhs.z) < kLandmarkTolerance; +} + +// A list of Landmarks. +struct Landmarks { + std::vector landmarks; +}; + +// A list of NormalizedLandmarks. +struct NormalizedLandmarks { + std::vector landmarks; +}; + +// Utility function to convert from Landmark proto to Landmark struct. +Landmark ConvertToLandmark(const mediapipe::Landmark& proto); + +// Utility function to convert from NormalizedLandmark proto to +// NormalizedLandmark struct. +NormalizedLandmark ConvertToNormalizedLandmark( + const mediapipe::NormalizedLandmark& proto); + +// Utility function to convert from LandmarkList proto to Landmarks struct. +Landmarks ConvertToLandmarks(const mediapipe::LandmarkList& proto); + +// Utility function to convert from NormalizedLandmarkList proto to +// NormalizedLandmarks struct. 
+NormalizedLandmarks ConvertToNormalizedLandmarks( + const mediapipe::NormalizedLandmarkList& proto); + +} // namespace mediapipe::tasks::components::containers + +#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_LANDMARK_H_ diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/BUILD b/mediapipe/tasks/cc/vision/hand_landmarker/BUILD index 03ec45f7d..2552e7a10 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/BUILD +++ b/mediapipe/tasks/cc/vision/hand_landmarker/BUILD @@ -154,11 +154,14 @@ cc_library( cc_library( name = "hand_landmarker_result", + srcs = ["hand_landmarker_result.cc"], hdrs = ["hand_landmarker_result.h"], visibility = ["//visibility:public"], deps = [ "//mediapipe/framework/formats:classification_cc_proto", "//mediapipe/framework/formats:landmark_cc_proto", + "//mediapipe/tasks/cc/components/containers:classification_result", + "//mediapipe/tasks/cc/components/containers:landmark", ], ) diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker.cc b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker.cc index 3bb1ee8d8..ab66fe136 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker.cc +++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker.cc @@ -155,9 +155,13 @@ absl::StatusOr> HandLandmarker::Create( Packet hand_world_landmarks_packet = status_or_packets.value()[kHandWorldLandmarksStreamName]; result_callback( - {{handedness_packet.Get>(), - hand_landmarks_packet.Get>(), - hand_world_landmarks_packet.Get>()}}, + ConvertToHandLandmarkerResult( + /* handedness= */ handedness_packet + .Get>(), + /* hand_landmarks= */ + hand_landmarks_packet.Get>(), + /* hand_world_landmarks= */ + hand_world_landmarks_packet.Get>()), image_packet.Get(), hand_landmarks_packet.Timestamp().Value() / kMicroSecondsPerMilliSecond); @@ -193,15 +197,21 @@ absl::StatusOr HandLandmarker::Detect( if (output_packets[kHandLandmarksStreamName].IsEmpty()) { return {HandLandmarkerResult()}; } - return {{/* handedness= */ - {output_packets[kHandednessStreamName] - .Get>()}, - /* hand_landmarks= */ - {output_packets[kHandLandmarksStreamName] - .Get>()}, - /* hand_world_landmarks */ - {output_packets[kHandWorldLandmarksStreamName] - .Get>()}}}; + return ConvertToHandLandmarkerResult(/* handedness= */ + output_packets[kHandednessStreamName] + .Get>(), + /* hand_landmarks= */ + output_packets[kHandLandmarksStreamName] + .Get>(), + /* hand_world_landmarks */ + output_packets + [kHandWorldLandmarksStreamName] + .Get>()); } absl::StatusOr HandLandmarker::DetectForVideo( @@ -228,17 +238,21 @@ absl::StatusOr HandLandmarker::DetectForVideo( if (output_packets[kHandLandmarksStreamName].IsEmpty()) { return {HandLandmarkerResult()}; } - return { - {/* handedness= */ - {output_packets[kHandednessStreamName] - .Get>()}, - /* hand_landmarks= */ - {output_packets[kHandLandmarksStreamName] - .Get>()}, - /* hand_world_landmarks */ - {output_packets[kHandWorldLandmarksStreamName] - .Get>()}}, - }; + return ConvertToHandLandmarkerResult(/* handedness= */ + output_packets[kHandednessStreamName] + .Get>(), + /* hand_landmarks= */ + output_packets[kHandLandmarksStreamName] + .Get>(), + /* hand_world_landmarks */ + output_packets + [kHandWorldLandmarksStreamName] + .Get>()); } absl::Status HandLandmarker::DetectAsync( diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_result.cc b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_result.cc new file mode 100644 index 000000000..9d2ae2be8 --- /dev/null +++ 
b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_result.cc @@ -0,0 +1,56 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_result.h" + +#include + +#include "mediapipe/framework/formats/classification.pb.h" +#include "mediapipe/tasks/cc/components/containers/classification_result.h" +#include "mediapipe/tasks/cc/components/containers/landmark.h" + +namespace mediapipe { +namespace tasks { +namespace vision { +namespace hand_landmarker { + +HandLandmarkerResult ConvertToHandLandmarkerResult( + const std::vector& handedness_proto, + const std::vector& hand_landmarks_proto, + const std::vector& hand_world_landmarks_proto) { + HandLandmarkerResult result; + result.handedness.resize(handedness_proto.size()); + result.hand_landmarks.resize(hand_landmarks_proto.size()); + result.hand_world_landmarks.resize(hand_world_landmarks_proto.size()); + std::transform(handedness_proto.begin(), handedness_proto.end(), + result.handedness.begin(), + [](const mediapipe::ClassificationList& classification_list) { + return components::containers::ConvertToClassifications( + classification_list); + }); + std::transform(hand_landmarks_proto.begin(), hand_landmarks_proto.end(), + result.hand_landmarks.begin(), + components::containers::ConvertToNormalizedLandmarks); + std::transform(hand_world_landmarks_proto.begin(), + hand_world_landmarks_proto.end(), + result.hand_world_landmarks.begin(), + components::containers::ConvertToLandmarks); + return result; +} + +} // namespace hand_landmarker +} // namespace vision +} // namespace tasks +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_result.h b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_result.h index 5e51c244e..1bca8e66a 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_result.h +++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_result.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -18,6 +18,8 @@ limitations under the License. #include "mediapipe/framework/formats/classification.pb.h" #include "mediapipe/framework/formats/landmark.pb.h" +#include "mediapipe/tasks/cc/components/containers/classification_result.h" +#include "mediapipe/tasks/cc/components/containers/landmark.h" namespace mediapipe { namespace tasks { @@ -28,13 +30,18 @@ namespace hand_landmarker { // element represents a single hand detected in the image. struct HandLandmarkerResult { // Classification of handedness. - std::vector handedness; + std::vector handedness; // Detected hand landmarks in normalized image coordinates. 
- std::vector hand_landmarks; + std::vector hand_landmarks; // Detected hand landmarks in world coordinates. - std::vector hand_world_landmarks; + std::vector hand_world_landmarks; }; +HandLandmarkerResult ConvertToHandLandmarkerResult( + const std::vector& handedness_proto, + const std::vector& hand_landmarks_proto, + const std::vector& hand_world_landmarks_proto); + } // namespace hand_landmarker } // namespace vision } // namespace tasks diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_result_test.cc b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_result_test.cc new file mode 100644 index 000000000..109749b01 --- /dev/null +++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_result_test.cc @@ -0,0 +1,88 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_result.h" + +#include +#include + +#include "mediapipe/framework/formats/classification.pb.h" +#include "mediapipe/framework/formats/landmark.pb.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/tasks/cc/components/containers/classification_result.h" +#include "mediapipe/tasks/cc/components/containers/landmark.h" + +namespace mediapipe { +namespace tasks { +namespace vision { +namespace hand_landmarker { + +TEST(ConvertFromProto, Succeeds) { + mediapipe::ClassificationList classification_list_proto; + mediapipe::Classification& classification_proto = + *classification_list_proto.add_classification(); + classification_proto.set_index(1); + classification_proto.set_score(0.5); + classification_proto.set_label("Left"); + classification_proto.set_display_name("Left_Hand"); + + mediapipe::NormalizedLandmarkList normalized_landmark_list_proto; + mediapipe::NormalizedLandmark& normalized_landmark_proto = + *normalized_landmark_list_proto.add_landmark(); + normalized_landmark_proto.set_x(0.1); + normalized_landmark_proto.set_y(0.2); + normalized_landmark_proto.set_z(0.3); + + mediapipe::LandmarkList landmark_list_proto; + mediapipe::Landmark& landmark_proto = *landmark_list_proto.add_landmark(); + landmark_proto.set_x(3.1); + landmark_proto.set_y(5.2); + landmark_proto.set_z(4.3); + + std::vector classification_lists = { + classification_list_proto}; + std::vector normalized_landmarks_lists = { + normalized_landmark_list_proto}; + std::vector landmarks_lists = {landmark_list_proto}; + + HandLandmarkerResult hand_landmarker_result = ConvertToHandLandmarkerResult( + classification_lists, normalized_landmarks_lists, landmarks_lists); + + EXPECT_EQ(hand_landmarker_result.handedness.size(), 1); + EXPECT_EQ(hand_landmarker_result.handedness[0].categories.size(), 1); + EXPECT_THAT( + hand_landmarker_result.handedness[0].categories[0], + testing::FieldsAre(1, testing::FloatEq(0.5), "Left", "Left_Hand")); + + 
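The assertions above walk the converted handedness field by field. The conversion idiom they exercise, pre-sizing the output vectors and then mapping each repeated proto into its container struct with std::transform, can be sketched standalone. The types below are simplified stand-ins for the MediaPipe protos and containers, not the real API:

```cpp
// Standalone sketch of the proto-to-container conversion idiom used by
// ConvertToHandLandmarkerResult (stand-in types; the real code maps
// mediapipe::ClassificationList into components::containers types).
#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

struct ClassificationProto {  // stand-in for mediapipe::Classification
  int index;
  float score;
  std::string label;
};

struct Category {  // stand-in for the tasks container type
  int index;
  float score;
  std::string category_name;
};

Category ConvertToCategory(const ClassificationProto& proto) {
  return {proto.index, proto.score, proto.label};
}

int main() {
  std::vector<ClassificationProto> protos = {{1, 0.5f, "Left"}};
  // Size the destination first, then transform into it, mirroring the
  // resize() + std::transform structure of hand_landmarker_result.cc.
  std::vector<Category> categories(protos.size());
  std::transform(protos.begin(), protos.end(), categories.begin(),
                 ConvertToCategory);
  std::cout << categories[0].category_name << " " << categories[0].score
            << "\n";  // prints: Left 0.5
}
```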
EXPECT_EQ(hand_landmarker_result.hand_landmarks.size(), 1); + EXPECT_EQ(hand_landmarker_result.hand_landmarks[0].landmarks.size(), 1); + EXPECT_THAT(hand_landmarker_result.hand_landmarks[0].landmarks[0], + testing::FieldsAre(testing::FloatEq(0.1), testing::FloatEq(0.2), + testing::FloatEq(0.3), std::nullopt, + std::nullopt, std::nullopt)); + + EXPECT_EQ(hand_landmarker_result.hand_world_landmarks.size(), 1); + EXPECT_EQ(hand_landmarker_result.hand_world_landmarks[0].landmarks.size(), 1); + EXPECT_THAT(hand_landmarker_result.hand_world_landmarks[0].landmarks[0], + testing::FieldsAre(testing::FloatEq(3.1), testing::FloatEq(5.2), + testing::FloatEq(4.3), std::nullopt, + std::nullopt, std::nullopt)); +} + +} // namespace hand_landmarker +} // namespace vision +} // namespace tasks +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_test.cc b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_test.cc index 94d1b1c12..b21f1bee9 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_test.cc +++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_test.cc @@ -32,6 +32,8 @@ limitations under the License. #include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gtest.h" #include "mediapipe/tasks/cc/common.h" +#include "mediapipe/tasks/cc/components/containers/classification_result.h" +#include "mediapipe/tasks/cc/components/containers/landmark.h" #include "mediapipe/tasks/cc/components/containers/proto/landmarks_detection_result.pb.h" #include "mediapipe/tasks/cc/components/containers/rect.h" #include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h" @@ -50,18 +52,16 @@ namespace { using ::file::Defaults; using ::mediapipe::file::JoinPath; +using ::mediapipe::tasks::components::containers::ConvertToClassifications; +using ::mediapipe::tasks::components::containers::ConvertToNormalizedLandmarks; using ::mediapipe::tasks::components::containers::RectF; using ::mediapipe::tasks::containers::proto::LandmarksDetectionResult; using ::mediapipe::tasks::vision::core::ImageProcessingOptions; -using ::testing::EqualsProto; using ::testing::HasSubstr; using ::testing::Optional; -using ::testing::Pointwise; using ::testing::TestParamInfo; using ::testing::TestWithParam; using ::testing::Values; -using ::testing::proto::Approximately; -using ::testing::proto::Partially; constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/"; constexpr char kHandLandmarkerBundleAsset[] = "hand_landmarker.task"; @@ -74,7 +74,6 @@ constexpr char kPointingUpImage[] = "pointing_up.jpg"; constexpr char kPointingUpRotatedImage[] = "pointing_up_rotated.jpg"; constexpr char kNoHandsImage[] = "cats_and_dogs.jpg"; -constexpr float kLandmarksFractionDiff = 0.03; // percentage constexpr float kLandmarksAbsMargin = 0.03; constexpr float kHandednessMargin = 0.05; @@ -101,13 +100,47 @@ HandLandmarkerResult GetExpectedHandLandmarkerResult( const auto landmarks_detection_result = GetLandmarksDetectionResult(file_name); expected_results.hand_landmarks.push_back( - landmarks_detection_result.landmarks()); + ConvertToNormalizedLandmarks(landmarks_detection_result.landmarks())); expected_results.handedness.push_back( - landmarks_detection_result.classifications()); + ConvertToClassifications(landmarks_detection_result.classifications())); } return expected_results; } +MATCHER_P2(HandednessMatches, expected_handedness, tolerance, "") { + for (int i = 0; i < arg.size(); i++) { + for (int j = 0; j < 
arg[i].categories.size(); j++) { + if (arg[i].categories[j].index != + expected_handedness[i].categories[j].index) { + return false; + } + if (std::abs(arg[i].categories[j].score - + expected_handedness[i].categories[j].score) > tolerance) { + return false; + } + if (arg[i].categories[j].category_name != + expected_handedness[i].categories[j].category_name) { + return false; + } + } + } + return true; +} + +MATCHER_P2(LandmarksMatches, expected_landmarks, toleration, "") { + for (int i = 0; i < arg.size(); i++) { + for (int j = 0; j < arg[i].landmarks.size(); j++) { + if (std::abs(arg[i].landmarks[j].x - + expected_landmarks[i].landmarks[j].x) > toleration || + std::abs(arg[i].landmarks[j].y - + expected_landmarks[i].landmarks[j].y) > toleration) { + return false; + } + } + } + return true; +} + void ExpectHandLandmarkerResultsCorrect( const HandLandmarkerResult& actual_results, const HandLandmarkerResult& expected_results) { @@ -119,16 +152,15 @@ void ExpectHandLandmarkerResultsCorrect( ASSERT_EQ(actual_landmarks.size(), expected_landmarks.size()); ASSERT_EQ(actual_handedness.size(), expected_handedness.size()); + if (actual_landmarks.empty()) { + return; + } + ASSERT_GE(actual_landmarks.size(), 1); - EXPECT_THAT( - actual_handedness, - Pointwise(Approximately(Partially(EqualsProto()), kHandednessMargin), - expected_handedness)); + EXPECT_THAT(actual_handedness, + HandednessMatches(expected_handedness, kHandednessMargin)); EXPECT_THAT(actual_landmarks, - Pointwise(Approximately(Partially(EqualsProto()), - /*margin=*/kLandmarksAbsMargin, - /*fraction=*/kLandmarksFractionDiff), - expected_landmarks)); + LandmarksMatches(expected_landmarks, kLandmarksAbsMargin)); } } // namespace From ccd1461add4b6ecc974a46df597bcac8c154bbc9 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Mon, 23 Jan 2023 13:36:32 -0800 Subject: [PATCH 418/469] Don't error in ExternalFile handler on Windows if FileContent is provided PiperOrigin-RevId: 504069137 --- mediapipe/tasks/cc/core/external_file_handler.cc | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/mediapipe/tasks/cc/core/external_file_handler.cc b/mediapipe/tasks/cc/core/external_file_handler.cc index 33dfeca0b..ff30bea72 100644 --- a/mediapipe/tasks/cc/core/external_file_handler.cc +++ b/mediapipe/tasks/cc/core/external_file_handler.cc @@ -84,12 +84,6 @@ ExternalFileHandler::CreateFromExternalFile( } absl::Status ExternalFileHandler::MapExternalFile() { -// TODO: Add Windows support -#ifdef _WIN32 - return CreateStatusWithPayload(StatusCode::kFailedPrecondition, - "File loading is not yet supported on Windows", - MediaPipeTasksStatus::kFileReadError); -#else if (!external_file_.file_content().empty()) { return absl::OkStatus(); } else if (external_file_.has_file_pointer_meta()) { @@ -106,6 +100,13 @@ absl::Status ExternalFileHandler::MapExternalFile() { } return absl::OkStatus(); } + +// TODO: Add Windows support +#ifdef _WIN32 + return CreateStatusWithPayload(StatusCode::kFailedPrecondition, + "File loading is not yet supported on Windows", + MediaPipeTasksStatus::kFileReadError); +#else if (external_file_.file_name().empty() && !external_file_.has_file_descriptor_meta()) { return CreateStatusWithPayload( From 873d7181bf60fb29a5a441c8207a219029dafb98 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Mon, 23 Jan 2023 14:13:38 -0800 Subject: [PATCH 419/469] Add mediapipe tasks face detector graph PiperOrigin-RevId: 504078951 --- mediapipe/tasks/cc/vision/face_detector/BUILD | 61 +++++ 
.../face_detector/face_detector_graph.cc | 208 ++++++++++++++++++ .../face_detector/face_detector_graph_test.cc | 183 +++++++++++++++ .../tasks/cc/vision/face_detector/proto/BUILD | 31 +++ .../proto/face_detector_graph_options.proto | 42 ++++ mediapipe/tasks/testdata/vision/BUILD | 8 + .../vision/portrait_expected_detection.pbtxt | 35 +++ third_party/external_files.bzl | 20 +- 8 files changed, 584 insertions(+), 4 deletions(-) create mode 100644 mediapipe/tasks/cc/vision/face_detector/BUILD create mode 100644 mediapipe/tasks/cc/vision/face_detector/face_detector_graph.cc create mode 100644 mediapipe/tasks/cc/vision/face_detector/face_detector_graph_test.cc create mode 100644 mediapipe/tasks/cc/vision/face_detector/proto/BUILD create mode 100644 mediapipe/tasks/cc/vision/face_detector/proto/face_detector_graph_options.proto create mode 100644 mediapipe/tasks/testdata/vision/portrait_expected_detection.pbtxt diff --git a/mediapipe/tasks/cc/vision/face_detector/BUILD b/mediapipe/tasks/cc/vision/face_detector/BUILD new file mode 100644 index 000000000..09af34aa0 --- /dev/null +++ b/mediapipe/tasks/cc/vision/face_detector/BUILD @@ -0,0 +1,61 @@ +# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package(default_visibility = [ + # "//mediapipe/tasks:internal", + "//visibility:public", +]) + +licenses(["notice"]) + +cc_library( + name = "face_detector_graph", + srcs = ["face_detector_graph.cc"], + deps = [ + "//mediapipe/calculators/core:clip_vector_size_calculator", + "//mediapipe/calculators/core:clip_vector_size_calculator_cc_proto", + "//mediapipe/calculators/tensor:image_to_tensor_calculator_cc_proto", + "//mediapipe/calculators/tensor:inference_calculator", + "//mediapipe/calculators/tensor:tensors_to_detections_calculator", + "//mediapipe/calculators/tensor:tensors_to_detections_calculator_cc_proto", + "//mediapipe/calculators/tflite:ssd_anchors_calculator", + "//mediapipe/calculators/tflite:ssd_anchors_calculator_cc_proto", + "//mediapipe/calculators/util:detection_label_id_to_text_calculator", + "//mediapipe/calculators/util:detection_label_id_to_text_calculator_cc_proto", + "//mediapipe/calculators/util:detection_projection_calculator", + "//mediapipe/calculators/util:detections_to_rects_calculator", + "//mediapipe/calculators/util:detections_to_rects_calculator_cc_proto", + "//mediapipe/calculators/util:non_max_suppression_calculator", + "//mediapipe/calculators/util:non_max_suppression_calculator_cc_proto", + "//mediapipe/calculators/util:rect_transformation_calculator", + "//mediapipe/calculators/util:rect_transformation_calculator_cc_proto", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/formats:detection_cc_proto", + "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/framework/formats:tensor", + "//mediapipe/tasks/cc:common", + "//mediapipe/tasks/cc/components/processors:image_preprocessing_graph", + 
"//mediapipe/tasks/cc/core:model_resources", + "//mediapipe/tasks/cc/core:model_task_graph", + "//mediapipe/tasks/cc/core:utils", + "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", + "//mediapipe/tasks/cc/vision/face_detector/proto:face_detector_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/utils:image_tensor_specs", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + ], + alwayslink = 1, +) diff --git a/mediapipe/tasks/cc/vision/face_detector/face_detector_graph.cc b/mediapipe/tasks/cc/vision/face_detector/face_detector_graph.cc new file mode 100644 index 000000000..6b60621a6 --- /dev/null +++ b/mediapipe/tasks/cc/vision/face_detector/face_detector_graph.cc @@ -0,0 +1,208 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "mediapipe/calculators/core/clip_vector_size_calculator.pb.h" +#include "mediapipe/calculators/tensor/image_to_tensor_calculator.pb.h" +#include "mediapipe/calculators/tensor/tensors_to_detections_calculator.pb.h" +#include "mediapipe/calculators/tflite/ssd_anchors_calculator.pb.h" +#include "mediapipe/calculators/util/detection_label_id_to_text_calculator.pb.h" +#include "mediapipe/calculators/util/detections_to_rects_calculator.pb.h" +#include "mediapipe/calculators/util/non_max_suppression_calculator.pb.h" +#include "mediapipe/calculators/util/rect_transformation_calculator.pb.h" +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/api2/port.h" +#include "mediapipe/framework/formats/detection.pb.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/rect.pb.h" +#include "mediapipe/framework/formats/tensor.h" +#include "mediapipe/tasks/cc/common.h" +#include "mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h" +#include "mediapipe/tasks/cc/core/model_resources.h" +#include "mediapipe/tasks/cc/core/model_task_graph.h" +#include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h" +#include "mediapipe/tasks/cc/core/utils.h" +#include "mediapipe/tasks/cc/vision/face_detector/proto/face_detector_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/utils/image_tensor_specs.h" + +namespace mediapipe { +namespace tasks { +namespace vision { +namespace face_detector { + +using ::mediapipe::NormalizedRect; +using ::mediapipe::Tensor; +using ::mediapipe::api2::Input; +using ::mediapipe::api2::Output; +using ::mediapipe::api2::builder::Graph; +using ::mediapipe::api2::builder::Source; +using ::mediapipe::tasks::vision::face_detector::proto:: + FaceDetectorGraphOptions; + +namespace { +constexpr char kImageTag[] = "IMAGE"; +constexpr char kNormRectTag[] = "NORM_RECT"; +constexpr char kDetectionsTag[] = "DETECTIONS"; + +void ConfigureSsdAnchorsCalculator( + mediapipe::SsdAnchorsCalculatorOptions* options) { + // TODO config SSD anchors 
parameters from metadata.
+  options->set_num_layers(1);
+  options->set_min_scale(0.1484375);
+  options->set_max_scale(0.75);
+  options->set_input_size_height(192);
+  options->set_input_size_width(192);
+  options->set_anchor_offset_x(0.5);
+  options->set_anchor_offset_y(0.5);
+  options->add_strides(4);
+  options->add_aspect_ratios(1.0);
+  options->set_fixed_anchor_size(true);
+  options->set_interpolated_scale_aspect_ratio(0.0);
+}
+
+void ConfigureTensorsToDetectionsCalculator(
+    const FaceDetectorGraphOptions& tasks_options,
+    mediapipe::TensorsToDetectionsCalculatorOptions* options) {
+  // TODO use metadata to configure these fields.
+  options->set_num_classes(1);
+  options->set_num_boxes(2304);
+  options->set_num_coords(16);
+  options->set_box_coord_offset(0);
+  options->set_keypoint_coord_offset(4);
+  options->set_num_keypoints(6);
+  options->set_num_values_per_keypoint(2);
+  options->set_sigmoid_score(true);
+  options->set_score_clipping_thresh(100.0);
+  options->set_reverse_output_order(true);
+  options->set_min_score_thresh(tasks_options.min_detection_confidence());
+  options->set_x_scale(192.0);
+  options->set_y_scale(192.0);
+  options->set_w_scale(192.0);
+  options->set_h_scale(192.0);
+}
+
+void ConfigureNonMaxSuppressionCalculator(
+    const FaceDetectorGraphOptions& tasks_options,
+    mediapipe::NonMaxSuppressionCalculatorOptions* options) {
+  options->set_min_suppression_threshold(
+      tasks_options.min_suppression_threshold());
+  options->set_overlap_type(
+      mediapipe::NonMaxSuppressionCalculatorOptions::INTERSECTION_OVER_UNION);
+  options->set_algorithm(
+      mediapipe::NonMaxSuppressionCalculatorOptions::WEIGHTED);
+}
+
+}  // namespace
+
+class FaceDetectorGraph : public core::ModelTaskGraph {
+ public:
+  absl::StatusOr<CalculatorGraphConfig> GetConfig(
+      SubgraphContext* sc) override {
+    ASSIGN_OR_RETURN(const auto* model_resources,
+                     CreateModelResources<FaceDetectorGraphOptions>(sc));
+    Graph graph;
+    ASSIGN_OR_RETURN(auto face_detections,
+                     BuildFaceDetectionSubgraph(
+                         sc->Options<FaceDetectorGraphOptions>(),
+                         *model_resources, graph[Input<Image>(kImageTag)],
+                         graph[Input<NormalizedRect>(kNormRectTag)], graph));
+    face_detections >> graph[Output<std::vector<Detection>>(kDetectionsTag)];
+    return graph.GetConfig();
+  }
+
+ private:
+  absl::StatusOr<Source<std::vector<Detection>>> BuildFaceDetectionSubgraph(
+      const FaceDetectorGraphOptions& subgraph_options,
+      const core::ModelResources& model_resources, Source<Image> image_in,
+      Source<NormalizedRect> norm_rect_in, Graph& graph) {
+    // Image preprocessing subgraph to convert image to tensor for the tflite
+    // model.
+    auto& preprocessing = graph.AddNode(
+        "mediapipe.tasks.components.processors.ImagePreprocessingGraph");
+    bool use_gpu =
+        components::processors::DetermineImagePreprocessingGpuBackend(
+            subgraph_options.base_options().acceleration());
+    MP_RETURN_IF_ERROR(components::processors::ConfigureImagePreprocessingGraph(
+        model_resources, use_gpu,
+        &preprocessing.GetOptions<
+            components::processors::proto::ImagePreprocessingGraphOptions>()));
+    auto& image_to_tensor_options =
+        *preprocessing
+             .GetOptions<components::processors::proto::
+                             ImagePreprocessingGraphOptions>()
+             .mutable_image_to_tensor_options();
+    image_to_tensor_options.set_keep_aspect_ratio(true);
+    image_to_tensor_options.set_border_mode(
+        mediapipe::ImageToTensorCalculatorOptions::BORDER_ZERO);
+    image_in >> preprocessing.In("IMAGE");
+    norm_rect_in >> preprocessing.In("NORM_RECT");
+    auto preprocessed_tensors = preprocessing.Out("TENSORS");
+    auto matrix = preprocessing.Out("MATRIX");
+
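The Configure* helpers above pin down the model's post-processing contract: 2304 anchors on a 192x192 input, sigmoid scores clipped at plus or minus 100, and box coordinates decoded relative to each anchor. The sketch below illustrates that decode math under those assumptions; it is an illustration, not the calculator's actual implementation:

```cpp
// Illustrative decode math for one detection, following the options set
// above: sigmoid_score with score_clipping_thresh, and anchor-relative box
// decoding with x/y scales of 192. Sketch only; the real logic lives in
// TensorsToDetectionsCalculator.
#include <algorithm>
#include <cmath>
#include <cstdio>

struct Anchor {  // simplified stand-in for MediaPipe's SSD anchor
  float x_center, y_center, w, h;
};

float DecodeScore(float raw_logit, float clipping_thresh = 100.0f) {
  // score_clipping_thresh bounds the logit before applying the sigmoid.
  raw_logit = std::clamp(raw_logit, -clipping_thresh, clipping_thresh);
  return 1.0f / (1.0f + std::exp(-raw_logit));
}

float DecodeXCenter(float raw_x, const Anchor& anchor, float x_scale = 192.0f) {
  // With reverse_output_order=true the first coordinate is x; it is scaled
  // down by x_scale and shifted by the anchor center. fixed_anchor_size(true)
  // makes anchor.w == 1.0, so the anchor width is effectively a no-op here.
  return raw_x / x_scale * anchor.w + anchor.x_center;
}

int main() {
  Anchor anchor{0.5f, 0.5f, 1.0f, 1.0f};
  std::printf("score=%.3f x_center=%.3f\n", DecodeScore(2.0f),
              DecodeXCenter(9.6f, anchor));  // score~0.881, x_center~0.550
}
```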
+    // Face detection model inference.
+    auto& inference = AddInference(
+        model_resources, subgraph_options.base_options().acceleration(), graph);
+    preprocessed_tensors >> inference.In("TENSORS");
+    auto model_output_tensors =
+        inference.Out("TENSORS").Cast<std::vector<Tensor>>();
+
+    // Generates a single side packet containing a vector of SSD anchors.
+    auto& ssd_anchor = graph.AddNode("SsdAnchorsCalculator");
+    ConfigureSsdAnchorsCalculator(
+        &ssd_anchor.GetOptions<mediapipe::SsdAnchorsCalculatorOptions>());
+    auto anchors = ssd_anchor.SideOut("");
+
+    // Converts output tensors to Detections.
+    auto& tensors_to_detections =
+        graph.AddNode("TensorsToDetectionsCalculator");
+    ConfigureTensorsToDetectionsCalculator(
+        subgraph_options,
+        &tensors_to_detections
+             .GetOptions<mediapipe::TensorsToDetectionsCalculatorOptions>());
+    model_output_tensors >> tensors_to_detections.In("TENSORS");
+    anchors >> tensors_to_detections.SideIn("ANCHORS");
+    auto detections = tensors_to_detections.Out("DETECTIONS");
+
+    // Non-maximum suppression removes redundant face detections.
+    auto& non_maximum_suppression =
+        graph.AddNode("NonMaxSuppressionCalculator");
+    ConfigureNonMaxSuppressionCalculator(
+        subgraph_options,
+        &non_maximum_suppression
+             .GetOptions<mediapipe::NonMaxSuppressionCalculatorOptions>());
+    detections >> non_maximum_suppression.In("");
+    auto nms_detections = non_maximum_suppression.Out("");
+
+    // Projects detections back into the input image coordinate system.
+    auto& detection_projection = graph.AddNode("DetectionProjectionCalculator");
+    nms_detections >> detection_projection.In("DETECTIONS");
+    matrix >> detection_projection.In("PROJECTION_MATRIX");
+    auto face_detections =
+        detection_projection[Output<std::vector<Detection>>("DETECTIONS")];
+
+    return {face_detections};
+  }
+};
+
+REGISTER_MEDIAPIPE_GRAPH(
+    ::mediapipe::tasks::vision::face_detector::FaceDetectorGraph);
+
+}  // namespace face_detector
+}  // namespace vision
+}  // namespace tasks
+}  // namespace mediapipe
diff --git a/mediapipe/tasks/cc/vision/face_detector/face_detector_graph_test.cc b/mediapipe/tasks/cc/vision/face_detector/face_detector_graph_test.cc
new file mode 100644
index 000000000..fc1f49f13
--- /dev/null
+++ b/mediapipe/tasks/cc/vision/face_detector/face_detector_graph_test.cc
@@ -0,0 +1,183 @@
+/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ + +#include +#include +#include +#include +#include +#include + +#include "absl/flags/flag.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/api2/port.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/deps/file_path.h" +#include "mediapipe/framework/formats/detection.pb.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/rect.pb.h" +#include "mediapipe/framework/packet.h" +#include "mediapipe/framework/port/file_helpers.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h" +#include "mediapipe/tasks/cc/core/model_resources.h" +#include "mediapipe/tasks/cc/core/proto/base_options.pb.h" +#include "mediapipe/tasks/cc/core/proto/external_file.pb.h" +#include "mediapipe/tasks/cc/core/task_runner.h" +#include "mediapipe/tasks/cc/vision/face_detector/proto/face_detector_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/utils/image_utils.h" + +namespace mediapipe { +namespace tasks { +namespace vision { +namespace face_detector { +namespace { + +using ::file::Defaults; +using ::file::GetTextProto; +using ::mediapipe::NormalizedRect; +using ::mediapipe::api2::Input; +using ::mediapipe::api2::Output; +using ::mediapipe::api2::builder::Graph; +using ::mediapipe::api2::builder::Source; +using ::mediapipe::file::JoinPath; +using ::mediapipe::tasks::core::TaskRunner; +using ::mediapipe::tasks::vision::DecodeImageFromFile; +using ::mediapipe::tasks::vision::face_detector::proto:: + FaceDetectorGraphOptions; +using ::testing::EqualsProto; +using ::testing::Pointwise; +using ::testing::TestParamInfo; +using ::testing::TestWithParam; +using ::testing::Values; +using ::testing::proto::Approximately; +using ::testing::proto::Partially; + +constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/"; +constexpr char kFullRangeBlazeFaceModel[] = "face_detection_full_range.tflite"; +constexpr char kFullRangeSparseBlazeFaceModel[] = + "face_detection_full_range_sparse.tflite"; +constexpr char kPortraitImage[] = "portrait.jpg"; +constexpr char kPortraitExpectedDetection[] = + "portrait_expected_detection.pbtxt"; + +constexpr char kImageTag[] = "IMAGE"; +constexpr char kImageName[] = "image"; +constexpr char kNormRectTag[] = "NORM_RECT"; +constexpr char kNormRectName[] = "norm_rect"; +constexpr char kDetectionsTag[] = "DETECTIONS"; +constexpr char kDetectionsName[] = "detections"; + +constexpr float kFaceDetectionMaxDiff = 0.01; + +// Helper function to create a TaskRunner. 
+absl::StatusOr> CreateTaskRunner( + absl::string_view model_name) { + Graph graph; + + auto& face_detector_graph = + graph.AddNode("mediapipe.tasks.vision.face_detector.FaceDetectorGraph"); + + auto options = std::make_unique(); + options->mutable_base_options()->mutable_model_asset()->set_file_name( + JoinPath("./", kTestDataDirectory, model_name)); + options->set_min_detection_confidence(0.6); + options->set_min_suppression_threshold(0.3); + face_detector_graph.GetOptions().Swap( + options.get()); + + graph[Input(kImageTag)].SetName(kImageName) >> + face_detector_graph.In(kImageTag); + graph[Input(kNormRectTag)].SetName(kNormRectName) >> + face_detector_graph.In(kNormRectTag); + + face_detector_graph.Out(kDetectionsTag).SetName(kDetectionsName) >> + graph[Output>(kDetectionsTag)]; + + return TaskRunner::Create( + graph.GetConfig(), std::make_unique()); +} + +Detection GetExpectedFaceDetectionResult(absl::string_view file_name) { + Detection detection; + CHECK_OK(GetTextProto(file::JoinPath("./", kTestDataDirectory, file_name), + &detection, Defaults())) + << "Expected face detection result does not exist."; + return detection; +} + +struct TestParams { + // The name of this test, for convenience when displaying test results. + std::string test_name; + // The filename of face landmark detection model. + std::string face_detection_model_name; + // The filename of test image. + std::string test_image_name; + // Expected face detection results. + std::vector expected_result; +}; + +class FaceDetectorGraphTest : public testing::TestWithParam {}; + +TEST_P(FaceDetectorGraphTest, Succeed) { + MP_ASSERT_OK_AND_ASSIGN( + Image image, DecodeImageFromFile(JoinPath("./", kTestDataDirectory, + GetParam().test_image_name))); + NormalizedRect input_norm_rect; + input_norm_rect.set_x_center(0.5); + input_norm_rect.set_y_center(0.5); + input_norm_rect.set_width(1.0); + input_norm_rect.set_height(1.0); + MP_ASSERT_OK_AND_ASSIGN( + auto task_runner, CreateTaskRunner(GetParam().face_detection_model_name)); + auto output_packets = task_runner->Process( + {{kImageName, MakePacket(std::move(image))}, + {kNormRectName, + MakePacket(std::move(input_norm_rect))}}); + MP_ASSERT_OK(output_packets); + const std::vector& face_detections = + (*output_packets)[kDetectionsName].Get>(); + EXPECT_THAT(face_detections, Pointwise(Approximately(Partially(EqualsProto()), + kFaceDetectionMaxDiff), + GetParam().expected_result)); +} + +INSTANTIATE_TEST_SUITE_P( + FaceDetectorGraphTest, FaceDetectorGraphTest, + Values(TestParams{.test_name = "FullRange", + .face_detection_model_name = kFullRangeBlazeFaceModel, + .test_image_name = kPortraitImage, + .expected_result = {GetExpectedFaceDetectionResult( + kPortraitExpectedDetection)}}, + TestParams{ + .test_name = "FullRangeSparse", + .face_detection_model_name = kFullRangeSparseBlazeFaceModel, + .test_image_name = kPortraitImage, + .expected_result = {GetExpectedFaceDetectionResult( + kPortraitExpectedDetection)}}), + [](const TestParamInfo& info) { + return info.param.test_name; + }); + +} // namespace +} // namespace face_detector +} // namespace vision +} // namespace tasks +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/face_detector/proto/BUILD b/mediapipe/tasks/cc/vision/face_detector/proto/BUILD new file mode 100644 index 000000000..ca9a6f8c4 --- /dev/null +++ b/mediapipe/tasks/cc/vision/face_detector/proto/BUILD @@ -0,0 +1,31 @@ +# Copyright 2023 The MediaPipe Authors. All Rights Reserved. 
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library")
+
+package(default_visibility = [
+    "//mediapipe/tasks:internal",
+])
+
+licenses(["notice"])
+
+mediapipe_proto_library(
+    name = "face_detector_graph_options_proto",
+    srcs = ["face_detector_graph_options.proto"],
+    deps = [
+        "//mediapipe/framework:calculator_options_proto",
+        "//mediapipe/framework:calculator_proto",
+        "//mediapipe/tasks/cc/core/proto:base_options_proto",
+    ],
+)
diff --git a/mediapipe/tasks/cc/vision/face_detector/proto/face_detector_graph_options.proto b/mediapipe/tasks/cc/vision/face_detector/proto/face_detector_graph_options.proto
new file mode 100644
index 000000000..a58338288
--- /dev/null
+++ b/mediapipe/tasks/cc/vision/face_detector/proto/face_detector_graph_options.proto
@@ -0,0 +1,42 @@
+/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+syntax = "proto2";
+
+package mediapipe.tasks.vision.face_detector.proto;
+
+import "mediapipe/framework/calculator.proto";
+import "mediapipe/framework/calculator_options.proto";
+import "mediapipe/tasks/cc/core/proto/base_options.proto";
+
+option java_package = "com.google.mediapipe.tasks.vision.facedetector.proto";
+option java_outer_classname = "FaceDetectorGraphOptionsProto";
+
+message FaceDetectorGraphOptions {
+  extend mediapipe.CalculatorOptions {
+    optional FaceDetectorGraphOptions ext = 502141897;
+  }
+  // Base options for configuring Task library, such as specifying the TfLite
+  // model file with metadata, accelerator options, etc.
+  optional core.proto.BaseOptions base_options = 1;
+
+  // Minimum confidence value ([0.0, 1.0]) for a face detection to be
+  // considered successful.
+  optional float min_detection_confidence = 2 [default = 0.5];
+
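Because the min_suppression_threshold field declared just below is an IoU cutoff, a compact reference implementation of the metric may help. This is a generic illustration, not MediaPipe's NonMaxSuppressionCalculator:

```cpp
// Generic intersection-over-union for two axis-aligned boxes, the overlap
// measure that min_suppression_threshold is compared against
// (INTERSECTION_OVER_UNION overlap type in the graph above). Sketch only.
#include <algorithm>
#include <cstdio>

struct Box {
  float xmin, ymin, width, height;
};

float IntersectionOverUnion(const Box& a, const Box& b) {
  const float ix = std::max(
      0.0f, std::min(a.xmin + a.width, b.xmin + b.width) - std::max(a.xmin, b.xmin));
  const float iy = std::max(
      0.0f, std::min(a.ymin + a.height, b.ymin + b.height) - std::max(a.ymin, b.ymin));
  const float intersection = ix * iy;
  const float union_area =
      a.width * a.height + b.width * b.height - intersection;
  return union_area > 0.0f ? intersection / union_area : 0.0f;
}

int main() {
  Box a{0.0f, 0.0f, 2.0f, 2.0f};
  Box b{1.0f, 1.0f, 2.0f, 2.0f};
  std::printf("IoU = %.3f\n", IntersectionOverUnion(a, b));  // 1/7 ~= 0.143
}
```

+  // IoU threshold ([0.0, 1.0]) for non-maximum suppression to consider
+  // detections as duplicates.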
+ optional float min_suppression_threshold = 3 [default = 0.5]; +} diff --git a/mediapipe/tasks/testdata/vision/BUILD b/mediapipe/tasks/testdata/vision/BUILD index 607245700..09f830aba 100644 --- a/mediapipe/tasks/testdata/vision/BUILD +++ b/mediapipe/tasks/testdata/vision/BUILD @@ -37,6 +37,8 @@ mediapipe_files(srcs = [ "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite", "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29_with_dummy_score_calibration.tflite", "deeplabv3.tflite", + "face_detection_full_range.tflite", + "face_detection_full_range_sparse.tflite", "fist.jpg", "fist.png", "hand_landmark_full.tflite", @@ -58,6 +60,7 @@ mediapipe_files(srcs = [ "palm_detection_full.tflite", "pointing_up.jpg", "pointing_up_rotated.jpg", + "portrait.jpg", "right_hands.jpg", "right_hands_rotated.jpg", "segmentation_golden_rotation0.png", @@ -79,6 +82,7 @@ exports_files( "expected_right_down_hand_landmarks.prototxt", "expected_right_up_hand_landmarks.prototxt", "gesture_recognizer.task", + "portrait_expected_detection.pbtxt", ], ) @@ -106,6 +110,7 @@ filegroup( "multi_objects_rotated.jpg", "pointing_up.jpg", "pointing_up_rotated.jpg", + "portrait.jpg", "right_hands.jpg", "right_hands_rotated.jpg", "segmentation_golden_rotation0.png", @@ -129,6 +134,8 @@ filegroup( "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite", "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29_with_dummy_score_calibration.tflite", "deeplabv3.tflite", + "face_detection_full_range.tflite", + "face_detection_full_range_sparse.tflite", "hand_landmark_full.tflite", "hand_landmark_lite.tflite", "hand_landmarker.task", @@ -161,6 +168,7 @@ filegroup( "hand_detector_result_two_hands.pbtxt", "pointing_up_landmarks.pbtxt", "pointing_up_rotated_landmarks.pbtxt", + "portrait_expected_detection.pbtxt", "thumb_up_landmarks.pbtxt", "thumb_up_rotated_landmarks.pbtxt", "victory_landmarks.pbtxt", diff --git a/mediapipe/tasks/testdata/vision/portrait_expected_detection.pbtxt b/mediapipe/tasks/testdata/vision/portrait_expected_detection.pbtxt new file mode 100644 index 000000000..775f4479b --- /dev/null +++ b/mediapipe/tasks/testdata/vision/portrait_expected_detection.pbtxt @@ -0,0 +1,35 @@ +# proto-file: mediapipe/framework/formats/detection.proto +# proto-message: Detection +location_data { + format: RELATIVE_BOUNDING_BOX + relative_bounding_box { + xmin: 0.35494408 + ymin: 0.1059662 + width: 0.28768203 + height: 0.23037356 + } + relative_keypoints { + x: 0.44416338 + y: 0.17643969 + } + relative_keypoints { + x: 0.55514044 + y: 0.17731678 + } + relative_keypoints { + x: 0.5046702 + y: 0.2265771 + } + relative_keypoints { + x: 0.50227845 + y: 0.2719954 + } + relative_keypoints { + x: 0.37245658 + y: 0.20143759 + } + relative_keypoints { + x: 0.6084143 + y: 0.20409837 + } +} diff --git a/third_party/external_files.bzl b/third_party/external_files.bzl index 5adfbdfc6..1d9239c83 100644 --- a/third_party/external_files.bzl +++ b/third_party/external_files.bzl @@ -240,14 +240,14 @@ def external_files(): http_file( name = "com_google_mediapipe_face_detection_full_range_sparse_tflite", - sha256 = "671dd2f9ed11a78436fc21cc42357a803dfc6f73e9fb86541be942d5716c2dce", - urls = ["https://storage.googleapis.com/mediapipe-assets/face_detection_full_range_sparse.tflite?generation=1661875739104017"], + sha256 = "2c3728e6da56f21e21a320433396fb06d40d9088f2247c05e5635a688d45dfe1", + urls = ["https://storage.googleapis.com/mediapipe-assets/face_detection_full_range_sparse.tflite?generation=1674261618323821"], ) http_file( name = 
"com_google_mediapipe_face_detection_full_range_tflite", - sha256 = "99bf9494d84f50acc6617d89873f71bf6635a841ea699c17cb3377f9507cfec3", - urls = ["https://storage.googleapis.com/mediapipe-assets/face_detection_full_range.tflite?generation=1661875742733283"], + sha256 = "3698b18f063835bc609069ef052228fbe86d9c9a6dc8dcb7c7c2d69aed2b181b", + urls = ["https://storage.googleapis.com/mediapipe-assets/face_detection_full_range.tflite?generation=1674261620964007"], ) http_file( @@ -712,6 +712,18 @@ def external_files(): urls = ["https://storage.googleapis.com/mediapipe-assets/pointing_up_rotated_landmarks.pbtxt?generation=1666629486774022"], ) + http_file( + name = "com_google_mediapipe_portrait_expected_detection_pbtxt", + sha256 = "bb54e08e87844ef14bb185d5cb808908eb6011bfa6db48bd22d9650f6fda338b", + urls = ["https://storage.googleapis.com/mediapipe-assets/portrait_expected_detection.pbtxt?generation=1674261627835475"], + ) + + http_file( + name = "com_google_mediapipe_portrait_jpg", + sha256 = "a6f11efaa834706db23f275b6115058fa87fc7f14362681e6abe14e82749de3e", + urls = ["https://storage.googleapis.com/mediapipe-assets/portrait.jpg?generation=1674261630039907"], + ) + http_file( name = "com_google_mediapipe_pose_detection_tflite", sha256 = "a63c614bef30d35947f13be361820b1e4e3bec9cfeebf4d11216a18373108e85", From 2465e47b01e6883c22c7a7047e9dd087e93e7615 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Mon, 23 Jan 2023 16:41:32 -0800 Subject: [PATCH 420/469] Stream/SidePacket == and != operators PiperOrigin-RevId: 504114182 --- mediapipe/framework/api2/builder.h | 13 +++++++ mediapipe/framework/api2/builder_test.cc | 46 ++++++++++++++++++++++++ 2 files changed, 59 insertions(+) diff --git a/mediapipe/framework/api2/builder.h b/mediapipe/framework/api2/builder.h index 2a98c4166..da09acc83 100644 --- a/mediapipe/framework/api2/builder.h +++ b/mediapipe/framework/api2/builder.h @@ -206,6 +206,16 @@ class SourceImpl { return ConnectTo(dest); } + template + bool operator==(const SourceImpl& other) { + return base_ == other.base_; + } + + template + bool operator!=(const SourceImpl& other) { + return !(*this == other); + } + Src& SetName(std::string name) { base_->name_ = std::move(name); return *this; @@ -218,6 +228,9 @@ class SourceImpl { } private: + template + friend class SourceImpl; + // Never null. 
SourceBase* base_; }; diff --git a/mediapipe/framework/api2/builder_test.cc b/mediapipe/framework/api2/builder_test.cc index 08f4f0ca1..194f1b8ff 100644 --- a/mediapipe/framework/api2/builder_test.cc +++ b/mediapipe/framework/api2/builder_test.cc @@ -494,5 +494,51 @@ TEST(BuilderTest, SinglePortAccessWorksThroughSlicing) { EXPECT_THAT(graph.GetConfig(), EqualsProto(expected)); } +TEST(BuilderTest, TestStreamEqualsNotEqualsOperators) { + Graph graph; + Stream input0 = graph.In(0); + EXPECT_TRUE(input0 == input0); + EXPECT_FALSE(input0 != input0); + + EXPECT_TRUE(input0 == input0.Cast()); + EXPECT_FALSE(input0.Cast() != input0); + + EXPECT_TRUE(input0.Cast() == input0.Cast()); + EXPECT_FALSE(input0.Cast() != input0.Cast()); + + Stream input1 = graph.In(1); + EXPECT_FALSE(input0 == input1); + EXPECT_TRUE(input0 != input1); + + input1 = input0; + EXPECT_TRUE(input0 == input1); + EXPECT_FALSE(input0 != input1); + EXPECT_TRUE(input0.Cast() == input1.Cast()); + EXPECT_FALSE(input0.Cast() != input1.Cast()); +} + +TEST(BuilderTest, TestSidePacketEqualsNotEqualsOperators) { + Graph graph; + SidePacket side_input0 = graph.SideIn(0); + EXPECT_TRUE(side_input0 == side_input0); + EXPECT_FALSE(side_input0 != side_input0); + + EXPECT_TRUE(side_input0 == side_input0.Cast()); + EXPECT_FALSE(side_input0.Cast() != side_input0); + + EXPECT_TRUE(side_input0.Cast() == side_input0.Cast()); + EXPECT_FALSE(side_input0.Cast() != side_input0.Cast()); + + SidePacket side_input1 = graph.SideIn(1); + EXPECT_FALSE(side_input0 == side_input1); + EXPECT_TRUE(side_input0 != side_input1); + + side_input1 = side_input0; + EXPECT_TRUE(side_input0 == side_input1); + EXPECT_FALSE(side_input0 != side_input1); + EXPECT_TRUE(side_input0.Cast() == side_input1.Cast()); + EXPECT_FALSE(side_input0.Cast() != side_input1.Cast()); +} + } // namespace } // namespace mediapipe::api2::builder From 4e135ccdb9273c2e465b701130529bb3d4c77172 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Tue, 24 Jan 2023 10:36:36 -0800 Subject: [PATCH 421/469] Internal Model Maker change. 
PiperOrigin-RevId: 504315641 --- mediapipe/model_maker/python/text/text_classifier/BUILD | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/mediapipe/model_maker/python/text/text_classifier/BUILD b/mediapipe/model_maker/python/text/text_classifier/BUILD index 43f2b6c75..ac5b04f20 100644 --- a/mediapipe/model_maker/python/text/text_classifier/BUILD +++ b/mediapipe/model_maker/python/text/text_classifier/BUILD @@ -140,7 +140,11 @@ py_test( "//mediapipe/model_maker/models/text_classifier:mobilebert_tiny", "//mediapipe/model_maker/python/text/text_classifier/testdata", ], - tags = ["requires-net:external"], + tags = [ + "notsan", + "requires-mem:16g", + "requires-net:external", + ], deps = [ ":text_classifier_import", "//mediapipe/tasks/python/test:test_utils", From 9cde57d8303437576d131d2a8def7670fac90064 Mon Sep 17 00:00:00 2001 From: Mike Kruskal Date: Tue, 24 Jan 2023 12:11:35 -0800 Subject: [PATCH 422/469] Internal change PiperOrigin-RevId: 504341832 --- mediapipe/framework/tool/template_parser.cc | 26 ++++++++++----------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/mediapipe/framework/tool/template_parser.cc b/mediapipe/framework/tool/template_parser.cc index cf23f3443..6c7237f8e 100644 --- a/mediapipe/framework/tool/template_parser.cc +++ b/mediapipe/framework/tool/template_parser.cc @@ -330,7 +330,7 @@ class TemplateParser::Parser::ParserImpl { return suc && LookingAtType(io::Tokenizer::TYPE_END); } - void ReportError(int line, int col, const std::string& message) { + void ReportError(int line, int col, absl::string_view message) { had_errors_ = true; if (error_collector_ == NULL) { if (line >= 0) { @@ -342,11 +342,11 @@ class TemplateParser::Parser::ParserImpl { << root_message_type_->full_name() << ": " << message; } } else { - error_collector_->AddError(line, col, message); + error_collector_->AddError(line, col, std::string(message)); } } - void ReportWarning(int line, int col, const std::string& message) { + void ReportWarning(int line, int col, absl::string_view message) { if (error_collector_ == NULL) { if (line >= 0) { LOG(WARNING) << "Warning parsing text-format " @@ -357,21 +357,21 @@ class TemplateParser::Parser::ParserImpl { << root_message_type_->full_name() << ": " << message; } } else { - error_collector_->AddWarning(line, col, message); + error_collector_->AddWarning(line, col, std::string(message)); } } protected: // Reports an error with the given message with information indicating // the position (as derived from the current token). - void ReportError(const std::string& message) { + void ReportError(absl::string_view message) { ReportError(tokenizer_.current().line, tokenizer_.current().column, message); } // Reports a warning with the given message with information indicating // the position (as derived from the current token). - void ReportWarning(const std::string& message) { + void ReportWarning(absl::string_view message) { ReportWarning(tokenizer_.current().line, tokenizer_.current().column, message); } @@ -379,7 +379,7 @@ class TemplateParser::Parser::ParserImpl { // Consumes the specified message with the given starting delimiter. // This method checks to see that the end delimiter at the conclusion of // the consumption matches the starting delimiter passed in here. 
- bool ConsumeMessage(Message* message, const std::string delimiter) { + bool ConsumeMessage(Message* message, absl::string_view delimiter) { while (!LookingAt(">") && !LookingAt("}")) { if (LookingAt("%")) { DO(ConsumeFieldTemplate(message)); @@ -407,7 +407,7 @@ class TemplateParser::Parser::ParserImpl { #ifndef PROTO2_OPENSOURCE // Consumes a string value and parses it as a packed repeated field into // the given field of the given message. - bool ConsumePackedFieldAsString(const std::string& field_name, + bool ConsumePackedFieldAsString(absl::string_view field_name, const FieldDescriptor* field, Message* message) { std::string packed; @@ -431,8 +431,8 @@ class TemplateParser::Parser::ParserImpl { io::ArrayInputStream array_input(tagged.data(), tagged.size()); io::CodedInputStream coded_input(&array_input); if (!message->MergePartialFromCodedStream(&coded_input)) { - ReportError("Could not parse packed field \"" + field_name + - "\" as wire-encoded string."); + ReportError(absl::StrCat("Could not parse packed field \"", field_name, + "\" as wire-encoded string.")); return false; } @@ -1219,12 +1219,12 @@ class TemplateParser::Parser::ParserImpl { // Consumes a token and confirms that it matches that specified in the // value parameter. Returns false if the token found does not match that // which was specified. - bool Consume(const std::string& value) { + bool Consume(absl::string_view value) { const std::string& current_value = tokenizer_.current().text; if (current_value != value) { - ReportError("Expected \"" + value + "\", found \"" + current_value + - "\"."); + ReportError(absl::StrCat("Expected \"", value, "\", found \"", + current_value, "\".")); return false; } From ce9fec806cd47a4c78cf2362274cf23e0e7341c7 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Tue, 24 Jan 2023 12:11:53 -0800 Subject: [PATCH 423/469] Internal change PiperOrigin-RevId: 504341886 --- mediapipe/gpu/BUILD | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/mediapipe/gpu/BUILD b/mediapipe/gpu/BUILD index 9074daf61..55e9c98c2 100644 --- a/mediapipe/gpu/BUILD +++ b/mediapipe/gpu/BUILD @@ -902,6 +902,17 @@ cc_library( alwayslink = 1, ) +### Simple calculators + +mediapipe_proto_library( + name = "gl_animation_overlay_calculator_proto", + srcs = ["gl_animation_overlay_calculator.proto"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework:calculator_proto", + ], +) + proto_library( name = "gl_scaler_calculator_proto", srcs = ["gl_scaler_calculator.proto"], From 679dbb3fd83717949ad6c5eccb05dfd5481e5e65 Mon Sep 17 00:00:00 2001 From: Yuqi Li Date: Tue, 24 Jan 2023 14:44:19 -0800 Subject: [PATCH 424/469] nit: update the metadata_schema.fbs file path. PiperOrigin-RevId: 504380873 --- mediapipe/tasks/python/metadata/metadata.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mediapipe/tasks/python/metadata/metadata.py b/mediapipe/tasks/python/metadata/metadata.py index 2327ebbdf..6afb5a3fa 100644 --- a/mediapipe/tasks/python/metadata/metadata.py +++ b/mediapipe/tasks/python/metadata/metadata.py @@ -106,7 +106,7 @@ class MetadataPopulator(object): The metadata file (or buffer) should be generated based on the metadata schema: - third_party/tensorflow/lite/schema/metadata_schema.fbs + mediapipe/tasks/metadata/metadata_schema.fbs Example usage: Populate matadata and label file into an image classifier model. 
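A note on the template_parser.cc refactor in patch 422 above: taking absl::string_view avoids a std::string copy at every call site, and absl::StrCat builds each message in one allocation instead of chained '+' concatenations. Below is a self-contained sketch of that pattern, where ErrorCollector is a stand-in for the proto2 interface rather than the real class:

```cpp
// Sketch of the absl::string_view reporting pattern: accept a view, build
// the message with StrCat, and materialize an owning std::string only at
// the collector boundary. ErrorCollector here is a simplified stand-in.
#include <iostream>
#include <string>
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"

struct ErrorCollector {  // stand-in for the proto2 ErrorCollector interface
  void AddError(int line, int col, const std::string& message) {
    std::cerr << line << ":" << col << ": " << message << "\n";
  }
};

void ReportError(ErrorCollector& collector, int line, int col,
                 absl::string_view message) {
  // std::string(message) makes the single owning copy at the boundary.
  collector.AddError(line, col, std::string(message));
}

int main() {
  ErrorCollector collector;
  // StrCat assembles the message in one allocation instead of chained '+'.
  ReportError(collector, 3, 7,
              absl::StrCat("Expected \"", "}", "\", found \"", "]", "\"."));
}
```

The owning copy is made only where the collector interface demands one, mirroring the std::string(message) conversions in the patched code.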
From 5dc81c4c27baf3e75072447622d49c09d46f6ac6 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Tue, 24 Jan 2023 15:52:38 -0800 Subject: [PATCH 425/469] Remove unused import on strings.h PiperOrigin-RevId: 504397437 --- mediapipe/tasks/cc/components/containers/detection_result.cc | 2 -- 1 file changed, 2 deletions(-) diff --git a/mediapipe/tasks/cc/components/containers/detection_result.cc b/mediapipe/tasks/cc/components/containers/detection_result.cc index 43c8ca0f5..38126f917 100644 --- a/mediapipe/tasks/cc/components/containers/detection_result.cc +++ b/mediapipe/tasks/cc/components/containers/detection_result.cc @@ -15,8 +15,6 @@ limitations under the License. #include "mediapipe/tasks/cc/components/containers/detection_result.h" -#include - #include #include #include From afb018293514ff2707f19ca3bd955aef90ca16b4 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Tue, 24 Jan 2023 23:14:13 -0800 Subject: [PATCH 426/469] Internal model maker change. PiperOrigin-RevId: 504472342 --- .../python/core/tasks/classifier.py | 3 +- .../gesture_recognizer/gesture_recognizer.py | 85 ++++++++----------- 2 files changed, 35 insertions(+), 53 deletions(-) diff --git a/mediapipe/model_maker/python/core/tasks/classifier.py b/mediapipe/model_maker/python/core/tasks/classifier.py index 0908dddf5..abcfff835 100644 --- a/mediapipe/model_maker/python/core/tasks/classifier.py +++ b/mediapipe/model_maker/python/core/tasks/classifier.py @@ -37,7 +37,7 @@ class Classifier(custom_model.CustomModel): label_names: A list of label names for the classes. shuffle: Whether the dataset should be shuffled. """ - super(Classifier, self).__init__(model_spec, shuffle) + super().__init__(model_spec, shuffle) self._label_names = label_names self._num_classes = len(label_names) self._model: tf.keras.Model = None @@ -48,7 +48,6 @@ class Classifier(custom_model.CustomModel): self._hparams: hp.BaseHParams = None self._history: tf.keras.callbacks.History = None - # TODO: Integrate this into GestureRecognizer. def _train_model(self, train_data: classification_ds.ClassificationDataset, validation_data: classification_ds.ClassificationDataset, diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer.py b/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer.py index 556d2fcd7..b27f7161f 100644 --- a/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer.py +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer.py @@ -53,6 +53,10 @@ class GestureRecognizer(classifier.Classifier): model_spec=None, label_names=label_names, shuffle=hparams.shuffle) self._model_options = model_options self._hparams = hparams + self._loss_function = loss_functions.FocalLoss(gamma=self._hparams.gamma) + self._metric_function = 'categorical_accuracy' + self._optimizer = 'adam' + self._callbacks = self._get_callbacks() self._history = None self.embedding_size = _EMBEDDING_SIZE @@ -71,7 +75,7 @@ class GestureRecognizer(classifier.Classifier): Args: train_data: Training data. - validation_data: Validation data. If None, skips validation process. + validation_data: Validation data. options: options for creating and training gesture recognizer model. 
Returns: @@ -87,49 +91,39 @@ class GestureRecognizer(classifier.Classifier): label_names=train_data.label_names, model_options=options.model_options, hparams=options.hparams) - - gesture_recognizer._create_model() - - train_dataset = train_data.gen_tf_dataset( - batch_size=options.hparams.batch_size, - is_training=True, - shuffle=options.hparams.shuffle) - options.hparams.steps_per_epoch = model_util.get_steps_per_epoch( - steps_per_epoch=options.hparams.steps_per_epoch, - batch_size=options.hparams.batch_size, - train_data=train_data) - train_dataset = train_dataset.take(count=options.hparams.steps_per_epoch) - - validation_dataset = validation_data.gen_tf_dataset( - batch_size=options.hparams.batch_size, is_training=False) - - tf.compat.v1.logging.info('Training the gesture recognizer model...') - gesture_recognizer._train( - train_data=train_dataset, validation_data=validation_dataset) - + gesture_recognizer._create_and_train_model(train_data, validation_data) return gesture_recognizer - def _train(self, train_data: tf.data.Dataset, - validation_data: tf.data.Dataset): - """Trains the model with input train_data. - - The training results are recorded by a self.History object returned by - tf.keras.Model.fit(). + def _create_and_train_model( + self, + train_data: classification_ds.ClassificationDataset, + validation_data: classification_ds.ClassificationDataset, + ): + """Creates and trains the model. Args: train_data: Training data. validation_data: Validation data. """ + self._create_model() + self._train_model( + train_data=train_data, + validation_data=validation_data, + checkpoint_path=self._get_checkpoint_path(), + ) + + def _get_callbacks(self) -> List[tf.keras.callbacks.Callback]: + """Gets the list of callbacks to use in model training.""" hparams = self._hparams scheduler = lambda epoch: hparams.learning_rate * (hparams.lr_decay**epoch) scheduler_callback = tf.keras.callbacks.LearningRateScheduler(scheduler) job_dir = hparams.export_dir - checkpoint_path = os.path.join(job_dir, 'epoch_models') checkpoint_callback = tf.keras.callbacks.ModelCheckpoint( - os.path.join(checkpoint_path, 'model-{epoch:04d}'), - save_weights_only=True) + os.path.join(self._get_checkpoint_path(), 'model-{epoch:04d}'), + save_weights_only=True, + ) best_model_path = os.path.join(job_dir, 'best_model_weights') best_model_callback = tf.keras.callbacks.ModelCheckpoint( @@ -141,27 +135,15 @@ class GestureRecognizer(classifier.Classifier): tensorboard_callback = tf.keras.callbacks.TensorBoard( log_dir=os.path.join(job_dir, 'logs')) + return [ + checkpoint_callback, + best_model_callback, + scheduler_callback, + tensorboard_callback, + ] - self._model.compile( - optimizer='adam', - loss=loss_functions.FocalLoss(gamma=self._hparams.gamma), - metrics=['categorical_accuracy']) - - latest_checkpoint = tf.train.latest_checkpoint(checkpoint_path) - if latest_checkpoint: - print(f'Resuming from {latest_checkpoint}') - self._model.load_weights(latest_checkpoint) - - self._history = self._model.fit( - x=train_data, - epochs=hparams.epochs, - validation_data=validation_data, - validation_freq=1, - callbacks=[ - checkpoint_callback, best_model_callback, scheduler_callback, - tensorboard_callback - ], - ) + def _get_checkpoint_path(self) -> str: + return os.path.join(self._hparams.export_dir, 'epoch_models') def _create_model(self): """Creates the hand gesture recognizer model. 
@@ -172,7 +154,8 @@ class GestureRecognizer(classifier.Classifier):
         shape=[self.embedding_size],
         batch_size=None,
         dtype=tf.float32,
-        name='hand_embedding')
+        name='hand_embedding',
+    )
     x = inputs
     dropout_rate = self._model_options.dropout_rate
     for i, width in enumerate(self._model_options.layer_widths):

From 7d624027687894705a68835a2cee239101f1fb94 Mon Sep 17 00:00:00 2001
From: Prianka Liz Kariat
Date: Wed, 25 Jan 2023 20:16:54 +0530
Subject: [PATCH 427/469] Added MPPEmbedding

---
 .../tasks/ios/components/containers/BUILD    |  7 ++
 .../containers/sources/MPPEmbedding.h        | 70 +++++++++++++++++++
 .../containers/sources/MPPEmbedding.m        | 34 +++++++++
 3 files changed, 111 insertions(+)
 create mode 100644 mediapipe/tasks/ios/components/containers/sources/MPPEmbedding.h
 create mode 100644 mediapipe/tasks/ios/components/containers/sources/MPPEmbedding.m

diff --git a/mediapipe/tasks/ios/components/containers/BUILD b/mediapipe/tasks/ios/components/containers/BUILD
index fb23160b8..22bbe7731 100644
--- a/mediapipe/tasks/ios/components/containers/BUILD
+++ b/mediapipe/tasks/ios/components/containers/BUILD
@@ -28,3 +28,10 @@ objc_library(
     hdrs = ["sources/MPPClassificationResult.h"],
     deps = [":MPPCategory"],
 )
+
+objc_library(
+    name = "MPPEmbedding",
+    srcs = ["sources/MPPEmbedding.m"],
+    hdrs = ["sources/MPPEmbedding.h"],
+)
+
diff --git a/mediapipe/tasks/ios/components/containers/sources/MPPEmbedding.h b/mediapipe/tasks/ios/components/containers/sources/MPPEmbedding.h
new file mode 100644
index 000000000..a9db8e579
--- /dev/null
+++ b/mediapipe/tasks/ios/components/containers/sources/MPPEmbedding.h
@@ -0,0 +1,70 @@
+// Copyright 2023 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#import <Foundation/Foundation.h>
+
+NS_ASSUME_NONNULL_BEGIN
+
+/**
+ * Represents the embedding for a given embedder head. Typically used in embedding tasks.
+ *
+ * One and only one of the two 'floatEmbedding' and 'quantizedEmbedding' will contain data, based on
+ * whether or not the embedder was configured to perform scalar quantization.
+ */
+NS_SWIFT_NAME(Embedding)
+@interface MPPEmbedding : NSObject
+
+/**
+ * @brief The Floating-point embedding.
+ * Empty if the embedder was configured to perform scalar quantization.
+ */
+@property(nonatomic, readonly, nullable) float *floatEmbedding;
+
+/**
+ * @brief The Quantized embedding.
+ * Empty if the embedder was not configured to perform scalar quantization.
+ */
+@property(nonatomic, readonly, nullable) char *quantizedEmbedding;
+
+/** The index of the embedder head these entries refer to. This is useful for multi-head models. */
+@property(nonatomic, readonly) NSInteger headIndex;
+
+/** The optional name of the embedder head, which is the corresponding tensor metadata name. */
+@property(nonatomic, readonly, nullable) NSString *headName;
+
+/**
+ * Initializes a new `MPPEmbedding` with the given float embedding, quantized embedding, head index
+ * and head name.
+ *
+ * @param floatEmbedding The optional Floating-point embedding.
+ * @param quantizedEmbedding The optional quantized embedding.
+ * @param headIndex The index of the embedder head.
+ * @param headName The optional name of the embedder head.
+ *
+ * @return An instance of `MPPEmbedding` initialized with the given float embedding, quantized
+ * embedding, head index and head name.
+ *
+ */
+- (instancetype)initWithFloatEmbedding:(nullable float *)floatEmbedding
+                    quantizedEmbedding:(nullable char *)quantizedEmbedding
+                             headIndex:(NSInteger)headIndex
+                              headName:(nullable NSString *)headName NS_DESIGNATED_INITIALIZER;
+
+- (instancetype)init NS_UNAVAILABLE;
+
++ (instancetype)new NS_UNAVAILABLE;
+
+@end
+
+NS_ASSUME_NONNULL_END
diff --git a/mediapipe/tasks/ios/components/containers/sources/MPPEmbedding.m b/mediapipe/tasks/ios/components/containers/sources/MPPEmbedding.m
new file mode 100644
index 000000000..642853ef1
--- /dev/null
+++ b/mediapipe/tasks/ios/components/containers/sources/MPPEmbedding.m
@@ -0,0 +1,34 @@
+// Copyright 2023 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#import "mediapipe/tasks/ios/components/containers/sources/MPPEmbedding.h"
+
+@implementation MPPEmbedding
+
+- (instancetype)initWithFloatEmbedding:(nullable float *)floatEmbedding
+                    quantizedEmbedding:(nullable char *)quantizedEmbedding
+                             headIndex:(NSInteger)headIndex
+                              headName:(nullable NSString *)headName {
+  // TODO: Should a null check for the embeddings be done here?
+  self = [super init];
+  if (self) {
+    _headIndex = headIndex;
+    _headName = headName;
+    _floatEmbedding = floatEmbedding;
+    _quantizedEmbedding = quantizedEmbedding;
+  }
+  return self;
+}
+
+@end

From db5ee6689f195e2e5eeef4e389d53f12cf0b5e3b Mon Sep 17 00:00:00 2001
From: Prianka Liz Kariat
Date: Wed, 25 Jan 2023 20:17:04 +0530
Subject: [PATCH 428/469] Added MPPEmbeddingResult

---
 .../tasks/ios/components/containers/BUILD     |  6 ++
 .../containers/sources/MPPEmbeddingResult.h   | 59 +++++++++++++++++++
 .../containers/sources/MPPEmbeddingResult.m   | 30 ++++++++++
 3 files changed, 95 insertions(+)
 create mode 100644 mediapipe/tasks/ios/components/containers/sources/MPPEmbeddingResult.h
 create mode 100644 mediapipe/tasks/ios/components/containers/sources/MPPEmbeddingResult.m

diff --git a/mediapipe/tasks/ios/components/containers/BUILD b/mediapipe/tasks/ios/components/containers/BUILD
index 22bbe7731..ee54bb712 100644
--- a/mediapipe/tasks/ios/components/containers/BUILD
+++ b/mediapipe/tasks/ios/components/containers/BUILD
@@ -35,3 +35,9 @@ objc_library(
     hdrs = ["sources/MPPEmbedding.h"],
 )
 
+objc_library(
+    name = "MPPEmbeddingResult",
+    srcs = ["sources/MPPEmbeddingResult.m"],
+    hdrs = ["sources/MPPEmbeddingResult.h"],
+    deps = [":MPPEmbedding"],
+)
diff --git a/mediapipe/tasks/ios/components/containers/sources/MPPEmbeddingResult.h b/mediapipe/tasks/ios/components/containers/sources/MPPEmbeddingResult.h
new file mode 100644
index 000000000..3d5d48b9b
--- /dev/null
+++ b/mediapipe/tasks/ios/components/containers/sources/MPPEmbeddingResult.h
@@ -0,0 +1,59 @@
+// Copyright 2023 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#import <Foundation/Foundation.h>
+#import "mediapipe/tasks/ios/components/containers/sources/MPPEmbedding.h"
+
+NS_ASSUME_NONNULL_BEGIN
+
+/** Represents the embedding results of a model. Typically used as a result for embedding tasks. */
+NS_SWIFT_NAME(EmbeddingResult)
+@interface MPPEmbeddingResult : NSObject
+
+/**
+ * An array of `MPPEmbedding` objects containing the embedding results for each head of the model.
+ */
+@property(nonatomic, readonly) NSArray<MPPEmbedding *> *embeddings;
+
+/**
+ * @brief The optional timestamp (in milliseconds) of the start of the chunk of data corresponding
+ * to these results.
+ * This is only used for embedding extraction on time series (e.g. audio embedder). In these use
+ * cases, the amount of data to process might exceed the maximum size that the model can process: to
+ * solve this, the input data is split into multiple chunks starting at different timestamps.
+ */
+@property(nonatomic, readonly) NSInteger timestampMs;
+
+/**
+ * Initializes a new `MPPEmbeddingResult` with the given array of embeddings and timestamp (in
+ * milliseconds).
+ *
+ * @param embeddings An array of `MPPEmbedding` objects containing the embedding results for each
+ * head of the model.
+ * @param timestampMs The optional timestamp (in milliseconds) of the start of the chunk of data
+ * corresponding to these results. Pass `0` if timestamp is absent.
+ *
+ * @return An instance of `MPPEmbeddingResult` initialized with the given array of embeddings and
+ * timestampMs.
+ */
+- (instancetype)initWithEmbeddings:(NSArray<MPPEmbedding *> *)embeddings
+                       timestampMs:(NSInteger)timestampMs NS_DESIGNATED_INITIALIZER;
+
+- (instancetype)init NS_UNAVAILABLE;
+
++ (instancetype)new NS_UNAVAILABLE;
+
+@end
+
+NS_ASSUME_NONNULL_END
diff --git a/mediapipe/tasks/ios/components/containers/sources/MPPEmbeddingResult.m b/mediapipe/tasks/ios/components/containers/sources/MPPEmbeddingResult.m
new file mode 100644
index 000000000..56dd30fdd
--- /dev/null
+++ b/mediapipe/tasks/ios/components/containers/sources/MPPEmbeddingResult.m
@@ -0,0 +1,30 @@
+// Copyright 2023 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#import "mediapipe/tasks/ios/components/containers/sources/MPPEmbeddingResult.h"
+
+@implementation MPPEmbeddingResult
+
+- (instancetype)initWithEmbeddings:(NSArray<MPPEmbedding *> *)embeddings
+                       timestampMs:(NSInteger)timestampMs {
+  self = [super init];
+  if (self) {
+    _embeddings = embeddings;
+    _timestampMs = timestampMs;
+  }
+
+  return self;
+}
+
+@end

From 60e72bf1655689a67a2cc6ef0aa2a008f8401be1 Mon Sep 17 00:00:00 2001
From: Prianka Liz Kariat
Date: Wed, 25 Jan 2023 20:19:27 +0530
Subject: [PATCH 429/469] Added MPPTextEmbedderOptions

---
 mediapipe/tasks/ios/text/text_embedder/BUILD  | 24 ++++++++++
 .../sources/MPPTextEmbedderOptions.h          | 47 +++++++++++++++++++
 .../sources/MPPTextEmbedderOptions.m          | 28 +++++++++++
 3 files changed, 99 insertions(+)
 create mode 100644 mediapipe/tasks/ios/text/text_embedder/BUILD
 create mode 100644 mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedderOptions.h
 create mode 100644 mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedderOptions.m

diff --git a/mediapipe/tasks/ios/text/text_embedder/BUILD b/mediapipe/tasks/ios/text/text_embedder/BUILD
new file mode 100644
index 000000000..65cbde093
--- /dev/null
+++ b/mediapipe/tasks/ios/text/text_embedder/BUILD
@@ -0,0 +1,24 @@
+# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+package(default_visibility = ["//mediapipe/tasks:internal"])
+
+licenses(["notice"])
+
+objc_library(
+    name = "MPPTextEmbedderOptions",
+    srcs = ["sources/MPPTextEmbedderOptions.m"],
+    hdrs = ["sources/MPPTextEmbedderOptions.h"],
+    deps = ["//mediapipe/tasks/ios/core:MPPTaskOptions"],
+)
diff --git a/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedderOptions.h b/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedderOptions.h
new file mode 100644
index 000000000..ce9fc8b20
--- /dev/null
+++ b/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedderOptions.h
@@ -0,0 +1,47 @@
+// Copyright 2023 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#import <Foundation/Foundation.h>
+
+#import "mediapipe/tasks/ios/core/sources/MPPTaskOptions.h"
+
+NS_ASSUME_NONNULL_BEGIN
+
+/**
+ * Options for setting up a `MPPTextEmbedder`.
+ */
+NS_SWIFT_NAME(TextEmbedderOptions)
+@interface MPPTextEmbedderOptions : MPPTaskOptions
+
+/**
+ * @brief Sets whether L2 normalization should be performed on the returned embeddings.
+ * Use this option only if the model does not already contain a native L2_NORMALIZATION TF Lite Op.
+ * In most cases, this is already the case and L2 norm is thus achieved through TF Lite inference.
+ *
+ * NO by default.
+ */
+@property(nonatomic) BOOL l2Normalize;
+
+/**
+ * @brief Sets whether the returned embedding should be quantized to bytes via scalar quantization.
+ * Embeddings are implicitly assumed to be unit-norm and therefore any dimension is guaranteed to
+ * have a value in [-1.0, 1.0]. Use the `l2Normalize` property if this is not the case.
+ *
+ * NO by default.
+ */
+@property(nonatomic) BOOL quantize;
+
+@end
+
+NS_ASSUME_NONNULL_END
diff --git a/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedderOptions.m b/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedderOptions.m
new file mode 100644
index 000000000..6da3659f7
--- /dev/null
+++ b/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedderOptions.m
@@ -0,0 +1,28 @@
+// Copyright 2023 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#import "mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedderOptions.h"
+
+@implementation MPPTextEmbedderOptions
+
+- (id)copyWithZone:(NSZone *)zone {
+  MPPTextEmbedderOptions *textEmbedderOptions = [super copyWithZone:zone];
+
+  textEmbedderOptions.l2Normalize = self.l2Normalize;
+  textEmbedderOptions.quantize = self.quantize;
+
+  return textEmbedderOptions;
+}
+
+@end

From 168ea0a9ea33affea3982e89ed75fb4307ffc8c4 Mon Sep 17 00:00:00 2001
From: Prianka Liz Kariat
Date: Wed, 25 Jan 2023 20:19:40 +0530
Subject: [PATCH 430/469] Added MPPTextEmbedderResult

---
 mediapipe/tasks/ios/text/text_embedder/BUILD | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/mediapipe/tasks/ios/text/text_embedder/BUILD b/mediapipe/tasks/ios/text/text_embedder/BUILD
index 65cbde093..143f0a587 100644
--- a/mediapipe/tasks/ios/text/text_embedder/BUILD
+++ b/mediapipe/tasks/ios/text/text_embedder/BUILD
@@ -23,2 +23,12 @@ objc_library(
     deps = ["//mediapipe/tasks/ios/core:MPPTaskOptions"],
 )
+
+objc_library(
+    name = "MPPTextEmbedderResult",
+    srcs = ["sources/MPPTextEmbedderResult.m"],
+    hdrs = ["sources/MPPTextEmbedderResult.h"],
+    deps = [
+        "//mediapipe/tasks/ios/components/containers:MPPEmbeddingResult",
+        "//mediapipe/tasks/ios/core:MPPTaskResult",
+    ],
+)

From d01f75a295f6779d56f86c5e1a15c3f2360867aa Mon Sep 17 00:00:00 2001
From: Prianka Liz Kariat
Date: Wed, 25 Jan 2023 20:19:57 +0530
Subject: [PATCH 431/469] Added iOS text embedder result files

---
 .../sources/MPPTextEmbedderResult.h           | 48 +++++++++++++++++++
 .../sources/MPPTextEmbedderResult.m           | 28 +++++++++++
 2 files changed, 76 insertions(+)
 create mode 100644 mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedderResult.h
 create mode 100644 mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedderResult.m

diff --git a/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedderResult.h b/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedderResult.h
new file mode 100644
index 000000000..e4697dcef
--- /dev/null
+++ b/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedderResult.h
@@ -0,0 +1,48 @@
+// Copyright 2023 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#import <Foundation/Foundation.h>
+#import "mediapipe/tasks/ios/components/containers/sources/MPPEmbeddingResult.h"
+#import "mediapipe/tasks/ios/core/sources/MPPTaskResult.h"
+
+NS_ASSUME_NONNULL_BEGIN
+
+/** Represents the embedding results generated by `MPPTextEmbedder`. **/
+NS_SWIFT_NAME(TextEmbedderResult)
+@interface MPPTextEmbedderResult : MPPTaskResult
+
+/** The `MPPEmbeddingResult` instance containing one embedding per embedder head. **/
+@property(nonatomic, readonly) MPPEmbeddingResult *embeddingResult;
+
+/**
+ * Initializes a new `MPPTextEmbedderResult` with the given `MPPEmbeddingResult` and
+ * timestamp (in milliseconds).
+ *
+ * @param embeddingResult The `MPPEmbeddingResult` instance containing one set of results
+ * per embedder head.
+ * @param timestampMs The timestamp for this result.
+ * + * @return An instance of `MPPTextEmbedderResult` initialized with the given + * `MPPEmbeddingResult` and timestamp (in milliseconds). + */ +- (instancetype)initWithEmbeddingResult:(MPPEmbeddingResult *)embeddingResult + timestampMs:(NSInteger)timestampMs; + +- (instancetype)init NS_UNAVAILABLE; + ++ (instancetype)new NS_UNAVAILABLE; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedderResult.m b/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedderResult.m new file mode 100644 index 000000000..5483e3c3f --- /dev/null +++ b/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedderResult.m @@ -0,0 +1,28 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedderResult.h" + +@implementation MPPTextEmbedderResult + +- (instancetype)initWithEmbeddingResult:(MPPEmbeddingResult *)embeddingResult + timestampMs:(NSInteger)timestampMs { + self = [super initWithTimestampMs:timestampMs]; + if (self) { + _embeddingResult = embeddingResult; + } + return self; +} + +@end From 61f7739ff6ffb953f7243b4a09efde03897aa7a1 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Wed, 25 Jan 2023 20:20:22 +0530 Subject: [PATCH 432/469] Updated documentation --- .../ios/text/text_classifier/sources/MPPTextClassifierOptions.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.h b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.h index 4726203d3..55ab020f7 100644 --- a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.h +++ b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.h @@ -19,7 +19,7 @@ NS_ASSUME_NONNULL_BEGIN /** - * Options for setting up a `MPPTextClassifierOptions`. + * Options for setting up a `MPPTextClassifier`. */ NS_SWIFT_NAME(TextClassifierOptions) @interface MPPTextClassifierOptions : MPPTaskOptions From 1538740dcbbfa87f8d43c34ad1bccaa95b9f6efa Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Wed, 25 Jan 2023 10:31:10 -0800 Subject: [PATCH 433/469] Formatting fix PiperOrigin-RevId: 504599712 --- mediapipe/tasks/web/components/containers/embedding_result.d.ts | 1 + 1 file changed, 1 insertion(+) diff --git a/mediapipe/tasks/web/components/containers/embedding_result.d.ts b/mediapipe/tasks/web/components/containers/embedding_result.d.ts index 3779abd96..43d14d30e 100644 --- a/mediapipe/tasks/web/components/containers/embedding_result.d.ts +++ b/mediapipe/tasks/web/components/containers/embedding_result.d.ts @@ -33,6 +33,7 @@ export declare interface Embedding { * perform scalar quantization. */ quantizedEmbedding?: Uint8Array; + /** * The index of the classifier head these categories refer to. This is * useful for multi-head models. 
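A note for readers on the `l2Normalize` and `quantize` options threaded through the embedder patches above: they correspond to two small post-processing steps on an embedding vector. The following self-contained C++ sketch is illustrative only — it is not MediaPipe code, and the function names are made up — but it shows what each step does, assuming (as the option documentation requires) that the input to the quantization step is unit-norm:

```
#include <cmath>
#include <cstdint>
#include <vector>

// L2-normalize an embedding so every component lies in [-1.0, 1.0].
std::vector<float> L2Normalize(const std::vector<float>& v) {
  float norm = 0.f;
  for (float x : v) norm += x * x;
  norm = std::sqrt(norm);
  std::vector<float> out;
  out.reserve(v.size());
  for (float x : v) out.push_back(norm > 0.f ? x / norm : 0.f);
  return out;
}

// Scalar-quantize a unit-norm embedding to one signed byte per dimension,
// which is what a "quantized embedding" stores instead of floats.
std::vector<int8_t> ScalarQuantize(const std::vector<float>& v) {
  std::vector<int8_t> out;
  out.reserve(v.size());
  for (float x : v) {
    // Map [-1.0, 1.0] to [-127, 127], clamping against rounding overflow.
    int q = static_cast<int>(std::lround(x * 127.f));
    if (q > 127) q = 127;
    if (q < -127) q = -127;
    out.push_back(static_cast<int8_t>(q));
  }
  return out;
}
```

This also explains why exactly one of the float and quantized fields is populated in the result containers: the two representations encode the same vector, and quantization trades a little precision for a 4x smaller payload.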
From ff0ccfc20952e2575d0de51ffaf86fbdb6215865 Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Wed, 25 Jan 2023 15:29:17 -0800
Subject: [PATCH 434/469] Internal change

PiperOrigin-RevId: 504677663
---
 docs/BUILD | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/docs/BUILD b/docs/BUILD
index 8e85dbf86..80d3ab550 100644
--- a/docs/BUILD
+++ b/docs/BUILD
@@ -4,12 +4,10 @@ py_binary(
     name = "build_py_api_docs",
     srcs = ["build_py_api_docs.py"],
     deps = [
-        "//mediapipe",
        "//third_party/py/absl:app",
        "//third_party/py/absl/flags",
-        "//third_party/py/tensorflow_docs",
+        "//third_party/py/mediapipe",
        "//third_party/py/tensorflow_docs/api_generator:generate_lib",
-        "//third_party/py/tensorflow_docs/api_generator:public_api",
     ],
 )

From be546d22fcfa30bba2dd76d4afbf4b036840d31d Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Wed, 25 Jan 2023 17:11:41 -0800
Subject: [PATCH 435/469] Update test to reflect the recommended graph
 construction style:

First, graph inputs and their names:
- Makes it clear what inputs the graph has
- Indirectly demands type specification, e.g.
  Stream<AnyType> a = graph.In(0); vs Stream<int> a = graph.In(0).Cast<int>();

Then graph nodes:
- Nodes are added and used as they are needed
- One node is not mixed in with other nodes; only its outputs are
- Indirectly demands type specification, e.g.
  Stream<AnyType> a = node.Out(0); vs Stream<int> a = node.Out(0).Cast<int>();

Then graph outputs:
- Makes it clear what outputs the graph has

The recommended structure keeps the C++ graph similar to its pbtxt
representation.

PiperOrigin-RevId: 504701023
---
 mediapipe/framework/api2/builder_test.cc | 170 ++++++++++++++++-------
 1 file changed, 121 insertions(+), 49 deletions(-)

diff --git a/mediapipe/framework/api2/builder_test.cc b/mediapipe/framework/api2/builder_test.cc
index 194f1b8ff..363971689 100644
--- a/mediapipe/framework/api2/builder_test.cc
+++ b/mediapipe/framework/api2/builder_test.cc
@@ -26,12 +26,21 @@ using ::mediapipe::api2::test::FooBar1;
 
 TEST(BuilderTest, BuildGraph) {
   Graph graph;
+  // Graph inputs.
+  Stream<AnyType> base = graph.In("IN").SetName("base");
+  SidePacket<AnyType> side = graph.SideIn("SIDE").SetName("side");
+
   auto& foo = graph.AddNode("Foo");
+  base >> foo.In("BASE");
+  side >> foo.SideIn("SIDE");
+  Stream<AnyType> foo_out = foo.Out("OUT");
+
   auto& bar = graph.AddNode("Bar");
-  graph.In("IN").SetName("base") >> foo.In("BASE");
-  graph.SideIn("SIDE").SetName("side") >> foo.SideIn("SIDE");
-  foo.Out("OUT") >> bar.In("IN");
-  bar.Out("OUT").SetName("out") >> graph.Out("OUT");
+  foo_out >> bar.In("IN");
+  Stream<AnyType> bar_out = bar.Out("OUT");
+
+  // Graph outputs.
+  bar_out.SetName("out") >> graph.Out("OUT");
 
   CalculatorGraphConfig expected =
       mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
@@ -87,6 +96,7 @@ TEST(BuilderTest, CopyableStream) {
 
 TEST(BuilderTest, BuildGraphWithFunctions) {
   Graph graph;
+  // Graph inputs.
   Stream<int> base = graph.In("IN").SetName("base").Cast<int>();
   SidePacket<float> side = graph.SideIn("SIDE").SetName("side").Cast<float>();
 
@@ -105,6 +115,7 @@ TEST(BuilderTest, BuildGraphWithFunctions) {
   };
   Stream<double> bar_out = bar_fn(foo_out, graph);
 
+  // Graph outputs.
   bar_out.SetName("out") >> graph.Out("OUT");
 
   CalculatorGraphConfig expected =
@@ -130,12 +141,21 @@ TEST(BuilderTest, BuildGraphWithFunctions) {
 template <typename FooT>
 void BuildGraphTypedTest() {
   Graph graph;
+  // Graph inputs.
+ Stream base = graph.In("IN").SetName("base"); + SidePacket side = graph.SideIn("SIDE").SetName("side"); + auto& foo = graph.AddNode(); + base >> foo.In(MPP_TAG("BASE")); + side >> foo.SideIn(MPP_TAG("BIAS")); + Stream foo_out = foo.Out(MPP_TAG("OUT")); + auto& bar = graph.AddNode(); - graph.In("IN").SetName("base") >> foo.In(MPP_TAG("BASE")); - graph.SideIn("SIDE").SetName("side") >> foo.SideIn(MPP_TAG("BIAS")); - foo.Out(MPP_TAG("OUT")) >> bar.In(MPP_TAG("IN")); - bar.Out(MPP_TAG("OUT")).SetName("out") >> graph.Out("OUT"); + foo_out >> bar.In(MPP_TAG("IN")); + Stream bar_out = bar.Out(MPP_TAG("OUT")); + + // Graph outputs. + bar_out.SetName("out") >> graph.Out("OUT"); CalculatorGraphConfig expected = mediapipe::ParseTextProtoOrDie( @@ -165,12 +185,20 @@ TEST(BuilderTest, BuildGraphTyped2) { BuildGraphTypedTest(); } TEST(BuilderTest, FanOut) { Graph graph; + // Graph inputs. + Stream base = graph.In("IN").SetName("base"); + auto& foo = graph.AddNode("Foo"); + base >> foo.In("BASE"); + Stream foo_out = foo.Out("OUT"); + auto& adder = graph.AddNode("FloatAdder"); - graph.In("IN").SetName("base") >> foo.In("BASE"); - foo.Out("OUT") >> adder.In("IN")[0]; - foo.Out("OUT") >> adder.In("IN")[1]; - adder.Out("OUT").SetName("out") >> graph.Out("OUT"); + foo_out >> adder.In("IN")[0]; + foo_out >> adder.In("IN")[1]; + Stream out = adder.Out("OUT"); + + // Graph outputs. + out.SetName("out") >> graph.Out("OUT"); CalculatorGraphConfig expected = mediapipe::ParseTextProtoOrDie(R"pb( @@ -193,12 +221,20 @@ TEST(BuilderTest, FanOut) { TEST(BuilderTest, TypedMultiple) { Graph graph; - auto& foo = graph.AddNode(); - auto& adder = graph.AddNode(); - graph.In("IN").SetName("base") >> foo.In(MPP_TAG("BASE")); - foo.Out(MPP_TAG("OUT")) >> adder.In(MPP_TAG("IN"))[0]; - foo.Out(MPP_TAG("OUT")) >> adder.In(MPP_TAG("IN"))[1]; - adder.Out(MPP_TAG("OUT")).SetName("out") >> graph.Out("OUT"); + // Graph inputs. + Stream base = graph.In("IN").SetName("base"); + + auto& foo = graph.AddNode(); + base >> foo.In(MPP_TAG("BASE")); + Stream foo_out = foo.Out(MPP_TAG("OUT")); + + auto& adder = graph.AddNode(); + foo_out >> adder.In(MPP_TAG("IN"))[0]; + foo_out >> adder.In(MPP_TAG("IN"))[1]; + Stream out = adder.Out(MPP_TAG("OUT")); + + // Graph outputs. + out.SetName("out") >> graph.Out("OUT"); CalculatorGraphConfig expected = mediapipe::ParseTextProtoOrDie(R"pb( @@ -221,13 +257,20 @@ TEST(BuilderTest, TypedMultiple) { TEST(BuilderTest, TypedByPorts) { Graph graph; - auto& foo = graph.AddNode(); - auto& adder = graph.AddNode(); + // Graph inputs. + Stream base = graph.In(FooBar1::kIn).SetName("base"); - graph.In(FooBar1::kIn).SetName("base") >> foo[Foo::kBase]; - foo[Foo::kOut] >> adder[FloatAdder::kIn][0]; - foo[Foo::kOut] >> adder[FloatAdder::kIn][1]; - adder[FloatAdder::kOut].SetName("out") >> graph.Out(FooBar1::kOut); + auto& foo = graph.AddNode(); + base >> foo[Foo::kBase]; + Stream foo_out = foo[Foo::kOut]; + + auto& adder = graph.AddNode(); + foo_out >> adder[FloatAdder::kIn][0]; + foo_out >> adder[FloatAdder::kIn][1]; + Stream out = adder[FloatAdder::kOut]; + + // Graph outputs. + out.SetName("out") >> graph.Out(FooBar1::kOut); CalculatorGraphConfig expected = mediapipe::ParseTextProtoOrDie(R"pb( @@ -250,9 +293,15 @@ TEST(BuilderTest, TypedByPorts) { TEST(BuilderTest, PacketGenerator) { Graph graph; + // Graph inputs. 
+ SidePacket side_in = graph.SideIn("IN"); + auto& generator = graph.AddPacketGenerator("FloatGenerator"); - graph.SideIn("IN") >> generator.SideIn("IN"); - generator.SideOut("OUT") >> graph.SideOut("OUT"); + side_in >> generator.SideIn("IN"); + SidePacket side_out = generator.SideOut("OUT"); + + // Graph outputs. + side_out >> graph.SideOut("OUT"); CalculatorGraphConfig expected = mediapipe::ParseTextProtoOrDie(R"pb( @@ -269,12 +318,21 @@ TEST(BuilderTest, PacketGenerator) { TEST(BuilderTest, EmptyTag) { Graph graph; + // Graph inputs. + Stream a = graph.In("A").SetName("a"); + Stream c = graph.In("C").SetName("c"); + Stream b = graph.In("B").SetName("b"); + auto& foo = graph.AddNode("Foo"); - graph.In("A").SetName("a") >> foo.In("")[0]; - graph.In("C").SetName("c") >> foo.In("")[2]; - graph.In("B").SetName("b") >> foo.In("")[1]; - foo.Out("")[0].SetName("x") >> graph.Out("ONE"); - foo.Out("")[1].SetName("y") >> graph.Out("TWO"); + a >> foo.In("")[0]; + c >> foo.In("")[2]; + b >> foo.In("")[1]; + Stream x = foo.Out("")[0]; + Stream y = foo.Out("")[1]; + + // Graph outputs. + x.SetName("x") >> graph.Out("ONE"); + y.SetName("y") >> graph.Out("TWO"); CalculatorGraphConfig expected = mediapipe::ParseTextProtoOrDie(R"pb( @@ -301,10 +359,17 @@ TEST(BuilderTest, StringLikeTags) { constexpr absl::string_view kC = "C"; Graph graph; + // Graph inputs. + Stream a = graph.In(kA).SetName("a"); + Stream b = graph.In(kB).SetName("b"); + auto& foo = graph.AddNode("Foo"); - graph.In(kA).SetName("a") >> foo.In(kA); - graph.In(kB).SetName("b") >> foo.In(kB); - foo.Out(kC).SetName("c") >> graph.Out(kC); + a >> foo.In(kA); + b >> foo.In(kB); + Stream c = foo.Out(kC); + + // Graph outputs. + c.SetName("c") >> graph.Out(kC); CalculatorGraphConfig expected = mediapipe::ParseTextProtoOrDie(R"pb( @@ -323,12 +388,21 @@ TEST(BuilderTest, StringLikeTags) { TEST(BuilderTest, GraphIndexes) { Graph graph; + // Graph inputs. + Stream a = graph.In(0).SetName("a"); + Stream c = graph.In(1).SetName("c"); + Stream b = graph.In(2).SetName("b"); + auto& foo = graph.AddNode("Foo"); - graph.In(0).SetName("a") >> foo.In("")[0]; - graph.In(1).SetName("c") >> foo.In("")[2]; - graph.In(2).SetName("b") >> foo.In("")[1]; - foo.Out("")[0].SetName("x") >> graph.Out(1); - foo.Out("")[1].SetName("y") >> graph.Out(0); + a >> foo.In("")[0]; + c >> foo.In("")[2]; + b >> foo.In("")[1]; + Stream x = foo.Out("")[0]; + Stream y = foo.Out("")[1]; + + // Graph outputs. 
+ x.SetName("x") >> graph.Out(1); + y.SetName("y") >> graph.Out(0); CalculatorGraphConfig expected = mediapipe::ParseTextProtoOrDie(R"pb( @@ -381,21 +455,20 @@ TEST(BuilderTest, AnyAndSameTypeHandledProperly) { auto& node = graph.AddNode("AnyAndSameTypeCalculator"); any_input >> node[AnyAndSameTypeCalculator::kAnyTypeInput]; int_input >> node[AnyAndSameTypeCalculator::kIntInput]; - Stream any_type_output = node[AnyAndSameTypeCalculator::kAnyTypeOutput]; - any_type_output.SetName("any_type_output"); - Stream same_type_output = node[AnyAndSameTypeCalculator::kSameTypeOutput]; - same_type_output.SetName("same_type_output"); Stream recursive_same_type_output = node[AnyAndSameTypeCalculator::kRecursiveSameTypeOutput]; - recursive_same_type_output.SetName("recursive_same_type_output"); Stream same_int_output = node[AnyAndSameTypeCalculator::kSameIntOutput]; - same_int_output.SetName("same_int_output"); Stream recursive_same_int_type_output = node[AnyAndSameTypeCalculator::kRecursiveSameIntOutput]; + + any_type_output.SetName("any_type_output"); + same_type_output.SetName("same_type_output"); + recursive_same_type_output.SetName("recursive_same_type_output"); + same_int_output.SetName("same_int_output"); recursive_same_int_type_output.SetName("recursive_same_int_type_output"); CalculatorGraphConfig expected = mediapipe::ParseTextProtoOrDie< @@ -424,11 +497,10 @@ TEST(BuilderTest, AnyTypeCanBeCast) { auto& node = graph.AddNode("AnyAndSameTypeCalculator"); any_input >> node[AnyAndSameTypeCalculator::kAnyTypeInput]; Stream any_type_output = - node[AnyAndSameTypeCalculator::kAnyTypeOutput] - .SetName("any_type_output") - .Cast(); + node[AnyAndSameTypeCalculator::kAnyTypeOutput].Cast(); - any_type_output >> graph.Out("GRAPH_ANY_OUTPUT").Cast(); + any_type_output.SetName("any_type_output") >> + graph.Out("GRAPH_ANY_OUTPUT").Cast(); CalculatorGraphConfig expected = mediapipe::ParseTextProtoOrDie(R"pb( From 0566e0e7ca1803ba7b98dfad5f236e6ce997a80e Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Wed, 25 Jan 2023 17:50:52 -0800 Subject: [PATCH 436/469] Fix the output stream tag of the end loop calculator in the example code. PiperOrigin-RevId: 504708273 --- mediapipe/calculators/core/begin_loop_calculator.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mediapipe/calculators/core/begin_loop_calculator.h b/mediapipe/calculators/core/begin_loop_calculator.h index a9d29e687..6d17f9953 100644 --- a/mediapipe/calculators/core/begin_loop_calculator.h +++ b/mediapipe/calculators/core/begin_loop_calculator.h @@ -49,7 +49,7 @@ namespace mediapipe { // calculator: "EndLoopWithOutputCalculator" // input_stream: "ITEM:output_of_loop_body" # ItemU @loop_internal_ts // input_stream: "BATCH_END:ext_ts" # Timestamp @loop_internal_ts -// output_stream: "OUTPUT:aggregated_result" # IterableU @ext_ts +// output_stream: "ITERABLE:aggregated_result" # IterableU @ext_ts // } // // Input streams tagged with "CLONE" are cloned to the corresponding output From 2547f07c77976b2bc9f9ea3ce0f75ce7a33c18a1 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Thu, 26 Jan 2023 07:38:58 -0800 Subject: [PATCH 437/469] Add FrameBuffer format. 
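Before the patch body, a minimal usage sketch of the FrameBuffer API this change introduces, mirroring the examples documented in the new frame_buffer.h. The sketch is illustrative only: the NV12 buffers, dimensions, and the `InspectNv12` helper are made up for the example.

```
#include <cstdint>
#include <vector>

#include "mediapipe/framework/formats/frame_buffer.h"

// Wrap an NV12 image (Y plane + interleaved UV plane) in a FrameBuffer view
// and read back the YUV plane pointers and strides. FrameBuffer does not own
// the backing buffers; the caller must keep them alive.
void InspectNv12(const uint8_t* y_data, const uint8_t* uv_data, int width,
                 int height) {
  mediapipe::FrameBuffer::Plane y_plane = {y_data, /*stride=*/{width, 1}};
  mediapipe::FrameBuffer::Plane uv_plane = {uv_data, /*stride=*/{width, 2}};
  auto buffer = mediapipe::FrameBuffer::Create(
      {y_plane, uv_plane}, {width, height},
      mediapipe::FrameBuffer::Format::kNV12,
      mediapipe::FrameBuffer::Orientation::kTopLeft);
  auto yuv = mediapipe::FrameBuffer::GetYuvDataFromFrameBuffer(*buffer);
  if (yuv.ok()) {
    // yuv->y_buffer / u_buffer / v_buffer, plus the row and pixel strides,
    // are now ready to feed into a converter or a model input.
  }
}
```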
PiperOrigin-RevId: 504838580 --- mediapipe/framework/formats/BUILD | 12 + mediapipe/framework/formats/frame_buffer.cc | 176 ++++++++++++++ mediapipe/framework/formats/frame_buffer.h | 246 ++++++++++++++++++++ 3 files changed, 434 insertions(+) create mode 100644 mediapipe/framework/formats/frame_buffer.cc create mode 100644 mediapipe/framework/formats/frame_buffer.h diff --git a/mediapipe/framework/formats/BUILD b/mediapipe/framework/formats/BUILD index 10aa3fca0..abd530b46 100644 --- a/mediapipe/framework/formats/BUILD +++ b/mediapipe/framework/formats/BUILD @@ -486,3 +486,15 @@ cc_test( "//mediapipe/gpu:disable_gpu": [], }), ) + +cc_library( + name = "frame_buffer", + srcs = ["frame_buffer.cc"], + hdrs = ["frame_buffer.h"], + deps = [ + "//mediapipe/framework/port:integral_types", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + ], +) diff --git a/mediapipe/framework/formats/frame_buffer.cc b/mediapipe/framework/formats/frame_buffer.cc new file mode 100644 index 000000000..a86d3f2ad --- /dev/null +++ b/mediapipe/framework/formats/frame_buffer.cc @@ -0,0 +1,176 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/framework/formats/frame_buffer.h" + +#include "absl/status/status.h" +#include "absl/status/statusor.h" + +namespace mediapipe { + +namespace { + +// Returns whether the input `format` is a supported YUV format. +bool IsSupportedYuvFormat(FrameBuffer::Format format) { + return format == FrameBuffer::Format::kNV21 || + format == FrameBuffer::Format::kNV12 || + format == FrameBuffer::Format::kYV12 || + format == FrameBuffer::Format::kYV21; +} + +// Returns supported 1-plane FrameBuffer in YuvData structure. +absl::StatusOr GetYuvDataFromOnePlaneFrameBuffer( + const FrameBuffer& source) { + if (!IsSupportedYuvFormat(source.format())) { + return absl::InvalidArgumentError( + "The source FrameBuffer format is not part of YUV420 family."); + } + + FrameBuffer::YuvData result; + const int y_buffer_size = + source.plane(0).stride.row_stride_bytes * source.dimension().height; + const int uv_buffer_size = + ((source.plane(0).stride.row_stride_bytes + 1) / 2) * + ((source.dimension().height + 1) / 2); + result.y_buffer = source.plane(0).buffer; + result.y_row_stride = source.plane(0).stride.row_stride_bytes; + result.uv_row_stride = result.y_row_stride; + + if (source.format() == FrameBuffer::Format::kNV21) { + result.v_buffer = result.y_buffer + y_buffer_size; + result.u_buffer = result.v_buffer + 1; + result.uv_pixel_stride = 2; + // If y_row_stride equals to the frame width and is an odd value, + // uv_row_stride = y_row_stride + 1, otherwise uv_row_stride = y_row_stride. 
+ if (result.y_row_stride == source.dimension().width && + result.y_row_stride % 2 == 1) { + result.uv_row_stride = (result.y_row_stride + 1) / 2 * 2; + } + } else if (source.format() == FrameBuffer::Format::kNV12) { + result.u_buffer = result.y_buffer + y_buffer_size; + result.v_buffer = result.u_buffer + 1; + result.uv_pixel_stride = 2; + // If y_row_stride equals to the frame width and is an odd value, + // uv_row_stride = y_row_stride + 1, otherwise uv_row_stride = y_row_stride. + if (result.y_row_stride == source.dimension().width && + result.y_row_stride % 2 == 1) { + result.uv_row_stride = (result.y_row_stride + 1) / 2 * 2; + } + } else if (source.format() == FrameBuffer::Format::kYV21) { + result.u_buffer = result.y_buffer + y_buffer_size; + result.v_buffer = result.u_buffer + uv_buffer_size; + result.uv_pixel_stride = 1; + result.uv_row_stride = (result.y_row_stride + 1) / 2; + } else if (source.format() == FrameBuffer::Format::kYV12) { + result.v_buffer = result.y_buffer + y_buffer_size; + result.u_buffer = result.v_buffer + uv_buffer_size; + result.uv_pixel_stride = 1; + result.uv_row_stride = (result.y_row_stride + 1) / 2; + } + return result; +} + +// Returns supported 2-plane FrameBuffer in YuvData structure. +absl::StatusOr GetYuvDataFromTwoPlaneFrameBuffer( + const FrameBuffer& source) { + if (source.format() != FrameBuffer::Format::kNV12 && + source.format() != FrameBuffer::Format::kNV21) { + return absl::InvalidArgumentError("Unsupported YUV planar format."); + } + + FrameBuffer::YuvData result; + // Y plane + result.y_buffer = source.plane(0).buffer; + // All plane strides + result.y_row_stride = source.plane(0).stride.row_stride_bytes; + result.uv_row_stride = source.plane(1).stride.row_stride_bytes; + result.uv_pixel_stride = 2; + + if (source.format() == FrameBuffer::Format::kNV12) { + // Y and UV interleaved format + result.u_buffer = source.plane(1).buffer; + result.v_buffer = result.u_buffer + 1; + } else { + // Y and VU interleaved format + result.v_buffer = source.plane(1).buffer; + result.u_buffer = result.v_buffer + 1; + } + return result; +} + +// Returns supported 3-plane FrameBuffer in YuvData structure. Note that NV21 +// and NV12 are included in the supported Yuv formats. Technically, NV21 and +// NV12 should not be described by the 3-plane format. Historically, NV21 is +// used loosely such that it can also be used to describe YV21 format. For +// backwards compatibility, FrameBuffer supports NV21/NV12 with 3-plane format +// but such usage is discouraged +absl::StatusOr GetYuvDataFromThreePlaneFrameBuffer( + const FrameBuffer& source) { + if (!IsSupportedYuvFormat(source.format())) { + return absl::InvalidArgumentError( + "The source FrameBuffer format is not part of YUV420 family."); + } + + if (source.plane(1).stride.row_stride_bytes != + source.plane(2).stride.row_stride_bytes || + source.plane(1).stride.pixel_stride_bytes != + source.plane(2).stride.pixel_stride_bytes) { + return absl::InternalError("Unsupported YUV planar format."); + } + FrameBuffer::YuvData result; + if (source.format() == FrameBuffer::Format::kNV21 || + source.format() == FrameBuffer::Format::kYV12) { + // Y follow by VU order. The VU chroma planes can be interleaved or + // planar. 
+    result.y_buffer = source.plane(0).buffer;
+    result.v_buffer = source.plane(1).buffer;
+    result.u_buffer = source.plane(2).buffer;
+    result.y_row_stride = source.plane(0).stride.row_stride_bytes;
+    result.uv_row_stride = source.plane(1).stride.row_stride_bytes;
+    result.uv_pixel_stride = source.plane(1).stride.pixel_stride_bytes;
+  } else {
+    // Y follow by UV order. The UV chroma planes can be interleaved or
+    // planar.
+    result.y_buffer = source.plane(0).buffer;
+    result.u_buffer = source.plane(1).buffer;
+    result.v_buffer = source.plane(2).buffer;
+    result.y_row_stride = source.plane(0).stride.row_stride_bytes;
+    result.uv_row_stride = source.plane(1).stride.row_stride_bytes;
+    result.uv_pixel_stride = source.plane(1).stride.pixel_stride_bytes;
+  }
+  return result;
+}
+
+}  // namespace
+
+absl::StatusOr<FrameBuffer::YuvData> FrameBuffer::GetYuvDataFromFrameBuffer(
+    const FrameBuffer& source) {
+  if (!IsSupportedYuvFormat(source.format())) {
+    return absl::InvalidArgumentError(
+        "The source FrameBuffer format is not part of YUV420 family.");
+  }
+
+  if (source.plane_count() == 1) {
+    return GetYuvDataFromOnePlaneFrameBuffer(source);
+  } else if (source.plane_count() == 2) {
+    return GetYuvDataFromTwoPlaneFrameBuffer(source);
+  } else if (source.plane_count() == 3) {
+    return GetYuvDataFromThreePlaneFrameBuffer(source);
+  }
+  return absl::InvalidArgumentError(
+      "The source FrameBuffer must consist of 1, 2, or 3 planes.");
+}
+
+}  // namespace mediapipe
diff --git a/mediapipe/framework/formats/frame_buffer.h b/mediapipe/framework/formats/frame_buffer.h
new file mode 100644
index 000000000..7578a0121
--- /dev/null
+++ b/mediapipe/framework/formats/frame_buffer.h
@@ -0,0 +1,246 @@
+/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef MEDIAPIPE_FRAMEWORK_FORMATS_FRAME_BUFFER_H_
+#define MEDIAPIPE_FRAMEWORK_FORMATS_FRAME_BUFFER_H_
+
+#include <cstddef>
+#include <cstdint>
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "absl/memory/memory.h"
+#include "absl/status/statusor.h"
+#include "mediapipe/framework/port/integral_types.h"
+
+namespace mediapipe {
+
+// A `FrameBuffer` provides a view into the provided backing buffer (e.g. a
+// camera frame or still image) with buffer format information. FrameBuffer
+// doesn't take ownership of the provided backing buffer. The caller is
+// responsible for managing the backing buffer lifecycle for the lifetime of
+// the FrameBuffer.
+//
+// Examples:
+//
+// // Create a metadata instance with no backing buffer.
+// auto buffer = FrameBuffer::Create(/*planes=*/{}, dimension, kRGBA,
+//                                   kTopLeft);
+//
+// // Create an RGBA instance with backing buffer on a single plane.
+// FrameBuffer::Plane plane =
+//     {rgba_buffer, /*stride=*/{dimension.width * 4, 4}};
+// auto buffer = FrameBuffer::Create({plane}, dimension, kRGBA, kTopLeft);
+//
+// // Create a YUV instance with planar backing buffer.
+// FrameBuffer::Plane y_plane = {y_buffer, /*stride=*/{dimension.width , 1}}; +// FrameBuffer::Plane uv_plane = {u_buffer, /*stride=*/{dimension.width, 2}}; +// auto buffer = FrameBuffer::Create({y_plane, uv_plane}, dimension, kNV21, +// kLeftTop); +class FrameBuffer { + public: + // Colorspace formats. + enum class Format { + kRGBA, + kRGB, + kNV12, + kNV21, + kYV12, + kYV21, + kGRAY, + kUNKNOWN + }; + + // Stride information. + struct Stride { + // The row stride in bytes. This is the distance between the start pixels of + // two consecutive rows in the image. + int row_stride_bytes; + // This is the distance between two consecutive pixel values in a row of + // pixels in bytes. It may be larger than the size of a single pixel to + // account for interleaved image data or padded formats. + int pixel_stride_bytes; + + bool operator==(const Stride& other) const { + return row_stride_bytes == other.row_stride_bytes && + pixel_stride_bytes == other.pixel_stride_bytes; + } + + bool operator!=(const Stride& other) const { return !operator==(other); } + }; + + // YUV data structure. + struct YuvData { + const uint8* y_buffer; + const uint8* u_buffer; + const uint8* v_buffer; + // Y buffer row stride in bytes. + int y_row_stride; + // U/V buffer row stride in bytes. + int uv_row_stride; + // U/V pixel stride in bytes. This is the distance between two consecutive + // u/v pixel values in a row. + int uv_pixel_stride; + }; + + // FrameBuffer content orientation follows EXIF specification. The name of + // each enum value defines the position of the 0th row and the 0th column of + // the image content. See http://jpegclub.org/exif_orientation.html for + // details. + enum class Orientation { + kTopLeft = 1, + kTopRight = 2, + kBottomRight = 3, + kBottomLeft = 4, + kLeftTop = 5, + kRightTop = 6, + kRightBottom = 7, + kLeftBottom = 8 + }; + + // Plane encapsulates buffer and stride information. + struct Plane { + const uint8* buffer; + Stride stride; + }; + + // Dimension information for the whole frame or a cropped portion of it. + struct Dimension { + // The width dimension in pixel unit. + int width; + // The height dimension in pixel unit. + int height; + + bool operator==(const Dimension& other) const { + return width == other.width && height == other.height; + } + + bool operator!=(const Dimension& other) const { + return width != other.width || height != other.height; + } + + bool operator>=(const Dimension& other) const { + return width >= other.width && height >= other.height; + } + + bool operator<=(const Dimension& other) const { + return width <= other.width && height <= other.height; + } + + // Swaps width and height. + void Swap() { + using std::swap; + swap(width, height); + } + + // Returns area represented by width * height. + int Size() const { return width * height; } + }; + + // Factory method for creating a FrameBuffer object from row-major backing + // buffers. + static std::unique_ptr Create(const std::vector& planes, + Dimension dimension, Format format, + Orientation orientation) { + return absl::make_unique(planes, dimension, format, + orientation); + } + + // Factory method for creating a FrameBuffer object from row-major movable + // backing buffers. 
+ static std::unique_ptr Create(std::vector&& planes, + Dimension dimension, Format format, + Orientation orientation) { + return absl::make_unique(std::move(planes), dimension, format, + orientation); + } + + // Returns YuvData which contains the Y, U, and V buffer and their + // stride info from the input `source` FrameBuffer which is in the YUV family + // formats (e.g NV12, NV21, YV12, and YV21). + static absl::StatusOr GetYuvDataFromFrameBuffer( + const FrameBuffer& source); + + // Builds a FrameBuffer object from a row-major backing buffer. + // + // The FrameBuffer does not take ownership of the backing buffer. The backing + // buffer is read-only and the caller is responsible for maintaining the + // backing buffer lifecycle for the lifetime of FrameBuffer. + FrameBuffer(const std::vector& planes, Dimension dimension, + Format format, Orientation orientation) + : planes_(planes), + dimension_(dimension), + format_(format), + orientation_(orientation) {} + + // Builds a FrameBuffer object from a movable row-major backing buffer. + // + // The FrameBuffer does not take ownership of the backing buffer. The backing + // buffer is read-only and the caller is responsible for maintaining the + // backing buffer lifecycle for the lifetime of FrameBuffer. + FrameBuffer(std::vector&& planes, Dimension dimension, Format format, + Orientation orientation) + : planes_(std::move(planes)), + dimension_(dimension), + format_(format), + orientation_(orientation) {} + + // Copy constructor. + // + // FrameBuffer does not take ownership of the backing buffer. The copy + // constructor behaves the same way to only copy the buffer pointer and not + // take ownership of the backing buffer. + FrameBuffer(const FrameBuffer& frame_buffer) { + planes_.clear(); + for (int i = 0; i < frame_buffer.plane_count(); i++) { + planes_.push_back( + FrameBuffer::Plane{.buffer = frame_buffer.plane(i).buffer, + .stride = frame_buffer.plane(i).stride}); + } + dimension_ = frame_buffer.dimension(); + format_ = frame_buffer.format(); + orientation_ = frame_buffer.orientation(); + } + + // Returns number of planes. + int plane_count() const { return planes_.size(); } + + // Returns plane indexed by the input `index`. + Plane plane(int index) const { + if (index > -1 && static_cast(index) < planes_.size()) { + return planes_[index]; + } + return {}; + } + + // Returns FrameBuffer dimension. + Dimension dimension() const { return dimension_; } + + // Returns FrameBuffer format. + Format format() const { return format_; } + + // Returns FrameBuffer orientation. + Orientation orientation() const { return orientation_; } + + private: + std::vector planes_; + Dimension dimension_; + Format format_; + Orientation orientation_; +}; + +} // namespace mediapipe + +#endif // MEDIAPIPE_FRAMEWORK_FORMATS_FRAME_BUFFER_H_ From 29001234d571d30a727e7bda2a6b04a0bfb8bb61 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Thu, 26 Jan 2023 10:43:36 -0800 Subject: [PATCH 438/469] Replace SourceOrNodeOutput with Source. 
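Before the patch body, a sketch of the pattern this change adopts, shown in isolation: declare a `Source` with its concrete payload type up front, then conditionally re-point it at a node output via `Cast()`. The helper name `MaybeDequantize` is made up for illustration; the calculator name is the one used in the graphs below.

```
#include <vector>

#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/formats/tensor.h"

using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source;

// Source<T> is copyable, so it can start out aliasing the graph input and be
// reassigned to a node output later, replacing SourceOrNodeOutput entirely.
Source<std::vector<mediapipe::Tensor>> MaybeDequantize(
    Source<std::vector<mediapipe::Tensor>> tensors, bool quantized,
    Graph& graph) {
  Source<std::vector<mediapipe::Tensor>> result = tensors;
  if (quantized) {
    auto& node = graph.AddNode("TensorsDequantizationCalculator");
    tensors >> node.In("TENSORS");
    result = node.Out("TENSORS").Cast<std::vector<mediapipe::Tensor>>();
  }
  return result;
}
```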
PiperOrigin-RevId: 504883990
---
 .../tasks/cc/components/processors/BUILD      |  2 -
 .../classification_postprocessing_graph.cc    | 16 +++--
 .../embedding_postprocessing_graph.cc         |  8 +--
 mediapipe/tasks/cc/components/utils/BUILD     |  6 --
 .../components/utils/source_or_node_output.h  | 66 -------------------
 .../tasks/cc/vision/object_detector/BUILD     |  1 -
 .../object_detector/object_detector_graph.cc  |  9 +--
 7 files changed, 17 insertions(+), 91 deletions(-)
 delete mode 100644 mediapipe/tasks/cc/components/utils/source_or_node_output.h

diff --git a/mediapipe/tasks/cc/components/processors/BUILD b/mediapipe/tasks/cc/components/processors/BUILD
index cec44a9e3..10bc0726a 100644
--- a/mediapipe/tasks/cc/components/processors/BUILD
+++ b/mediapipe/tasks/cc/components/processors/BUILD
@@ -48,7 +48,6 @@ cc_library(
         "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
         "//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto",
         "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto",
-        "//mediapipe/tasks/cc/components/utils:source_or_node_output",
        "//mediapipe/tasks/cc/core:model_resources",
        "//mediapipe/tasks/cc/metadata:metadata_extractor",
        "//mediapipe/tasks/metadata:metadata_schema_cc",
@@ -90,7 +89,6 @@ cc_library(
        "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
        "//mediapipe/tasks/cc/components/processors/proto:embedder_options_cc_proto",
        "//mediapipe/tasks/cc/components/processors/proto:embedding_postprocessing_graph_options_cc_proto",
-        "//mediapipe/tasks/cc/components/utils:source_or_node_output",
        "//mediapipe/tasks/cc/core:model_resources",
        "//mediapipe/tasks/cc/metadata:metadata_extractor",
        "@com_google_absl//absl/status",
diff --git a/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.cc b/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.cc
index 5a0472f5c..cfb3b02cf 100644
--- a/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.cc
+++ b/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.cc
@@ -40,7 +40,6 @@ limitations under the License.
 #include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
 #include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h"
 #include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"
-#include "mediapipe/tasks/cc/components/utils/source_or_node_output.h"
 #include "mediapipe/tasks/cc/core/model_resources.h"
 #include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
 #include "mediapipe/tasks/metadata/metadata_schema_generated.h"
@@ -68,7 +67,7 @@ using ::mediapipe::tasks::metadata::ModelMetadataExtractor;
 using ::tflite::ProcessUnit;
 using ::tflite::TensorMetadata;
 using LabelItems = mediapipe::proto_ns::Map<int64, ::mediapipe::LabelMapItem>;
-using TensorsSource = mediapipe::tasks::SourceOrNodeOutput<std::vector<Tensor>>;
+using TensorsSource = mediapipe::api2::builder::Source<std::vector<Tensor>>;
 
 constexpr float kDefaultScoreThreshold = std::numeric_limits<float>::lowest();
 
@@ -455,12 +454,13 @@ class ClassificationPostprocessingGraph : public mediapipe::Subgraph {
     }
 
     // If output tensors are quantized, they must be dequantized first.
-    TensorsSource dequantized_tensors(&tensors_in);
+    TensorsSource dequantized_tensors = tensors_in;
     if (options.has_quantized_outputs()) {
       GenericNode* tensors_dequantization_node =
           &graph.AddNode("TensorsDequantizationCalculator");
       tensors_in >> tensors_dequantization_node->In(kTensorsTag);
-      dequantized_tensors = {tensors_dequantization_node, kTensorsTag};
+      dequantized_tensors = tensors_dequantization_node->Out(kTensorsTag)
+                                .Cast<std::vector<Tensor>>();
     }
 
     // If there are multiple classification heads, the output tensors need to be
@@ -477,7 +477,8 @@ class ClassificationPostprocessingGraph : public mediapipe::Subgraph {
         auto* range = split_tensor_vector_options.add_ranges();
         range->set_begin(i);
         range->set_end(i + 1);
-        split_tensors.emplace_back(split_tensor_vector_node, i);
+        split_tensors.push_back(
+            split_tensor_vector_node->Out(i).Cast<std::vector<Tensor>>());
       }
       dequantized_tensors >> split_tensor_vector_node->In(0);
     } else {
@@ -494,8 +495,9 @@ class ClassificationPostprocessingGraph : public mediapipe::Subgraph {
         score_calibration_node->GetOptions<ScoreCalibrationCalculatorOptions>()
             .CopyFrom(options.score_calibration_options().at(i));
         split_tensors[i] >> score_calibration_node->In(kScoresTag);
-        calibrated_tensors.emplace_back(score_calibration_node,
-                                        kCalibratedScoresTag);
+        calibrated_tensors.push_back(
+            score_calibration_node->Out(kCalibratedScoresTag)
+                .Cast<std::vector<Tensor>>());
       } else {
         calibrated_tensors.emplace_back(split_tensors[i]);
       }
diff --git a/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.cc b/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.cc
index ad4881e12..7b023ba41 100644
--- a/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.cc
+++ b/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.cc
@@ -31,7 +31,6 @@ limitations under the License.
 #include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
 #include "mediapipe/tasks/cc/components/processors/proto/embedder_options.pb.h"
 #include "mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options.pb.h"
-#include "mediapipe/tasks/cc/components/utils/source_or_node_output.h"
 #include "mediapipe/tasks/cc/core/model_resources.h"
 #include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
 #include "tensorflow/lite/schema/schema_generated.h"
@@ -51,8 +50,6 @@ using ::mediapipe::api2::builder::Graph;
 using ::mediapipe::api2::builder::Source;
 using ::mediapipe::tasks::components::containers::proto::EmbeddingResult;
 using ::mediapipe::tasks::core::ModelResources;
-using TensorsSource =
-    ::mediapipe::tasks::SourceOrNodeOutput<std::vector<Tensor>>;
 
 constexpr char kTensorsTag[] = "TENSORS";
 constexpr char kEmbeddingsTag[] = "EMBEDDINGS";
@@ -229,12 +226,13 @@ class EmbeddingPostprocessingGraph : public mediapipe::Subgraph {
       Source<std::vector<Tensor>> tensors_in,
       Source<std::vector<Timestamp>> timestamps_in, Graph& graph) {
     // If output tensors are quantized, they must be dequantized first.
-    TensorsSource dequantized_tensors(&tensors_in);
+    Source<std::vector<Tensor>> dequantized_tensors = tensors_in;
     if (options.has_quantized_outputs()) {
       GenericNode& tensors_dequantization_node =
           graph.AddNode("TensorsDequantizationCalculator");
       tensors_in >> tensors_dequantization_node.In(kTensorsTag);
-      dequantized_tensors = {&tensors_dequantization_node, kTensorsTag};
+      dequantized_tensors = tensors_dequantization_node.Out(kTensorsTag)
+                                .Cast<std::vector<Tensor>>();
     }
 
     // Adds TensorsToEmbeddingsCalculator.
diff --git a/mediapipe/tasks/cc/components/utils/BUILD b/mediapipe/tasks/cc/components/utils/BUILD index 8bb5b8415..2e0ea3ce6 100644 --- a/mediapipe/tasks/cc/components/utils/BUILD +++ b/mediapipe/tasks/cc/components/utils/BUILD @@ -14,12 +14,6 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) -cc_library( - name = "source_or_node_output", - hdrs = ["source_or_node_output.h"], - deps = ["//mediapipe/framework/api2:builder"], -) - cc_library( name = "cosine_similarity", srcs = ["cosine_similarity.cc"], diff --git a/mediapipe/tasks/cc/components/utils/source_or_node_output.h b/mediapipe/tasks/cc/components/utils/source_or_node_output.h deleted file mode 100644 index 55805d5a3..000000000 --- a/mediapipe/tasks/cc/components/utils/source_or_node_output.h +++ /dev/null @@ -1,66 +0,0 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_UTILS_SOURCE_OR_NODE_OUTPUT_H_ -#define MEDIAPIPE_TASKS_CC_COMPONENTS_UTILS_SOURCE_OR_NODE_OUTPUT_H_ - -#include "mediapipe/framework/api2/builder.h" - -namespace mediapipe { -namespace tasks { - -// Helper class representing either a Source object or a GenericNode output. -// -// Source and MultiSource (the output of a GenericNode) are widely incompatible, -// but being able to represent either of these in temporary variables and -// connect them later on facilitates graph building. -template -class SourceOrNodeOutput { - public: - SourceOrNodeOutput() = delete; - // The caller is responsible for ensuring 'source' outlives this object. - explicit SourceOrNodeOutput(mediapipe::api2::builder::Source* source) - : source_(source) {} - // The caller is responsible for ensuring 'node' outlives this object. - SourceOrNodeOutput(mediapipe::api2::builder::GenericNode* node, - std::string tag) - : node_(node), tag_(tag) {} - // The caller is responsible for ensuring 'node' outlives this object. - SourceOrNodeOutput(mediapipe::api2::builder::GenericNode* node, int index) - : node_(node), index_(index) {} - - // Connects the source or node output to the provided destination. 
- template - void operator>>(const U& dest) { - if (source_ != nullptr) { - *source_ >> dest; - } else { - if (index_ < 0) { - node_->Out(tag_) >> dest; - } else { - node_->Out(index_) >> dest; - } - } - } - - private: - mediapipe::api2::builder::Source* source_ = nullptr; - mediapipe::api2::builder::GenericNode* node_ = nullptr; - std::string tag_ = ""; - int index_ = -1; -}; - -} // namespace tasks -} // namespace mediapipe -#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_UTILS_SOURCE_OR_NODE_OUTPUT_H_ diff --git a/mediapipe/tasks/cc/vision/object_detector/BUILD b/mediapipe/tasks/cc/vision/object_detector/BUILD index 5269796ae..0238449c7 100644 --- a/mediapipe/tasks/cc/vision/object_detector/BUILD +++ b/mediapipe/tasks/cc/vision/object_detector/BUILD @@ -74,7 +74,6 @@ cc_library( "//mediapipe/tasks/cc/components/calculators:score_calibration_calculator_cc_proto", "//mediapipe/tasks/cc/components/calculators:score_calibration_utils", "//mediapipe/tasks/cc/components/processors:image_preprocessing_graph", - "//mediapipe/tasks/cc/components/utils:source_or_node_output", "//mediapipe/tasks/cc/core:model_resources", "//mediapipe/tasks/cc/core:model_task_graph", "//mediapipe/tasks/cc/core:utils", diff --git a/mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc b/mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc index e5af7544d..cb85fc46f 100644 --- a/mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc +++ b/mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc @@ -34,7 +34,6 @@ limitations under the License. #include "mediapipe/tasks/cc/components/calculators/score_calibration_calculator.pb.h" #include "mediapipe/tasks/cc/components/calculators/score_calibration_utils.h" #include "mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h" -#include "mediapipe/tasks/cc/components/utils/source_or_node_output.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_task_graph.h" #include "mediapipe/tasks/cc/core/proto/acceleration.pb.h" @@ -69,7 +68,7 @@ using LabelItems = mediapipe::proto_ns::Map; using ObjectDetectorOptionsProto = object_detector::proto::ObjectDetectorOptions; using TensorsSource = - mediapipe::tasks::SourceOrNodeOutput>; + mediapipe::api2::builder::Source>; constexpr int kDefaultLocationsIndex = 0; constexpr int kDefaultCategoriesIndex = 1; @@ -584,7 +583,8 @@ class ObjectDetectorGraph : public core::ModelTaskGraph { auto post_processing_specs, BuildPostProcessingSpecs(task_options, metadata_extractor)); // Calculators to perform score calibration, if specified in the metadata. - TensorsSource calibrated_tensors = {&inference, kTensorTag}; + TensorsSource calibrated_tensors = + inference.Out(kTensorTag).Cast>(); if (post_processing_specs.score_calibration_options.has_value()) { // Split tensors. auto* split_tensor_vector_node = @@ -623,7 +623,8 @@ class ObjectDetectorGraph : public core::ModelTaskGraph { concatenate_tensor_vector_node->In(i); } } - calibrated_tensors = {concatenate_tensor_vector_node, 0}; + calibrated_tensors = + concatenate_tensor_vector_node->Out(0).Cast>(); } // Calculator to convert output tensors to a detection proto vector. 
// Connects TensorsToDetectionsCalculator's input stream to the output From 4d38557f116853ce8e90457d61c56b795a6ba86b Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Thu, 26 Jan 2023 12:30:05 -0800 Subject: [PATCH 439/469] Add MediaPipe Image Segmenter task for Web PiperOrigin-RevId: 504912518 --- mediapipe/tasks/web/vision/BUILD | 1 + mediapipe/tasks/web/vision/README.md | 17 + .../tasks/web/vision/image_segmenter/BUILD | 58 ++++ .../vision/image_segmenter/image_segmenter.ts | 300 ++++++++++++++++++ .../image_segmenter_options.d.ts | 41 +++ .../image_segmenter/image_segmenter_test.ts | 215 +++++++++++++ mediapipe/tasks/web/vision/index.ts | 3 + mediapipe/tasks/web/vision/types.ts | 1 + 8 files changed, 636 insertions(+) create mode 100644 mediapipe/tasks/web/vision/image_segmenter/BUILD create mode 100644 mediapipe/tasks/web/vision/image_segmenter/image_segmenter.ts create mode 100644 mediapipe/tasks/web/vision/image_segmenter/image_segmenter_options.d.ts create mode 100644 mediapipe/tasks/web/vision/image_segmenter/image_segmenter_test.ts diff --git a/mediapipe/tasks/web/vision/BUILD b/mediapipe/tasks/web/vision/BUILD index 8ba9c85b3..a229cbd2a 100644 --- a/mediapipe/tasks/web/vision/BUILD +++ b/mediapipe/tasks/web/vision/BUILD @@ -23,6 +23,7 @@ VISION_LIBS = [ "//mediapipe/tasks/web/vision/hand_landmarker", "//mediapipe/tasks/web/vision/image_classifier", "//mediapipe/tasks/web/vision/image_embedder", + "//mediapipe/tasks/web/vision/image_segmenter", "//mediapipe/tasks/web/vision/object_detector", ] diff --git a/mediapipe/tasks/web/vision/README.md b/mediapipe/tasks/web/vision/README.md index 51f43821c..9e86eafd3 100644 --- a/mediapipe/tasks/web/vision/README.md +++ b/mediapipe/tasks/web/vision/README.md @@ -39,6 +39,23 @@ const classifications = imageClassifier.classify(image); For more information, refer to the [Image Classification](https://developers.google.com/mediapipe/solutions/vision/image_classifier/web_js) documentation. +## Image Segmentation + +The MediaPipe Image Segmenter lets you segment an image into categories. + +``` +const vision = await FilesetResolver.forVisionTasks( + "https://cdn.jsdelivr.net/npm/@mediapipe/tasks-vision@latest/wasm" +); +const imageSegmenter = await ImageSegmenter.createFromModelPath(vision, + "model.tflite" +); +const image = document.getElementById("image") as HTMLImageElement; +imageSegmenter.segment(image, (masks, width, height) => { + ... +}); +``` + ## Gesture Recognition The MediaPipe Gesture Recognizer task lets you recognize hand gestures in real diff --git a/mediapipe/tasks/web/vision/image_segmenter/BUILD b/mediapipe/tasks/web/vision/image_segmenter/BUILD new file mode 100644 index 000000000..d15fe63f1 --- /dev/null +++ b/mediapipe/tasks/web/vision/image_segmenter/BUILD @@ -0,0 +1,58 @@ +# This contains the MediaPipe Image Segmenter Task. 
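To make the README snippet above concrete, here is a hedged sketch of a complete callback in the default `CATEGORY_MASK` mode. It assumes the same imports and placeholder model path (`model.tflite`) as the README example; the class-counting logic is purely illustrative.

```
const vision = await FilesetResolver.forVisionTasks(
    "https://cdn.jsdelivr.net/npm/@mediapipe/tasks-vision@latest/wasm"
);
const imageSegmenter = await ImageSegmenter.createFromModelPath(vision,
    "model.tflite"  // placeholder path, as in the README above
);
const image = document.getElementById("image") as HTMLImageElement;
imageSegmenter.segment(image, (masks, width, height) => {
  // In the default CATEGORY_MASK mode, masks holds a single Uint8Array in
  // which every pixel stores the index of its predicted class.
  const categoryMask = masks[0] as Uint8Array;
  const classes = new Set<number>();
  for (let i = 0; i < width * height; ++i) {
    classes.add(categoryMask[i]);
  }
  console.log(`Segmented into ${classes.size} distinct classes`);
});
```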
+ +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library") +load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test") + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +mediapipe_ts_library( + name = "image_segmenter", + srcs = ["image_segmenter.ts"], + deps = [ + ":image_segmenter_types", + "//mediapipe/framework:calculator_jspb_proto", + "//mediapipe/framework:calculator_options_jspb_proto", + "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", + "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_jspb_proto", + "//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_jspb_proto", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/vision/core:image_processing_options", + "//mediapipe/tasks/web/vision/core:vision_task_runner", + "//mediapipe/web/graph_runner:graph_runner_image_lib_ts", + "//mediapipe/web/graph_runner:graph_runner_ts", + ], +) + +mediapipe_ts_declaration( + name = "image_segmenter_types", + srcs = ["image_segmenter_options.d.ts"], + deps = [ + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:classifier_options", + "//mediapipe/tasks/web/vision/core:vision_task_options", + ], +) + +mediapipe_ts_library( + name = "image_segmenter_test_lib", + testonly = True, + srcs = [ + "image_segmenter_test.ts", + ], + deps = [ + ":image_segmenter", + ":image_segmenter_types", + "//mediapipe/framework:calculator_jspb_proto", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:task_runner_test_utils", + "//mediapipe/web/graph_runner:graph_runner_image_lib_ts", + ], +) + +jasmine_node_test( + name = "image_segmenter_test", + tags = ["nomsan"], + deps = [":image_segmenter_test_lib"], +) diff --git a/mediapipe/tasks/web/vision/image_segmenter/image_segmenter.ts b/mediapipe/tasks/web/vision/image_segmenter/image_segmenter.ts new file mode 100644 index 000000000..4f81977eb --- /dev/null +++ b/mediapipe/tasks/web/vision/image_segmenter/image_segmenter.ts @@ -0,0 +1,300 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; +import {CalculatorOptions} from '../../../../framework/calculator_options_pb'; +import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; +import {ImageSegmenterGraphOptions as ImageSegmenterGraphOptionsProto} from '../../../../tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options_pb'; +import {SegmenterOptions as SegmenterOptionsProto} from '../../../../tasks/cc/vision/image_segmenter/proto/segmenter_options_pb'; +import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; +import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options'; +import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; +import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner'; +// Placeholder for internal dependency on trusted resource url + +import {ImageSegmenterOptions} from './image_segmenter_options'; + +export * from './image_segmenter_options'; +export {ImageSource}; // Used in the public API + +/** + * The ImageSegmenter returns the segmentation result as a Uint8Array (when + * the default mode of `CATEGORY_MASK` is used) or as a Float32Array (for + * output type `CONFIDENCE_MASK`). The `WebGLTexture` output type is reserved + * for future usage. + */ +export type SegmentationMask = Uint8Array|Float32Array|WebGLTexture; + +/** + * A callback that receives the computed masks from the image segmenter. The + * callback either receives a single element array with a category mask (as a + * `[Uint8Array]`) or multiple confidence masks (as a `Float32Array[]`). + * The returned data is only valid for the duration of the callback. If + * asynchronous processing is needed, all data needs to be copied before the + * callback returns. + */ +export type SegmentationMaskCallback = + (masks: SegmentationMask[], width: number, height: number) => void; + +const IMAGE_STREAM = 'image_in'; +const NORM_RECT_STREAM = 'norm_rect'; +const GROUPED_SEGMENTATIONS_STREAM = 'segmented_masks'; +const IMAGEA_SEGMENTER_GRAPH = + 'mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph'; + +// The OSS JS API does not support the builder pattern. +// tslint:disable:jspb-use-builder-pattern + +/** Performs image segmentation on images. */ +export class ImageSegmenter extends VisionTaskRunner { + private userCallback: SegmentationMaskCallback = () => {}; + private readonly options: ImageSegmenterGraphOptionsProto; + private readonly segmenterOptions: SegmenterOptionsProto; + + /** + * Initializes the Wasm runtime and creates a new image segmenter from the + * provided options. + * @param wasmFileset A configuration object that provides the location of + * the Wasm binary and its loader. + * @param imageSegmenterOptions The options for the Image Segmenter. Note + * that either a path to the model asset or a model buffer needs to be + * provided (via `baseOptions`). + */ + static createFromOptions( + wasmFileset: WasmFileset, + imageSegmenterOptions: ImageSegmenterOptions): Promise { + return VisionTaskRunner.createInstance( + ImageSegmenter, /* initializeCanvas= */ true, wasmFileset, + imageSegmenterOptions); + } + + /** + * Initializes the Wasm runtime and creates a new image segmenter based on + * the provided model asset buffer. + * @param wasmFileset A configuration object that provides the location of + * the Wasm binary and its loader. 
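A usage sketch of `createFromModelBuffer` and of the mask-lifetime contract documented above: the mask buffers are only valid for the duration of the callback, so copies must be taken before any asynchronous processing. The model URL is a placeholder and `processLater` is a hypothetical consumer, not part of the API.

```
const response = await fetch("model.tflite");  // placeholder model URL
const modelAssetBuffer = new Uint8Array(await response.arrayBuffer());
const segmenter =
    await ImageSegmenter.createFromModelBuffer(vision, modelAssetBuffer);

segmenter.segment(image, (masks, width, height) => {
  // Copy before returning: the underlying buffers are only guaranteed to be
  // valid while the callback runs.
  const copies = masks.map(m => (m as Uint8Array).slice());
  processLater(copies, width, height);  // hypothetical async consumer
});
```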
+ * @param modelAssetBuffer A binary representation of the model.
+ */
+ static createFromModelBuffer(
+ wasmFileset: WasmFileset,
+ modelAssetBuffer: Uint8Array): Promise<ImageSegmenter> {
+ return VisionTaskRunner.createInstance(
+ ImageSegmenter, /* initializeCanvas= */ true, wasmFileset,
+ {baseOptions: {modelAssetBuffer}});
+ }
+
+ /**
+ * Initializes the Wasm runtime and creates a new image segmenter based on
+ * the path to the model asset.
+ * @param wasmFileset A configuration object that provides the location of
+ * the Wasm binary and its loader.
+ * @param modelAssetPath The path to the model asset.
+ */
+ static createFromModelPath(
+ wasmFileset: WasmFileset,
+ modelAssetPath: string): Promise<ImageSegmenter> {
+ return VisionTaskRunner.createInstance(
+ ImageSegmenter, /* initializeCanvas= */ true, wasmFileset,
+ {baseOptions: {modelAssetPath}});
+ }
+
+ /** @hideconstructor */
+ constructor(
+ wasmModule: WasmModule,
+ glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) {
+ super(
+ new VisionGraphRunner(wasmModule, glCanvas), IMAGE_STREAM,
+ NORM_RECT_STREAM, /* roiAllowed= */ false);
+ this.options = new ImageSegmenterGraphOptionsProto();
+ this.segmenterOptions = new SegmenterOptionsProto();
+ this.options.setSegmenterOptions(this.segmenterOptions);
+ this.options.setBaseOptions(new BaseOptionsProto());
+ }
+
+
+ protected override get baseOptions(): BaseOptionsProto {
+ return this.options.getBaseOptions()!;
+ }
+
+ protected override set baseOptions(proto: BaseOptionsProto) {
+ this.options.setBaseOptions(proto);
+ }
+
+ /**
+ * Sets new options for the image segmenter.
+ *
+ * Calling `setOptions()` with a subset of options only affects those
+ * options. You can reset an option back to its default value by
+ * explicitly setting it to `undefined`.
+ *
+ * @param options The options for the image segmenter.
+ */
+ override setOptions(options: ImageSegmenterOptions): Promise<void> {
+ // Note that we have to support both JSPB and ProtobufJS, hence we
+ // have to explicitly clear the values instead of setting them to
+ // `undefined`.
+ if (options.displayNamesLocale !== undefined) {
+ this.options.setDisplayNamesLocale(options.displayNamesLocale);
+ } else if ('displayNamesLocale' in options) { // Check for undefined
+ this.options.clearDisplayNamesLocale();
+ }
+
+ if (options.outputType === 'CONFIDENCE_MASK') {
+ this.segmenterOptions.setOutputType(
+ SegmenterOptionsProto.OutputType.CONFIDENCE_MASK);
+ } else {
+ this.segmenterOptions.setOutputType(
+ SegmenterOptionsProto.OutputType.CATEGORY_MASK);
+ }
+
+ return super.applyOptions(options);
+ }
+
+ /**
+ * Performs image segmentation on the provided single image and invokes the
+ * callback with the response. The method returns synchronously once the
+ * callback returns. Only use this method when the ImageSegmenter is
+ * created with running mode `image`.
+ *
+ * @param image An image to process.
+ * @param callback The callback that is invoked with the segmented masks. The
+ * lifetime of the returned data is only guaranteed for the duration of the
+ * callback.
+ */
+ segment(image: ImageSource, callback: SegmentationMaskCallback): void;
+ /**
+ * Performs image segmentation on the provided single image and invokes the
+ * callback with the response. The method returns synchronously once the
+ * callback returns. Only use this method when the ImageSegmenter is
+ * created with running mode `image`.
+ *
+ * @param image An image to process.
+ * @param imageProcessingOptions the `ImageProcessingOptions` specifying how + * to process the input image before running inference. + * @param callback The callback that is invoked with the segmented masks. The + * lifetime of the returned data is only guaranteed for the duration of the + * callback. + */ + segment( + image: ImageSource, imageProcessingOptions: ImageProcessingOptions, + callback: SegmentationMaskCallback): void; + segment( + image: ImageSource, + imageProcessingOptionsOrCallback: ImageProcessingOptions| + SegmentationMaskCallback, + callback?: SegmentationMaskCallback): void { + const imageProcessingOptions = + typeof imageProcessingOptionsOrCallback !== 'function' ? + imageProcessingOptionsOrCallback : + {}; + + this.userCallback = typeof imageProcessingOptionsOrCallback === 'function' ? + imageProcessingOptionsOrCallback : + callback!; + this.processImageData(image, imageProcessingOptions); + this.userCallback = () => {}; + } + + /** + * Performs image segmentation on the provided video frame and invokes the + * callback with the response. The method returns synchronously once the + * callback returns. Only use this method when the ImageSegmenter is + * created with running mode `video`. + * + * @param videoFrame A video frame to process. + * @param timestamp The timestamp of the current frame, in ms. + * @param callback The callback that is invoked with the segmented masks. The + * lifetime of the returned data is only guaranteed for the duration of the + * callback. + */ + segmentForVideo( + videoFrame: ImageSource, timestamp: number, + callback: SegmentationMaskCallback): void; + /** + * Performs image segmentation on the provided video frame and invokes the + * callback with the response. The method returns synchronously once the + * callback returns. Only use this method when the ImageSegmenter is + * created with running mode `video`. + * + * @param videoFrame A video frame to process. + * @param imageProcessingOptions the `ImageProcessingOptions` specifying how + * to process the input image before running inference. + * @param timestamp The timestamp of the current frame, in ms. + * @param callback The callback that is invoked with the segmented masks. The + * lifetime of the returned data is only guaranteed for the duration of the + * callback. + */ + segmentForVideo( + videoFrame: ImageSource, imageProcessingOptions: ImageProcessingOptions, + timestamp: number, callback: SegmentationMaskCallback): void; + segmentForVideo( + videoFrame: ImageSource, + timestampOrImageProcessingOptions: number|ImageProcessingOptions, + timestampOrCallback: number|SegmentationMaskCallback, + callback?: SegmentationMaskCallback): void { + const imageProcessingOptions = + typeof timestampOrImageProcessingOptions !== 'number' ? + timestampOrImageProcessingOptions : + {}; + const timestamp = typeof timestampOrImageProcessingOptions === 'number' ? + timestampOrImageProcessingOptions : + timestampOrCallback as number; + + this.userCallback = typeof timestampOrCallback === 'function' ? + timestampOrCallback : + callback!; + this.processVideoData(videoFrame, imageProcessingOptions, timestamp); + this.userCallback = () => {}; + } + + /** Updates the MediaPipe graph configuration. 
*/
+ protected override refreshGraph(): void {
+ const graphConfig = new CalculatorGraphConfig();
+ graphConfig.addInputStream(IMAGE_STREAM);
+ graphConfig.addInputStream(NORM_RECT_STREAM);
+ graphConfig.addOutputStream(GROUPED_SEGMENTATIONS_STREAM);
+
+ const calculatorOptions = new CalculatorOptions();
+ calculatorOptions.setExtension(
+ ImageSegmenterGraphOptionsProto.ext, this.options);
+
+ const segmenterNode = new CalculatorGraphConfig.Node();
+ segmenterNode.setCalculator(IMAGEA_SEGMENTER_GRAPH);
+ segmenterNode.addInputStream('IMAGE:' + IMAGE_STREAM);
+ segmenterNode.addInputStream('NORM_RECT:' + NORM_RECT_STREAM);
+ segmenterNode.addOutputStream(
+ 'GROUPED_SEGMENTATION:' + GROUPED_SEGMENTATIONS_STREAM);
+ segmenterNode.setOptions(calculatorOptions);
+
+ graphConfig.addNode(segmenterNode);
+
+ this.graphRunner.attachImageVectorListener(
+ GROUPED_SEGMENTATIONS_STREAM, (masks, timestamp) => {
+ if (masks.length === 0) {
+ this.userCallback([], 0, 0);
+ } else {
+ this.userCallback(
+ masks.map(m => m.data), masks[0].width, masks[0].height);
+ }
+ this.setLatestOutputTimestamp(timestamp);
+ });
+
+ const binaryGraph = graphConfig.serializeBinary();
+ this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true);
+ }
+}
+
+
diff --git a/mediapipe/tasks/web/vision/image_segmenter/image_segmenter_options.d.ts b/mediapipe/tasks/web/vision/image_segmenter/image_segmenter_options.d.ts
new file mode 100644
index 000000000..c17e7e421
--- /dev/null
+++ b/mediapipe/tasks/web/vision/image_segmenter/image_segmenter_options.d.ts
@@ -0,0 +1,41 @@
+/**
+ * Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import {VisionTaskOptions} from '../../../../tasks/web/vision/core/vision_task_options';
+
+/** Options to configure the MediaPipe Image Segmenter Task */
+export interface ImageSegmenterOptions extends VisionTaskOptions {
+ /**
+ * The locale to use for display names specified through the TFLite Model
+ * Metadata, if any. Defaults to English.
+ */
+ displayNamesLocale?: string|undefined;
+
+ /**
+ * The output type of segmentation results.
+ *
+ * The two supported modes are:
+ * - Category Mask: Gives a single output mask where each pixel represents
+ * the class which the pixel in the original image was
+ * predicted to belong to.
+ * - Confidence Mask: Gives a list of output masks (one for each class). For
+ * each mask, the pixel represents the prediction
+ * confidence, usually in the [0.0, 1.0] range.
+ *
+ * Defaults to `CATEGORY_MASK`.
+ */
+ outputType?: 'CATEGORY_MASK'|'CONFIDENCE_MASK'|undefined;
+}
diff --git a/mediapipe/tasks/web/vision/image_segmenter/image_segmenter_test.ts b/mediapipe/tasks/web/vision/image_segmenter/image_segmenter_test.ts
new file mode 100644
index 000000000..aa81be025
--- /dev/null
+++ b/mediapipe/tasks/web/vision/image_segmenter/image_segmenter_test.ts
@@ -0,0 +1,215 @@
+/**
+ * Copyright 2022 The MediaPipe Authors. All Rights Reserved.
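The `outputType` option above determines the shape of the data handed to the callback. A hedged sketch of consuming `CONFIDENCE_MASK` output, reusing the `segmenter` and `image` from the earlier sketches; the argmax reduction is illustrative, not part of the API.

```
await segmenter.setOptions({outputType: "CONFIDENCE_MASK"});
segmenter.segment(image, (masks, width, height) => {
  // One Float32Array per class; recover a per-pixel label via argmax.
  const confidences = masks as Float32Array[];
  const labels = new Uint8Array(width * height);
  for (let i = 0; i < width * height; ++i) {
    let best = 0;
    for (let c = 1; c < confidences.length; ++c) {
      if (confidences[c][i] > confidences[best][i]) best = c;
    }
    labels[i] = best;
  }
});
```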
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import 'jasmine'; + +// Placeholder for internal dependency on encodeByteArray +import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; +import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils'; +import {WasmImage} from '../../../../web/graph_runner/graph_runner_image_lib'; + +import {ImageSegmenter} from './image_segmenter'; +import {ImageSegmenterOptions} from './image_segmenter_options'; + +class ImageSegmenterFake extends ImageSegmenter implements MediapipeTasksFake { + calculatorName = 'mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph'; + attachListenerSpies: jasmine.Spy[] = []; + graph: CalculatorGraphConfig|undefined; + + fakeWasmModule: SpyWasmModule; + imageVectorListener: + ((images: WasmImage[], timestamp: number) => void)|undefined; + + constructor() { + super(createSpyWasmModule(), /* glCanvas= */ null); + this.fakeWasmModule = + this.graphRunner.wasmModule as unknown as SpyWasmModule; + + this.attachListenerSpies[0] = + spyOn(this.graphRunner, 'attachImageVectorListener') + .and.callFake((stream, listener) => { + expect(stream).toEqual('segmented_masks'); + this.imageVectorListener = listener; + }); + spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => { + this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph); + }); + spyOn(this.graphRunner, 'addGpuBufferAsImageToStream'); + } +} + +describe('ImageSegmenter', () => { + let imageSegmenter: ImageSegmenterFake; + + beforeEach(async () => { + addJasmineCustomFloatEqualityTester(); + imageSegmenter = new ImageSegmenterFake(); + await imageSegmenter.setOptions( + {baseOptions: {modelAssetBuffer: new Uint8Array([])}}); + }); + + it('initializes graph', async () => { + verifyGraph(imageSegmenter); + verifyListenersRegistered(imageSegmenter); + }); + + it('reloads graph when settings are changed', async () => { + await imageSegmenter.setOptions({displayNamesLocale: 'en'}); + verifyGraph(imageSegmenter, ['displayNamesLocale', 'en']); + verifyListenersRegistered(imageSegmenter); + + await imageSegmenter.setOptions({displayNamesLocale: 'de'}); + verifyGraph(imageSegmenter, ['displayNamesLocale', 'de']); + verifyListenersRegistered(imageSegmenter); + }); + + it('can use custom models', async () => { + const newModel = new Uint8Array([0, 1, 2, 3, 4]); + const newModelBase64 = Buffer.from(newModel).toString('base64'); + await imageSegmenter.setOptions({ + baseOptions: { + modelAssetBuffer: newModel, + } + }); + + verifyGraph( + imageSegmenter, + /* expectedCalculatorOptions= */ undefined, + /* expectedBaseOptions= */ + [ + 'modelAsset', { + fileContent: newModelBase64, + fileName: undefined, + fileDescriptorMeta: undefined, + filePointerMeta: undefined + } + ]); + }); + + it('merges options', async () => { + await imageSegmenter.setOptions({outputType: 'CATEGORY_MASK'}); + await 
imageSegmenter.setOptions({displayNamesLocale: 'en'}); + verifyGraph(imageSegmenter, [['segmenterOptions', 'outputType'], 1]); + verifyGraph(imageSegmenter, ['displayNamesLocale', 'en']); + }); + + describe('setOptions()', () => { + interface TestCase { + optionName: keyof ImageSegmenterOptions; + fieldPath: string[]; + userValue: unknown; + graphValue: unknown; + defaultValue: unknown; + } + + const testCases: TestCase[] = [ + { + optionName: 'displayNamesLocale', + fieldPath: ['displayNamesLocale'], + userValue: 'en', + graphValue: 'en', + defaultValue: 'en' + }, + { + optionName: 'outputType', + fieldPath: ['segmenterOptions', 'outputType'], + userValue: 'CONFIDENCE_MASK', + graphValue: 2, + defaultValue: 1 + }, + ]; + + for (const testCase of testCases) { + it(`can set ${testCase.optionName}`, async () => { + await imageSegmenter.setOptions( + {[testCase.optionName]: testCase.userValue}); + verifyGraph(imageSegmenter, [testCase.fieldPath, testCase.graphValue]); + }); + + it(`can clear ${testCase.optionName}`, async () => { + await imageSegmenter.setOptions( + {[testCase.optionName]: testCase.userValue}); + verifyGraph(imageSegmenter, [testCase.fieldPath, testCase.graphValue]); + await imageSegmenter.setOptions({[testCase.optionName]: undefined}); + verifyGraph( + imageSegmenter, [testCase.fieldPath, testCase.defaultValue]); + }); + } + }); + + it('doesn\'t support region of interest', () => { + expect(() => { + imageSegmenter.segment( + {} as HTMLImageElement, + {regionOfInterest: {left: 0, right: 0, top: 0, bottom: 0}}, () => {}); + }).toThrowError('This task doesn\'t support region-of-interest.'); + }); + + it('supports category masks', (done) => { + const mask = new Uint8Array([1, 2, 3, 4]); + + // Pass the test data to our listener + imageSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => { + verifyListenersRegistered(imageSegmenter); + imageSegmenter.imageVectorListener!( + [ + {data: mask, width: 2, height: 2}, + ], + /* timestamp= */ 1337); + }); + + // Invoke the image segmenter + imageSegmenter.segment({} as HTMLImageElement, (masks, width, height) => { + expect(imageSegmenter.fakeWasmModule._waitUntilIdle).toHaveBeenCalled(); + expect(masks).toHaveSize(1); + expect(masks[0]).toEqual(mask); + expect(width).toEqual(2); + expect(height).toEqual(2); + done(); + }); + }); + + it('supports confidence masks', async () => { + const mask1 = new Float32Array([0.1, 0.2, 0.3, 0.4]); + const mask2 = new Float32Array([0.5, 0.6, 0.7, 0.8]); + + await imageSegmenter.setOptions({outputType: 'CONFIDENCE_MASK'}); + + // Pass the test data to our listener + imageSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => { + verifyListenersRegistered(imageSegmenter); + imageSegmenter.imageVectorListener!( + [ + {data: mask1, width: 2, height: 2}, + {data: mask2, width: 2, height: 2}, + ], + 1337); + }); + + return new Promise(resolve => { + // Invoke the image segmenter + imageSegmenter.segment({} as HTMLImageElement, (masks, width, height) => { + expect(imageSegmenter.fakeWasmModule._waitUntilIdle).toHaveBeenCalled(); + expect(masks).toHaveSize(2); + expect(masks[0]).toEqual(mask1); + expect(masks[1]).toEqual(mask2); + expect(width).toEqual(2); + expect(height).toEqual(2); + resolve(); + }); + }); + }); +}); diff --git a/mediapipe/tasks/web/vision/index.ts b/mediapipe/tasks/web/vision/index.ts index 49f23c243..5a87c7a82 100644 --- a/mediapipe/tasks/web/vision/index.ts +++ b/mediapipe/tasks/web/vision/index.ts @@ -19,6 +19,7 @@ import {GestureRecognizer as GestureRecognizerImpl} from 
'../../../tasks/web/vis import {HandLandmarker as HandLandmarkerImpl} from '../../../tasks/web/vision/hand_landmarker/hand_landmarker'; import {ImageClassifier as ImageClassifierImpl} from '../../../tasks/web/vision/image_classifier/image_classifier'; import {ImageEmbedder as ImageEmbedderImpl} from '../../../tasks/web/vision/image_embedder/image_embedder'; +import {ImageSegmenter as ImageSegementerImpl} from '../../../tasks/web/vision/image_segmenter/image_segmenter'; import {ObjectDetector as ObjectDetectorImpl} from '../../../tasks/web/vision/object_detector/object_detector'; // Declare the variables locally so that Rollup in OSS includes them explicitly @@ -28,6 +29,7 @@ const GestureRecognizer = GestureRecognizerImpl; const HandLandmarker = HandLandmarkerImpl; const ImageClassifier = ImageClassifierImpl; const ImageEmbedder = ImageEmbedderImpl; +const ImageSegmenter = ImageSegementerImpl; const ObjectDetector = ObjectDetectorImpl; export { @@ -36,5 +38,6 @@ export { HandLandmarker, ImageClassifier, ImageEmbedder, + ImageSegmenter, ObjectDetector }; diff --git a/mediapipe/tasks/web/vision/types.ts b/mediapipe/tasks/web/vision/types.ts index dd1f58294..b9d951f60 100644 --- a/mediapipe/tasks/web/vision/types.ts +++ b/mediapipe/tasks/web/vision/types.ts @@ -19,4 +19,5 @@ export * from '../../../tasks/web/vision/gesture_recognizer/gesture_recognizer'; export * from '../../../tasks/web/vision/hand_landmarker/hand_landmarker'; export * from '../../../tasks/web/vision/image_classifier/image_classifier'; export * from '../../../tasks/web/vision/image_embedder/image_embedder'; +export * from '../../../tasks/web/vision/image_segmenter/image_segmenter'; export * from '../../../tasks/web/vision/object_detector/object_detector'; From c29ab7f083a8195b1a82bb39ec5abcab54b3e83c Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Thu, 26 Jan 2023 13:34:36 -0800 Subject: [PATCH 440/469] Internal change PiperOrigin-RevId: 504928797 --- mediapipe/framework/BUILD | 1 + mediapipe/framework/subgraph.cc | 2 +- mediapipe/framework/subgraph.h | 3 ++- 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/mediapipe/framework/BUILD b/mediapipe/framework/BUILD index da8ef3b3e..e082ef2e6 100644 --- a/mediapipe/framework/BUILD +++ b/mediapipe/framework/BUILD @@ -1051,6 +1051,7 @@ cc_library( "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", ], ) diff --git a/mediapipe/framework/subgraph.cc b/mediapipe/framework/subgraph.cc index d0f018e1a..7cbde28bf 100644 --- a/mediapipe/framework/subgraph.cc +++ b/mediapipe/framework/subgraph.cc @@ -92,7 +92,7 @@ bool GraphRegistry::IsRegistered(const std::string& ns, } absl::StatusOr GraphRegistry::CreateByName( - const std::string& ns, const std::string& type_name, + absl::string_view ns, absl::string_view type_name, SubgraphContext* context) const { absl::StatusOr> maker = local_factories_.IsRegistered(ns, type_name) diff --git a/mediapipe/framework/subgraph.h b/mediapipe/framework/subgraph.h index b3e7d958b..5b1d9646a 100644 --- a/mediapipe/framework/subgraph.h +++ b/mediapipe/framework/subgraph.h @@ -20,6 +20,7 @@ #include "absl/base/macros.h" #include "absl/memory/memory.h" #include "absl/status/status.h" +#include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "mediapipe/framework/calculator.pb.h" #include "mediapipe/framework/deps/registration.h" @@ -187,7 +188,7 @@ class GraphRegistry { // Returns the 
specified graph config. absl::StatusOr CreateByName( - const std::string& ns, const std::string& type_name, + absl::string_view ns, absl::string_view type_name, SubgraphContext* context = nullptr) const; static GraphRegistry global_graph_registry; From 8531803462e98974d19f9c1c7d507ebdd4493ef1 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Fri, 27 Jan 2023 11:08:39 +0530 Subject: [PATCH 441/469] Updated documentation of embedding containers --- .../tasks/ios/components/containers/sources/MPPEmbedding.h | 1 - .../tasks/ios/components/containers/sources/MPPEmbedding.m | 1 - .../ios/components/containers/sources/MPPEmbeddingResult.h | 2 +- .../ios/text/text_embedder/sources/MPPTextEmbedderOptions.h | 4 ++-- 4 files changed, 3 insertions(+), 5 deletions(-) diff --git a/mediapipe/tasks/ios/components/containers/sources/MPPEmbedding.h b/mediapipe/tasks/ios/components/containers/sources/MPPEmbedding.h index a9db8e579..b2104990f 100644 --- a/mediapipe/tasks/ios/components/containers/sources/MPPEmbedding.h +++ b/mediapipe/tasks/ios/components/containers/sources/MPPEmbedding.h @@ -54,7 +54,6 @@ NS_SWIFT_NAME(Embedding) * * @return An instance of `MPPEmbedding` initialized with the given float embedding, quantized * embedding, head index and head name. - * */ - (instancetype)initWithFloatEmbedding:(nullable float *)floatEmbedding quantizedEmbedding:(nullable char *)quantizedEmbedding diff --git a/mediapipe/tasks/ios/components/containers/sources/MPPEmbedding.m b/mediapipe/tasks/ios/components/containers/sources/MPPEmbedding.m index 642853ef1..a4c1e224a 100644 --- a/mediapipe/tasks/ios/components/containers/sources/MPPEmbedding.m +++ b/mediapipe/tasks/ios/components/containers/sources/MPPEmbedding.m @@ -20,7 +20,6 @@ quantizedEmbedding:(nullable char *)quantizedEmbedding headIndex:(NSInteger)headIndex headName:(nullable NSString *)headName { - // TODO: Should null check for embeddings be done here ? self = [super init]; if (self) { _headIndex = headIndex; diff --git a/mediapipe/tasks/ios/components/containers/sources/MPPEmbeddingResult.h b/mediapipe/tasks/ios/components/containers/sources/MPPEmbeddingResult.h index 3d5d48b9b..8fd9b9dff 100644 --- a/mediapipe/tasks/ios/components/containers/sources/MPPEmbeddingResult.h +++ b/mediapipe/tasks/ios/components/containers/sources/MPPEmbeddingResult.h @@ -30,7 +30,7 @@ NS_SWIFT_NAME(EmbeddingResult) * @brief The optional timestamp (in milliseconds) of the start of the chunk of data corresponding * to these results. * This is only used for embedding extraction on time series (e.g. audio embedder). In these use - * cases, the amount of data to process might exceed the maximum size that the model can process: to + * cases, the amount of data to process might exceed the maximum size that the model can process. To * solve this, the input data is split into multiple chunks starting at different timestamps. */ @property(nonatomic, readonly) NSInteger timestampMs; diff --git a/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedderOptions.h b/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedderOptions.h index ce9fc8b20..fd2a7034c 100644 --- a/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedderOptions.h +++ b/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedderOptions.h @@ -29,7 +29,7 @@ NS_SWIFT_NAME(TextEmbedderptions) * Use this option only if the model does not already contain a native L2_NORMALIZATION TF Lite Op. * In most cases, this is already the case and L2 norm is thus achieved through TF Lite inference. 
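The `l2Normalize` and `quantize` options documented above behave the same way across the language bindings. One practical consequence of `quantize`, and the detail behind the cosine-similarity fix in patch 444 below, is that the quantized embedding bytes are signed int8 values stored in an unsigned container. A self-contained TypeScript sketch of the conversion and of the similarity computation; the helper names are illustrative, not MediaPipe API.

```
// Bytes above 127 are really negative int8 values.
function toInt8(data: Uint8Array): number[] {
  return Array.from(data, v => v > 127 ? v - 256 : v);
}

function cosineSimilarity(u: number[], v: number[]): number {
  let dot = 0;
  let normU = 0;
  let normV = 0;
  for (let i = 0; i < u.length; ++i) {
    dot += u[i] * v[i];
    normU += u[i] * u[i];
    normV += v[i] * v[i];
  }
  return dot / Math.sqrt(normU * normV);
}
```

With the naive `v - 128` conversion that patch 444 removes, the byte 255 would map to 127 instead of -1, silently distorting every similarity score computed on quantized embeddings.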
* - * NO by default. + * `NO` by default. */ @property(nonatomic) BOOL l2Normalize; @@ -38,7 +38,7 @@ NS_SWIFT_NAME(TextEmbedderptions) * Embeddings are implicitly assumed to be unit-norm and therefore any dimensions is guaranteed to * have value in [-1.0, 1.0]. Use the `l2Normalize` property if this is not the case. * - * NO by default. + * `NO` by default. */ @property(nonatomic) BOOL quantize; From e059d55d29f3a4e99378ec7e857202f1345a92a5 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Fri, 27 Jan 2023 01:50:57 -0800 Subject: [PATCH 442/469] Correctly check refCount in finalize. PiperOrigin-RevId: 505057866 --- .../java/com/google/mediapipe/framework/GraphTextureFrame.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mediapipe/java/com/google/mediapipe/framework/GraphTextureFrame.java b/mediapipe/java/com/google/mediapipe/framework/GraphTextureFrame.java index 63ea7854b..6a2c97b94 100644 --- a/mediapipe/java/com/google/mediapipe/framework/GraphTextureFrame.java +++ b/mediapipe/java/com/google/mediapipe/framework/GraphTextureFrame.java @@ -158,7 +158,7 @@ public class GraphTextureFrame implements TextureFrame { @Override protected void finalize() throws Throwable { - if (refCount >= 0 || nativeBufferHandle != 0) { + if (refCount > 0 || nativeBufferHandle != 0) { logger.atWarning().log("release was not called before finalize"); } if (!activeConsumerContextHandleSet.isEmpty()) { From 1df4511e9d5a8603835512318606262be949f9ff Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Fri, 27 Jan 2023 08:49:28 -0800 Subject: [PATCH 443/469] Add YuvImage as a GpuBuffer storage backend. PiperOrigin-RevId: 505128789 --- mediapipe/gpu/BUILD | 29 +++++++++++++++++++++++++++++ mediapipe/gpu/gpu_buffer_format.cc | 4 ++++ mediapipe/gpu/gpu_buffer_format.h | 12 ++++++++++++ 3 files changed, 45 insertions(+) diff --git a/mediapipe/gpu/BUILD b/mediapipe/gpu/BUILD index 55e9c98c2..702812718 100644 --- a/mediapipe/gpu/BUILD +++ b/mediapipe/gpu/BUILD @@ -441,6 +441,21 @@ cc_library( ], ) +cc_library( + name = "gpu_buffer_storage_yuv_image", + srcs = ["gpu_buffer_storage_yuv_image.cc"], + hdrs = ["gpu_buffer_storage_yuv_image.h"], + visibility = ["//visibility:public"], + deps = [ + ":gpu_buffer_format", + ":gpu_buffer_storage", + "//mediapipe/framework/formats:yuv_image", + "//third_party/libyuv", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + ], +) + cc_library( name = "gpu_buffer_storage_ahwb", srcs = ["gpu_buffer_storage_ahwb.cc"], @@ -1187,3 +1202,17 @@ mediapipe_cc_test( "//mediapipe/framework/port:gtest_main", ], ) + +mediapipe_cc_test( + name = "gpu_buffer_storage_yuv_image_test", + size = "small", + srcs = ["gpu_buffer_storage_yuv_image_test.cc"], + exclude_platforms = [ + "ios", + ], + deps = [ + ":gpu_buffer_format", + ":gpu_buffer_storage_yuv_image", + "//mediapipe/framework/port:gtest_main", + ], +) diff --git a/mediapipe/gpu/gpu_buffer_format.cc b/mediapipe/gpu/gpu_buffer_format.cc index 1dcd58e63..8e2e3858e 100644 --- a/mediapipe/gpu/gpu_buffer_format.cc +++ b/mediapipe/gpu/gpu_buffer_format.cc @@ -212,6 +212,10 @@ ImageFormat::Format ImageFormatForGpuBufferFormat(GpuBufferFormat format) { case GpuBufferFormat::kTwoComponentHalf16: case GpuBufferFormat::kRGBAHalf64: case GpuBufferFormat::kRGBAFloat128: + case GpuBufferFormat::kNV12: + case GpuBufferFormat::kNV21: + case GpuBufferFormat::kI420: + case GpuBufferFormat::kYV12: case GpuBufferFormat::kUnknown: return ImageFormat::UNKNOWN; } diff --git a/mediapipe/gpu/gpu_buffer_format.h 
b/mediapipe/gpu/gpu_buffer_format.h index 06c5a0439..5d77afeb6 100644 --- a/mediapipe/gpu/gpu_buffer_format.h +++ b/mediapipe/gpu/gpu_buffer_format.h @@ -52,6 +52,14 @@ enum class GpuBufferFormat : uint32_t { kRGB24 = 0x00000018, // Note: prefer BGRA32 whenever possible. kRGBAHalf64 = MEDIAPIPE_FOURCC('R', 'G', 'h', 'A'), kRGBAFloat128 = MEDIAPIPE_FOURCC('R', 'G', 'f', 'A'), + // 8-bit Y plane + interleaved 8-bit U/V plane with 2x2 subsampling. + kNV12 = MEDIAPIPE_FOURCC('N', 'V', '1', '2'), + // 8-bit Y plane + interleaved 8-bit V/U plane with 2x2 subsampling. + kNV21 = MEDIAPIPE_FOURCC('N', 'V', '2', '1'), + // 8-bit Y plane + non-interleaved 8-bit U/V planes with 2x2 subsampling. + kI420 = MEDIAPIPE_FOURCC('I', '4', '2', '0'), + // 8-bit Y plane + non-interleaved 8-bit V/U planes with 2x2 subsampling. + kYV12 = MEDIAPIPE_FOURCC('Y', 'V', '1', '2'), }; #if !MEDIAPIPE_DISABLE_GPU @@ -111,6 +119,10 @@ inline OSType CVPixelFormatForGpuBufferFormat(GpuBufferFormat format) { return kCVPixelFormatType_64RGBAHalf; case GpuBufferFormat::kRGBAFloat128: return kCVPixelFormatType_128RGBAFloat; + case GpuBufferFormat::kNV12: + case GpuBufferFormat::kNV21: + case GpuBufferFormat::kI420: + case GpuBufferFormat::kYV12: case GpuBufferFormat::kUnknown: return -1; } From a6f6be95125f530a0cf044ec37c59c2970f664fb Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Fri, 27 Jan 2023 09:20:07 -0800 Subject: [PATCH 444/469] Fix incorrect uint8 -> int8 conversion in JS cosine similarity. PiperOrigin-RevId: 505135368 --- .../tasks/web/components/utils/cosine_similarity.test.ts | 4 ++-- mediapipe/tasks/web/components/utils/cosine_similarity.ts | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/mediapipe/tasks/web/components/utils/cosine_similarity.test.ts b/mediapipe/tasks/web/components/utils/cosine_similarity.test.ts index f442caa20..2a82f388d 100644 --- a/mediapipe/tasks/web/components/utils/cosine_similarity.test.ts +++ b/mediapipe/tasks/web/components/utils/cosine_similarity.test.ts @@ -70,12 +70,12 @@ describe('computeCosineSimilarity', () => { it('succeeds with quantized embeddings', () => { const u: Embedding = { - quantizedEmbedding: new Uint8Array([255, 128, 128, 128]), + quantizedEmbedding: new Uint8Array([127, 0, 0, 0]), headIndex: 0, headName: '' }; const v: Embedding = { - quantizedEmbedding: new Uint8Array([0, 128, 128, 128]), + quantizedEmbedding: new Uint8Array([128, 0, 0, 0]), headIndex: 0, headName: '' }; diff --git a/mediapipe/tasks/web/components/utils/cosine_similarity.ts b/mediapipe/tasks/web/components/utils/cosine_similarity.ts index 1f483b9b6..b512478f4 100644 --- a/mediapipe/tasks/web/components/utils/cosine_similarity.ts +++ b/mediapipe/tasks/web/components/utils/cosine_similarity.ts @@ -38,7 +38,7 @@ export function computeCosineSimilarity(u: Embedding, v: Embedding): number { } function convertToBytes(data: Uint8Array): number[] { - return Array.from(data, v => v - 128); + return Array.from(data, v => v > 127 ? 
v - 256 : v);
 }
 
 function compute(u: number[], v: number[]) {

From dc3fdf6eb4ff66455497fe903ec5d4279c640518 Mon Sep 17 00:00:00 2001
From: Sebastian Schmidt
Date: Fri, 27 Jan 2023 13:14:26 -0800
Subject: [PATCH 445/469] Internal change

PiperOrigin-RevId: 505193224
---
 ...863b622fe13612433fdf43f76547d5edda0c93001.diff | 15 ++++++++++++++-
 1 file changed, 14 insertions(+), 1 deletion(-)

diff --git a/third_party/com_google_absl_f863b622fe13612433fdf43f76547d5edda0c93001.diff b/third_party/com_google_absl_f863b622fe13612433fdf43f76547d5edda0c93001.diff
index 0cd2dffa4..e46ae9f81 100644
--- a/third_party/com_google_absl_f863b622fe13612433fdf43f76547d5edda0c93001.diff
+++ b/third_party/com_google_absl_f863b622fe13612433fdf43f76547d5edda0c93001.diff
@@ -11,4 +11,17 @@ index 9fceffe..e7f9d01 100644
 - ],
 ":ios": [
 "-framework Foundation",
- ],
\ No newline at end of file
+ ],
+diff --git a/absl/types/compare.h b/absl/types/compare.h
+index 19b076e..0201004 100644
+--- a/absl/types/compare.h
++++ b/absl/types/compare.h
+@@ -84,7 +84,7 @@ enum class ncmp : value_type { unordered = -127 };
+ // based on whether the feature is supported. Note: we can't use
+ // ABSL_INTERNAL_INLINE_CONSTEXPR here because the variables here are of
+ // incomplete types so they need to be defined after the types are complete.
+-#ifdef __cpp_inline_variables
++#if defined(__cpp_inline_variables) && !(defined(_MSC_VER) && _MSC_VER <= 1916)
+
+ // A no-op expansion that can be followed by a semicolon at class level.
+ #define ABSL_COMPARE_INLINE_BASECLASS_DECL(name) static_assert(true, "")
\ No newline at end of file

From 702cc0c42c8112730bf9f8b83a48bf02808eea7e Mon Sep 17 00:00:00 2001
From: Sebastian Schmidt
Date: Fri, 27 Jan 2023 16:19:17 -0800
Subject: [PATCH 446/469] Change documentation to use shallow clones of the MP
 Repo

PiperOrigin-RevId: 505234066
---
 docs/getting_started/install.md | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/docs/getting_started/install.md b/docs/getting_started/install.md
index e630b073a..d7a028ec3 100644
--- a/docs/getting_started/install.md
+++ b/docs/getting_started/install.md
@@ -35,7 +35,7 @@ install --user six`.
 
 ```bash
 $ cd $HOME
- $ git clone https://github.com/google/mediapipe.git
+ $ git clone --depth 1 https://github.com/google/mediapipe.git
 
 # Change directory into MediaPipe root directory
 $ cd mediapipe
@@ -287,7 +287,7 @@ build issues.
 2. Checkout MediaPipe repository.
 
 ```bash
- $ git clone https://github.com/google/mediapipe.git
+ $ git clone --depth 1 https://github.com/google/mediapipe.git
 
 # Change directory into MediaPipe root directory
 $ cd mediapipe
@@ -416,7 +416,7 @@ build issues.
 3. Checkout MediaPipe repository.
 
 ```bash
- $ git clone https://github.com/google/mediapipe.git
+ $ git clone --depth 1 https://github.com/google/mediapipe.git
 
 $ cd mediapipe
 ```
@@ -590,7 +590,7 @@ next section.
 7. Checkout MediaPipe repository.
 
 ```
- C:\Users\Username\mediapipe_repo> git clone https://github.com/google/mediapipe.git
+ C:\Users\Username\mediapipe_repo> git clone --depth 1 https://github.com/google/mediapipe.git
 
 # Change directory into MediaPipe root directory
 C:\Users\Username\mediapipe_repo> cd mediapipe
@@ -680,7 +680,7 @@ cameras. Alternatively, you use a video file as input.
 6. Checkout MediaPipe repository.
```bash
- username@DESKTOP-TMVLBJ1:~$ git clone https://github.com/google/mediapipe.git
+ username@DESKTOP-TMVLBJ1:~$ git clone --depth 1 https://github.com/google/mediapipe.git
 
 username@DESKTOP-TMVLBJ1:~$ cd mediapipe
 ```
@@ -771,7 +771,7 @@ This will use a Docker image that will isolate mediapipe's installation from the
 2. Build a docker image with tag "mediapipe".
 
 ```bash
- $ git clone https://github.com/google/mediapipe.git
+ $ git clone --depth 1 https://github.com/google/mediapipe.git
 
 $ cd mediapipe
 $ docker build --tag=mediapipe .

From ee2f940e1fcc0e07c16afd773250b42ad41b36af Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Fri, 27 Jan 2023 18:06:09 -0800
Subject: [PATCH 447/469] Make TensorToVectorFloatCalculator compatible with
 unaligned tensors.

No performance impact is expected, since the unaligned Eigen::TensorMap is
used only to populate a std::vector.

PiperOrigin-RevId: 505251810
---
 mediapipe/calculators/tensorflow/BUILD | 1 +
 .../tensor_to_vector_float_calculator.cc | 2 +-
 .../tensor_to_vector_float_calculator_test.cc | 24 +++++++++++++++++++
 3 files changed, 26 insertions(+), 1 deletion(-)

diff --git a/mediapipe/calculators/tensorflow/BUILD b/mediapipe/calculators/tensorflow/BUILD
index 0f8f8706a..4aec15dcb 100644
--- a/mediapipe/calculators/tensorflow/BUILD
+++ b/mediapipe/calculators/tensorflow/BUILD
@@ -1054,6 +1054,7 @@ cc_test(
 "//mediapipe/framework:calculator_framework",
 "//mediapipe/framework:calculator_runner",
 "//mediapipe/framework/port:gtest_main",
+ "//mediapipe/util:packet_test_util",
 "@org_tensorflow//tensorflow/core:framework",
 "@org_tensorflow//tensorflow/core:protos_all_cc",
 ],
diff --git a/mediapipe/calculators/tensorflow/tensor_to_vector_float_calculator.cc b/mediapipe/calculators/tensorflow/tensor_to_vector_float_calculator.cc
index cd807b87b..ec7cd70fa 100644
--- a/mediapipe/calculators/tensorflow/tensor_to_vector_float_calculator.cc
+++ b/mediapipe/calculators/tensorflow/tensor_to_vector_float_calculator.cc
@@ -102,7 +102,7 @@ absl::Status TensorToVectorFloatCalculator::Process(CalculatorContext* cc) {
 }
 auto output = absl::make_unique<std::vector<float>>(input_tensor.NumElements());
- const auto& tensor_values = input_tensor.flat<float>();
+ const auto& tensor_values = input_tensor.unaligned_flat<float>();
 for (int i = 0; i < input_tensor.NumElements(); ++i) {
 output->at(i) = tensor_values(i);
 }
diff --git a/mediapipe/calculators/tensorflow/tensor_to_vector_float_calculator_test.cc b/mediapipe/calculators/tensorflow/tensor_to_vector_float_calculator_test.cc
index 69d3af60a..98ba4f020 100644
--- a/mediapipe/calculators/tensorflow/tensor_to_vector_float_calculator_test.cc
+++ b/mediapipe/calculators/tensorflow/tensor_to_vector_float_calculator_test.cc
@@ -16,6 +16,7 @@
 #include "mediapipe/framework/calculator_framework.h"
 #include "mediapipe/framework/calculator_runner.h"
 #include "mediapipe/framework/port/gtest.h"
+#include "mediapipe/util/packet_test_util.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/types.pb.h"
 
@@ -129,5 +129,28 @@ TEST_F(TensorToVectorFloatCalculatorTest, FlattenShouldTakeAllDimensions) {
 }
 }
 
+TEST_F(TensorToVectorFloatCalculatorTest, AcceptsUnalignedTensors) {
+ SetUpRunner(/*tensor_is_2d=*/false, /*flatten_nd=*/false);
+
+ const tf::TensorShape tensor_shape(std::vector<tf::int64>{2, 5});
+ tf::Tensor tensor(tf::DT_FLOAT, tensor_shape);
+ auto slice = tensor.Slice(1, 1).flat<float>();
+ for (int i = 0; i < 5; ++i) {
+ slice(i) = i;
+ }
+
+ auto input_tensor = tensor.SubSlice(1);
+ // Ensure that the input tensor is unaligned.
+ ASSERT_FALSE(input_tensor.IsAligned()); + runner_->MutableInputs()->Index(0).packets.push_back( + MakePacket(input_tensor).At(Timestamp(5))); + + ASSERT_TRUE(runner_->Run().ok()); + + EXPECT_THAT(runner_->Outputs().Index(0).packets, + ElementsAre(PacketContainsTimestampAndPayload>( + Timestamp(5), std::vector({0, 1, 2, 3, 4})))); +} + } // namespace } // namespace mediapipe From 8c21dc02a6699bf0a4673ffe1adb757d01044248 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Mon, 30 Jan 2023 11:42:33 +0530 Subject: [PATCH 448/469] Updated to types of float and quantized embedding --- .../ios/components/containers/sources/MPPEmbedding.h | 8 ++++---- .../ios/components/containers/sources/MPPEmbedding.m | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/mediapipe/tasks/ios/components/containers/sources/MPPEmbedding.h b/mediapipe/tasks/ios/components/containers/sources/MPPEmbedding.h index b2104990f..77780fc5f 100644 --- a/mediapipe/tasks/ios/components/containers/sources/MPPEmbedding.h +++ b/mediapipe/tasks/ios/components/containers/sources/MPPEmbedding.h @@ -29,13 +29,13 @@ NS_SWIFT_NAME(Embedding) * @brief The Floating-point embedding. * Empty if the embedder was configured to perform scalar quantization. */ -@property(nonatomic, readonly, nullable) float *floatEmbedding; +@property(nonatomic, readonly, nullable) NSArray *floatEmbedding; /** * @brief The Quantized embedding. * Empty if the embedder was not configured to perform scalar quantization. */ -@property(nonatomic, readonly, nullable) char *quantizedEmbedding; +@property(nonatomic, readonly, nullable) NSData *quantizedEmbedding; /** The index of the embedder head these entries refer to. This is useful for multi-head models. */ @property(nonatomic, readonly) NSInteger headIndex; @@ -55,8 +55,8 @@ NS_SWIFT_NAME(Embedding) * @return An instance of `MPPEmbedding` initialized with the given float embedding, quantized * embedding, head index and head name. 
*/ -- (instancetype)initWithFloatEmbedding:(nullable float *)floatEmbedding - quantizedEmbedding:(nullable char *)quantizedEmbedding +- (instancetype)initWithFloatEmbedding:(nullable NSArray *)floatEmbedding + quantizedEmbedding:(nullable NSData *)quantizedEmbedding headIndex:(NSInteger)headIndex headName:(nullable NSString *)headName NS_DESIGNATED_INITIALIZER; diff --git a/mediapipe/tasks/ios/components/containers/sources/MPPEmbedding.m b/mediapipe/tasks/ios/components/containers/sources/MPPEmbedding.m index a4c1e224a..17e216a2c 100644 --- a/mediapipe/tasks/ios/components/containers/sources/MPPEmbedding.m +++ b/mediapipe/tasks/ios/components/containers/sources/MPPEmbedding.m @@ -16,8 +16,8 @@ @implementation MPPEmbedding -- (instancetype)initWithFloatEmbedding:(nullable float *)floatEmbedding - quantizedEmbedding:(nullable char *)quantizedEmbedding +- (instancetype)initWithFloatEmbedding:(nullable NSArray *)floatEmbedding + quantizedEmbedding:(nullable NSData *)quantizedEmbedding headIndex:(NSInteger)headIndex headName:(nullable NSString *)headName { self = [super init]; From f9f6acffed8b58ebed637643ff67bda8a892e4ef Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Mon, 30 Jan 2023 09:15:06 -0800 Subject: [PATCH 449/469] Make NORM_RECT optional for GestureRecognizerGraph and add PALM_DETECTION output PORT PiperOrigin-RevId: 505712542 --- .../tasks/cc/vision/gesture_recognizer/BUILD | 1 + .../gesture_recognizer_graph.cc | 41 +++++++++++++++---- .../hand_detector/hand_detector_graph.cc | 17 ++++---- .../hand_landmarker/hand_landmarker_graph.cc | 16 ++++---- .../hand_landmarks_detector_graph.cc | 11 ++--- 5 files changed, 58 insertions(+), 28 deletions(-) diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/BUILD b/mediapipe/tasks/cc/vision/gesture_recognizer/BUILD index d473a8dc3..7ffae6ff2 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/BUILD +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/BUILD @@ -140,6 +140,7 @@ cc_library( "//mediapipe/framework/api2:builder", "//mediapipe/framework/api2:port", "//mediapipe/framework/formats:classification_cc_proto", + "//mediapipe/framework/formats:detection_cc_proto", "//mediapipe/framework/formats:image", "//mediapipe/framework/formats:landmark_cc_proto", "//mediapipe/framework/formats:rect_cc_proto", diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer_graph.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer_graph.cc index 2d949c410..b6f6c88da 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer_graph.cc +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer_graph.cc @@ -22,6 +22,7 @@ limitations under the License. 
#include "mediapipe/framework/api2/builder.h" #include "mediapipe/framework/api2/port.h" #include "mediapipe/framework/formats/classification.pb.h" +#include "mediapipe/framework/formats/detection.pb.h" #include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/landmark.pb.h" #include "mediapipe/framework/formats/rect.pb.h" @@ -68,6 +69,9 @@ constexpr char kHandednessTag[] = "HANDEDNESS"; constexpr char kImageSizeTag[] = "IMAGE_SIZE"; constexpr char kHandGesturesTag[] = "HAND_GESTURES"; constexpr char kHandTrackingIdsTag[] = "HAND_TRACKING_IDS"; +constexpr char kRectNextFrameTag[] = "HAND_RECT_NEXT_FRAME"; +constexpr char kPalmRectsTag[] = "PALM_RECTS"; +constexpr char kPalmDetectionsTag[] = "PALM_DETECTIONS"; constexpr char kHandLandmarkerBundleAssetName[] = "hand_landmarker.task"; constexpr char kHandGestureRecognizerBundleAssetName[] = "hand_gesture_recognizer.task"; @@ -77,6 +81,9 @@ struct GestureRecognizerOutputs { Source> handedness; Source> hand_landmarks; Source> hand_world_landmarks; + Source> hand_rects_next_frame; + Source> palm_rects; + Source> palm_detections; Source image; }; @@ -135,9 +142,10 @@ absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources, // Inputs: // IMAGE - Image // Image to perform hand gesture recognition on. -// NORM_RECT - NormalizedRect +// NORM_RECT - NormalizedRect @Optional // Describes image rotation and region of image to perform landmarks -// detection on. +// detection on. If not provided, whole image is used for gesture +// recognition. // // Outputs: // HAND_GESTURES - std::vector @@ -208,11 +216,12 @@ class GestureRecognizerGraph : public core::ModelTaskGraph { !sc->Service(::mediapipe::tasks::core::kModelResourcesCacheService) .IsAvailable())); } - ASSIGN_OR_RETURN(auto hand_gesture_recognition_output, - BuildGestureRecognizerGraph( - *sc->MutableOptions(), - graph[Input(kImageTag)], - graph[Input(kNormRectTag)], graph)); + ASSIGN_OR_RETURN( + auto hand_gesture_recognition_output, + BuildGestureRecognizerGraph( + *sc->MutableOptions(), + graph[Input(kImageTag)], + graph[Input::Optional(kNormRectTag)], graph)); hand_gesture_recognition_output.gesture >> graph[Output>(kHandGesturesTag)]; hand_gesture_recognition_output.handedness >> @@ -222,6 +231,12 @@ class GestureRecognizerGraph : public core::ModelTaskGraph { hand_gesture_recognition_output.hand_world_landmarks >> graph[Output>(kWorldLandmarksTag)]; hand_gesture_recognition_output.image >> graph[Output(kImageTag)]; + hand_gesture_recognition_output.hand_rects_next_frame >> + graph[Output>(kRectNextFrameTag)]; + hand_gesture_recognition_output.palm_rects >> + graph[Output>(kPalmRectsTag)]; + hand_gesture_recognition_output.palm_detections >> + graph[Output>(kPalmDetectionsTag)]; return graph.GetConfig(); } @@ -279,7 +294,17 @@ class GestureRecognizerGraph : public core::ModelTaskGraph { /*handedness=*/handedness, /*hand_landmarks=*/hand_landmarks, /*hand_world_landmarks=*/hand_world_landmarks, - /*image=*/hand_landmarker_graph[Output(kImageTag)]}; + /*hand_rects_next_frame =*/ + hand_landmarker_graph[Output>( + kRectNextFrameTag)], + /*palm_rects =*/ + hand_landmarker_graph[Output>( + kPalmRectsTag)], + /*palm_detections =*/ + hand_landmarker_graph[Output>( + kPalmDetectionsTag)], + /*image=*/hand_landmarker_graph[Output(kImageTag)], + }; } }; diff --git a/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph.cc b/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph.cc index 49958e36b..d7163e331 100644 --- 
a/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph.cc +++ b/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph.cc @@ -150,9 +150,9 @@ void ConfigureRectTransformationCalculator( // Inputs: // IMAGE - Image // Image to perform detection on. -// NORM_RECT - NormalizedRect -// Describes image rotation and region of image to perform detection -// on. +// NORM_RECT - NormalizedRect @Optional +// Describes image rotation and region of image to perform detection on. If +// not provided, whole image is used for hand detection. // // Outputs: // PALM_DETECTIONS - std::vector @@ -197,11 +197,12 @@ class HandDetectorGraph : public core::ModelTaskGraph { ASSIGN_OR_RETURN(const auto* model_resources, CreateModelResources(sc)); Graph graph; - ASSIGN_OR_RETURN(auto hand_detection_outs, - BuildHandDetectionSubgraph( - sc->Options(), - *model_resources, graph[Input(kImageTag)], - graph[Input(kNormRectTag)], graph)); + ASSIGN_OR_RETURN( + auto hand_detection_outs, + BuildHandDetectionSubgraph( + sc->Options(), *model_resources, + graph[Input(kImageTag)], + graph[Input::Optional(kNormRectTag)], graph)); hand_detection_outs.palm_detections >> graph[Output>(kPalmDetectionsTag)]; hand_detection_outs.hand_rects >> diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph.cc b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph.cc index 05ad97efe..74d288ac1 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph.cc +++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph.cc @@ -136,9 +136,10 @@ absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources, // Inputs: // IMAGE - Image // Image to perform hand landmarks detection on. -// NORM_RECT - NormalizedRect +// NORM_RECT - NormalizedRect @Optional // Describes image rotation and region of image to perform landmarks -// detection on. +// detection on. If not provided, whole image is used for hand landmarks +// detection. 
// // Outputs: // LANDMARKS: - std::vector @@ -218,11 +219,12 @@ class HandLandmarkerGraph : public core::ModelTaskGraph { !sc->Service(::mediapipe::tasks::core::kModelResourcesCacheService) .IsAvailable())); } - ASSIGN_OR_RETURN(auto hand_landmarker_outputs, - BuildHandLandmarkerGraph( - sc->Options(), - graph[Input(kImageTag)], - graph[Input(kNormRectTag)], graph)); + ASSIGN_OR_RETURN( + auto hand_landmarker_outputs, + BuildHandLandmarkerGraph( + sc->Options(), + graph[Input(kImageTag)], + graph[Input::Optional(kNormRectTag)], graph)); hand_landmarker_outputs.landmark_lists >> graph[Output>(kLandmarksTag)]; hand_landmarker_outputs.world_landmark_lists >> diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph.cc b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph.cc index 4ea066aab..914bc30fc 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph.cc +++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph.cc @@ -243,11 +243,12 @@ class SingleHandLandmarksDetectorGraph : public core::ModelTaskGraph { const auto* model_resources, CreateModelResources(sc)); Graph graph; - ASSIGN_OR_RETURN(auto hand_landmark_detection_outs, - BuildSingleHandLandmarksDetectorGraph( - sc->Options(), - *model_resources, graph[Input(kImageTag)], - graph[Input(kHandRectTag)], graph)); + ASSIGN_OR_RETURN( + auto hand_landmark_detection_outs, + BuildSingleHandLandmarksDetectorGraph( + sc->Options(), *model_resources, + graph[Input(kImageTag)], + graph[Input::Optional(kHandRectTag)], graph)); hand_landmark_detection_outs.hand_landmarks >> graph[Output(kLandmarksTag)]; hand_landmark_detection_outs.world_hand_landmarks >> From 2c4dece02320682bfb1334359e32cd124ca70cd9 Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Mon, 30 Jan 2023 09:58:26 -0800 Subject: [PATCH 450/469] Internal change PiperOrigin-RevId: 505723714 --- WORKSPACE | 11 +++++------ ...001.diff => com_google_absl_windows_patch.diff} | 14 -------------- 2 files changed, 5 insertions(+), 20 deletions(-) rename third_party/{com_google_absl_f863b622fe13612433fdf43f76547d5edda0c93001.diff => com_google_absl_windows_patch.diff} (59%) diff --git a/WORKSPACE b/WORKSPACE index bf5e4236b..e14473e50 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -22,21 +22,20 @@ bazel_skylib_workspace() load("@bazel_skylib//lib:versions.bzl", "versions") versions.check(minimum_bazel_version = "3.7.2") -# ABSL cpp library lts_2021_03_24, patch 2. +# ABSL cpp library lts_2023_01_25. http_archive( name = "com_google_absl", urls = [ - "https://github.com/abseil/abseil-cpp/archive/refs/tags/20220623.1.tar.gz", + "https://github.com/abseil/abseil-cpp/archive/refs/tags/20230125.0.tar.gz", ], - # Remove after https://github.com/abseil/abseil-cpp/issues/326 is solved. 
patches = [ - "@//third_party:com_google_absl_f863b622fe13612433fdf43f76547d5edda0c93001.diff" + "@//third_party:com_google_absl_windows_patch.diff" ], patch_args = [ "-p1", ], - strip_prefix = "abseil-cpp-20220623.1", - sha256 = "91ac87d30cc6d79f9ab974c51874a704de9c2647c40f6932597329a282217ba8" + strip_prefix = "abseil-cpp-20230125.0", + sha256 = "3ea49a7d97421b88a8c48a0de16c16048e17725c7ec0f1d3ea2683a2a75adc21" ) http_archive( diff --git a/third_party/com_google_absl_f863b622fe13612433fdf43f76547d5edda0c93001.diff b/third_party/com_google_absl_windows_patch.diff similarity index 59% rename from third_party/com_google_absl_f863b622fe13612433fdf43f76547d5edda0c93001.diff rename to third_party/com_google_absl_windows_patch.diff index e46ae9f81..a4b5b96bb 100644 --- a/third_party/com_google_absl_f863b622fe13612433fdf43f76547d5edda0c93001.diff +++ b/third_party/com_google_absl_windows_patch.diff @@ -1,17 +1,3 @@ -diff --git a/absl/time/internal/cctz/BUILD.bazel b/absl/time/internal/cctz/BUILD.bazel -index 9fceffe..e7f9d01 100644 ---- a/absl/time/internal/cctz/BUILD.bazel -+++ b/absl/time/internal/cctz/BUILD.bazel -@@ -69,8 +69,5 @@ cc_library( - "include/cctz/zone_info_source.h", - ], - linkopts = select({ -- ":osx": [ -- "-framework Foundation", -- ], - ":ios": [ - "-framework Foundation", - ], diff --git a/absl/types/compare.h b/absl/types/compare.h index 19b076e..0201004 100644 --- a/absl/types/compare.h From be3bddc6208dfe00b0ff66fa278561c543f1a803 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Tue, 31 Jan 2023 09:18:07 -0800 Subject: [PATCH 451/469] Add Text Embedder tests for text with different themes PiperOrigin-RevId: 506023265 --- .../text/text_embedder/text_embedder_test.cc | 27 +++++++++++++++++ .../text/textembedder/TextEmbedderTest.java | 20 +++++++++++++ .../python/test/text/text_embedder_test.py | 30 +++++++++++++++++++ 3 files changed, 77 insertions(+) diff --git a/mediapipe/tasks/cc/text/text_embedder/text_embedder_test.cc b/mediapipe/tasks/cc/text/text_embedder/text_embedder_test.cc index fa3d8af91..1ddea3358 100644 --- a/mediapipe/tasks/cc/text/text_embedder/text_embedder_test.cc +++ b/mediapipe/tasks/cc/text/text_embedder/text_embedder_test.cc @@ -139,5 +139,32 @@ TEST_F(EmbedderTest, SucceedsWithQuantization) { MP_ASSERT_OK(text_embedder->Close()); } +TEST_F(EmbedderTest, SucceedsWithMobileBertAndDifferentThemes) { + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kMobileBert); + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr text_embedder, + TextEmbedder::Create(std::move(options))); + + MP_ASSERT_OK_AND_ASSIGN( + TextEmbedderResult result0, + text_embedder->Embed("When you go to this restaurant, they hold the " + "pancake upside-down before they hand it " + "to you. It's a great gimmick.")); + MP_ASSERT_OK_AND_ASSIGN( + TextEmbedderResult result1, + text_embedder->Embed( + "Let's make a plan to steal the declaration of independence.")); + + // Check cosine similarity. 
+ MP_ASSERT_OK_AND_ASSIGN( + double similarity, TextEmbedder::CosineSimilarity(result0.embeddings[0], + result1.embeddings[0])); + // TODO: The similarity should likely be lower + EXPECT_NEAR(similarity, 0.98088, kSimilarityTolerancy); + + MP_ASSERT_OK(text_embedder->Close()); +} + } // namespace } // namespace mediapipe::tasks::text::text_embedder diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textembedder/TextEmbedderTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textembedder/TextEmbedderTest.java index b6d53c94d..48f214770 100644 --- a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textembedder/TextEmbedderTest.java +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textembedder/TextEmbedderTest.java @@ -95,4 +95,24 @@ public class TextEmbedderTest { result1.embeddingResult().embeddings().get(0)); assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.999937); } + + @Test + public void classify_succeedsWithBertAndDifferentThemes() throws Exception { + TextEmbedder textEmbedder = + TextEmbedder.createFromFile(ApplicationProvider.getApplicationContext(), BERT_MODEL_FILE); + + TextEmbedderResult result0 = + textEmbedder.embed( + "When you go to this restaurant, they hold the pancake upside-down before they hand " + + "it to you. It's a great gimmick."); + TextEmbedderResult result1 = + textEmbedder.embed("Let\'s make a plan to steal the declaration of independence.'"); + + // Check cosine similarity. + double similarity = + TextEmbedder.cosineSimilarity( + result0.embeddingResult().embeddings().get(0), + result1.embeddingResult().embeddings().get(0)); + assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.3477488707202946); + } } diff --git a/mediapipe/tasks/python/test/text/text_embedder_test.py b/mediapipe/tasks/python/test/text/text_embedder_test.py index 1346ba373..455deba03 100644 --- a/mediapipe/tasks/python/test/text/text_embedder_test.py +++ b/mediapipe/tasks/python/test/text/text_embedder_test.py @@ -192,6 +192,36 @@ class TextEmbedderTest(parameterized.TestCase): self._check_embedding_value(result1, expected_result1_value) self._check_cosine_similarity(result0, result1, expected_similarity) + def test_embed_with_mobile_bert_and_different_themes(self): + # Creates embedder. + model_path = test_utils.get_test_data_path( + os.path.join(_TEST_DATA_DIR, _BERT_MODEL_FILE) + ) + base_options = _BaseOptions(model_asset_path=model_path) + options = _TextEmbedderOptions(base_options=base_options) + embedder = _TextEmbedder.create_from_options(options) + + # Extracts both embeddings. + text0 = ( + 'When you go to this restaurant, they hold the pancake upside-down ' + "before they hand it to you. It's a great gimmick." + ) + result0 = embedder.embed(text0) + + text1 = "Let's make a plan to steal the declaration of independence." + result1 = embedder.embed(text1) + + similarity = _TextEmbedder.cosine_similarity( + result0.embeddings[0], result1.embeddings[0] + ) + + # TODO: The similarity should likely be lower + self.assertAlmostEqual(similarity, 0.980880, delta=_SIMILARITY_TOLERANCE) + + # Closes the embedder explicitly when the embedder is not used in + # a context. 
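For reference, the value asserted in the C++, Java, and Python tests above is plain cosine similarity between the two embedding vectors $u$ and $v$:

\[ \mathrm{similarity}(u, v) = \frac{u \cdot v}{\lVert u \rVert \, \lVert v \rVert} \]

Values near 1 mean the embeddings point in nearly the same direction, so the ~0.98 asserted for the MobileBert model on two unrelated themes is exactly why the TODOs note that the similarity should likely be lower.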
+ embedder.close() + if __name__ == '__main__': absltest.main() From 591eb204a677a0ba645c90fc1230f75b3361a441 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Tue, 31 Jan 2023 09:35:12 -0800 Subject: [PATCH 452/469] Internal change PiperOrigin-RevId: 506027661 --- mediapipe/framework/calculator_graph.h | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/mediapipe/framework/calculator_graph.h b/mediapipe/framework/calculator_graph.h index 04f9de45f..8d58ff312 100644 --- a/mediapipe/framework/calculator_graph.h +++ b/mediapipe/framework/calculator_graph.h @@ -53,14 +53,10 @@ #include "mediapipe/framework/port/status.h" #include "mediapipe/framework/scheduler.h" #include "mediapipe/framework/thread_pool_executor.pb.h" +#include "mediapipe/gpu/gpu_service.h" namespace mediapipe { -#if !MEDIAPIPE_DISABLE_GPU -class GpuResources; -struct GpuSharedData; -#endif // !MEDIAPIPE_DISABLE_GPU - typedef absl::StatusOr StatusOrPoller; // The class representing a DAG of calculator nodes. From 5730dec260336c03666095bf0b57b7ce7456b677 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Tue, 31 Jan 2023 11:01:00 -0800 Subject: [PATCH 453/469] Internal change PiperOrigin-RevId: 506053206 --- mediapipe/calculators/core/BUILD | 1 + .../mediapipe/framework/PacketGetter.java | 30 +- .../framework/image/ByteBufferExtractor.java | 15 +- .../mediapipe/framework/image/MPImage.java | 4 + .../framework/jni/packet_getter_jni.cc | 125 +++-- .../framework/jni/packet_getter_jni.h | 11 + .../image_segmenter/image_segmenter_test.cc | 11 +- .../com/google/mediapipe/tasks/vision/BUILD | 25 + .../vision/imagesegmenter/AndroidManifest.xml | 8 + .../vision/imagesegmenter/ImageSegmenter.java | 462 ++++++++++++++++++ .../imagesegmenter/ImageSegmenterResult.java | 45 ++ .../vision/imagesegmenter/AndroidManifest.xml | 24 + .../tasks/vision/imagesegmenter/BUILD | 19 + .../imagesegmenter/ImageSegmenterTest.java | 427 ++++++++++++++++ 14 files changed, 1162 insertions(+), 45 deletions(-) create mode 100644 mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/AndroidManifest.xml create mode 100644 mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenter.java create mode 100644 mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenterResult.java create mode 100644 mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imagesegmenter/AndroidManifest.xml create mode 100644 mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imagesegmenter/BUILD create mode 100644 mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenterTest.java diff --git a/mediapipe/calculators/core/BUILD b/mediapipe/calculators/core/BUILD index df54c5800..ecfdd5d0b 100644 --- a/mediapipe/calculators/core/BUILD +++ b/mediapipe/calculators/core/BUILD @@ -1329,6 +1329,7 @@ cc_library( hdrs = ["merge_to_vector_calculator.h"], deps = [ "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:packet", "//mediapipe/framework/api2:node", "//mediapipe/framework/api2:port", "//mediapipe/framework/formats:detection_cc_proto", diff --git a/mediapipe/java/com/google/mediapipe/framework/PacketGetter.java b/mediapipe/java/com/google/mediapipe/framework/PacketGetter.java index 7e66e0b75..92cf723e6 100644 --- a/mediapipe/java/com/google/mediapipe/framework/PacketGetter.java +++ b/mediapipe/java/com/google/mediapipe/framework/PacketGetter.java @@ -199,6 +199,28 @@ public final class PacketGetter { return 
nativeGetImageData(packet.getNativeHandle(), buffer); } + /** Returns the size of the Image list. This helps to determine the size of the allocated ByteBuffer array. */ + public static int getImageListSize(final Packet packet) { + return nativeGetImageListSize(packet.getNativeHandle()); + } + + /** + * Assigns the native image buffers to the given ByteBuffer array. It assumes the given ByteBuffer + * array has the same size as the image list packet, and assumes the output buffer stores pixels + * contiguously. It returns false if this assumption does not hold. + * + *

If deepCopy is true, it assumes the given buffersArray has allocated ByteBuffers of the + * required size to copy image data into. If false, each ByteBuffer will wrap the memory address of + * the MediaPipe ImageFrame from the graph output, and the ByteBuffer data is valid only while the + * MediaPipe graph is alive. + * + *

Note: this function does not assume the pixel format. + */ + public static boolean getImageList( + final Packet packet, ByteBuffer[] buffersArray, boolean deepCopy) { + return nativeGetImageList(packet.getNativeHandle(), buffersArray, deepCopy); + } + /** * Converts an RGB mediapipe image frame packet to an RGBA Byte buffer. * @@ -316,7 +338,8 @@ public final class PacketGetter { public static GraphTextureFrame getTextureFrameDeferredSync(final Packet packet) { return new GraphTextureFrame( nativeGetGpuBuffer(packet.getNativeHandle(), /* waitOnCpu= */ false), - packet.getTimestamp(), /* deferredSync= */true); + packet.getTimestamp(), + /* deferredSync= */ true); } private static native long nativeGetPacketFromReference(long nativePacketHandle); @@ -363,6 +386,11 @@ public final class PacketGetter { private static native boolean nativeGetImageData(long nativePacketHandle, ByteBuffer buffer); + private static native int nativeGetImageListSize(long nativePacketHandle); + + private static native boolean nativeGetImageList( + long nativePacketHandle, ByteBuffer[] bufferArray, boolean deepCopy); + private static native boolean nativeGetRgbaFromRgb(long nativePacketHandle, ByteBuffer buffer); // Retrieves the values that are in the VideoHeader. private static native int nativeGetVideoHeaderWidth(long nativepackethandle); diff --git a/mediapipe/java/com/google/mediapipe/framework/image/ByteBufferExtractor.java b/mediapipe/java/com/google/mediapipe/framework/image/ByteBufferExtractor.java index 748a10667..68c53b0c4 100644 --- a/mediapipe/java/com/google/mediapipe/framework/image/ByteBufferExtractor.java +++ b/mediapipe/java/com/google/mediapipe/framework/image/ByteBufferExtractor.java @@ -50,7 +50,10 @@ public class ByteBufferExtractor { switch (container.getImageProperties().getStorageType()) { case MPImage.STORAGE_TYPE_BYTEBUFFER: ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container; - return byteBufferImageContainer.getByteBuffer().asReadOnlyBuffer(); + return byteBufferImageContainer + .getByteBuffer() + .asReadOnlyBuffer() + .order(ByteOrder.nativeOrder()); default: throw new IllegalArgumentException( "Extract ByteBuffer from a MPImage created by objects other than Bytebuffer is not" @@ -74,7 +77,7 @@ public class ByteBufferExtractor { * @throws IllegalArgumentException when the extraction requires unsupported format or data type * conversions. 
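A minimal usage sketch of the two new PacketGetter methods above, assuming a hypothetical maskListPacket holding the image list, known maskWidth and maskHeight values, and 4-byte float pixels:

import com.google.mediapipe.framework.PacketGetter;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;

int size = PacketGetter.getImageListSize(maskListPacket);
ByteBuffer[] buffers = new ByteBuffer[size];
for (int i = 0; i < size; i++) {
  // With deepCopy == true, the caller pre-allocates each destination buffer.
  buffers[i] =
      ByteBuffer.allocateDirect(maskWidth * maskHeight * 4).order(ByteOrder.nativeOrder());
}
// With deepCopy == false, the buffers would instead wrap native memory that is
// only valid while the MediaPipe graph is alive.
if (!PacketGetter.getImageList(maskListPacket, buffers, /* deepCopy= */ true)) {
  throw new IllegalStateException("Failed to read the image list packet.");
}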
*/ - static ByteBuffer extract(MPImage image, @MPImageFormat int targetFormat) { + public static ByteBuffer extract(MPImage image, @MPImageFormat int targetFormat) { MPImageContainer container; MPImageProperties byteBufferProperties = MPImageProperties.builder() @@ -83,12 +86,16 @@ public class ByteBufferExtractor { .build(); if ((container = image.getContainer(byteBufferProperties)) != null) { ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container; - return byteBufferImageContainer.getByteBuffer().asReadOnlyBuffer(); + return byteBufferImageContainer + .getByteBuffer() + .asReadOnlyBuffer() + .order(ByteOrder.nativeOrder()); } else if ((container = image.getContainer(MPImage.STORAGE_TYPE_BYTEBUFFER)) != null) { ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container; @MPImageFormat int sourceFormat = byteBufferImageContainer.getImageFormat(); return convertByteBuffer(byteBufferImageContainer.getByteBuffer(), sourceFormat, targetFormat) - .asReadOnlyBuffer(); + .asReadOnlyBuffer() + .order(ByteOrder.nativeOrder()); } else if ((container = image.getContainer(MPImage.STORAGE_TYPE_BITMAP)) != null) { BitmapImageContainer bitmapImageContainer = (BitmapImageContainer) container; ByteBuffer byteBuffer = diff --git a/mediapipe/java/com/google/mediapipe/framework/image/MPImage.java b/mediapipe/java/com/google/mediapipe/framework/image/MPImage.java index e17cc4d30..946beae37 100644 --- a/mediapipe/java/com/google/mediapipe/framework/image/MPImage.java +++ b/mediapipe/java/com/google/mediapipe/framework/image/MPImage.java @@ -67,6 +67,8 @@ public class MPImage implements Closeable { IMAGE_FORMAT_YUV_420_888, IMAGE_FORMAT_ALPHA, IMAGE_FORMAT_JPEG, + IMAGE_FORMAT_VEC32F1, + IMAGE_FORMAT_VEC32F2, }) @Retention(RetentionPolicy.SOURCE) public @interface MPImageFormat {} @@ -81,6 +83,8 @@ public class MPImage implements Closeable { public static final int IMAGE_FORMAT_YUV_420_888 = 7; public static final int IMAGE_FORMAT_ALPHA = 8; public static final int IMAGE_FORMAT_JPEG = 9; + public static final int IMAGE_FORMAT_VEC32F1 = 10; + public static final int IMAGE_FORMAT_VEC32F2 = 11; /** Specifies the image container type. Would be useful for choosing extractors. */ @IntDef({ diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.cc b/mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.cc index 737f6db72..234209b8c 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.cc +++ b/mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.cc @@ -14,6 +14,7 @@ #include "mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.h" +#include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "mediapipe/framework/calculator.pb.h" #include "mediapipe/framework/formats/image.h" @@ -39,6 +40,52 @@ template const T& GetFromNativeHandle(int64_t packet_handle) { return mediapipe::android::Graph::GetPacketFromHandle(packet_handle).Get(); } + +bool CopyImageDataToByteBuffer(JNIEnv* env, const mediapipe::ImageFrame& image, + jobject byte_buffer) { + int64_t buffer_size = env->GetDirectBufferCapacity(byte_buffer); + void* buffer_data = env->GetDirectBufferAddress(byte_buffer); + if (buffer_data == nullptr || buffer_size < 0) { + ThrowIfError(env, absl::InvalidArgumentError( + "input buffer does not support direct access")); + return false; + } + + // Assume byte buffer stores pixel data contiguously. 
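+ // Note: a direct ByteBuffer is required, and its capacity must equal the + // packed frame size computed below (width * height * byte depth * channels).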
+ const int expected_buffer_size = image.Width() * image.Height() * + image.ByteDepth() * image.NumberOfChannels(); + if (buffer_size != expected_buffer_size) { + ThrowIfError( + env, absl::InvalidArgumentError(absl::StrCat( + "Expected buffer size ", expected_buffer_size, + " got: ", buffer_size, ", width ", image.Width(), ", height ", + image.Height(), ", channels ", image.NumberOfChannels()))); + return false; + } + + switch (image.ByteDepth()) { + case 1: { + uint8* data = static_cast(buffer_data); + image.CopyToBuffer(data, expected_buffer_size); + break; + } + case 2: { + uint16* data = static_cast(buffer_data); + image.CopyToBuffer(data, expected_buffer_size); + break; + } + case 4: { + float* data = static_cast(buffer_data); + image.CopyToBuffer(data, expected_buffer_size); + break; + } + default: { + return false; + } + } + return true; +} + } // namespace JNIEXPORT jlong JNICALL PACKET_GETTER_METHOD(nativeGetPacketFromReference)( @@ -298,46 +345,51 @@ JNIEXPORT jboolean JNICALL PACKET_GETTER_METHOD(nativeGetImageData)( .GetImageFrameSharedPtr() .get() : GetFromNativeHandle(packet); + return CopyImageDataToByteBuffer(env, image, byte_buffer); +} - int64_t buffer_size = env->GetDirectBufferCapacity(byte_buffer); - void* buffer_data = env->GetDirectBufferAddress(byte_buffer); - if (buffer_data == nullptr || buffer_size < 0) { - ThrowIfError(env, absl::InvalidArgumentError( - "input buffer does not support direct access")); +JNIEXPORT jint JNICALL PACKET_GETTER_METHOD(nativeGetImageListSize)( + JNIEnv* env, jobject thiz, jlong packet) { + const auto& image_list = + GetFromNativeHandle>(packet); + return image_list.size(); +} + +JNIEXPORT jboolean JNICALL PACKET_GETTER_METHOD(nativeGetImageList)( + JNIEnv* env, jobject thiz, jlong packet, jobjectArray byte_buffer_array, + jboolean deep_copy) { + const auto& image_list = + GetFromNativeHandle>(packet); + if (env->GetArrayLength(byte_buffer_array) != image_list.size()) { + ThrowIfError(env, absl::InvalidArgumentError(absl::StrCat( + "Expected ByteBuffer array size: ", image_list.size(), + " but get ByteBuffer array size: ", + env->GetArrayLength(byte_buffer_array)))); return false; } - - // Assume byte buffer stores pixel data contiguously. 
- const int expected_buffer_size = image.Width() * image.Height() * - image.ByteDepth() * image.NumberOfChannels(); - if (buffer_size != expected_buffer_size) { - ThrowIfError( - env, absl::InvalidArgumentError(absl::StrCat( - "Expected buffer size ", expected_buffer_size, - " got: ", buffer_size, ", width ", image.Width(), ", height ", - image.Height(), ", channels ", image.NumberOfChannels()))); - return false; - } - - switch (image.ByteDepth()) { - case 1: { - uint8* data = static_cast(buffer_data); - image.CopyToBuffer(data, expected_buffer_size); - break; - } - case 2: { - uint16* data = static_cast(buffer_data); - image.CopyToBuffer(data, expected_buffer_size); - break; - } - case 4: { - float* data = static_cast(buffer_data); - image.CopyToBuffer(data, expected_buffer_size); - break; - } - default: { + for (int i = 0; i < image_list.size(); ++i) { + auto& image = *image_list[i].GetImageFrameSharedPtr().get(); + if (!image.IsContiguous()) { + ThrowIfError( + env, absl::InternalError("ImageFrame must store data contiguously to " + "be allocated as ByteBuffer.")); return false; } + if (deep_copy) { + jobject byte_buffer = reinterpret_cast( + env->GetObjectArrayElement(byte_buffer_array, i)); + if (!CopyImageDataToByteBuffer(env, image, byte_buffer)) { + return false; + } + } else { + // Assume byte buffer stores pixel data contiguously. + const int expected_buffer_size = image.Width() * image.Height() * + image.ByteDepth() * + image.NumberOfChannels(); + jobject image_data_byte_buffer = env->NewDirectByteBuffer( + image.MutablePixelData(), expected_buffer_size); + env->SetObjectArrayElement(byte_buffer_array, i, image_data_byte_buffer); + } } return true; } @@ -415,7 +467,8 @@ JNIEXPORT jbyteArray JNICALL PACKET_GETTER_METHOD(nativeGetAudioData)( int16 value = static_cast(audio_mat(channel, sample) * kMultiplier); // The java and native has the same byte order, by default is little - // Endian, we can safely copy data directly, we have tests to cover this. + // Endian, we can safely copy data directly, we have tests to cover + // this. env->SetByteArrayRegion(byte_data, offset, 2, reinterpret_cast(&value)); offset += 2; diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.h b/mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.h index 6a20d3daf..4602ebd59 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.h +++ b/mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.h @@ -106,6 +106,17 @@ JNIEXPORT jint JNICALL PACKET_GETTER_METHOD(nativeGetImageHeight)(JNIEnv* env, JNIEXPORT jboolean JNICALL PACKET_GETTER_METHOD(nativeGetImageData)( JNIEnv* env, jobject thiz, jlong packet, jobject byte_buffer); +// Return the vector size of std::vector. +JNIEXPORT jint JNICALL PACKET_GETTER_METHOD(nativeGetImageListSize)( + JNIEnv* env, jobject thiz, jlong packet); + +// Fill ByteBuffer[] from the Packet of std::vector. +// Before calling this, the byte_buffer_array needs to have the correct +// allocated size. +JNIEXPORT jboolean JNICALL PACKET_GETTER_METHOD(nativeGetImageList)( + JNIEnv* env, jobject thiz, jlong packet, jobjectArray byte_buffer_array, + jboolean deep_copy); + // Before calling this, the byte_buffer needs to have the correct allocated // size. 
JNIEXPORT jboolean JNICALL PACKET_GETTER_METHOD(nativeGetRgbaFromRgb)( diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc index f9618c1b1..c8c6e9036 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc @@ -257,10 +257,12 @@ TEST_F(ImageModeTest, SucceedsWithConfidenceMask) { SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity)); } -TEST_F(ImageModeTest, SucceedsWithRotation) { +// TODO: fix this unit test after image segmenter handled post +// processing correctly with rotated image. +TEST_F(ImageModeTest, DISABLED_SucceedsWithRotation) { MP_ASSERT_OK_AND_ASSIGN( - Image image, DecodeImageFromFile( - JoinPath("./", kTestDataDirectory, "cat_rotated.jpg"))); + Image image, + DecodeImageFromFile(JoinPath("./", kTestDataDirectory, "cat.jpg"))); auto options = std::make_unique(); options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata); @@ -271,7 +273,8 @@ TEST_F(ImageModeTest, SucceedsWithRotation) { ImageSegmenter::Create(std::move(options))); ImageProcessingOptions image_processing_options; image_processing_options.rotation_degrees = -90; - MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks, segmenter->Segment(image)); + MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks, + segmenter->Segment(image, image_processing_options)); EXPECT_EQ(confidence_masks.size(), 21); cv::Mat expected_mask = diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD index f469aed0c..0c30d7646 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD @@ -44,6 +44,7 @@ cc_binary( "//mediapipe/tasks/cc/vision/gesture_recognizer:gesture_recognizer_graph", "//mediapipe/tasks/cc/vision/image_classifier:image_classifier_graph", "//mediapipe/tasks/cc/vision/image_embedder:image_embedder_graph", + "//mediapipe/tasks/cc/vision/image_segmenter:image_segmenter_graph", "//mediapipe/tasks/cc/vision/object_detector:object_detector_graph", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni:model_resources_cache_jni", ], @@ -176,6 +177,30 @@ android_library( ], ) +android_library( + name = "imagesegmenter", + srcs = [ + "imagesegmenter/ImageSegmenter.java", + "imagesegmenter/ImageSegmenterResult.java", + ], + javacopts = [ + "-Xep:AndroidJdkLibsChecker:OFF", + ], + manifest = "imagesegmenter/AndroidManifest.xml", + deps = [ + ":core", + "//mediapipe/framework:calculator_options_java_proto_lite", + "//mediapipe/java/com/google/mediapipe/framework:android_framework", + "//mediapipe/java/com/google/mediapipe/framework/image", + "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", + "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_java_proto_lite", + "//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_java_proto_lite", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", + "//third_party:autovalue", + "@maven//:com_google_guava_guava", + ], +) + android_library( name = "imageembedder", srcs = [ diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/AndroidManifest.xml b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/AndroidManifest.xml new file mode 100644 index 000000000..6c8070364 --- /dev/null +++ 
b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/AndroidManifest.xml @@ -0,0 +1,8 @@ + + + + + + diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenter.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenter.java new file mode 100644 index 000000000..8d07b7c68 --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenter.java @@ -0,0 +1,462 @@ +// Copyright 2023 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.tasks.vision.imagesegmenter; + +import android.content.Context; +import com.google.auto.value.AutoValue; +import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions; +import com.google.mediapipe.framework.AndroidPacketGetter; +import com.google.mediapipe.framework.MediaPipeException; +import com.google.mediapipe.framework.Packet; +import com.google.mediapipe.framework.PacketGetter; +import com.google.mediapipe.framework.image.BitmapImageBuilder; +import com.google.mediapipe.framework.image.ByteBufferImageBuilder; +import com.google.mediapipe.framework.image.MPImage; +import com.google.mediapipe.tasks.core.BaseOptions; +import com.google.mediapipe.tasks.core.ErrorListener; +import com.google.mediapipe.tasks.core.OutputHandler; +import com.google.mediapipe.tasks.core.OutputHandler.ResultListener; +import com.google.mediapipe.tasks.core.TaskInfo; +import com.google.mediapipe.tasks.core.TaskOptions; +import com.google.mediapipe.tasks.core.TaskRunner; +import com.google.mediapipe.tasks.core.proto.BaseOptionsProto; +import com.google.mediapipe.tasks.vision.core.BaseVisionTaskApi; +import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions; +import com.google.mediapipe.tasks.vision.core.RunningMode; +import com.google.mediapipe.tasks.vision.imagesegmenter.proto.ImageSegmenterGraphOptionsProto; +import com.google.mediapipe.tasks.vision.imagesegmenter.proto.SegmenterOptionsProto; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Optional; + +/** + * Performs image segmentation on images. + * + *

Note that, unlike other vision tasks, the output of ImageSegmenter is provided through a + * user-defined callback function even for the synchronous API. This makes it possible for + * ImageSegmenter to return the output masks without any copy. {@link ResultListener} must be set in + * the {@link ImageSegmenterOptions} for all {@link RunningMode}. + * + *

The API expects a TFLite model with TFLite Model Metadata. + *

    + *
  • Input image {@link MPImage} + *
      + *
    • The image that image segmenter runs on. + *
    + *
  • Output ImageSegmenterResult {@link ImageSegmenterResult} + *
      + *
    • An ImageSegmenterResult containing segmented masks. + *
    + *
+ */ +public final class ImageSegmenter extends BaseVisionTaskApi { + private static final String TAG = ImageSegmenter.class.getSimpleName(); + private static final String IMAGE_IN_STREAM_NAME = "image_in"; + private static final String NORM_RECT_IN_STREAM_NAME = "norm_rect_in"; + private static final List INPUT_STREAMS = + Collections.unmodifiableList( + Arrays.asList("IMAGE:" + IMAGE_IN_STREAM_NAME, "NORM_RECT:" + NORM_RECT_IN_STREAM_NAME)); + private static final List OUTPUT_STREAMS = + Collections.unmodifiableList( + Arrays.asList( + "GROUPED_SEGMENTATION:segmented_mask_out", + "IMAGE:image_out", + "SEGMENTATION:0:segmentation")); + private static final int GROUPED_SEGMENTATION_OUT_STREAM_INDEX = 0; + private static final int IMAGE_OUT_STREAM_INDEX = 1; + private static final int SEGMENTATION_OUT_STREAM_INDEX = 2; + private static final String TASK_GRAPH_NAME = + "mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph"; + + /** + * Creates an {@link ImageSegmenter} instance from an {@link ImageSegmenterOptions}. + * + * @param context an Android {@link Context}. + * @param segmenterOptions an {@link ImageSegmenterOptions} instance. + * @throws MediaPipeException if there is an error during {@link ImageSegmenter} creation. + */ + public static ImageSegmenter createFromOptions( + Context context, ImageSegmenterOptions segmenterOptions) { + // TODO: Consolidate OutputHandler and TaskRunner. + OutputHandler handler = new OutputHandler<>(); + handler.setOutputPacketConverter( + new OutputHandler.OutputPacketConverter() { + @Override + public ImageSegmenterResult convertToTaskResult(List packets) + throws MediaPipeException { + if (packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX).isEmpty()) { + return ImageSegmenterResult.create( + new ArrayList<>(), + packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX).getTimestamp()); + } + List segmentedMasks = new ArrayList<>(); + int width = PacketGetter.getImageWidth(packets.get(SEGMENTATION_OUT_STREAM_INDEX)); + int height = PacketGetter.getImageHeight(packets.get(SEGMENTATION_OUT_STREAM_INDEX)); + int imageFormat = + segmenterOptions.outputType() == ImageSegmenterOptions.OutputType.CONFIDENCE_MASK + ? MPImage.IMAGE_FORMAT_VEC32F1 + : MPImage.IMAGE_FORMAT_ALPHA; + int imageListSize = + PacketGetter.getImageListSize(packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX)); + ByteBuffer[] buffersArray = new ByteBuffer[imageListSize]; + if (!PacketGetter.getImageList( + packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX), buffersArray, false)) { + throw new MediaPipeException( + MediaPipeException.StatusCode.INTERNAL.ordinal(), + "There is an error getting segmented masks. 
It usually results from incorrect" + + " options of unsupported OutputType of given model."); + } + for (ByteBuffer buffer : buffersArray) { + ByteBufferImageBuilder builder = + new ByteBufferImageBuilder(buffer, width, height, imageFormat); + segmentedMasks.add(builder.build()); + } + + return ImageSegmenterResult.create( + segmentedMasks, + BaseVisionTaskApi.generateResultTimestampMs( + segmenterOptions.runningMode(), + packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX))); + } + + @Override + public MPImage convertToTaskInput(List packets) { + return new BitmapImageBuilder( + AndroidPacketGetter.getBitmapFromRgb(packets.get(IMAGE_OUT_STREAM_INDEX))) + .build(); + } + }); + handler.setResultListener(segmenterOptions.resultListener()); + segmenterOptions.errorListener().ifPresent(handler::setErrorListener); + TaskRunner runner = + TaskRunner.create( + context, + TaskInfo.builder() + .setTaskName(ImageSegmenter.class.getSimpleName()) + .setTaskRunningModeName(segmenterOptions.runningMode().name()) + .setTaskGraphName(TASK_GRAPH_NAME) + .setInputStreams(INPUT_STREAMS) + .setOutputStreams(OUTPUT_STREAMS) + .setTaskOptions(segmenterOptions) + .setEnableFlowLimiting(segmenterOptions.runningMode() == RunningMode.LIVE_STREAM) + .build(), + handler); + return new ImageSegmenter(runner, segmenterOptions.runningMode()); + } + + /** + * Constructor to initialize an {@link ImageSegmenter} from a {@link TaskRunner} and a {@link + * RunningMode}. + * + * @param taskRunner a {@link TaskRunner}. + * @param runningMode a mediapipe vision task {@link RunningMode}. + */ + private ImageSegmenter(TaskRunner taskRunner, RunningMode runningMode) { + super(taskRunner, runningMode, IMAGE_IN_STREAM_NAME, NORM_RECT_IN_STREAM_NAME); + } + + /** + * Performs image segmentation on the provided single image with default image processing options, + * i.e. without any rotation applied, and the results will be available via the {@link + * ResultListener} provided in the {@link ImageSegmenterOptions}. Only use this method when the + * {@link ImageSegmenter} is created with {@link RunningMode.IMAGE}. TODO update java + * doc for input image format. + * + *

{@link ImageSegmenter} supports the following color space types: + * + *

    + *
  • {@link Bitmap.Config.ARGB_8888} + *
+ * + * @param image a MediaPipe {@link MPImage} object for processing. + * @throws MediaPipeException if there is an internal error. + */ + public void segment(MPImage image) { + segment(image, ImageProcessingOptions.builder().build()); + } + + /** + * Performs image segmentation on the provided single image, and the results will be available via + * the {@link ResultListener} provided in the {@link ImageSegmenterOptions}. Only use this method + * when the {@link ImageSegmenter} is created with {@link RunningMode.IMAGE}. TODO + * update java doc for input image format. + * + *

{@link ImageSegmenter} supports the following color space types: + * + *

    + *
  • {@link Bitmap.Config.ARGB_8888} + *
+ * + * @param image a MediaPipe {@link MPImage} object for processing. + * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the + * input image before running inference. Note that region-of-interest is not supported + * by this task: specifying {@link ImageProcessingOptions#regionOfInterest()} will result in + * this method throwing an IllegalArgumentException. + * @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a + * region-of-interest. + * @throws MediaPipeException if there is an internal error. + */ + public void segment(MPImage image, ImageProcessingOptions imageProcessingOptions) { + validateImageProcessingOptions(imageProcessingOptions); + ImageSegmenterResult unused = + (ImageSegmenterResult) processImageData(image, imageProcessingOptions); + } + + /** + * Performs image segmentation on the provided video frame with default image processing options, + * i.e. without any rotation applied, and the results will be available via the {@link + * ResultListener} provided in the {@link ImageSegmenterOptions}. Only use this method when the + * {@link ImageSegmenter} is created with {@link RunningMode.VIDEO}. + * + *

It's required to provide the video frame's timestamp (in milliseconds). The input timestamps + * must be monotonically increasing. + * + *

{@link ImageSegmenter} supports the following color space types: + * + *

    + *
  • {@link Bitmap.Config.ARGB_8888} + *
+ * + * @param image a MediaPipe {@link MPImage} object for processing. + * @param timestampMs the input timestamp (in milliseconds). + * @throws MediaPipeException if there is an internal error. + */ + public void segmentForVideo(MPImage image, long timestampMs) { + segmentForVideo(image, ImageProcessingOptions.builder().build(), timestampMs); + } + + /** + * Performs image segmentation on the provided video frame, and the results will be available via + * the {@link ResultListener} provided in the {@link ImageSegmenterOptions}. Only use this method + * when the {@link ImageSegmenter} is created with {@link RunningMode.VIDEO}. + * + *

It's required to provide the video frame's timestamp (in milliseconds). The input timestamps + * must be monotonically increasing. + * + *

{@link ImageSegmenter} supports the following color space types: + * + *

    + *
  • {@link Bitmap.Config.ARGB_8888} + *
+ * + * @param image a MediaPipe {@link MPImage} object for processing. + * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the + * input image before running inference. Note that region-of-interest is not supported + * by this task: specifying {@link ImageProcessingOptions#regionOfInterest()} will result in + * this method throwing an IllegalArgumentException. + * @param timestampMs the input timestamp (in milliseconds). + * @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a + * region-of-interest. + * @throws MediaPipeException if there is an internal error. + */ + public void segmentForVideo( + MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) { + validateImageProcessingOptions(imageProcessingOptions); + ImageSegmenterResult unused = + (ImageSegmenterResult) processVideoData(image, imageProcessingOptions, timestampMs); + } + + /** + * Sends live image data to perform image segmentation with default image processing + * options, i.e. without any rotation applied, and the results will be available via the {@link + * ResultListener} provided in the {@link ImageSegmenterOptions}. Only use this method when the + * {@link ImageSegmenter} is created with {@link RunningMode.LIVE_STREAM}. + * + *

It's required to provide a timestamp (in milliseconds) to indicate when the input image is + * sent to the image segmenter. The input timestamps must be monotonically increasing. + * + *

{@link ImageSegmenter} supports the following color space types: + * + *

    + *
  • {@link Bitmap.Config.ARGB_8888} + *
+ * + * @param image a MediaPipe {@link MPImage} object for processing. + * @param timestampMs the input timestamp (in milliseconds). + * @throws MediaPipeException if there is an internal error. + */ + public void segmentAsync(MPImage image, long timestampMs) { + segmentAsync(image, ImageProcessingOptions.builder().build(), timestampMs); + } + + /** + * Sends live image data to perform image segmentation, and the results will be available via the + * {@link ResultListener} provided in the {@link ImageSegmenterOptions}. Only use this method when + * the {@link ImageSegmenter} is created with {@link RunningMode.LIVE_STREAM}. + * + *

It's required to provide a timestamp (in milliseconds) to indicate when the input image is + * sent to the image segmenter. The input timestamps must be monotonically increasing. + * + *

{@link ImageSegmenter} supports the following color space types: + * + *

    + *
  • {@link Bitmap.Config.ARGB_8888} + *
+ * + * @param image a MediaPipe {@link MPImage} object for processing. + * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the + * input image before running inference. Note that region-of-interest is not supported + * by this task: specifying {@link ImageProcessingOptions#regionOfInterest()} will result in + * this method throwing an IllegalArgumentException. + * @param timestampMs the input timestamp (in milliseconds). + * @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a + * region-of-interest. + * @throws MediaPipeException if there is an internal error. + */ + public void segmentAsync( + MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) { + validateImageProcessingOptions(imageProcessingOptions); + sendLiveStreamData(image, imageProcessingOptions, timestampMs); + } + + /** Options for setting up an {@link ImageSegmenter}. */ + @AutoValue + public abstract static class ImageSegmenterOptions extends TaskOptions { + + /** Builder for {@link ImageSegmenterOptions}. */ + @AutoValue.Builder + public abstract static class Builder { + /** Sets the base options for the image segmenter task. */ + public abstract Builder setBaseOptions(BaseOptions value); + + /** + * Sets the running mode for the image segmenter task. Default to the image mode. Image + * segmenter has three modes: + * + *
    + *
  • IMAGE: The mode for segmenting images on single image inputs. + *
  • VIDEO: The mode for segmenting images on the decoded frames of a video. + *
  • LIVE_STREAM: The mode for segmenting images on a live stream of input data, such + * as from a camera. In this mode, {@code setResultListener} must be called to set up a + * listener to receive the segmentation results asynchronously. + *
+ */ + public abstract Builder setRunningMode(RunningMode value); + + /** + * The locale to use for display names specified through the TFLite Model Metadata, if any. + * Defaults to English. + */ + public abstract Builder setDisplayNamesLocale(String value); + + /** The output type from the image segmenter. */ + public abstract Builder setOutputType(OutputType value); + + /** + * Sets the {@link ResultListener} to receive the segmentation results when the graph pipeline + * is done processing an image. + */ + public abstract Builder setResultListener( + ResultListener<ImageSegmenterResult, MPImage> value); + + /** Sets an optional {@link ErrorListener}. */ + public abstract Builder setErrorListener(ErrorListener value); + + abstract ImageSegmenterOptions autoBuild(); + + /** + * Validates and builds the {@link ImageSegmenterOptions} instance. + * + * @throws IllegalArgumentException if the result listener and the running mode are not + * properly configured. The result listener should only be set when the image segmenter is + * in the live stream mode. + */ + public final ImageSegmenterOptions build() { + ImageSegmenterOptions options = autoBuild(); + return options; + } + } + + abstract BaseOptions baseOptions(); + + abstract RunningMode runningMode(); + + abstract String displayNamesLocale(); + + abstract OutputType outputType(); + + abstract ResultListener<ImageSegmenterResult, MPImage> resultListener(); + + abstract Optional<ErrorListener> errorListener(); + + /** The output type of segmentation results. */ + public enum OutputType { + // Gives a single output mask where each pixel represents the class which + // the pixel in the original image was predicted to belong to. + CATEGORY_MASK, + // Gives a list of output masks where, for each mask, each pixel represents + // the prediction confidence, usually in the [0, 1] range. + CONFIDENCE_MASK + } + + public static Builder builder() { + return new AutoValue_ImageSegmenter_ImageSegmenterOptions.Builder() + .setRunningMode(RunningMode.IMAGE) + .setDisplayNamesLocale("en") + .setOutputType(OutputType.CATEGORY_MASK) + .setResultListener((result, image) -> {}); + } + + /** + * Converts an {@link ImageSegmenterOptions} to a {@link CalculatorOptions} protobuf message. + */ + @Override + public CalculatorOptions convertToCalculatorOptionsProto() { + ImageSegmenterGraphOptionsProto.ImageSegmenterGraphOptions.Builder taskOptionsBuilder = + ImageSegmenterGraphOptionsProto.ImageSegmenterGraphOptions.newBuilder() + .setBaseOptions( + BaseOptionsProto.BaseOptions.newBuilder() + .setUseStreamMode(runningMode() != RunningMode.IMAGE) + .mergeFrom(convertBaseOptionsToProto(baseOptions())) + .build()) + .setDisplayNamesLocale(displayNamesLocale()); + + SegmenterOptionsProto.SegmenterOptions.Builder segmenterOptionsBuilder = + SegmenterOptionsProto.SegmenterOptions.newBuilder(); + if (outputType() == OutputType.CONFIDENCE_MASK) { + segmenterOptionsBuilder.setOutputType( + SegmenterOptionsProto.SegmenterOptions.OutputType.CONFIDENCE_MASK); + } else if (outputType() == OutputType.CATEGORY_MASK) { + segmenterOptionsBuilder.setOutputType( + SegmenterOptionsProto.SegmenterOptions.OutputType.CATEGORY_MASK); + } + // TODO: remove this once activation is handled at the metadata and graph level.
+ segmenterOptionsBuilder.setActivation( + SegmenterOptionsProto.SegmenterOptions.Activation.SOFTMAX); + taskOptionsBuilder.setSegmenterOptions(segmenterOptionsBuilder); + return CalculatorOptions.newBuilder() + .setExtension( + ImageSegmenterGraphOptionsProto.ImageSegmenterGraphOptions.ext, + taskOptionsBuilder.build()) + .build(); + } + } + + /** + * Validates that the provided {@link ImageProcessingOptions} doesn't contain a + * region-of-interest. + */ + private static void validateImageProcessingOptions( + ImageProcessingOptions imageProcessingOptions) { + if (imageProcessingOptions.regionOfInterest().isPresent()) { + throw new IllegalArgumentException("ImageSegmenter doesn't support region-of-interest."); + } + } +} diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenterResult.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenterResult.java new file mode 100644 index 000000000..40fb93dd1 --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenterResult.java @@ -0,0 +1,45 @@ +// Copyright 2023 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.tasks.vision.imagesegmenter; + +import com.google.auto.value.AutoValue; +import com.google.mediapipe.framework.image.MPImage; +import com.google.mediapipe.tasks.core.TaskResult; +import java.util.Collections; +import java.util.List; + +/** Represents the segmentation results generated by {@link ImageSegmenter}. */ +@AutoValue +public abstract class ImageSegmenterResult implements TaskResult { + + /** + * Creates an {@link ImageSegmenterResult} instance from a list of segmentation MPImages. + * + * @param segmentations a {@link List} of MPImage representing the segmented masks. If OutputType + * is CATEGORY_MASK, the masks will be in IMAGE_FORMAT_ALPHA format. If OutputType is + * CONFIDENCE_MASK, the masks will be in IMAGE_FORMAT_VEC32F1 format. + * @param timestampMs a timestamp for this result. + */ + // TODO: consolidate output formats across platforms.
+ static ImageSegmenterResult create(List segmentations, long timestampMs) { + return new AutoValue_ImageSegmenterResult( + Collections.unmodifiableList(segmentations), timestampMs); + } + + public abstract List segmentations(); + + @Override + public abstract long timestampMs(); +} diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imagesegmenter/AndroidManifest.xml b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imagesegmenter/AndroidManifest.xml new file mode 100644 index 000000000..c641d446f --- /dev/null +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imagesegmenter/AndroidManifest.xml @@ -0,0 +1,24 @@ + + + + + + + + + + + + + + + diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imagesegmenter/BUILD b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imagesegmenter/BUILD new file mode 100644 index 000000000..c14486766 --- /dev/null +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imagesegmenter/BUILD @@ -0,0 +1,19 @@ +# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +# TODO: Enable this in OSS diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenterTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenterTest.java new file mode 100644 index 000000000..c11bb1f31 --- /dev/null +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenterTest.java @@ -0,0 +1,427 @@ +// Copyright 2023 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
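Before the tests, a compact end-to-end sketch of the new Java API under stated assumptions: appContext and mpImage are hypothetical variables, and the deeplabv3.tflite asset name mirrors the tests below.

import com.google.mediapipe.framework.image.ByteBufferExtractor;
import com.google.mediapipe.framework.image.MPImage;
import com.google.mediapipe.tasks.core.BaseOptions;
import com.google.mediapipe.tasks.vision.imagesegmenter.ImageSegmenter;
import com.google.mediapipe.tasks.vision.imagesegmenter.ImageSegmenter.ImageSegmenterOptions;
import java.nio.FloatBuffer;

ImageSegmenterOptions options =
    ImageSegmenterOptions.builder()
        .setBaseOptions(BaseOptions.builder().setModelAssetPath("deeplabv3.tflite").build())
        .setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK)
        .setResultListener(
            (result, inputImage) -> {
              // One VEC32F1 mask per class; read the confidences back as floats.
              MPImage mask = result.segmentations().get(0);
              FloatBuffer confidences =
                  ByteBufferExtractor.extract(mask, MPImage.IMAGE_FORMAT_VEC32F1).asFloatBuffer();
            })
        .build();
try (ImageSegmenter segmenter = ImageSegmenter.createFromOptions(appContext, options)) {
  segmenter.segment(mpImage); // Masks arrive via the result listener, even in IMAGE mode.
}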
+ +package com.google.mediapipe.tasks.vision.imagesegmenter; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; + +import android.content.res.AssetManager; +import android.graphics.Bitmap; +import android.graphics.BitmapFactory; +import android.graphics.Color; +import androidx.test.core.app.ApplicationProvider; +import androidx.test.ext.junit.runners.AndroidJUnit4; +import com.google.mediapipe.framework.MediaPipeException; +import com.google.mediapipe.framework.image.BitmapExtractor; +import com.google.mediapipe.framework.image.BitmapImageBuilder; +import com.google.mediapipe.framework.image.ByteBufferExtractor; +import com.google.mediapipe.framework.image.MPImage; +import com.google.mediapipe.tasks.core.BaseOptions; +import com.google.mediapipe.tasks.vision.core.RunningMode; +import com.google.mediapipe.tasks.vision.imagesegmenter.ImageSegmenter.ImageSegmenterOptions; +import java.io.InputStream; +import java.nio.ByteBuffer; +import java.nio.FloatBuffer; +import java.util.List; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Suite; +import org.junit.runners.Suite.SuiteClasses; + +/** Test for {@link ImageSegmenter}. */ +@RunWith(Suite.class) +@SuiteClasses({ImageSegmenterTest.General.class, ImageSegmenterTest.RunningModeTest.class}) +public class ImageSegmenterTest { + private static final String DEEPLAB_MODEL_FILE = "deeplabv3.tflite"; + private static final String SELFIE_128x128_MODEL_FILE = "selfie_segm_128_128_3.tflite"; + private static final String SELFIE_144x256_MODEL_FILE = "selfie_segm_144_256_3.tflite"; + private static final String CAT_IMAGE = "cat.jpg"; + private static final float GOLDEN_MASK_SIMILARITY = 0.96f; + private static final int MAGNIFICATION_FACTOR = 10; + + @RunWith(AndroidJUnit4.class) + public static final class General extends ImageSegmenterTest { + + @Test + public void segment_successWithCategoryMask() throws Exception { + final String inputImageName = "segmentation_input_rotation0.jpg"; + final String goldenImageName = "segmentation_golden_rotation0.png"; + MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName); + ImageSegmenterOptions options = + ImageSegmenterOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build()) + .setOutputType(ImageSegmenterOptions.OutputType.CATEGORY_MASK) + .setResultListener( + (actualResult, inputImage) -> { + List segmentations = actualResult.segmentations(); + assertThat(segmentations.size()).isEqualTo(1); + MPImage actualMaskBuffer = actualResult.segmentations().get(0); + verifyCategoryMask( + actualMaskBuffer, + expectedMaskBuffer, + GOLDEN_MASK_SIMILARITY, + MAGNIFICATION_FACTOR); + }) + .build(); + ImageSegmenter imageSegmenter = + ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options); + imageSegmenter.segment(getImageFromAsset(inputImageName)); + } + + @Test + public void segment_successWithConfidenceMask() throws Exception { + final String inputImageName = "cat.jpg"; + final String goldenImageName = "cat_mask.jpg"; + MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName); + ImageSegmenterOptions options = + ImageSegmenterOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build()) + .setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK) + .setResultListener( + (actualResult, inputImage) -> { + List segmentations = actualResult.segmentations(); + 
assertThat(segmentations.size()).isEqualTo(21); + // Cat category index 8. + MPImage actualMaskBuffer = actualResult.segmentations().get(8); + verifyConfidenceMask( + actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY); + }) + .build(); + ImageSegmenter imageSegmenter = + ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options); + imageSegmenter.segment(getImageFromAsset(inputImageName)); + } + + @Test + public void segment_successWith128x128Segmentation() throws Exception { + final String inputImageName = "mozart_square.jpg"; + final String goldenImageName = "selfie_segm_128_128_3_expected_mask.jpg"; + MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName); + ImageSegmenterOptions options = + ImageSegmenterOptions.builder() + .setBaseOptions( + BaseOptions.builder().setModelAssetPath(SELFIE_128x128_MODEL_FILE).build()) + .setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK) + .setResultListener( + (actualResult, inputImage) -> { + List segmentations = actualResult.segmentations(); + assertThat(segmentations.size()).isEqualTo(2); + // Selfie category index 1. + MPImage actualMaskBuffer = actualResult.segmentations().get(1); + verifyConfidenceMask( + actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY); + }) + .build(); + ImageSegmenter imageSegmenter = + ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options); + imageSegmenter.segment(getImageFromAsset(inputImageName)); + } + + // TODO: enable this unit test once activation option is supported in metadata. + // @Test + // public void segment_successWith144x256Segmentation() throws Exception { + // final String inputImageName = "mozart_square.jpg"; + // final String goldenImageName = "selfie_segm_144_256_3_expected_mask.jpg"; + // MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName); + // ImageSegmenterOptions options = + // ImageSegmenterOptions.builder() + // .setBaseOptions( + // BaseOptions.builder().setModelAssetPath(SELFIE_144x256_MODEL_FILE).build()) + // .setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK) + // .setActivation(ImageSegmenterOptions.Activation.NONE) + // .setResultListener( + // (actualResult, inputImage) -> { + // List segmentations = actualResult.segmentations(); + // assertThat(segmentations.size()).isEqualTo(1); + // MPImage actualMaskBuffer = actualResult.segmentations().get(0); + // verifyConfidenceMask( + // actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY); + // }) + // .build(); + // ImageSegmenter imageSegmenter = + // ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), + // options); + // imageSegmenter.segment(getImageFromAsset(inputImageName)); + // } + } + + @RunWith(AndroidJUnit4.class) + public static final class RunningModeTest extends ImageSegmenterTest { + @Test + public void segment_failsWithCallingWrongApiInImageMode() throws Exception { + ImageSegmenterOptions options = + ImageSegmenterOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build()) + .setRunningMode(RunningMode.IMAGE) + .build(); + + ImageSegmenter imageSegmenter = + ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options); + MediaPipeException exception = + assertThrows( + MediaPipeException.class, + () -> + imageSegmenter.segmentForVideo( + getImageFromAsset(CAT_IMAGE), /* timestampsMs= */ 0)); + assertThat(exception).hasMessageThat().contains("not initialized with the video mode"); + exception = + 
assertThrows(
+                MediaPipeException.class,
+                () ->
+                    imageSegmenter.segmentAsync(getImageFromAsset(CAT_IMAGE), /* timestampsMs= */ 0));
+      assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode");
+    }
+
+    @Test
+    public void segment_failsWithCallingWrongApiInVideoMode() throws Exception {
+      ImageSegmenterOptions options =
+          ImageSegmenterOptions.builder()
+              .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
+              .setRunningMode(RunningMode.VIDEO)
+              .build();
+
+      ImageSegmenter imageSegmenter =
+          ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options);
+      MediaPipeException exception =
+          assertThrows(
+              MediaPipeException.class, () -> imageSegmenter.segment(getImageFromAsset(CAT_IMAGE)));
+      assertThat(exception).hasMessageThat().contains("not initialized with the image mode");
+      exception =
+          assertThrows(
+              MediaPipeException.class,
+              () ->
+                  imageSegmenter.segmentAsync(getImageFromAsset(CAT_IMAGE), /* timestampsMs= */ 0));
+      assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode");
+    }
+
+    @Test
+    public void segment_failsWithCallingWrongApiInLiveStreamMode() throws Exception {
+      ImageSegmenterOptions options =
+          ImageSegmenterOptions.builder()
+              .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
+              .setRunningMode(RunningMode.LIVE_STREAM)
+              .setResultListener((result, inputImage) -> {})
+              .build();
+
+      ImageSegmenter imageSegmenter =
+          ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options);
+      MediaPipeException exception =
+          assertThrows(
+              MediaPipeException.class, () -> imageSegmenter.segment(getImageFromAsset(CAT_IMAGE)));
+      assertThat(exception).hasMessageThat().contains("not initialized with the image mode");
+      exception =
+          assertThrows(
+              MediaPipeException.class,
+              () ->
+                  imageSegmenter.segmentForVideo(
+                      getImageFromAsset(CAT_IMAGE), /* timestampsMs= */ 0));
+      assertThat(exception).hasMessageThat().contains("not initialized with the video mode");
+    }
+
+    @Test
+    public void segment_successWithImageMode() throws Exception {
+      final String inputImageName = "cat.jpg";
+      final String goldenImageName = "cat_mask.jpg";
+      MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName);
+      ImageSegmenterOptions options =
+          ImageSegmenterOptions.builder()
+              .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
+              .setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK)
+              .setRunningMode(RunningMode.IMAGE)
+              .setResultListener(
+                  (actualResult, inputImage) -> {
+                    List segmentations = actualResult.segmentations();
+                    assertThat(segmentations.size()).isEqualTo(21);
+                    // Cat category index 8.
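+                    // (In the PASCAL VOC label ordering, index 0 is background and index 8 is cat.)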
+ MPImage actualMaskBuffer = actualResult.segmentations().get(8); + verifyConfidenceMask( + actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY); + }) + .build(); + ImageSegmenter imageSegmenter = + ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options); + imageSegmenter.segment(getImageFromAsset(inputImageName)); + } + + @Test + public void segment_successWithVideoMode() throws Exception { + final String inputImageName = "cat.jpg"; + final String goldenImageName = "cat_mask.jpg"; + MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName); + ImageSegmenterOptions options = + ImageSegmenterOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build()) + .setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK) + .setRunningMode(RunningMode.VIDEO) + .setResultListener( + (actualResult, inputImage) -> { + List segmentations = actualResult.segmentations(); + assertThat(segmentations.size()).isEqualTo(21); + // Cat category index 8. + MPImage actualMaskBuffer = actualResult.segmentations().get(8); + verifyConfidenceMask( + actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY); + }) + .build(); + ImageSegmenter imageSegmenter = + ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options); + for (int i = 0; i < 3; i++) { + imageSegmenter.segmentForVideo(getImageFromAsset(inputImageName), /* timestampsMs= */ i); + } + } + + @Test + public void segment_successWithLiveStreamMode() throws Exception { + final String inputImageName = "cat.jpg"; + final String goldenImageName = "cat_mask.jpg"; + MPImage image = getImageFromAsset(inputImageName); + MPImage expectedResult = getImageFromAsset(goldenImageName); + ImageSegmenterOptions options = + ImageSegmenterOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build()) + .setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK) + .setRunningMode(RunningMode.LIVE_STREAM) + .setResultListener( + (segmenterResult, inputImage) -> { + verifyConfidenceMask( + segmenterResult.segmentations().get(8), + expectedResult, + GOLDEN_MASK_SIMILARITY); + }) + .build(); + try (ImageSegmenter imageSegmenter = + ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options)) { + for (int i = 0; i < 3; i++) { + imageSegmenter.segmentAsync(image, /* timestampsMs= */ i); + } + } + } + + @Test + public void segment_failsWithOutOfOrderInputTimestamps() throws Exception { + final String inputImageName = "cat.jpg"; + final String goldenImageName = "cat_mask.jpg"; + MPImage image = getImageFromAsset(inputImageName); + MPImage expectedResult = getImageFromAsset(goldenImageName); + ImageSegmenterOptions options = + ImageSegmenterOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build()) + .setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK) + .setRunningMode(RunningMode.LIVE_STREAM) + .setResultListener( + (segmenterResult, inputImage) -> { + verifyConfidenceMask( + segmenterResult.segmentations().get(8), + expectedResult, + GOLDEN_MASK_SIMILARITY); + }) + .build(); + try (ImageSegmenter imageSegmenter = + ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options)) { + imageSegmenter.segmentAsync(image, /* timestampsMs= */ 1); + MediaPipeException exception = + assertThrows( + MediaPipeException.class, + () -> imageSegmenter.segmentAsync(image, /* timestampsMs= */ 0)); + assertThat(exception) + 
.hasMessageThat() + .contains("having a smaller timestamp than the processed timestamp"); + } + } + } + + private static void verifyCategoryMask( + MPImage actualMask, MPImage goldenMask, float similarityThreshold, int magnificationFactor) { + assertThat(actualMask.getWidth()).isEqualTo(goldenMask.getWidth()); + assertThat(actualMask.getHeight()).isEqualTo(goldenMask.getHeight()); + ByteBuffer actualMaskBuffer = ByteBufferExtractor.extract(actualMask); + Bitmap goldenMaskBitmap = BitmapExtractor.extract(goldenMask); + int consistentPixels = 0; + final int numPixels = actualMask.getWidth() * actualMask.getHeight(); + actualMaskBuffer.rewind(); + for (int y = 0; y < actualMask.getHeight(); y++) { + for (int x = 0; x < actualMask.getWidth(); x++) { + // RGB values are the same in the golden mask image. + consistentPixels += + actualMaskBuffer.get() * magnificationFactor + == Color.red(goldenMaskBitmap.getPixel(x, y)) + ? 1 + : 0; + } + } + assertThat((float) consistentPixels / numPixels).isGreaterThan(similarityThreshold); + } + + private static void verifyConfidenceMask( + MPImage actualMask, MPImage goldenMask, float similarityThreshold) { + assertThat(actualMask.getWidth()).isEqualTo(goldenMask.getWidth()); + assertThat(actualMask.getHeight()).isEqualTo(goldenMask.getHeight()); + FloatBuffer actualMaskBuffer = ByteBufferExtractor.extract(actualMask).asFloatBuffer(); + Bitmap goldenMaskBitmap = BitmapExtractor.extract(goldenMask); + FloatBuffer goldenMaskBuffer = getByteBufferFromBitmap(goldenMaskBitmap).asFloatBuffer(); + assertThat( + calculateSoftIOU( + actualMaskBuffer, goldenMaskBuffer, actualMask.getWidth() * actualMask.getHeight())) + .isGreaterThan((double) similarityThreshold); + } + + private static MPImage getImageFromAsset(String filePath) throws Exception { + AssetManager assetManager = ApplicationProvider.getApplicationContext().getAssets(); + InputStream istr = assetManager.open(filePath); + return new BitmapImageBuilder(BitmapFactory.decodeStream(istr)).build(); + } + + private static ByteBuffer getByteBufferFromBitmap(Bitmap bitmap) { + ByteBuffer byteBuffer = ByteBuffer.allocateDirect(bitmap.getWidth() * bitmap.getHeight() * 4); + for (int y = 0; y < bitmap.getHeight(); y++) { + for (int x = 0; x < bitmap.getWidth(); x++) { + byteBuffer.putFloat((float) Color.red(bitmap.getPixel(x, y)) / 255.f); + } + } + byteBuffer.rewind(); + return byteBuffer; + } + + private static double calculateSum(FloatBuffer m) { + m.rewind(); + double sum = 0; + while (m.hasRemaining()) { + sum += m.get(); + } + m.rewind(); + return sum; + } + + private static FloatBuffer multiply(FloatBuffer m1, FloatBuffer m2, int bufferSize) { + m1.rewind(); + m2.rewind(); + FloatBuffer buffer = FloatBuffer.allocate(bufferSize); + while (m1.hasRemaining()) { + buffer.put(m1.get() * m2.get()); + } + m1.rewind(); + m2.rewind(); + buffer.rewind(); + return buffer; + } + + private static double calculateSoftIOU(FloatBuffer m1, FloatBuffer m2, int bufferSize) { + double intersectionSum = calculateSum(multiply(m1, m2, bufferSize)); + double m1m1 = calculateSum(multiply(m1, m1.duplicate(), bufferSize)); + double m2m2 = calculateSum(multiply(m2, m2.duplicate(), bufferSize)); + double unionSum = m1m1 + m2m2 - intersectionSum; + return unionSum > 0.0 ? 
intersectionSum / unionSum : 0.0; + } +} From b53acf626759998ec3463f42dcb089c164f4b5f3 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Tue, 31 Jan 2023 11:18:30 -0800 Subject: [PATCH 454/469] Internal change PiperOrigin-RevId: 506059384 --- mediapipe/framework/calculator_graph.h | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/mediapipe/framework/calculator_graph.h b/mediapipe/framework/calculator_graph.h index 8d58ff312..04f9de45f 100644 --- a/mediapipe/framework/calculator_graph.h +++ b/mediapipe/framework/calculator_graph.h @@ -53,10 +53,14 @@ #include "mediapipe/framework/port/status.h" #include "mediapipe/framework/scheduler.h" #include "mediapipe/framework/thread_pool_executor.pb.h" -#include "mediapipe/gpu/gpu_service.h" namespace mediapipe { +#if !MEDIAPIPE_DISABLE_GPU +class GpuResources; +struct GpuSharedData; +#endif // !MEDIAPIPE_DISABLE_GPU + typedef absl::StatusOr StatusOrPoller; // The class representing a DAG of calculator nodes. From d283e6a05abcba303884f1f7232c1ac64597554b Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Tue, 31 Jan 2023 18:41:42 -0800 Subject: [PATCH 455/469] Support downloading model files on-demand from GCS in model_maker PiperOrigin-RevId: 506174708 --- mediapipe/model_maker/python/core/utils/BUILD | 1 + .../python/core/utils/file_util.py | 82 +++++++++++++++++++ .../python/core/utils/file_util_test.py | 55 +++++++++++++ 3 files changed, 138 insertions(+) diff --git a/mediapipe/model_maker/python/core/utils/BUILD b/mediapipe/model_maker/python/core/utils/BUILD index 492bba0a9..3c9107dba 100644 --- a/mediapipe/model_maker/python/core/utils/BUILD +++ b/mediapipe/model_maker/python/core/utils/BUILD @@ -61,6 +61,7 @@ py_test( name = "file_util_test", srcs = ["file_util_test.py"], data = ["//mediapipe/model_maker/python/core/utils/testdata"], + tags = ["requires-net:external"], deps = [":file_util"], ) diff --git a/mediapipe/model_maker/python/core/utils/file_util.py b/mediapipe/model_maker/python/core/utils/file_util.py index 66addad54..29d11ebbe 100644 --- a/mediapipe/model_maker/python/core/utils/file_util.py +++ b/mediapipe/model_maker/python/core/utils/file_util.py @@ -13,11 +13,93 @@ # limitations under the License. """Utilities for files.""" +import dataclasses import os +import pathlib +import shutil +import tarfile +import tempfile +import requests # resources dependency +_TEMPDIR_FOLDER = 'model_maker' + + +@dataclasses.dataclass +class DownloadedFiles: + """File(s) that are downloaded from a url into a local directory. + + If `is_folder` is True: + 1. `path` should be a folder + 2. `url` should point to a .tar.gz file which contains a single folder at + the root level. + + Attributes: + path: Relative path in local directory. + url: GCS url to download the file(s). + is_folder: Whether the path and url represents a folder. + """ + + path: str + url: str + is_folder: bool = False + + def get_path(self) -> str: + """Gets the path of files saved in a local directory. + + If the path doesn't exist, this method will download the file(s) from the + provided url. The path is not cleaned up so it can be reused for subsequent + calls to the same path. + Folders are expected to be zipped in a .tar.gz file which will be extracted + into self.path in the local directory. + + Raises: + RuntimeError: If the extracted folder does not have a singular root + directory. 
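+
+    Example (illustrative values only; any relative path and GCS url work):
+
+      downloaded_files = DownloadedFiles(
+          path='example_task/model.tflite',
+          url='https://storage.googleapis.com/some-bucket/model.tflite')
+      model_path = downloaded_files.get_path()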
+ + Returns: + The absolute path to the downloaded file(s) + """ + tmpdir = tempfile.gettempdir() + absolute_path = pathlib.Path( + os.path.join(tmpdir, _TEMPDIR_FOLDER, self.path) + ) + if not absolute_path.exists(): + print(f'Downloading {self.url} to {absolute_path}') + r = requests.get(self.url, allow_redirects=True) + if self.is_folder: + # Use tempf to store the downloaded .tar.gz file + tempf = tempfile.NamedTemporaryFile(suffix='.tar.gz', mode='wb') + tempf.write(r.content) + tarf = tarfile.open(tempf.name) + # Use tmpdir to store the extracted contents of the .tar.gz file + with tempfile.TemporaryDirectory() as tmpdir: + tarf.extractall(tmpdir) + tarf.close() + tempf.close() + subdirs = os.listdir(tmpdir) + # Make sure tmpdir only has one subdirectory + if len(subdirs) > 1 or not os.path.isdir( + os.path.join(tmpdir, subdirs[0]) + ): + raise RuntimeError( + f"Extracted folder from {self.url} doesn't contain a " + f'single root directory: {subdirs}' + ) + # Create the parent dir of absolute_path and copy the contents of the + # top level dir in the .tar.gz file into absolute_path. + pathlib.Path.mkdir(absolute_path.parent, parents=True, exist_ok=True) + shutil.copytree(os.path.join(tmpdir, subdirs[0]), absolute_path) + else: + pathlib.Path.mkdir(absolute_path.parent, parents=True, exist_ok=True) + with open(absolute_path, 'wb') as f: + f.write(r.content) + return str(absolute_path) + + +# TODO Remove after text_classifier supports downloading on demand. def get_absolute_path(file_path: str) -> str: """Gets the absolute path of a file in the model_maker directory. diff --git a/mediapipe/model_maker/python/core/utils/file_util_test.py b/mediapipe/model_maker/python/core/utils/file_util_test.py index 4a2d6dcfb..f9f4a5954 100644 --- a/mediapipe/model_maker/python/core/utils/file_util_test.py +++ b/mediapipe/model_maker/python/core/utils/file_util_test.py @@ -12,13 +12,68 @@ # See the License for the specific language governing permissions and # limitations under the License. 
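+# NOTE: These tests download real files from GCS; the corresponding BUILD
+# target is tagged `requires-net:external`, so they need network access to run.
+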
import os +import tempfile +from unittest import mock as unittest_mock from absl.testing import absltest +import requests + from mediapipe.model_maker.python.core.utils import file_util class FileUtilTest(absltest.TestCase): + def setUp(self): + super().setUp() + mock_gettempdir = unittest_mock.patch.object( + tempfile, + 'gettempdir', + return_value=self.create_tempdir(), + autospec=True, + ) + self.mock_gettempdir = mock_gettempdir.start() + self.addCleanup(mock_gettempdir.stop) + + def test_get_path(self): + path = 'gesture_recognizer/hand_landmark_full.tflite' + url = 'https://storage.googleapis.com/mediapipe-assets/hand_landmark_full.tflite' + downloaded_files = file_util.DownloadedFiles(path, url, is_folder=False) + model_path = downloaded_files.get_path() + self.assertTrue(os.path.exists(model_path)) + self.assertGreater(os.path.getsize(model_path), 0) + + def test_get_path_folder(self): + folder_contents = [ + 'keras_metadata.pb', + 'saved_model.pb', + 'assets/vocab.txt', + 'variables/variables.data-00000-of-00001', + 'variables/variables.index', + ] + path = 'text_classifier/mobilebert_tiny' + url = ( + 'https://storage.googleapis.com/mediapipe-assets/mobilebert_tiny.tar.gz' + ) + downloaded_files = file_util.DownloadedFiles(path, url, is_folder=True) + model_path = downloaded_files.get_path() + self.assertTrue(os.path.exists(model_path)) + for file_name in folder_contents: + file_path = os.path.join(model_path, file_name) + self.assertTrue(os.path.exists(file_path)) + self.assertGreater(os.path.getsize(file_path), 0) + + @unittest_mock.patch.object(requests, 'get', wraps=requests.get) + def test_get_path_multiple_calls(self, mock_get): + path = 'gesture_recognizer/hand_landmark_full.tflite' + url = 'https://storage.googleapis.com/mediapipe-assets/hand_landmark_full.tflite' + downloaded_files = file_util.DownloadedFiles(path, url, is_folder=False) + model_path = downloaded_files.get_path() + self.assertTrue(os.path.exists(model_path)) + self.assertGreater(os.path.getsize(model_path), 0) + model_path_2 = downloaded_files.get_path() + self.assertEqual(model_path, model_path_2) + self.assertEqual(mock_get.call_count, 1) + def test_get_absolute_path(self): test_file = 'mediapipe/model_maker/python/core/utils/testdata/test.txt' absolute_path = file_util.get_absolute_path(test_file) From 24bd104b0f6093b21db59d1ef00ebd2d1445daed Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Wed, 1 Feb 2023 18:46:11 +0530 Subject: [PATCH 456/469] Added MPPEmbedding Helpers --- .../ios/components/containers/utils/BUILD | 11 ++++ .../utils/sources/MPPEmbedding+Helpers.h | 26 ++++++++ .../utils/sources/MPPEmbedding+Helpers.mm | 62 +++++++++++++++++++ 3 files changed, 99 insertions(+) create mode 100644 mediapipe/tasks/ios/components/containers/utils/sources/MPPEmbedding+Helpers.h create mode 100644 mediapipe/tasks/ios/components/containers/utils/sources/MPPEmbedding+Helpers.mm diff --git a/mediapipe/tasks/ios/components/containers/utils/BUILD b/mediapipe/tasks/ios/components/containers/utils/BUILD index 923a8e013..e0989530e 100644 --- a/mediapipe/tasks/ios/components/containers/utils/BUILD +++ b/mediapipe/tasks/ios/components/containers/utils/BUILD @@ -38,3 +38,14 @@ objc_library( "//mediapipe/tasks/ios/components/containers:MPPClassificationResult", ], ) + +objc_library( + name = "MPPEmbeddingHelpers", + srcs = ["sources/MPPEmbedding+Helpers.mm"], + hdrs = ["sources/MPPEmbedding+Helpers.h"], + deps = [ + "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto", + 
"//mediapipe/tasks/ios/common/utils:NSStringHelpers", + "//mediapipe/tasks/ios/components/containers:MPPEmbedding", + ] +) diff --git a/mediapipe/tasks/ios/components/containers/utils/sources/MPPEmbedding+Helpers.h b/mediapipe/tasks/ios/components/containers/utils/sources/MPPEmbedding+Helpers.h new file mode 100644 index 000000000..9ff5455d9 --- /dev/null +++ b/mediapipe/tasks/ios/components/containers/utils/sources/MPPEmbedding+Helpers.h @@ -0,0 +1,26 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h" +#import "mediapipe/tasks/ios/components/containers/sources/MPPEmbedding.h" + +NS_ASSUME_NONNULL_BEGIN + +@interface MPPEmbedding (Helpers) + ++ (MPPEmbedding *)embeddingWithProto:(const ::mediapipe::tasks::components::containers::proto::Embedding &)embeddingProto; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/components/containers/utils/sources/MPPEmbedding+Helpers.mm b/mediapipe/tasks/ios/components/containers/utils/sources/MPPEmbedding+Helpers.mm new file mode 100644 index 000000000..4b2acff7c --- /dev/null +++ b/mediapipe/tasks/ios/components/containers/utils/sources/MPPEmbedding+Helpers.mm @@ -0,0 +1,62 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h" +#import "mediapipe/tasks/ios/components/containers/utils/sources/MPPEmbedding+Helpers.h" + +#include + +namespace { +using EmbeddingProto = ::mediapipe::tasks::components::containers::proto::Embedding; +} + +@implementation MPPEmbedding (Helpers) + ++ (MPPEmbedding *)embeddingWithProto:(const EmbeddingProto &)embeddingProto { + NSString *categoryName; + NSString *displayName; + + NSMutableArray *floatEmbedding; + NSData *quantizedEmbedding; + + if (embeddingProto.has_float_embedding()) { + floatEmbedding = [NSMutableArray arrayWithCapacity:embeddingProto.float_embedding().values_size()]; + const auto floatEmbeddingValues = embeddingProto.float_embedding().values(); + + for (const auto value : embeddingProto.float_embedding().values()) { + [floatEmbedding addObject:[NSNumber numberWithFloat:value]]; + } + } + + if (embeddingProto.has_quantized_embedding()) { + const std::string& cppQuantizedEmbedding = + embeddingProto.quantized_embedding().values().data(); + + const char *cppQuantizedEmbeddingCString = cppQuantizedEmbedding.c_str(); + quantizedEmbedding = [NSData dataWithBytes:cppQuantizedEmbeddingCString length:sizeof(cppQuantizedEmbeddingCString)]; + } + + NSString *headName; + + if (embeddingProto.has_head_name()) { + headName = [NSString stringWithCppString:embeddingProto.head_name()]; + } + + return [[MPPEmbedding alloc] initWithFloatEmbedding:floatEmbedding + quantizedEmbedding:quantizedEmbedding + headIndex:embeddingProto.head_index() + headName:headName]; +} + +@end From ffc9f1d47e6d81f6f0c99812f8c13853b5f25957 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Wed, 1 Feb 2023 18:47:23 +0530 Subject: [PATCH 457/469] Added MPPEmbeddingResultHelpers --- .../ios/components/containers/utils/BUILD | 12 ++++++ .../sources/MPPEmbeddingResult+Helpers.h | 26 ++++++++++++ .../sources/MPPEmbeddingResult+Helpers.mm | 42 +++++++++++++++++++ 3 files changed, 80 insertions(+) create mode 100644 mediapipe/tasks/ios/components/containers/utils/sources/MPPEmbeddingResult+Helpers.h create mode 100644 mediapipe/tasks/ios/components/containers/utils/sources/MPPEmbeddingResult+Helpers.mm diff --git a/mediapipe/tasks/ios/components/containers/utils/BUILD b/mediapipe/tasks/ios/components/containers/utils/BUILD index e0989530e..567784acd 100644 --- a/mediapipe/tasks/ios/components/containers/utils/BUILD +++ b/mediapipe/tasks/ios/components/containers/utils/BUILD @@ -49,3 +49,15 @@ objc_library( "//mediapipe/tasks/ios/components/containers:MPPEmbedding", ] ) + +objc_library( + name = "MPPEmbeddingResultHelpers", + srcs = ["sources/MPPEmbeddingResult+Helpers.mm"], + hdrs = ["sources/MPPEmbeddingResult+Helpers.h"], + deps = [ + ":MPPEmbeddingHelpers", + "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto", + "//mediapipe/tasks/ios/common/utils:NSStringHelpers", + "//mediapipe/tasks/ios/components/containers:MPPEmbeddingResult", + ], +) diff --git a/mediapipe/tasks/ios/components/containers/utils/sources/MPPEmbeddingResult+Helpers.h b/mediapipe/tasks/ios/components/containers/utils/sources/MPPEmbeddingResult+Helpers.h new file mode 100644 index 000000000..6ec26b764 --- /dev/null +++ b/mediapipe/tasks/ios/components/containers/utils/sources/MPPEmbeddingResult+Helpers.h @@ -0,0 +1,26 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h" +#import "mediapipe/tasks/ios/components/containers/sources/MPPEmbeddingResult.h" + +NS_ASSUME_NONNULL_BEGIN + +@interface MPPEmbeddingResult (Helpers) + ++ (MPPEmbeddingResult *)embeddingResultWithProto:(const ::mediapipe::tasks::components::containers::proto::EmbeddingResult &)embeddingResultProto; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/components/containers/utils/sources/MPPEmbeddingResult+Helpers.mm b/mediapipe/tasks/ios/components/containers/utils/sources/MPPEmbeddingResult+Helpers.mm new file mode 100644 index 000000000..385b74efb --- /dev/null +++ b/mediapipe/tasks/ios/components/containers/utils/sources/MPPEmbeddingResult+Helpers.mm @@ -0,0 +1,42 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
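+
+// Converts the C++ `EmbeddingResult` proto, including its optional timestamp,
+// into the Objective-C `MPPEmbeddingResult` container.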
+ +#import "mediapipe/tasks/ios/components/containers/utils/sources/MPPEmbeddingResult+Helpers.h" + +#import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h" +#import "mediapipe/tasks/ios/components/containers/utils/sources/MPPEmbedding+Helpers.h" + +namespace { +using EmbeddingResultProto = ::mediapipe::tasks::components::containers::proto::EmbeddingResult; +} + +@implementation MPPEmbeddingResult (Helpers) + ++ (MPPEmbeddingResult *)embeddingResultWithProto:(const EmbeddingResultProto &)embeddingResultProto { + NSMutableArray *embeddings = [NSMutableArray + arrayWithCapacity:(NSUInteger)embeddingResultProto.embeddings_size()]; + for (const auto &embeddingProto : embeddingResultProto.embeddings()) { + [embeddings addObject:[MPPEmbedding embeddingWithProto:embeddingProto]]; + } + + NSInteger timestampMs = 0; + if (embeddingResultProto.has_timestamp_ms()) { + timestampMs = (NSInteger)embeddingResultProto.timestamp_ms(); + } + + return [[MPPEmbeddingResult alloc] initWithEmbeddings:embeddings + timestampMs:timestampMs]; +} + +@end From 69809e218155494583cf53726d46f93632c93632 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Wed, 1 Feb 2023 18:48:10 +0530 Subject: [PATCH 458/469] Updated formatting --- .../utils/sources/MPPEmbedding+Helpers.h | 3 ++- .../utils/sources/MPPEmbedding+Helpers.mm | 15 ++++++++------- .../utils/sources/MPPEmbeddingResult+Helpers.h | 4 +++- .../utils/sources/MPPEmbeddingResult+Helpers.mm | 10 +++++----- 4 files changed, 18 insertions(+), 14 deletions(-) diff --git a/mediapipe/tasks/ios/components/containers/utils/sources/MPPEmbedding+Helpers.h b/mediapipe/tasks/ios/components/containers/utils/sources/MPPEmbedding+Helpers.h index 9ff5455d9..33fb3839d 100644 --- a/mediapipe/tasks/ios/components/containers/utils/sources/MPPEmbedding+Helpers.h +++ b/mediapipe/tasks/ios/components/containers/utils/sources/MPPEmbedding+Helpers.h @@ -19,7 +19,8 @@ NS_ASSUME_NONNULL_BEGIN @interface MPPEmbedding (Helpers) -+ (MPPEmbedding *)embeddingWithProto:(const ::mediapipe::tasks::components::containers::proto::Embedding &)embeddingProto; ++ (MPPEmbedding *)embeddingWithProto: + (const ::mediapipe::tasks::components::containers::proto::Embedding &)embeddingProto; @end diff --git a/mediapipe/tasks/ios/components/containers/utils/sources/MPPEmbedding+Helpers.mm b/mediapipe/tasks/ios/components/containers/utils/sources/MPPEmbedding+Helpers.mm index 4b2acff7c..a676242e8 100644 --- a/mediapipe/tasks/ios/components/containers/utils/sources/MPPEmbedding+Helpers.mm +++ b/mediapipe/tasks/ios/components/containers/utils/sources/MPPEmbedding+Helpers.mm @@ -31,7 +31,8 @@ using EmbeddingProto = ::mediapipe::tasks::components::containers::proto::Embedd NSData *quantizedEmbedding; if (embeddingProto.has_float_embedding()) { - floatEmbedding = [NSMutableArray arrayWithCapacity:embeddingProto.float_embedding().values_size()]; + floatEmbedding = + [NSMutableArray arrayWithCapacity:embeddingProto.float_embedding().values_size()]; const auto floatEmbeddingValues = embeddingProto.float_embedding().values(); for (const auto value : embeddingProto.float_embedding().values()) { @@ -40,11 +41,11 @@ using EmbeddingProto = ::mediapipe::tasks::components::containers::proto::Embedd } if (embeddingProto.has_quantized_embedding()) { - const std::string& cppQuantizedEmbedding = - embeddingProto.quantized_embedding().values().data(); + const std::string &cppQuantizedEmbedding = embeddingProto.quantized_embedding().values().data(); const char *cppQuantizedEmbeddingCString = 
cppQuantizedEmbedding.c_str(); - quantizedEmbedding = [NSData dataWithBytes:cppQuantizedEmbeddingCString length:sizeof(cppQuantizedEmbeddingCString)]; + quantizedEmbedding = [NSData dataWithBytes:cppQuantizedEmbeddingCString + length:sizeof(cppQuantizedEmbeddingCString)]; } NSString *headName; @@ -54,9 +55,9 @@ using EmbeddingProto = ::mediapipe::tasks::components::containers::proto::Embedd } return [[MPPEmbedding alloc] initWithFloatEmbedding:floatEmbedding - quantizedEmbedding:quantizedEmbedding - headIndex:embeddingProto.head_index() - headName:headName]; + quantizedEmbedding:quantizedEmbedding + headIndex:embeddingProto.head_index() + headName:headName]; } @end diff --git a/mediapipe/tasks/ios/components/containers/utils/sources/MPPEmbeddingResult+Helpers.h b/mediapipe/tasks/ios/components/containers/utils/sources/MPPEmbeddingResult+Helpers.h index 6ec26b764..cc53c3e25 100644 --- a/mediapipe/tasks/ios/components/containers/utils/sources/MPPEmbeddingResult+Helpers.h +++ b/mediapipe/tasks/ios/components/containers/utils/sources/MPPEmbeddingResult+Helpers.h @@ -19,7 +19,9 @@ NS_ASSUME_NONNULL_BEGIN @interface MPPEmbeddingResult (Helpers) -+ (MPPEmbeddingResult *)embeddingResultWithProto:(const ::mediapipe::tasks::components::containers::proto::EmbeddingResult &)embeddingResultProto; ++ (MPPEmbeddingResult *)embeddingResultWithProto: + (const ::mediapipe::tasks::components::containers::proto::EmbeddingResult &) + embeddingResultProto; @end diff --git a/mediapipe/tasks/ios/components/containers/utils/sources/MPPEmbeddingResult+Helpers.mm b/mediapipe/tasks/ios/components/containers/utils/sources/MPPEmbeddingResult+Helpers.mm index 385b74efb..f9863e9ca 100644 --- a/mediapipe/tasks/ios/components/containers/utils/sources/MPPEmbeddingResult+Helpers.mm +++ b/mediapipe/tasks/ios/components/containers/utils/sources/MPPEmbeddingResult+Helpers.mm @@ -23,9 +23,10 @@ using EmbeddingResultProto = ::mediapipe::tasks::components::containers::proto:: @implementation MPPEmbeddingResult (Helpers) -+ (MPPEmbeddingResult *)embeddingResultWithProto:(const EmbeddingResultProto &)embeddingResultProto { - NSMutableArray *embeddings = [NSMutableArray - arrayWithCapacity:(NSUInteger)embeddingResultProto.embeddings_size()]; ++ (MPPEmbeddingResult *)embeddingResultWithProto: + (const EmbeddingResultProto &)embeddingResultProto { + NSMutableArray *embeddings = + [NSMutableArray arrayWithCapacity:(NSUInteger)embeddingResultProto.embeddings_size()]; for (const auto &embeddingProto : embeddingResultProto.embeddings()) { [embeddings addObject:[MPPEmbedding embeddingWithProto:embeddingProto]]; } @@ -35,8 +36,7 @@ using EmbeddingResultProto = ::mediapipe::tasks::components::containers::proto:: timestampMs = (NSInteger)embeddingResultProto.timestamp_ms(); } - return [[MPPEmbeddingResult alloc] initWithEmbeddings:embeddings - timestampMs:timestampMs]; + return [[MPPEmbeddingResult alloc] initWithEmbeddings:embeddings timestampMs:timestampMs]; } @end From 42e712e911724fffe79314f00686ab019d89a0f5 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Wed, 1 Feb 2023 18:50:14 +0530 Subject: [PATCH 459/469] Added MPPTextEmbedderOptionsHelpers --- .../tasks/ios/text/text_embedder/utils/BUILD | 33 ++++++++++++++ .../sources/MPPTextEmbedderOptions+Helpers.h | 27 ++++++++++++ .../sources/MPPTextEmbedderOptions+Helpers.mm | 44 +++++++++++++++++++ 3 files changed, 104 insertions(+) create mode 100644 mediapipe/tasks/ios/text/text_embedder/utils/BUILD create mode 100644 
mediapipe/tasks/ios/text/text_embedder/utils/sources/MPPTextEmbedderOptions+Helpers.h create mode 100644 mediapipe/tasks/ios/text/text_embedder/utils/sources/MPPTextEmbedderOptions+Helpers.mm diff --git a/mediapipe/tasks/ios/text/text_embedder/utils/BUILD b/mediapipe/tasks/ios/text/text_embedder/utils/BUILD new file mode 100644 index 000000000..0c41298ac --- /dev/null +++ b/mediapipe/tasks/ios/text/text_embedder/utils/BUILD @@ -0,0 +1,33 @@ +# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +objc_library( + name = "MPPTextEmbedderOptionsHelpers", + srcs = ["sources/MPPTextEmbedderOptions+Helpers.mm"], + hdrs = ["sources/MPPTextEmbedderOptions+Helpers.h"], + deps = [ + "//mediapipe/framework:calculator_options_cc_proto", + "//mediapipe/tasks/cc/components/processors/proto:embedder_options_cc_proto", + "//mediapipe/tasks/cc/text/text_embedder/proto:text_embedder_graph_options_cc_proto", + "//mediapipe/tasks/ios/common/utils:NSStringHelpers", + "//mediapipe/tasks/ios/core:MPPTaskOptionsProtocol", + "//mediapipe/tasks/ios/core/utils:MPPBaseOptionsHelpers", + "//mediapipe/tasks/ios/text/text_embedder:MPPTextEmbedderOptions", + ], +) + diff --git a/mediapipe/tasks/ios/text/text_embedder/utils/sources/MPPTextEmbedderOptions+Helpers.h b/mediapipe/tasks/ios/text/text_embedder/utils/sources/MPPTextEmbedderOptions+Helpers.h new file mode 100644 index 000000000..7f3d1c958 --- /dev/null +++ b/mediapipe/tasks/ios/text/text_embedder/utils/sources/MPPTextEmbedderOptions+Helpers.h @@ -0,0 +1,27 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
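+
+// Category that serializes an `MPPTextEmbedderOptions` instance into the task
+// graph's `CalculatorOptions` proto (see `copyToProto:` below).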
+ +#include "mediapipe/framework/calculator_options.pb.h" +#import "mediapipe/tasks/ios/core/sources/MPPTaskOptionsProtocol.h" +#import "mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedderOptions.h" + +NS_ASSUME_NONNULL_BEGIN + +@interface MPPTextEmbedderOptions (Helpers) + +- (void)copyToProto:(::mediapipe::CalculatorOptions *)optionsProto; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/text/text_embedder/utils/sources/MPPTextEmbedderOptions+Helpers.mm b/mediapipe/tasks/ios/text/text_embedder/utils/sources/MPPTextEmbedderOptions+Helpers.mm new file mode 100644 index 000000000..c2450f955 --- /dev/null +++ b/mediapipe/tasks/ios/text/text_embedder/utils/sources/MPPTextEmbedderOptions+Helpers.mm @@ -0,0 +1,44 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "mediapipe/tasks/ios/text/text_embedder/utils/sources/MPPTextEmbedderOptions+Helpers.h" + +#import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h" +#import "mediapipe/tasks/ios/core/utils/sources/MPPBaseOptions+Helpers.h" + +#include "mediapipe/tasks/cc/text/text_embedder/proto/text_embedder_graph_options.pb.h" +#include "mediapipe/tasks/cc/components/processors/proto/embedder_options.pb.h" + +namespace { +using CalculatorOptionsProto = ::mediapipe::CalculatorOptions; +using TextEmbedderGraphOptionsProto = + ::mediapipe::tasks::text::text_embedder::proto::TextEmbedderGraphOptions; +using EmbedderOptionsProto = ::mediapipe::tasks::components::processors::proto::EmbedderOptions; +} // namespace + +@implementation MPPTextEmbedderOptions (Helpers) + +- (void)copyToProto:(CalculatorOptionsProto *)optionsProto { + TextEmbedderGraphOptionsProto *graphOptions = + optionsProto->MutableExtension(TextEmbedderGraphOptionsProto::ext); + [self.baseOptions copyToProto:graphOptions->mutable_base_options()]; + + EmbedderOptionsProto *embedderOptionsProto = graphOptions->mutable_embedder_options(); + embedderOptionsProto->Clear(); + + embedderOptionsProto->set_l2_normalize(self.l2Normalize ? true : false); + embedderOptionsProto->set_quantize(self.quantize ? 
true : false); +} + +@end From 38eac174e9a36dbb67f6c2e33d9ca984ba27d3ad Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Wed, 1 Feb 2023 18:51:01 +0530 Subject: [PATCH 460/469] Added MPPTextEmbedderResultHelpers --- .../tasks/ios/text/text_embedder/utils/BUILD | 11 +++++ .../sources/MPPTextEmbedderResult+Helpers.h | 28 +++++++++++++ .../sources/MPPTextEmbedderResult+Helpers.mm | 42 +++++++++++++++++++ 3 files changed, 81 insertions(+) create mode 100644 mediapipe/tasks/ios/text/text_embedder/utils/sources/MPPTextEmbedderResult+Helpers.h create mode 100644 mediapipe/tasks/ios/text/text_embedder/utils/sources/MPPTextEmbedderResult+Helpers.mm diff --git a/mediapipe/tasks/ios/text/text_embedder/utils/BUILD b/mediapipe/tasks/ios/text/text_embedder/utils/BUILD index 0c41298ac..eeb4981fb 100644 --- a/mediapipe/tasks/ios/text/text_embedder/utils/BUILD +++ b/mediapipe/tasks/ios/text/text_embedder/utils/BUILD @@ -31,3 +31,14 @@ objc_library( ], ) +objc_library( + name = "MPPTextEmbedderResultHelpers", + srcs = ["sources/MPPTextEmbedderResult+Helpers.mm"], + hdrs = ["sources/MPPTextEmbedderResult+Helpers.h"], + deps = [ + "//mediapipe/framework:packet", + "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto", + "//mediapipe/tasks/ios/components/containers/utils:MPPEmbeddingResultHelpers", + "//mediapipe/tasks/ios/text/text_embedder:MPPTextEmbedderResult", + ], +) diff --git a/mediapipe/tasks/ios/text/text_embedder/utils/sources/MPPTextEmbedderResult+Helpers.h b/mediapipe/tasks/ios/text/text_embedder/utils/sources/MPPTextEmbedderResult+Helpers.h new file mode 100644 index 000000000..de899a103 --- /dev/null +++ b/mediapipe/tasks/ios/text/text_embedder/utils/sources/MPPTextEmbedderResult+Helpers.h @@ -0,0 +1,28 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedderResult.h" + +#include "mediapipe/framework/packet.h" + +NS_ASSUME_NONNULL_BEGIN + +@interface MPPTextEmbedderResult (Helpers) + ++ (MPPTextEmbedderResult *)textEmbedderResultWithOutputPacket: + (const mediapipe::Packet &)packet; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/text/text_embedder/utils/sources/MPPTextEmbedderResult+Helpers.mm b/mediapipe/tasks/ios/text/text_embedder/utils/sources/MPPTextEmbedderResult+Helpers.mm new file mode 100644 index 000000000..bba1b345e --- /dev/null +++ b/mediapipe/tasks/ios/text/text_embedder/utils/sources/MPPTextEmbedderResult+Helpers.mm @@ -0,0 +1,42 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#import "mediapipe/tasks/ios/components/containers/utils/sources/MPPEmbeddingResult+Helpers.h"
+#import "mediapipe/tasks/ios/text/text_embedder/utils/sources/MPPTextEmbedderResult+Helpers.h"
+
+#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
+
+static const int kMicroSecondsPerMilliSecond = 1000;
+
+namespace {
+using EmbeddingResultProto =
+    ::mediapipe::tasks::components::containers::proto::EmbeddingResult;
+using ::mediapipe::Packet;
+}  // namespace
+
+@implementation MPPTextEmbedderResult (Helpers)
+
++ (MPPTextEmbedderResult *)textEmbedderResultWithOutputPacket:(const Packet &)packet {
+  MPPEmbeddingResult *embeddingResult = [MPPEmbeddingResult
+      embeddingResultWithProto:packet.Get()];
+
+  return [[MPPTextEmbedderResult alloc]
+      initWithEmbeddingResult:embeddingResult
+                  timestampMs:(NSInteger)(packet.Timestamp().Value() /
+                                          kMicroSecondsPerMilliSecond)];
+}
+
+@end

From d588f73a6d84d1ef5c36da72442b0db5a01b85dd Mon Sep 17 00:00:00 2001
From: Prianka Liz Kariat
Date: Wed, 1 Feb 2023 18:51:30 +0530
Subject: [PATCH 461/469] Added MPPTextEmbedder

---
 mediapipe/tasks/ios/text/text_embedder/BUILD  | 26 +++++
 .../text_embedder/sources/MPPTextEmbedder.h   | 91 ++++++++++++++++++
 .../text_embedder/sources/MPPTextEmbedder.mm  | 96 +++++++++++++++++++
 3 files changed, 213 insertions(+)
 create mode 100644 mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedder.h
 create mode 100644 mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedder.mm

diff --git a/mediapipe/tasks/ios/text/text_embedder/BUILD b/mediapipe/tasks/ios/text/text_embedder/BUILD
index 143f0a587..21226b012 100644
--- a/mediapipe/tasks/ios/text/text_embedder/BUILD
+++ b/mediapipe/tasks/ios/text/text_embedder/BUILD
@@ -32,3 +32,29 @@ objc_library(
         "//mediapipe/tasks/ios/core:MPPTaskResult",
     ],
 )
+
+objc_library(
+    name = "MPPTextEmbedder",
+    srcs = ["sources/MPPTextEmbedder.mm"],
+    hdrs = ["sources/MPPTextEmbedder.h"],
+    copts = [
+        "-ObjC++",
+        "-std=c++17",
+        "-x objective-c++",
+    ],
+    module_name = "MPPTextEmbedder",
+    deps = [
+        ":MPPTextEmbedderOptions",
+        ":MPPTextEmbedderResult",
+        "//mediapipe/tasks/cc/text/text_embedder:text_embedder_graph",
+        "//mediapipe/tasks/ios/common/utils:MPPCommonUtils",
+        "//mediapipe/tasks/ios/common/utils:NSStringHelpers",
+        "//mediapipe/tasks/ios/core:MPPTaskInfo",
+        "//mediapipe/tasks/ios/core:MPPTaskOptions",
+        "//mediapipe/tasks/ios/core:MPPTextPacketCreator",
+        "//mediapipe/tasks/ios/text/core:MPPTextTaskRunner",
+        "//mediapipe/tasks/ios/text/text_embedder/utils:MPPTextEmbedderOptionsHelpers",
+        "//mediapipe/tasks/ios/text/text_embedder/utils:MPPTextEmbedderResultHelpers",
+        "@com_google_absl//absl/status:statusor",
+    ],
+)
diff --git a/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedder.h b/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedder.h
new file mode 100644
index 000000000..d1deb60ed
--- /dev/null
+++ b/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedder.h
@@ -0,0 +1,91 @@
+// Copyright 2023 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#import <Foundation/Foundation.h>
+
+#import "mediapipe/tasks/ios/core/sources/MPPTaskOptions.h"
+#import "mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedderOptions.h"
+#import "mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedderResult.h"
+
+NS_ASSUME_NONNULL_BEGIN
+
+/**
+ * @brief Performs embedding extraction on text.
+ *
+ * This API expects a TFLite model with (optional) [TFLite Model Metadata](https://www.tensorflow.org/lite/convert/metadata).
+ *
+ * Metadata is required for models with int32 input tensors because it contains the input process
+ * unit for the model's Tokenizer. No metadata is required for models with string input tensors.
+ *
+ * Input tensors
+ *  - Three input tensors `kTfLiteInt32` of shape `[batch_size x bert_max_seq_len]`
+ *    representing the input ids, mask ids, and segment ids. This input signature requires
+ *    a Bert Tokenizer process unit in the model metadata.
+ *  - Or one input tensor `kTfLiteInt32` of shape `[batch_size x max_seq_len]` representing
+ *    the input ids. This input signature requires a Regex Tokenizer process unit in the
+ *    model metadata.
+ *  - Or one input tensor (`kTfLiteString`) that is shapeless or has shape `[1]` containing
+ *    the input string.
+ *
+ * At least one output tensor (`kTfLiteFloat32`/`kTfLiteUint8`) with shape `[1 x N]` where `N` is the number of dimensions in the produced embeddings.
+ */
+NS_SWIFT_NAME(TextEmbedder)
+@interface MPPTextEmbedder : NSObject
+
+/**
+ * Creates a new instance of `MPPTextEmbedder` from an absolute path to a TensorFlow Lite
+ * model file stored locally on the device and the default `MPPTextEmbedderOptions`.
+ *
+ * @param modelPath An absolute path to a TensorFlow Lite model file stored locally on the device.
+ * @param error An optional error parameter populated when there is an error in initializing the
+ * text embedder.
+ *
+ * @return A new instance of `MPPTextEmbedder` with the given model path. `nil` if there is an
+ * error in initializing the text embedder.
+ */
+- (nullable instancetype)initWithModelPath:(NSString *)modelPath error:(NSError **)error;
+
+/**
+ * Creates a new instance of `MPPTextEmbedder` from the given `MPPTextEmbedderOptions`.
+ *
+ * @param options The options of type `MPPTextEmbedderOptions` to use for configuring the
+ * `MPPTextEmbedder`.
+ * @param error An optional error parameter populated when there is an error in initializing the
+ * text embedder.
+ *
+ * @return A new instance of `MPPTextEmbedder` with the given options. `nil` if there is an
+ * error in initializing the text embedder.
+ */
+- (nullable instancetype)initWithOptions:(MPPTextEmbedderOptions *)options
+                                   error:(NSError **)error NS_DESIGNATED_INITIALIZER;
+
+/**
+ * Performs embedding extraction on the input text.
+ *
+ * @param text The `NSString` on which embedding extraction is to be performed.
+ * @param error An optional error parameter populated when there is an error in performing
+ * embedding extraction on the input text.
+ * + * @return A `MPPTextEmbedderResult` object that contains a list of embeddings. + */ +- (nullable MPPTextEmbedderResult *)embedText:(NSString *)text + error:(NSError **)error NS_SWIFT_NAME(embed(text:)); + +- (instancetype)init NS_UNAVAILABLE; + ++ (instancetype)new NS_UNAVAILABLE; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedder.mm b/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedder.mm new file mode 100644 index 000000000..395ce28f6 --- /dev/null +++ b/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedder.mm @@ -0,0 +1,96 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedder.h" + +#import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h" +#import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h" +#import "mediapipe/tasks/ios/core/sources/MPPTaskInfo.h" +#import "mediapipe/tasks/ios/core/sources/MPPTextPacketCreator.h" +#import "mediapipe/tasks/ios/text/core/sources/MPPTextTaskRunner.h" +#import "mediapipe/tasks/ios/text/text_embedder/utils/sources/MPPTextEmbedderOptions+Helpers.h" +#import "mediapipe/tasks/ios/text/text_embedder/utils/sources/MPPTextEmbedderResult+Helpers.h" + +#include "absl/status/statusor.h" + +namespace { +using ::mediapipe::Packet; +using ::mediapipe::tasks::core::PacketMap; +} // namespace + +static NSString *const kEmbeddingsOutStreamName = @"embeddings_out"; +static NSString *const kEmbeddingsTag = @"EMBEDDINGS"; +static NSString *const kTextInStreamName = @"text_in"; +static NSString *const kTextTag = @"TEXT"; +static NSString *const kTaskGraphName = @"mediapipe.tasks.text.text_embedder.TextEmbedderGraph"; + +@interface MPPTextEmbedder () { + /** iOS Text Task Runner */ + MPPTextTaskRunner *_textTaskRunner; +} +@end + +@implementation MPPTextEmbedder + +- (instancetype)initWithOptions:(MPPTextEmbedderOptions *)options error:(NSError **)error { + self = [super init]; + if (self) { + MPPTaskInfo *taskInfo = [[MPPTaskInfo alloc] + initWithTaskGraphName:kTaskGraphName + inputStreams:@[ [NSString stringWithFormat:@"%@:%@", kTextTag, kTextInStreamName] ] + outputStreams:@[ [NSString stringWithFormat:@"%@:%@", kEmbeddingsTag, + kEmbeddingsOutStreamName] ] + taskOptions:options + enableFlowLimiting:NO + error:error]; + + if (!taskInfo) { + return nil; + } + + _textTaskRunner = + [[MPPTextTaskRunner alloc] initWithCalculatorGraphConfig:[taskInfo generateGraphConfig] + error:error]; + + if (!_textTaskRunner) { + return nil; + } + } + return self; +} + +- (instancetype)initWithModelPath:(NSString *)modelPath error:(NSError **)error { + MPPTextEmbedderOptions *options = [[MPPTextEmbedderOptions alloc] init]; + + options.baseOptions.modelAssetPath = modelPath; + + return [self initWithOptions:options error:error]; +} + +- (nullable MPPTextEmbedderResult *)embedText:(NSString *)text error:(NSError **)error { + Packet packet = [MPPTextPacketCreator 
createWithText:text];
+
+  std::map<std::string, Packet> packetMap = {{kTextInStreamName.cppString, packet}};
+  absl::StatusOr<PacketMap> statusOrOutputPacketMap = [_textTaskRunner process:packetMap];
+
+  if (![MPPCommonUtils checkCppError:statusOrOutputPacketMap.status() toError:error]) {
+    return nil;
+  }
+
+  return [MPPTextEmbedderResult
+      textEmbedderResultWithOutputPacket:statusOrOutputPacketMap.value()
+                                             [kEmbeddingsOutStreamName.cppString]];
+}
+
+@end

From 85c310d01c1a2cd8acc376493a7879e53ca62b37 Mon Sep 17 00:00:00 2001
From: Prianka Liz Kariat
Date: Wed, 1 Feb 2023 18:52:55 +0530
Subject: [PATCH 462/469] Updated formatting

---
 .../ios/text/text_embedder/sources/MPPTextEmbedder.h     |  8 +++++---
 .../ios/text/text_embedder/sources/MPPTextEmbedder.mm    |  4 ++--
 .../utils/sources/MPPTextEmbedderOptions+Helpers.mm      |  2 +-
 .../utils/sources/MPPTextEmbedderResult+Helpers.h        |  3 +--
 .../utils/sources/MPPTextEmbedderResult+Helpers.mm       | 11 +++++------
 5 files changed, 14 insertions(+), 14 deletions(-)

diff --git a/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedder.h b/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedder.h
index d1deb60ed..a45ab6747 100644
--- a/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedder.h
+++ b/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedder.h
@@ -23,7 +23,8 @@ NS_ASSUME_NONNULL_BEGIN
 /**
  * @brief Performs embedding extraction on text.
  *
- * This API expects a TFLite model with (optional) [TFLite Model Metadata](https://www.tensorflow.org/lite/convert/metadata).
+ * This API expects a TFLite model with (optional) [TFLite Model
+ * Metadata](https://www.tensorflow.org/lite/convert/metadata).
  *
  * Metadata is required for models with int32 input tensors because it contains the input process
 * unit for the model's Tokenizer. No metadata is required for models with string input tensors.
@@ -38,7 +39,8 @@ NS_ASSUME_NONNULL_BEGIN
 *  - Or one input tensor (`kTfLiteString`) that is shapeless or has shape `[1]` containing
 *    the input string.
 *
- * At least one output tensor (`kTfLiteFloat32`/`kTfLiteUint8`) with shape `[1 x N]` where `N` is the number of dimensions in the produced embeddings.
+ * At least one output tensor (`kTfLiteFloat32`/`kTfLiteUint8`) with shape `[1 x N]` where `N` is
+ * the number of dimensions in the produced embeddings.
 */
 NS_SWIFT_NAME(TextEmbedder)
 @interface MPPTextEmbedder : NSObject
@@ -80,7 +82,7 @@ NS_SWIFT_NAME(TextEmbedder)
  * @return A `MPPTextEmbedderResult` object that contains a list of embeddings.
*/ - (nullable MPPTextEmbedderResult *)embedText:(NSString *)text - error:(NSError **)error NS_SWIFT_NAME(embed(text:)); + error:(NSError **)error NS_SWIFT_NAME(embed(text:)); - (instancetype)init NS_UNAVAILABLE; diff --git a/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedder.mm b/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedder.mm index 395ce28f6..a9c811cdb 100644 --- a/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedder.mm +++ b/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedder.mm @@ -89,8 +89,8 @@ static NSString *const kTaskGraphName = @"mediapipe.tasks.text.text_embedder.Tex } return [MPPTextEmbedderResult - textEmbedderResultWithOutputPacket:statusOrOutputPacketMap.value() - [kEmbeddingsOutStreamName.cppString]]; + textEmbedderResultWithOutputPacket:statusOrOutputPacketMap + .value()[kEmbeddingsOutStreamName.cppString]]; } @end diff --git a/mediapipe/tasks/ios/text/text_embedder/utils/sources/MPPTextEmbedderOptions+Helpers.mm b/mediapipe/tasks/ios/text/text_embedder/utils/sources/MPPTextEmbedderOptions+Helpers.mm index c2450f955..e17b6e8da 100644 --- a/mediapipe/tasks/ios/text/text_embedder/utils/sources/MPPTextEmbedderOptions+Helpers.mm +++ b/mediapipe/tasks/ios/text/text_embedder/utils/sources/MPPTextEmbedderOptions+Helpers.mm @@ -17,8 +17,8 @@ #import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h" #import "mediapipe/tasks/ios/core/utils/sources/MPPBaseOptions+Helpers.h" -#include "mediapipe/tasks/cc/text/text_embedder/proto/text_embedder_graph_options.pb.h" #include "mediapipe/tasks/cc/components/processors/proto/embedder_options.pb.h" +#include "mediapipe/tasks/cc/text/text_embedder/proto/text_embedder_graph_options.pb.h" namespace { using CalculatorOptionsProto = ::mediapipe::CalculatorOptions; diff --git a/mediapipe/tasks/ios/text/text_embedder/utils/sources/MPPTextEmbedderResult+Helpers.h b/mediapipe/tasks/ios/text/text_embedder/utils/sources/MPPTextEmbedderResult+Helpers.h index de899a103..0a808a54b 100644 --- a/mediapipe/tasks/ios/text/text_embedder/utils/sources/MPPTextEmbedderResult+Helpers.h +++ b/mediapipe/tasks/ios/text/text_embedder/utils/sources/MPPTextEmbedderResult+Helpers.h @@ -20,8 +20,7 @@ NS_ASSUME_NONNULL_BEGIN @interface MPPTextEmbedderResult (Helpers) -+ (MPPTextEmbedderResult *)textEmbedderResultWithOutputPacket: - (const mediapipe::Packet &)packet; ++ (MPPTextEmbedderResult *)textEmbedderResultWithOutputPacket:(const mediapipe::Packet &)packet; @end diff --git a/mediapipe/tasks/ios/text/text_embedder/utils/sources/MPPTextEmbedderResult+Helpers.mm b/mediapipe/tasks/ios/text/text_embedder/utils/sources/MPPTextEmbedderResult+Helpers.mm index bba1b345e..b769292ce 100644 --- a/mediapipe/tasks/ios/text/text_embedder/utils/sources/MPPTextEmbedderResult+Helpers.mm +++ b/mediapipe/tasks/ios/text/text_embedder/utils/sources/MPPTextEmbedderResult+Helpers.mm @@ -20,8 +20,7 @@ static const int kMicroSecondsPerMilliSecond = 1000; namespace { -using EmbeddingResultProto = - ::mediapipe::tasks::components::containers::proto::EmbeddingResult; +using EmbeddingResultProto = ::mediapipe::tasks::components::containers::proto::EmbeddingResult; using ::mediapipe::Packet; } // namespace @@ -30,13 +29,13 @@ using ::mediapipe::Packet; @implementation MPPTextEmbedderResult (Helpers) + (MPPTextEmbedderResult *)textEmbedderResultWithOutputPacket:(const Packet &)packet { - MPPEmbeddingResult *embeddingResult = [MPPEmbeddingResult - embeddingResultWithProto:packet.Get()]; + MPPEmbeddingResult *embeddingResult = + 
+      [MPPEmbeddingResult embeddingResultWithProto:packet.Get<EmbeddingResultProto>()];
 
   return [[MPPTextEmbedderResult alloc] initWithEmbeddingResult:embeddingResult
-                                                     timestampMs:(NSInteger)(packet.Timestamp().Value() /
-                                                                             kMicroSecondsPerMilliSecond)];
+                                                    timestampMs:(NSInteger)(packet.Timestamp().Value() /
+                                                                            kMicroSecondsPerMilliSecond)];
 }
 
 @end

From bd507b2d7b4c2915c6150f96f7865e89a66ef78c Mon Sep 17 00:00:00 2001
From: Prianka Liz Kariat
Date: Wed, 1 Feb 2023 19:27:05 +0530
Subject: [PATCH 463/469] Updated MPPEmbeddingHelpers to reflect type change of
 quantized embeddings

---
 .../containers/utils/sources/MPPEmbedding+Helpers.mm | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/mediapipe/tasks/ios/components/containers/utils/sources/MPPEmbedding+Helpers.mm b/mediapipe/tasks/ios/components/containers/utils/sources/MPPEmbedding+Helpers.mm
index a676242e8..faf490901 100644
--- a/mediapipe/tasks/ios/components/containers/utils/sources/MPPEmbedding+Helpers.mm
+++ b/mediapipe/tasks/ios/components/containers/utils/sources/MPPEmbedding+Helpers.mm
@@ -28,12 +28,11 @@ using EmbeddingProto = ::mediapipe::tasks::components::containers::proto::Embedd
   NSString *displayName;
 
   NSMutableArray<NSNumber *> *floatEmbedding;
-  NSData *quantizedEmbedding;
+  NSMutableArray<NSNumber *> *quantizedEmbedding;
 
   if (embeddingProto.has_float_embedding()) {
     floatEmbedding =
         [NSMutableArray arrayWithCapacity:embeddingProto.float_embedding().values_size()];
-    const auto floatEmbeddingValues = embeddingProto.float_embedding().values();
 
     for (const auto value : embeddingProto.float_embedding().values()) {
       [floatEmbedding addObject:[NSNumber numberWithFloat:value]];
@@ -41,11 +40,12 @@ using EmbeddingProto = ::mediapipe::tasks::components::containers::proto::Embedd
   }
 
   if (embeddingProto.has_quantized_embedding()) {
-    const std::string &cppQuantizedEmbedding = embeddingProto.quantized_embedding().values().data();
+    const std::string &cppQuantizedEmbedding = embeddingProto.quantized_embedding().values();
+    quantizedEmbedding = [NSMutableArray arrayWithCapacity:cppQuantizedEmbedding.length()];
 
-    const char *cppQuantizedEmbeddingCString = cppQuantizedEmbedding.c_str();
-    quantizedEmbedding = [NSData dataWithBytes:cppQuantizedEmbeddingCString
-                                        length:sizeof(cppQuantizedEmbeddingCString)];
+    for (char ch : cppQuantizedEmbedding) {
+      [quantizedEmbedding addObject:[NSNumber numberWithChar:ch]];
+    }
   }
 
   NSString *headName;
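In the generated C++ protobuf API behind this change, `quantized_embedding().values()` is a `bytes` field surfaced as a `std::string`, so each `char` carries one signed 8-bit quantized value; the new Objective-C loop simply boxes every byte into an `NSNumber`. A minimal C++ sketch of the same unpacking, for illustration only (the helper name and standalone setup are assumptions, not MediaPipe API):

#include <cstdint>
#include <string>
#include <vector>

// Unpacks a proto `bytes` payload (a std::string in C++) into signed 8-bit
// quantized embedding values, mirroring the Objective-C loop in the patch
// above. Illustrative helper, not part of MediaPipe.
std::vector<int8_t> UnpackQuantizedEmbedding(const std::string& values) {
  std::vector<int8_t> quantized;
  quantized.reserve(values.size());
  for (char ch : values) {
    quantized.push_back(static_cast<int8_t>(ch));
  }
  return quantized;
}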
From 3ee377f671d07e7a234ca0050050b7395e678c41 Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Wed, 1 Feb 2023 07:39:51 -0800
Subject: [PATCH 464/469] Internal change

PiperOrigin-RevId: 506312863
---
 mediapipe/framework/formats/BUILD           |   5 +-
 mediapipe/framework/formats/frame_buffer.cc | 176 --------------------
 mediapipe/framework/formats/frame_buffer.h  | 147 ++++------------
 3 files changed, 32 insertions(+), 296 deletions(-)
 delete mode 100644 mediapipe/framework/formats/frame_buffer.cc

diff --git a/mediapipe/framework/formats/BUILD b/mediapipe/framework/formats/BUILD
index abd530b46..989ee18f0 100644
--- a/mediapipe/framework/formats/BUILD
+++ b/mediapipe/framework/formats/BUILD
@@ -489,12 +489,9 @@ cc_test(
 
 cc_library(
     name = "frame_buffer",
-    srcs = ["frame_buffer.cc"],
     hdrs = ["frame_buffer.h"],
     deps = [
         "//mediapipe/framework/port:integral_types",
-        "@com_google_absl//absl/memory",
-        "@com_google_absl//absl/status",
-        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/log:check",
     ],
 )
 
diff --git a/mediapipe/framework/formats/frame_buffer.cc b/mediapipe/framework/formats/frame_buffer.cc
deleted file mode 100644
index a86d3f2ad..000000000
--- a/mediapipe/framework/formats/frame_buffer.cc
+++ /dev/null
@@ -1,176 +0,0 @@
-/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "mediapipe/framework/formats/frame_buffer.h"
-
-#include "absl/status/status.h"
-#include "absl/status/statusor.h"
-
-namespace mediapipe {
-
-namespace {
-
-// Returns whether the input `format` is a supported YUV format.
-bool IsSupportedYuvFormat(FrameBuffer::Format format) {
-  return format == FrameBuffer::Format::kNV21 ||
-         format == FrameBuffer::Format::kNV12 ||
-         format == FrameBuffer::Format::kYV12 ||
-         format == FrameBuffer::Format::kYV21;
-}
-
-// Returns supported 1-plane FrameBuffer in YuvData structure.
-absl::StatusOr<FrameBuffer::YuvData> GetYuvDataFromOnePlaneFrameBuffer(
-    const FrameBuffer& source) {
-  if (!IsSupportedYuvFormat(source.format())) {
-    return absl::InvalidArgumentError(
-        "The source FrameBuffer format is not part of YUV420 family.");
-  }
-
-  FrameBuffer::YuvData result;
-  const int y_buffer_size =
-      source.plane(0).stride.row_stride_bytes * source.dimension().height;
-  const int uv_buffer_size =
-      ((source.plane(0).stride.row_stride_bytes + 1) / 2) *
-      ((source.dimension().height + 1) / 2);
-  result.y_buffer = source.plane(0).buffer;
-  result.y_row_stride = source.plane(0).stride.row_stride_bytes;
-  result.uv_row_stride = result.y_row_stride;
-
-  if (source.format() == FrameBuffer::Format::kNV21) {
-    result.v_buffer = result.y_buffer + y_buffer_size;
-    result.u_buffer = result.v_buffer + 1;
-    result.uv_pixel_stride = 2;
-    // If y_row_stride equals to the frame width and is an odd value,
-    // uv_row_stride = y_row_stride + 1, otherwise uv_row_stride = y_row_stride.
-    if (result.y_row_stride == source.dimension().width &&
-        result.y_row_stride % 2 == 1) {
-      result.uv_row_stride = (result.y_row_stride + 1) / 2 * 2;
-    }
-  } else if (source.format() == FrameBuffer::Format::kNV12) {
-    result.u_buffer = result.y_buffer + y_buffer_size;
-    result.v_buffer = result.u_buffer + 1;
-    result.uv_pixel_stride = 2;
-    // If y_row_stride equals to the frame width and is an odd value,
-    // uv_row_stride = y_row_stride + 1, otherwise uv_row_stride = y_row_stride.
-    if (result.y_row_stride == source.dimension().width &&
-        result.y_row_stride % 2 == 1) {
-      result.uv_row_stride = (result.y_row_stride + 1) / 2 * 2;
-    }
-  } else if (source.format() == FrameBuffer::Format::kYV21) {
-    result.u_buffer = result.y_buffer + y_buffer_size;
-    result.v_buffer = result.u_buffer + uv_buffer_size;
-    result.uv_pixel_stride = 1;
-    result.uv_row_stride = (result.y_row_stride + 1) / 2;
-  } else if (source.format() == FrameBuffer::Format::kYV12) {
-    result.v_buffer = result.y_buffer + y_buffer_size;
-    result.u_buffer = result.v_buffer + uv_buffer_size;
-    result.uv_pixel_stride = 1;
-    result.uv_row_stride = (result.y_row_stride + 1) / 2;
-  }
-  return result;
-}
-
-// Returns supported 2-plane FrameBuffer in YuvData structure.
-absl::StatusOr<FrameBuffer::YuvData> GetYuvDataFromTwoPlaneFrameBuffer(
-    const FrameBuffer& source) {
-  if (source.format() != FrameBuffer::Format::kNV12 &&
-      source.format() != FrameBuffer::Format::kNV21) {
-    return absl::InvalidArgumentError("Unsupported YUV planar format.");
-  }
-
-  FrameBuffer::YuvData result;
-  // Y plane
-  result.y_buffer = source.plane(0).buffer;
-  // All plane strides
-  result.y_row_stride = source.plane(0).stride.row_stride_bytes;
-  result.uv_row_stride = source.plane(1).stride.row_stride_bytes;
-  result.uv_pixel_stride = 2;
-
-  if (source.format() == FrameBuffer::Format::kNV12) {
-    // Y and UV interleaved format
-    result.u_buffer = source.plane(1).buffer;
-    result.v_buffer = result.u_buffer + 1;
-  } else {
-    // Y and VU interleaved format
-    result.v_buffer = source.plane(1).buffer;
-    result.u_buffer = result.v_buffer + 1;
-  }
-  return result;
-}
-
-// Returns supported 3-plane FrameBuffer in YuvData structure. Note that NV21
-// and NV12 are included in the supported Yuv formats. Technically, NV21 and
-// NV12 should not be described by the 3-plane format. Historically, NV21 is
-// used loosely such that it can also be used to describe YV21 format. For
-// backwards compatibility, FrameBuffer supports NV21/NV12 with 3-plane format
-// but such usage is discouraged
-absl::StatusOr<FrameBuffer::YuvData> GetYuvDataFromThreePlaneFrameBuffer(
-    const FrameBuffer& source) {
-  if (!IsSupportedYuvFormat(source.format())) {
-    return absl::InvalidArgumentError(
-        "The source FrameBuffer format is not part of YUV420 family.");
-  }
-
-  if (source.plane(1).stride.row_stride_bytes !=
-          source.plane(2).stride.row_stride_bytes ||
-      source.plane(1).stride.pixel_stride_bytes !=
-          source.plane(2).stride.pixel_stride_bytes) {
-    return absl::InternalError("Unsupported YUV planar format.");
-  }
-  FrameBuffer::YuvData result;
-  if (source.format() == FrameBuffer::Format::kNV21 ||
-      source.format() == FrameBuffer::Format::kYV12) {
-    // Y follow by VU order. The VU chroma planes can be interleaved or
-    // planar.
-    result.y_buffer = source.plane(0).buffer;
-    result.v_buffer = source.plane(1).buffer;
-    result.u_buffer = source.plane(2).buffer;
-    result.y_row_stride = source.plane(0).stride.row_stride_bytes;
-    result.uv_row_stride = source.plane(1).stride.row_stride_bytes;
-    result.uv_pixel_stride = source.plane(1).stride.pixel_stride_bytes;
-  } else {
-    // Y follow by UV order. The UV chroma planes can be interleaved or
-    // planar.
-    result.y_buffer = source.plane(0).buffer;
-    result.u_buffer = source.plane(1).buffer;
-    result.v_buffer = source.plane(2).buffer;
-    result.y_row_stride = source.plane(0).stride.row_stride_bytes;
-    result.uv_row_stride = source.plane(1).stride.row_stride_bytes;
-    result.uv_pixel_stride = source.plane(1).stride.pixel_stride_bytes;
-  }
-  return result;
-}
-
-}  // namespace
-
-absl::StatusOr<FrameBuffer::YuvData> FrameBuffer::GetYuvDataFromFrameBuffer(
-    const FrameBuffer& source) {
-  if (!IsSupportedYuvFormat(source.format())) {
-    return absl::InvalidArgumentError(
-        "The source FrameBuffer format is not part of YUV420 family.");
-  }
-
-  if (source.plane_count() == 1) {
-    return GetYuvDataFromOnePlaneFrameBuffer(source);
-  } else if (source.plane_count() == 2) {
-    return GetYuvDataFromTwoPlaneFrameBuffer(source);
-  } else if (source.plane_count() == 3) {
-    return GetYuvDataFromThreePlaneFrameBuffer(source);
-  }
-  return absl::InvalidArgumentError(
-      "The source FrameBuffer must be consisted by 1, 2, or 3 planes");
-}
-
-}  // namespace mediapipe
diff --git a/mediapipe/framework/formats/frame_buffer.h b/mediapipe/framework/formats/frame_buffer.h
index 7578a0121..ccc699724 100644
--- a/mediapipe/framework/formats/frame_buffer.h
+++ b/mediapipe/framework/formats/frame_buffer.h
@@ -16,14 +16,9 @@ limitations under the License.
 #ifndef MEDIAPIPE_FRAMEWORK_FORMATS_FRAME_BUFFER_H_
 #define MEDIAPIPE_FRAMEWORK_FORMATS_FRAME_BUFFER_H_
 
-#include <cstddef>
-#include <cstdint>
-#include <memory>
-#include <utility>
 #include <vector>
 
-#include "absl/memory/memory.h"
-#include "absl/status/statusor.h"
+#include "absl/log/check.h"
 #include "mediapipe/framework/port/integral_types.h"
 
 namespace mediapipe {
@@ -36,19 +31,16 @@ namespace mediapipe {
 // Examples:
 //
 // // Create a metadata instance with no backing buffer.
-// auto buffer = FrameBuffer::Create(/*planes=*/{}, dimension, kRGBA,
-// KTopLeft);
+// FrameBuffer buffer{/*planes=*/{}, dimension, kRGBA};
 //
 // // Create an RGBA instance with backing buffer on single plane.
-// FrameBuffer::Plane plane =
-//     {rgba_buffer, /*stride=*/{dimension.width * 4, 4}};
-// auto buffer = FrameBuffer::Create({plane}, dimension, kRGBA, kTopLeft);
+// FrameBuffer::Plane plane{rgba_buffer, /*stride=*/{dimension.width * 4, 4}};
+// FrameBuffer buffer{{plane}, dimension, kRGBA};
 //
 // // Create a YUV instance with planar backing buffer.
-// FrameBuffer::Plane y_plane = {y_buffer, /*stride=*/{dimension.width , 1}};
-// FrameBuffer::Plane uv_plane = {u_buffer, /*stride=*/{dimension.width, 2}};
-// auto buffer = FrameBuffer::Create({y_plane, uv_plane}, dimension, kNV21,
-//                                   kLeftTop);
+// FrameBuffer::Plane y_plane{y_buffer, /*stride=*/{dimension.width , 1}};
+// FrameBuffer::Plane uv_plane{u_buffer, /*stride=*/{dimension.width, 2}};
+// FrameBuffer buffer{{y_plane, uv_plane}, dimension, kNV21};
 class FrameBuffer {
  public:
   // Colorspace formats.
@@ -81,39 +73,16 @@ class FrameBuffer {
     bool operator!=(const Stride& other) const { return !operator==(other); }
   };
 
-  // YUV data structure.
-  struct YuvData {
-    const uint8* y_buffer;
-    const uint8* u_buffer;
-    const uint8* v_buffer;
-    // Y buffer row stride in bytes.
-    int y_row_stride;
-    // U/V buffer row stride in bytes.
-    int uv_row_stride;
-    // U/V pixel stride in bytes. This is the distance between two consecutive
-    // u/v pixel values in a row.
-    int uv_pixel_stride;
-  };
-
-  // FrameBuffer content orientation follows EXIF specification. The name of
-  // each enum value defines the position of the 0th row and the 0th column of
-  // the image content. See http://jpegclub.org/exif_orientation.html for
-  // details.
-  enum class Orientation {
-    kTopLeft = 1,
-    kTopRight = 2,
-    kBottomRight = 3,
-    kBottomLeft = 4,
-    kLeftTop = 5,
-    kRightTop = 6,
-    kRightBottom = 7,
-    kLeftBottom = 8
-  };
-
   // Plane encapsulates buffer and stride information.
   struct Plane {
-    const uint8* buffer;
-    Stride stride;
+    Plane(uint8* buffer, Stride stride) : buffer_(buffer), stride_(stride) {}
+    const uint8* buffer() const { return buffer_; }
+    uint8* mutable_buffer() { return buffer_; }
+    Stride stride() const { return stride_; }
+
+   private:
+    uint8* buffer_;
+    Stride stride_;
   };
 
   // Dimension information for the whole frame or a cropped portion of it.
@@ -149,80 +118,30 @@ class FrameBuffer {
     int Size() const { return width * height; }
   };
 
-  // Factory method for creating a FrameBuffer object from row-major backing
-  // buffers.
-  static std::unique_ptr<FrameBuffer> Create(const std::vector<Plane>& planes,
-                                             Dimension dimension, Format format,
-                                             Orientation orientation) {
-    return absl::make_unique<FrameBuffer>(planes, dimension, format,
-                                          orientation);
-  }
-
-  // Factory method for creating a FrameBuffer object from row-major movable
-  // backing buffers.
-  static std::unique_ptr<FrameBuffer> Create(std::vector<Plane>&& planes,
-                                             Dimension dimension, Format format,
-                                             Orientation orientation) {
-    return absl::make_unique<FrameBuffer>(std::move(planes), dimension, format,
-                                          orientation);
-  }
-
-  // Returns YuvData which contains the Y, U, and V buffer and their
-  // stride info from the input `source` FrameBuffer which is in the YUV family
-  // formats (e.g NV12, NV21, YV12, and YV21).
-  static absl::StatusOr<YuvData> GetYuvDataFromFrameBuffer(
-      const FrameBuffer& source);
-
   // Builds a FrameBuffer object from a row-major backing buffer.
   //
-  // The FrameBuffer does not take ownership of the backing buffer. The backing
-  // buffer is read-only and the caller is responsible for maintaining the
-  // backing buffer lifecycle for the lifetime of FrameBuffer.
+  // The FrameBuffer does not take ownership of the backing buffer. The caller
+  // is responsible for maintaining the backing buffer lifecycle for the
+  // lifetime of FrameBuffer.
   FrameBuffer(const std::vector<Plane>& planes, Dimension dimension,
-              Format format, Orientation orientation)
-      : planes_(planes),
-        dimension_(dimension),
-        format_(format),
-        orientation_(orientation) {}
-
-  // Builds a FrameBuffer object from a movable row-major backing buffer.
-  //
-  // The FrameBuffer does not take ownership of the backing buffer. The backing
-  // buffer is read-only and the caller is responsible for maintaining the
-  // backing buffer lifecycle for the lifetime of FrameBuffer.
-  FrameBuffer(std::vector<Plane>&& planes, Dimension dimension, Format format,
-              Orientation orientation)
-      : planes_(std::move(planes)),
-        dimension_(dimension),
-        format_(format),
-        orientation_(orientation) {}
-
-  // Copy constructor.
-  //
-  // FrameBuffer does not take ownership of the backing buffer. The copy
-  // constructor behaves the same way to only copy the buffer pointer and not
-  // take ownership of the backing buffer.
-  FrameBuffer(const FrameBuffer& frame_buffer) {
-    planes_.clear();
-    for (int i = 0; i < frame_buffer.plane_count(); i++) {
-      planes_.push_back(
-          FrameBuffer::Plane{.buffer = frame_buffer.plane(i).buffer,
-                             .stride = frame_buffer.plane(i).stride});
-    }
-    dimension_ = frame_buffer.dimension();
-    format_ = frame_buffer.format();
-    orientation_ = frame_buffer.orientation();
-  }
+              Format format)
+      : planes_(planes), dimension_(dimension), format_(format) {}
 
   // Returns number of planes.
   int plane_count() const { return planes_.size(); }
 
   // Returns plane indexed by the input `index`.
-  Plane plane(int index) const {
-    if (index > -1 && static_cast<size_t>(index) < planes_.size()) {
-      return planes_[index];
-    }
-    return {};
+  const Plane& plane(int index) const {
+    CHECK_GE(index, 0);
+    CHECK_LT(static_cast<size_t>(index), planes_.size());
+    return planes_[index];
+  }
+
+  // Returns mutable plane indexed by the input `index`.
+  Plane mutable_plane(int index) {
+    CHECK_GE(index, 0);
+    CHECK_LT(static_cast<size_t>(index), planes_.size());
+    return planes_[index];
   }
 
   // Returns FrameBuffer dimension.
@@ -231,14 +150,10 @@ class FrameBuffer {
   // Returns FrameBuffer format.
   Format format() const { return format_; }
 
-  // Returns FrameBuffer orientation.
-  Orientation orientation() const { return orientation_; }
-
  private:
   std::vector<Plane> planes_;
   Dimension dimension_;
   Format format_;
-  Orientation orientation_;
 };
 
 }  // namespace mediapipe

From 0919a6c0a363115f7fad783e8eab09790e6f61a7 Mon Sep 17 00:00:00 2001
From: Juhyun Lee
Date: Wed, 1 Feb 2023 09:43:28 -0800
Subject: [PATCH 465/469] Log which InferenceCalculator impl is used at
 runtime.

PiperOrigin-RevId: 506343086
---
 mediapipe/calculators/tensor/inference_calculator.cc | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/mediapipe/calculators/tensor/inference_calculator.cc b/mediapipe/calculators/tensor/inference_calculator.cc
index 4ccdc07e1..2a6936eba 100644
--- a/mediapipe/calculators/tensor/inference_calculator.cc
+++ b/mediapipe/calculators/tensor/inference_calculator.cc
@@ -63,6 +63,10 @@ class InferenceCalculatorSelectorImpl
   for (const auto& suffix : impls) {
     const auto impl = absl::StrCat("InferenceCalculator", suffix);
     if (!mediapipe::CalculatorBaseRegistry::IsRegistered(impl)) continue;
+    VLOG(1) << "Using " << suffix << " for InferenceCalculator with "
+            << (options.has_model_path()
+                    ? "model " + options.model_path()
+                    : "output_stream " + subgraph_node.output_stream(0));
     CalculatorGraphConfig::Node impl_node = subgraph_node;
     impl_node.set_calculator(impl);
     return tool::MakeSingleNodeGraph(std::move(impl_node));

From 83e33b4dbe580a0879987925d62b845e97397c0f Mon Sep 17 00:00:00 2001
From: Sebastian Schmidt
Date: Wed, 1 Feb 2023 09:53:00 -0800
Subject: [PATCH 466/469] Internal change

PiperOrigin-RevId: 506345436
---
 mediapipe/tasks/cc/core/BUILD               | 13 ++--
 .../tasks/cc/core/external_file_handler.cc  | 77 +++++++++++++++----
 2 files changed, 72 insertions(+), 18 deletions(-)

diff --git a/mediapipe/tasks/cc/core/BUILD b/mediapipe/tasks/cc/core/BUILD
index d440271df..e5bc18306 100644
--- a/mediapipe/tasks/cc/core/BUILD
+++ b/mediapipe/tasks/cc/core/BUILD
@@ -40,16 +40,19 @@ cc_library(
     srcs = ["external_file_handler.cc"],
     hdrs = ["external_file_handler.h"],
     deps = [
-        "//mediapipe/framework/port:integral_types",
-        "//mediapipe/framework/port:status",
-        "//mediapipe/tasks/cc:common",
-        "//mediapipe/tasks/cc/core/proto:external_file_cc_proto",
         "@com_google_absl//absl/memory",
         "@com_google_absl//absl/status",
         "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
         "@com_google_absl//absl/strings:str_format",
-    ],
+        "//mediapipe/framework/port:integral_types",
+        "//mediapipe/framework/port:status",
+        "//mediapipe/tasks/cc:common",
+        "//mediapipe/tasks/cc/core/proto:external_file_cc_proto",
+    ] + select({
+        "//mediapipe:windows": ["@bazel_tools//tools/cpp/runfiles"],
+        "//conditions:default": [],
+    }),
 )
 
 cc_library(
diff --git a/mediapipe/tasks/cc/core/external_file_handler.cc b/mediapipe/tasks/cc/core/external_file_handler.cc
index ff30bea72..a95b8e744 100644
--- a/mediapipe/tasks/cc/core/external_file_handler.cc
+++ b/mediapipe/tasks/cc/core/external_file_handler.cc
@@ -37,12 +37,17 @@ limitations under the License.
 #include "absl/memory/memory.h"
 #include "absl/status/status.h"
 #include "absl/status/statusor.h"
+#include "absl/strings/match.h"
 #include "absl/strings/str_format.h"
 #include "absl/strings/string_view.h"
 #include "mediapipe/framework/port/status_macros.h"
 #include "mediapipe/tasks/cc/common.h"
 #include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
 
+#ifdef _WIN32
+#include "tools/cpp/runfiles/runfiles.h"
+#endif  // _WIN32
+
 namespace mediapipe {
 namespace tasks {
 namespace core {
@@ -50,13 +55,21 @@ namespace {
 
 using ::absl::StatusCode;
 
+#ifndef O_BINARY
+#ifdef _O_BINARY
+#define O_BINARY _O_BINARY
+#else
+#define O_BINARY 0  // If this isn't defined, the platform doesn't need it.
+#endif  // _O_BINARY
+#endif  // O_BINARY
+
 // Gets the offset aligned to page size for mapping given files into memory by
 // file descriptor correctly, as according to mmap(2), the offset used in mmap
 // must be a multiple of sysconf(_SC_PAGE_SIZE).
 int64 GetPageSizeAlignedOffset(int64 offset) {
 #ifdef _WIN32
   // mmap is not used on Windows
-  return -1;
+  return 0;
 #else
   int64 aligned_offset = offset;
   int64 page_size = sysconf(_SC_PAGE_SIZE);
@@ -64,7 +77,7 @@ int64 GetPageSizeAlignedOffset(int64 offset) {
     aligned_offset = offset / page_size * page_size;
   }
   return aligned_offset;
-#endif
+#endif  // _WIN32
 }
 
 }  // namespace
@@ -83,6 +96,24 @@ ExternalFileHandler::CreateFromExternalFile(
   return handler;
 }
 
+absl::StatusOr<std::string> PathToResourceAsFile(std::string path) {
+#ifndef _WIN32
+  return path;
+#else
+  if (absl::StartsWith(path, "./")) {
+    path = "mediapipe" + path.substr(1);
+  }
+
+  std::string error;
+  std::unique_ptr<::bazel::tools::cpp::runfiles::Runfiles> runfiles(
+      ::bazel::tools::cpp::runfiles::Runfiles::Create("", &error));
+  if (!runfiles) {
+    return absl::InternalError("Unable to initialize runfiles: " + error);
+  }
+  return runfiles->Rlocation(path);
+#endif  // _WIN32
+}
+
 absl::Status ExternalFileHandler::MapExternalFile() {
   if (!external_file_.file_content().empty()) {
     return absl::OkStatus();
@@ -101,12 +132,6 @@ absl::Status ExternalFileHandler::MapExternalFile() {
     return absl::OkStatus();
   }
 
-// TODO: Add Windows support
-#ifdef _WIN32
-  return CreateStatusWithPayload(StatusCode::kFailedPrecondition,
-                                 "File loading is not yet supported on Windows",
-                                 MediaPipeTasksStatus::kFileReadError);
-#else
   if (external_file_.file_name().empty() &&
       !external_file_.has_file_descriptor_meta()) {
     return CreateStatusWithPayload(
@@ -118,7 +143,9 @@ absl::Status ExternalFileHandler::MapExternalFile() {
   // Obtain file descriptor, offset and size.
   int fd = -1;
   if (!external_file_.file_name().empty()) {
-    owned_fd_ = open(external_file_.file_name().c_str(), O_RDONLY);
+    ASSIGN_OR_RETURN(std::string file_name,
+                     PathToResourceAsFile(external_file_.file_name()));
+    owned_fd_ = open(file_name.c_str(), O_RDONLY | O_BINARY);
     if (owned_fd_ < 0) {
       const std::string error_message = absl::StrFormat(
           "Unable to open file at %s", external_file_.file_name());
@@ -149,6 +176,12 @@ absl::Status ExternalFileHandler::MapExternalFile() {
     }
     fd = owned_fd_;
   } else {
+#ifdef _WIN32
+    return CreateStatusWithPayload(
+        StatusCode::kFailedPrecondition,
+        "File descriptors are not supported on Windows.",
+        MediaPipeTasksStatus::kFileReadError);
+#else
     fd = external_file_.file_descriptor_meta().fd();
     if (fd < 0) {
       return CreateStatusWithPayload(
@@ -158,6 +191,7 @@ absl::Status ExternalFileHandler::MapExternalFile() {
     }
     buffer_offset_ = external_file_.file_descriptor_meta().offset();
     buffer_size_ = external_file_.file_descriptor_meta().length();
+#endif  // _WIN32
   }
   // Get actual file size. Always use 0 as offset to lseek(2) to get the actual
   // file size, as SEEK_END returns the size of the file *plus* offset.
@@ -189,22 +223,37 @@ absl::Status ExternalFileHandler::MapExternalFile() {
                         buffer_size_ + buffer_offset_, file_size),
         MediaPipeTasksStatus::kInvalidArgumentError);
   }
+
   // If buffer_offset_ is not multiple of sysconf(_SC_PAGE_SIZE), align with
   // extra leading bytes and adjust buffer_size_ to account for the extra
   // leading bytes.
   buffer_aligned_offset_ = GetPageSizeAlignedOffset(buffer_offset_);
   buffer_aligned_size_ =
       buffer_size_ + buffer_offset_ - buffer_aligned_offset_;
+
+#ifdef _WIN32
+  buffer_ = malloc(file_size);
+  // Return the file pointer back to the beginning of the file
+  lseek(fd, 0L, SEEK_SET);
+  buffer_size_ = read(fd, buffer_, file_size);
+  if (buffer_size_ <= 0) {
+    free(buffer_);
+    buffer_ = nullptr;
+  }
+#else
   // Map into memory.
   buffer_ = mmap(/*addr=*/nullptr, buffer_aligned_size_, PROT_READ, MAP_SHARED,
                  fd, buffer_aligned_offset_);
   if (buffer_ == MAP_FAILED) {
+    buffer_ = nullptr;
+  }
+#endif  // _WIN32
+  if (!buffer_) {
     return CreateStatusWithPayload(
         StatusCode::kUnknown,
         absl::StrFormat("Unable to map file to memory buffer, errno=%d",
                         errno),
         MediaPipeTasksStatus::kFileMmapError);
   }
   return absl::OkStatus();
-#endif
 }
 
 absl::string_view ExternalFileHandler::GetFileContent() {
@@ -223,11 +272,13 @@ absl::string_view ExternalFileHandler::GetFileContent() {
 }
 
 ExternalFileHandler::~ExternalFileHandler() {
-#ifndef _WIN32
-  if (buffer_ != MAP_FAILED) {
+  if (buffer_) {
+#ifdef _WIN32
+    free(buffer_);
+#else
     munmap(buffer_, buffer_aligned_size_);
+#endif  // _WIN32
   }
-#endif
   if (owned_fd_ >= 0) {
     close(owned_fd_);
   }
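The Windows branch added in the patch above sidesteps mmap(2) entirely: it allocates a buffer, rewinds the file descriptor, reads the whole file into heap memory, and frees that buffer in the destructor instead of calling munmap. A self-contained sketch of that read-everything fallback using portable stdio; the function name and error handling are illustrative, and the real handler additionally tracks offsets and returns MediaPipe status types:

#include <cstddef>
#include <cstdio>
#include <string>

// Reads an entire file into a std::string, the portable fallback used when
// mmap(2) is unavailable (for example on Windows). Returns an empty string
// on any error. Illustrative sketch, not the MediaPipe implementation.
std::string ReadWholeFile(const char* path) {
  std::string contents;
  std::FILE* f = std::fopen(path, "rb");  // "b" mirrors O_BINARY: no CRLF translation.
  if (f == nullptr) return contents;
  if (std::fseek(f, 0, SEEK_END) == 0) {
    const long size = std::ftell(f);
    if (size > 0 && std::fseek(f, 0, SEEK_SET) == 0) {
      contents.resize(static_cast<std::size_t>(size));
      if (std::fread(&contents[0], 1, contents.size(), f) != contents.size()) {
        contents.clear();  // Short read: report failure as an empty result.
      }
    }
  }
  std::fclose(f);
  return contents;
}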
From 0f3cf9c56a5d1082482dccbbbcbc4bfde1639e7e Mon Sep 17 00:00:00 2001
From: Sebastian Schmidt
Date: Wed, 1 Feb 2023 11:02:58 -0800
Subject: [PATCH 467/469] Add "noasan" to MPPTextClassifierObjcTest

PiperOrigin-RevId: 506366650
---
 mediapipe/tasks/ios/test/text/text_classifier/BUILD | 1 +
 1 file changed, 1 insertion(+)

diff --git a/mediapipe/tasks/ios/test/text/text_classifier/BUILD b/mediapipe/tasks/ios/test/text/text_classifier/BUILD
index 97ec2aa8b..3b533646e 100644
--- a/mediapipe/tasks/ios/test/text/text_classifier/BUILD
+++ b/mediapipe/tasks/ios/test/text/text_classifier/BUILD
@@ -49,6 +49,7 @@ ios_unit_test(
     name = "MPPTextClassifierObjcTest",
     minimum_os_version = MPP_TASK_MINIMUM_OS_VERSION,
     runner = tflite_ios_lab_runner("IOS_LATEST"),
+    tags = TFL_DEFAULT_TAGS + TFL_DISABLED_SANITIZER_TAGS,
     deps = [
         ":MPPTextClassifierObjcTestLibrary",
     ],

From 286dde97ad82e3ae17896ee10438ae59221ebdca Mon Sep 17 00:00:00 2001
From: Sebastian Schmidt
Date: Wed, 1 Feb 2023 14:20:12 -0800
Subject: [PATCH 468/469] Make TextEmbedder and TextClassifier tests pass on
 Windows

PiperOrigin-RevId: 506421383
---
 .../text_classifier/text_classifier_test.cc  | 43 ++++++++++++++-----
 .../text/text_embedder/text_embedder_test.cc | 14 +++++-
 2 files changed, 46 insertions(+), 11 deletions(-)

diff --git a/mediapipe/tasks/cc/text/text_classifier/text_classifier_test.cc b/mediapipe/tasks/cc/text/text_classifier/text_classifier_test.cc
index 799885eac..71f7b1f2d 100644
--- a/mediapipe/tasks/cc/text/text_classifier/text_classifier_test.cc
+++ b/mediapipe/tasks/cc/text/text_classifier/text_classifier_test.cc
@@ -135,28 +135,46 @@ TEST_F(TextClassifierTest, TextClassifierWithBert) {
   options->base_options.model_asset_path = GetFullPath(kTestBertModelPath);
   MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextClassifier> classifier,
                           TextClassifier::Create(std::move(options)));
-  MP_ASSERT_OK_AND_ASSIGN(
-      TextClassifierResult negative_result,
-      classifier->Classify("unflinchingly bleak and desperate"));
   TextClassifierResult negative_expected;
+  TextClassifierResult positive_expected;
+
+#ifdef _WIN32
+  negative_expected.classifications.emplace_back(Classifications{
+      /*categories=*/{
+          {/*index=*/0, /*score=*/0.956124, /*category_name=*/"negative"},
+          {/*index=*/1, /*score=*/0.043875, /*category_name=*/"positive"}},
+      /*head_index=*/0,
+      /*head_name=*/"probability"});
+  positive_expected.classifications.emplace_back(Classifications{
+      /*categories=*/{
+          {/*index=*/1, /*score=*/0.999951, /*category_name=*/"positive"},
+          {/*index=*/0, /*score=*/0.000048, /*category_name=*/"negative"}},
+      /*head_index=*/0,
+      /*head_name=*/"probability"});
+#else
   negative_expected.classifications.emplace_back(Classifications{
       /*categories=*/{
          {/*index=*/0, /*score=*/0.956317, /*category_name=*/"negative"},
          {/*index=*/1, /*score=*/0.043683, /*category_name=*/"positive"}},
      /*head_index=*/0,
      /*head_name=*/"probability"});
-  ExpectApproximatelyEqual(negative_result, negative_expected);
-
-  MP_ASSERT_OK_AND_ASSIGN(
-      TextClassifierResult positive_result,
-      classifier->Classify("it's a charming and often affecting journey"));
-  TextClassifierResult positive_expected;
   positive_expected.classifications.emplace_back(Classifications{
       /*categories=*/{
          {/*index=*/1, /*score=*/0.999945, /*category_name=*/"positive"},
          {/*index=*/0, /*score=*/0.000056, /*category_name=*/"negative"}},
      /*head_index=*/0,
      /*head_name=*/"probability"});
+#endif  // _WIN32
+
+  MP_ASSERT_OK_AND_ASSIGN(
+      TextClassifierResult negative_result,
+      classifier->Classify("unflinchingly bleak and desperate"));
+  ExpectApproximatelyEqual(negative_result, negative_expected);
+
+  MP_ASSERT_OK_AND_ASSIGN(
+      TextClassifierResult positive_result,
+      classifier->Classify("it's a charming and often affecting journey"));
   ExpectApproximatelyEqual(positive_result, positive_expected);
 
   MP_ASSERT_OK(classifier->Close());
@@ -233,12 +251,17 @@ TEST_F(TextClassifierTest, BertLongPositive) {
   TextClassifierResult expected;
   std::vector<Category> categories;
-// Predicted scores are slightly different on Mac OS.
+// Predicted scores are slightly different across platforms.
 #ifdef __APPLE__
   categories.push_back(
       {/*index=*/1, /*score=*/0.974181, /*category_name=*/"positive"});
   categories.push_back(
       {/*index=*/0, /*score=*/0.025819, /*category_name=*/"negative"});
+#elif defined _WIN32
+  categories.push_back(
+      {/*index=*/1, /*score=*/0.976686, /*category_name=*/"positive"});
+  categories.push_back(
+      {/*index=*/0, /*score=*/0.023313, /*category_name=*/"negative"});
 #else
   categories.push_back(
       {/*index=*/1, /*score=*/0.985889, /*category_name=*/"positive"});
diff --git a/mediapipe/tasks/cc/text/text_embedder/text_embedder_test.cc b/mediapipe/tasks/cc/text/text_embedder/text_embedder_test.cc
index 1ddea3358..533d829b9 100644
--- a/mediapipe/tasks/cc/text/text_embedder/text_embedder_test.cc
+++ b/mediapipe/tasks/cc/text/text_embedder/text_embedder_test.cc
@@ -75,7 +75,11 @@ TEST_F(EmbedderTest, SucceedsWithMobileBert) {
       text_embedder->Embed("it's a charming and often affecting journey"));
   ASSERT_EQ(result0.embeddings.size(), 1);
   ASSERT_EQ(result0.embeddings[0].float_embedding.size(), 512);
+#ifdef _WIN32
+  ASSERT_NEAR(result0.embeddings[0].float_embedding[0], 21.2148f, kEpsilon);
+#else
   ASSERT_NEAR(result0.embeddings[0].float_embedding[0], 19.9016f, kEpsilon);
+#endif  // _WIN32
 
   MP_ASSERT_OK_AND_ASSIGN(
       auto result1, text_embedder->Embed("what a great and fantastic trip"));
@@ -87,7 +91,11 @@ TEST_F(EmbedderTest, SucceedsWithMobileBert) {
   MP_ASSERT_OK_AND_ASSIGN(
       double similarity, TextEmbedder::CosineSimilarity(result0.embeddings[0],
                                                         result1.embeddings[0]));
+#ifdef _WIN32
+  EXPECT_NEAR(similarity, 0.971417, kSimilarityTolerancy);
+#else
   EXPECT_NEAR(similarity, 0.969514, kSimilarityTolerancy);
+#endif  // _WIN32
 
   MP_ASSERT_OK(text_embedder->Close());
 }
@@ -160,8 +168,12 @@ TEST_F(EmbedderTest, SucceedsWithMobileBertAndDifferentThemes) {
   MP_ASSERT_OK_AND_ASSIGN(
       double similarity, TextEmbedder::CosineSimilarity(result0.embeddings[0],
                                                         result1.embeddings[0]));
-  // TODO: The similarity should likely be lower
+  // TODO: These similarities should likely be lower
+#ifdef _WIN32
+  EXPECT_NEAR(similarity, 0.98152, kSimilarityTolerancy);
+#else
   EXPECT_NEAR(similarity, 0.98088, kSimilarityTolerancy);
+#endif  // _WIN32
 
   MP_ASSERT_OK(text_embedder->Close());
} From e485961c2d2e6166e3f111d69e15d3a127f4f352 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Thu, 2 Feb 2023 12:47:34 -0800 Subject: [PATCH 469/469] fixes spelling mistake PiperOrigin-RevId: 506697863 --- mediapipe/tasks/cc/core/base_task_api.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mediapipe/tasks/cc/core/base_task_api.h b/mediapipe/tasks/cc/core/base_task_api.h index 1019c4fe9..92d41cc84 100644 --- a/mediapipe/tasks/cc/core/base_task_api.h +++ b/mediapipe/tasks/cc/core/base_task_api.h @@ -26,7 +26,7 @@ namespace mediapipe { namespace tasks { namespace core { -// The base calss of the user-facing mediapipe tasks api classes. +// The base class of the user-facing mediapipe tasks api classes. class BaseTaskApi { public: // Constructor.
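
The similarity expectations in the embedder tests above reduce to cosine similarity between two float embeddings, dot(u, v) / (|u| * |v|), which MediaPipe exposes as TextEmbedder::CosineSimilarity. A minimal standalone sketch of that computation; the sample vectors are made up and the function is illustrative, not the library implementation:

#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

// Cosine similarity of two equal-length embeddings:
// dot(u, v) / (|u| * |v|). Returns 0 when either vector has zero norm.
double CosineSimilarity(const std::vector<float>& u,
                        const std::vector<float>& v) {
  double dot = 0.0, uu = 0.0, vv = 0.0;
  for (std::size_t i = 0; i < u.size(); ++i) {
    dot += static_cast<double>(u[i]) * v[i];
    uu += static_cast<double>(u[i]) * u[i];
    vv += static_cast<double>(v[i]) * v[i];
  }
  if (uu == 0.0 || vv == 0.0) return 0.0;
  return dot / (std::sqrt(uu) * std::sqrt(vv));
}

int main() {
  // Made-up 4-dimensional vectors; real text embeddings have e.g. 512 dims.
  std::vector<float> u = {0.8f, 0.1f, 0.6f, 0.2f};
  std::vector<float> v = {0.7f, 0.2f, 0.6f, 0.1f};
  std::printf("similarity = %f\n", CosineSimilarity(u, v));
  return 0;
}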