From 7463e48fd42ebd33aa4838e6b7ca4d7eaabb1103 Mon Sep 17 00:00:00 2001 From: kinaryml Date: Thu, 9 Mar 2023 02:39:21 -0800 Subject: [PATCH 01/63] Added some files necessary for the Face Stylizer implementation --- mediapipe/python/BUILD | 1 + mediapipe/tasks/python/test/vision/BUILD | 17 ++ .../python/test/vision/face_stylizer_test.py | 118 ++++++++ mediapipe/tasks/python/vision/BUILD | 19 ++ .../tasks/python/vision/face_stylizer.py | 254 ++++++++++++++++++ 5 files changed, 409 insertions(+) create mode 100644 mediapipe/tasks/python/test/vision/face_stylizer_test.py create mode 100644 mediapipe/tasks/python/vision/face_stylizer.py diff --git a/mediapipe/python/BUILD b/mediapipe/python/BUILD index f56e5b3d4..9d5ea26ad 100644 --- a/mediapipe/python/BUILD +++ b/mediapipe/python/BUILD @@ -94,6 +94,7 @@ cc_library( "//mediapipe/tasks/cc/vision/image_embedder:image_embedder_graph", "//mediapipe/tasks/cc/vision/image_segmenter:image_segmenter_graph", "//mediapipe/tasks/cc/vision/object_detector:object_detector_graph", + "//mediapipe/tasks/cc/vision/face_stylizer:face_stylizer_graph", ] + select({ # TODO: Build text_classifier_graph and text_embedder_graph on Windows. "//mediapipe:windows": [], diff --git a/mediapipe/tasks/python/test/vision/BUILD b/mediapipe/tasks/python/test/vision/BUILD index 48ecc30b3..19d592895 100644 --- a/mediapipe/tasks/python/test/vision/BUILD +++ b/mediapipe/tasks/python/test/vision/BUILD @@ -114,3 +114,20 @@ py_test( "@com_google_protobuf//:protobuf_python", ], ) + +py_test( + name = "face_stylizer_test", + srcs = ["face_stylizer_test.py"], + data = [ + "//mediapipe/tasks/testdata/vision:test_images", + "//mediapipe/tasks/testdata/vision:test_models", + ], + deps = [ + "//mediapipe/python:_framework_bindings", + "//mediapipe/tasks/python/core:base_options", + "//mediapipe/tasks/python/test:test_utils", + "//mediapipe/tasks/python/vision:face_stylizer", + "//mediapipe/tasks/python/vision/core:image_processing_options", + "//mediapipe/tasks/python/vision/core:vision_task_running_mode", + ], +) diff --git a/mediapipe/tasks/python/test/vision/face_stylizer_test.py b/mediapipe/tasks/python/test/vision/face_stylizer_test.py new file mode 100644 index 000000000..3c39851dd --- /dev/null +++ b/mediapipe/tasks/python/test/vision/face_stylizer_test.py @@ -0,0 +1,118 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Tests for face stylizer.""" + +import enum +import os +from unittest import mock + +import numpy as np +from absl.testing import absltest +from absl.testing import parameterized + +from mediapipe.python._framework_bindings import image as image_module +from mediapipe.tasks.python.core import base_options as base_options_module +from mediapipe.tasks.python.test import test_utils +from mediapipe.tasks.python.vision import face_stylizer +from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module +from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module + + +_BaseOptions = base_options_module.BaseOptions +_Image = image_module.Image +_FaceStylizer = face_stylizer.FaceStylizer +_FaceStylizerOptions = face_stylizer.FaceStylizerOptions +_RUNNING_MODE = running_mode_module.VisionTaskRunningMode +_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions + +_MODEL = 'face_stylizer_model_placeholder.tflite' +_IMAGE = 'cats_and_dogs.jpg' +_STYLIZED_IMAGE = 'stylized_image_placeholder.jpg' +_TEST_DATA_DIR = 'mediapipe/tasks/testdata/vision' + + +class ModelFileType(enum.Enum): + FILE_CONTENT = 1 + FILE_NAME = 2 + + +class FaceStylizerTest(parameterized.TestCase): + + def setUp(self): + super().setUp() + self.test_image = _Image.create_from_file( + test_utils.get_test_data_path( + os.path.join(_TEST_DATA_DIR, _IMAGE))) + self.model_path = test_utils.get_test_data_path( + os.path.join(_TEST_DATA_DIR, _MODEL)) + + def test_create_from_file_succeeds_with_valid_model_path(self): + # Creates with default option and valid model file successfully. + with _FaceStylizer.create_from_model_path(self.model_path) as stylizer: + self.assertIsInstance(stylizer, _FaceStylizer) + + def test_create_from_options_succeeds_with_valid_model_path(self): + # Creates with options containing model file successfully. + base_options = _BaseOptions(model_asset_path=self.model_path) + options = _FaceStylizerOptions(base_options=base_options) + with _FaceStylizer.create_from_options(options) as stylizer: + self.assertIsInstance(stylizer, _FaceStylizer) + + def test_create_from_options_fails_with_invalid_model_path(self): + with self.assertRaisesRegex( + RuntimeError, 'Unable to open file at /path/to/invalid/model.tflite'): + base_options = _BaseOptions( + model_asset_path='/path/to/invalid/model.tflite') + options = _FaceStylizerOptions(base_options=base_options) + _FaceStylizer.create_from_options(options) + + def test_create_from_options_succeeds_with_valid_model_content(self): + # Creates with options containing model content successfully. + with open(self.model_path, 'rb') as f: + base_options = _BaseOptions(model_asset_buffer=f.read()) + options = _FaceStylizerOptions(base_options=base_options) + stylizer = _FaceStylizer.create_from_options(options) + self.assertIsInstance(stylizer, _FaceStylizer) + + @parameterized.parameters( + (ModelFileType.FILE_NAME, _STYLIZED_IMAGE), + (ModelFileType.FILE_CONTENT, _STYLIZED_IMAGE)) + def test_stylize(self, model_file_type, expected_detection_result_file): + # Creates stylizer. 
+ if model_file_type is ModelFileType.FILE_NAME: + base_options = _BaseOptions(model_asset_path=self.model_path) + elif model_file_type is ModelFileType.FILE_CONTENT: + with open(self.model_path, 'rb') as f: + model_content = f.read() + base_options = _BaseOptions(model_asset_buffer=model_content) + else: + # Should never happen + raise ValueError('model_file_type is invalid.') + + options = _FaceStylizerOptions(base_options=base_options) + stylizer = _FaceStylizer.create_from_options(options) + + # Performs face stylization on the input. + stylized_image = stylizer.detect(self.test_image) + # Comparing results. + self.assertTrue( + np.array_equal(stylized_image.numpy_view(), + self.test_image.numpy_view())) + # Closes the stylizer explicitly when the stylizer is not used in + # a context. + stylizer.close() + + +if __name__ == '__main__': + absltest.main() diff --git a/mediapipe/tasks/python/vision/BUILD b/mediapipe/tasks/python/vision/BUILD index eda8e290d..8ce0ef96e 100644 --- a/mediapipe/tasks/python/vision/BUILD +++ b/mediapipe/tasks/python/vision/BUILD @@ -152,3 +152,22 @@ py_library( "//mediapipe/tasks/python/vision/core:vision_task_running_mode", ], ) + +py_library( + name = "face_stylizer", + srcs = [ + "face_stylizer.py", + ], + deps = [ + "//mediapipe/python:_framework_bindings", + "//mediapipe/python:packet_creator", + "//mediapipe/python:packet_getter", + "//mediapipe/tasks/cc/vision/image_segmenter/proto:face_stylizer_graph_options_py_pb2", + "//mediapipe/tasks/python/core:base_options", + "//mediapipe/tasks/python/core:optional_dependencies", + "//mediapipe/tasks/python/core:task_info", + "//mediapipe/tasks/python/vision/core:base_vision_task_api", + "//mediapipe/tasks/python/vision/core:image_processing_options", + "//mediapipe/tasks/python/vision/core:vision_task_running_mode", + ], +) diff --git a/mediapipe/tasks/python/vision/face_stylizer.py b/mediapipe/tasks/python/vision/face_stylizer.py new file mode 100644 index 000000000..cd840fe85 --- /dev/null +++ b/mediapipe/tasks/python/vision/face_stylizer.py @@ -0,0 +1,254 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""MediaPipe face stylizer task.""" + +import dataclasses +from typing import Callable, Mapping, Optional + +from mediapipe.python import packet_creator +from mediapipe.python import packet_getter +from mediapipe.python._framework_bindings import image as image_module +from mediapipe.python._framework_bindings import packet as packet_module +from mediapipe.tasks.cc.vision.face_stylizer.proto import face_stylizer_graph_options_pb2 +from mediapipe.tasks.python.core import base_options as base_options_module +from mediapipe.tasks.python.core import task_info as task_info_module +from mediapipe.tasks.python.core.optional_dependencies import doc_controls +from mediapipe.tasks.python.vision.core import base_vision_task_api +from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module +from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module + +_BaseOptions = base_options_module.BaseOptions +_FaceStylizerGraphOptionsProto = face_stylizer_graph_options_pb2.FaceStylizerGraphOptions +_RunningMode = running_mode_module.VisionTaskRunningMode +_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions +_TaskInfo = task_info_module.TaskInfo + +_STYLIZED_IMAGE_NAME = 'stylized_image' +_STYLIZED_IMAGE_TAG = 'STYLIZED_IMAGE' +_NORM_RECT_STREAM_NAME = 'norm_rect_in' +_NORM_RECT_TAG = 'NORM_RECT' +_IMAGE_IN_STREAM_NAME = 'image_in' +_IMAGE_OUT_STREAM_NAME = 'image_out' +_IMAGE_TAG = 'IMAGE' +_TASK_GRAPH_NAME = 'mediapipe.tasks.vision.face_stylizer.FaceStylizerGraph' +_MICRO_SECONDS_PER_MILLISECOND = 1000 + + +@dataclasses.dataclass +class FaceStylizerOptions: + """Options for the face stylizer task. + + Attributes: + base_options: Base options for the face stylizer task. + running_mode: The running mode of the task. Default to the image mode. + Face stylizer task has three running modes: + 1) The image mode for stylizing faces on single image inputs. + 2) The video mode for stylizing faces on the decoded frames of a video. + 3) The live stream mode for stylizing faces on a live stream of input + data, such as from camera. + result_callback: The user-defined result callback for processing live stream + data. The result callback should only be specified when the running mode + is set to the live stream mode. + """ + base_options: _BaseOptions + running_mode: _RunningMode = _RunningMode.IMAGE + result_callback: Optional[ + Callable[[image_module.Image, image_module.Image, int], + None]] = None + + @doc_controls.do_not_generate_docs + def to_pb2(self) -> _FaceStylizerGraphOptionsProto: + """Generates an FaceStylizerOptions protobuf object.""" + base_options_proto = self.base_options.to_pb2() + base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.IMAGE else True + return _FaceStylizerGraphOptionsProto(base_options=base_options_proto) + + +class FaceStylizer(base_vision_task_api.BaseVisionTaskApi): + """Class that performs face stylization on images.""" + + @classmethod + def create_from_model_path(cls, model_path: str) -> 'FaceStylizer': + """Creates an `FaceStylizer` object from a TensorFlow Lite model and the default `FaceStylizerOptions`. + + Note that the created `FaceDetector` instance is in image mode, for + stylizing faces on single image inputs. + + Args: + model_path: Path to the model. + + Returns: + `FaceStylizer` object that's created from the model file and the default + `FaceStylizerOptions`. 
+ + Raises: + ValueError: If failed to create `FaceStylizer` object from the provided + file such as invalid file path. + RuntimeError: If other types of error occurred. + """ + base_options = _BaseOptions(model_asset_path=model_path) + options = FaceStylizerOptions( + base_options=base_options, running_mode=_RunningMode.IMAGE) + return cls.create_from_options(options) + + @classmethod + def create_from_options(cls, + options: FaceStylizerOptions) -> 'FaceStylizer': + """Creates the `FaceStylizer` object from face stylizer options. + + Args: + options: Options for the face stylizer task. + + Returns: + `FaceStylizer` object that's created from `options`. + + Raises: + ValueError: If failed to create `FaceStylizer` object from + `FaceStylizerOptions` such as missing the model. + RuntimeError: If other types of error occurred. + """ + + def packets_callback(output_packets: Mapping[str, packet_module.Packet]): + if output_packets[_IMAGE_OUT_STREAM_NAME].is_empty(): + return + image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME]) + stylized_image_packet = output_packets[_STYLIZED_IMAGE_NAME] + options.result_callback( + stylized_image_packet, image, + stylized_image_packet.timestamp.value // _MICRO_SECONDS_PER_MILLISECOND) + + task_info = _TaskInfo( + task_graph=_TASK_GRAPH_NAME, + input_streams=[ + ':'.join([_IMAGE_TAG, _IMAGE_IN_STREAM_NAME]), + ':'.join([_NORM_RECT_TAG, _NORM_RECT_STREAM_NAME]), + ], + output_streams=[ + ':'.join([_STYLIZED_IMAGE_TAG, _STYLIZED_IMAGE_NAME]), + ':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]) + ], + task_options=options) + return cls( + task_info.generate_graph_config( + enable_flow_limiting=options.running_mode == + _RunningMode.LIVE_STREAM), options.running_mode, + packets_callback if options.result_callback else None) + + def stylize( + self, + image: image_module.Image, + image_processing_options: Optional[_ImageProcessingOptions] = None + ) -> image_module.Image: + """Performs face stylization on the provided MediaPipe Image. + + Only use this method when the FaceStylizer is created with the image + running mode. + + Args: + image: MediaPipe Image. + image_processing_options: Options for image processing. + + Returns: + The stylized image. + + Raises: + ValueError: If any of the input arguments is invalid. + RuntimeError: If face stylization failed to run. + """ + normalized_rect = self.convert_to_normalized_rect(image_processing_options) + output_packets = self._process_image_data({ + _IMAGE_IN_STREAM_NAME: + packet_creator.create_image(image), + _NORM_RECT_STREAM_NAME: + packet_creator.create_proto(normalized_rect.to_pb2()) + }) + return output_packets[_STYLIZED_IMAGE_NAME] + + def stylize_for_video( + self, + image: image_module.Image, + timestamp_ms: int, + image_processing_options: Optional[_ImageProcessingOptions] = None + ) -> image_module.Image: + """Performs face stylization on the provided video frames. + + Only use this method when the FaceStylizer is created with the video + running mode. It's required to provide the video frame's timestamp (in + milliseconds) along with the video frame. The input timestamps should be + monotonically increasing for adjacent calls of this method. + + Args: + image: MediaPipe Image. + timestamp_ms: The timestamp of the input video frame in milliseconds. + image_processing_options: Options for image processing. + + Returns: + The stylized image. + + Raises: + ValueError: If any of the input arguments is invalid. + RuntimeError: If face stylization failed to run. 
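+
+    Example (an illustrative sketch; assumes frames decoded at ~30 fps, so
+    each frame advances the timestamp by 33 ms):
+      for index, frame in enumerate(frames):
+        stylized = stylizer.stylize_for_video(frame, index * 33)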
+ """ + normalized_rect = self.convert_to_normalized_rect(image_processing_options) + output_packets = self._process_video_data({ + _IMAGE_IN_STREAM_NAME: + packet_creator.create_image(image).at( + timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND), + _NORM_RECT_STREAM_NAME: + packet_creator.create_proto(normalized_rect.to_pb2()).at( + timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND) + }) + return output_packets[_STYLIZED_IMAGE_NAME] + + def stylize_async( + self, + image: image_module.Image, + timestamp_ms: int, + image_processing_options: Optional[_ImageProcessingOptions] = None + ) -> None: + """Sends live image data (an Image with a unique timestamp) to perform face stylization. + + Only use this method when the FaceStylizer is created with the live stream + running mode. The input timestamps should be monotonically increasing for + adjacent calls of this method. This method will return immediately after the + input image is accepted. The results will be available via the + `result_callback` provided in the `FaceStylizerOptions`. The + `stylize_async` method is designed to process live stream data such as camera + input. To lower the overall latency, face stylizer may drop the input + images if needed. In other words, it's not guaranteed to have output per + input image. + + The `result_callback` provides: + - The stylized image. + - The input image that the face stylizer runs on. + - The input timestamp in milliseconds. + + Args: + image: MediaPipe Image. + timestamp_ms: The timestamp of the input image in milliseconds. + image_processing_options: Options for image processing. + + Raises: + ValueError: If the current input timestamp is smaller than what the face + stylizer has already processed. + """ + normalized_rect = self.convert_to_normalized_rect(image_processing_options) + self._send_live_stream_data({ + _IMAGE_IN_STREAM_NAME: + packet_creator.create_image(image).at( + timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND), + _NORM_RECT_STREAM_NAME: + packet_creator.create_proto(normalized_rect.to_pb2()).at( + timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND) + }) From 97870565081643d3acd6c64428b50e0485a57709 Mon Sep 17 00:00:00 2001 From: kinaryml Date: Fri, 10 Mar 2023 10:17:03 -0800 Subject: [PATCH 02/63] Added the AudioRecord API --- mediapipe/tasks/python/audio/core/BUILD | 1 + .../python/audio/core/base_audio_task_api.py | 28 ++++ .../tasks/python/components/containers/BUILD | 5 + .../components/containers/audio_record.py | 126 ++++++++++++++++++ mediapipe/tasks/python/test/audio/BUILD | 2 + .../test/audio/audio_classifier_test.py | 14 ++ .../python/test/audio/audio_embedder_test.py | 14 ++ mediapipe/tasks/python/test/audio/core/BUILD | 27 ++++ .../test/audio/core/audio_record_test.py | 97 ++++++++++++++ 9 files changed, 314 insertions(+) create mode 100644 mediapipe/tasks/python/components/containers/audio_record.py create mode 100644 mediapipe/tasks/python/test/audio/core/BUILD create mode 100644 mediapipe/tasks/python/test/audio/core/audio_record_test.py diff --git a/mediapipe/tasks/python/audio/core/BUILD b/mediapipe/tasks/python/audio/core/BUILD index 5b4203d7b..28dc4b960 100644 --- a/mediapipe/tasks/python/audio/core/BUILD +++ b/mediapipe/tasks/python/audio/core/BUILD @@ -34,5 +34,6 @@ py_library( "//mediapipe/python:_framework_bindings", "//mediapipe/python:packet_creator", "//mediapipe/tasks/python/core:optional_dependencies", + "//mediapipe/tasks/python/components/containers:audio_record", ], ) diff --git a/mediapipe/tasks/python/audio/core/base_audio_task_api.py 
b/mediapipe/tasks/python/audio/core/base_audio_task_api.py
index 5b08a2b76..80e8ad605 100644
--- a/mediapipe/tasks/python/audio/core/base_audio_task_api.py
+++ b/mediapipe/tasks/python/audio/core/base_audio_task_api.py
@@ -22,6 +22,7 @@ from mediapipe.python._framework_bindings import task_runner as task_runner_modu
 from mediapipe.python._framework_bindings import timestamp as timestamp_module
 from mediapipe.tasks.python.audio.core import audio_task_running_mode as running_mode_module
 from mediapipe.tasks.python.core.optional_dependencies import doc_controls
+from mediapipe.tasks.python.components.containers import audio_record
 
 _TaskRunner = task_runner_module.TaskRunner
 _Packet = packet_module.Packet
@@ -126,6 +127,33 @@ class BaseAudioTaskApi(object):
           + self._running_mode.name)
     self._runner.send(inputs)
 
+  @staticmethod
+  def create_audio_record(
+      num_channels: int,
+      sample_rate: int,
+      required_input_buffer_size: int
+  ) -> audio_record.AudioRecord:
+    """Creates an AudioRecord instance to record an audio stream.
+
+    The returned AudioRecord instance is initialized, and the client needs to
+    call the appropriate method to start recording.
+
+    Note that MediaPipe Audio tasks will up/down sample automatically to fit the
+    sample rate required by the model. The default sample rate of the MediaPipe
+    pretrained audio model, Yamnet, is 16 kHz.
+
+    Args:
+      num_channels: The number of audio channels.
+      sample_rate: The audio sample rate.
+      required_input_buffer_size: The required input buffer size in number of
+        float elements.
+
+    Raises:
+      ValueError: If there's a problem creating the AudioRecord instance.
+    """
+    return audio_record.AudioRecord(num_channels, sample_rate,
+                                    required_input_buffer_size)
+
   def close(self) -> None:
     """Shuts down the mediapipe audio task instance.
 
diff --git a/mediapipe/tasks/python/components/containers/BUILD b/mediapipe/tasks/python/components/containers/BUILD
index 7108617ff..61163365c 100644
--- a/mediapipe/tasks/python/components/containers/BUILD
+++ b/mediapipe/tasks/python/components/containers/BUILD
@@ -23,6 +23,11 @@ py_library(
     srcs = ["audio_data.py"],
 )
 
+py_library(
+    name = "audio_record",
+    srcs = ["audio_record.py"],
+)
+
 py_library(
     name = "bounding_box",
     srcs = ["bounding_box.py"],
diff --git a/mediapipe/tasks/python/components/containers/audio_record.py b/mediapipe/tasks/python/components/containers/audio_record.py
new file mode 100644
index 000000000..824f36e3e
--- /dev/null
+++ b/mediapipe/tasks/python/components/containers/audio_record.py
@@ -0,0 +1,126 @@
+# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""A module to record audio in a streaming basis.""" +import threading +import numpy as np + +try: +# pylint: disable=g-import-not-at-top + import sounddevice as sd +# pylint: enable=g-import-not-at-top +except OSError as oe: + sd = None + sd_error = oe +except ImportError as ie: + sd = None + sd_error = ie + + +class AudioRecord(object): + """A class to record audio in a streaming basis.""" + + def __init__(self, channels: int, sampling_rate: int, + buffer_size: int) -> None: + """Creates an AudioRecord instance. + + Args: + channels: Number of input channels. + sampling_rate: Sampling rate in Hertz. + buffer_size: Size of the ring buffer in number of samples. + + Raises: + ValueError: if any of the arguments is non-positive. + ImportError: if failed to import `sounddevice`. + OSError: if failed to load `PortAudio`. + """ + if sd is None: + raise sd_error + + if channels <= 0: + raise ValueError('channels must be postive.') + if sampling_rate <= 0: + raise ValueError('sampling_rate must be postive.') + if buffer_size <= 0: + raise ValueError('buffer_size must be postive.') + + self._audio_buffer = [] + self._buffer_size = buffer_size + self._channels = channels + self._sampling_rate = sampling_rate + + # Create a ring buffer to store the input audio. + self._buffer = np.zeros([buffer_size, channels], dtype=float) + self._lock = threading.Lock() + + def audio_callback(data, *_): + """A callback to receive recorded audio data from sounddevice.""" + self._lock.acquire() + shift = len(data) + if shift > buffer_size: + self._buffer = np.copy(data[:buffer_size]) + else: + self._buffer = np.roll(self._buffer, -shift, axis=0) + self._buffer[-shift:, :] = np.copy(data) + self._lock.release() + + # Create an input stream to continuously capture the audio data. + self._stream = sd.InputStream( + channels=channels, + samplerate=sampling_rate, + callback=audio_callback, + ) + + @property + def channels(self) -> int: + return self._channels + + @property + def sampling_rate(self) -> int: + return self._sampling_rate + + @property + def buffer_size(self) -> int: + return self._buffer_size + + def start_recording(self) -> None: + """Starts the audio recording.""" + # Clear the internal ring buffer. + self._buffer.fill(0) + + # Start recording using sounddevice's InputStream. + self._stream.start() + + def stop(self) -> None: + """Stops the audio recording.""" + self._stream.stop() + + def read(self, size: int) -> np.ndarray: + """Reads the latest audio data captured in the buffer. + + Args: + size: Number of samples to read from the buffer. + + Returns: + A NumPy array containing the audio data. + + Raises: + ValueError: Raised if `size` is larger than the buffer size. 
+ """ + if size > self._buffer_size: + raise ValueError('Cannot read more samples than the size of the buffer.') + elif size <= 0: + raise ValueError('Size must be positive.') + + start_index = self._buffer_size - size + return np.copy(self._buffer[start_index:]) diff --git a/mediapipe/tasks/python/test/audio/BUILD b/mediapipe/tasks/python/test/audio/BUILD index 43f1d417c..d6e0788f2 100644 --- a/mediapipe/tasks/python/test/audio/BUILD +++ b/mediapipe/tasks/python/test/audio/BUILD @@ -29,6 +29,7 @@ py_test( "//mediapipe/tasks/python/audio:audio_classifier", "//mediapipe/tasks/python/audio/core:audio_task_running_mode", "//mediapipe/tasks/python/components/containers:audio_data", + "//mediapipe/tasks/python/components/containers:audio_record", "//mediapipe/tasks/python/components/containers:classification_result", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/test:test_utils", @@ -46,6 +47,7 @@ py_test( "//mediapipe/tasks/python/audio:audio_embedder", "//mediapipe/tasks/python/audio/core:audio_task_running_mode", "//mediapipe/tasks/python/components/containers:audio_data", + "//mediapipe/tasks/python/components/containers:audio_record", "//mediapipe/tasks/python/components/containers:embedding_result", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/test:test_utils", diff --git a/mediapipe/tasks/python/test/audio/audio_classifier_test.py b/mediapipe/tasks/python/test/audio/audio_classifier_test.py index 75146547c..665a5ca13 100644 --- a/mediapipe/tasks/python/test/audio/audio_classifier_test.py +++ b/mediapipe/tasks/python/test/audio/audio_classifier_test.py @@ -27,6 +27,7 @@ from mediapipe.tasks.python.audio import audio_classifier from mediapipe.tasks.python.audio.core import audio_task_running_mode from mediapipe.tasks.python.components.containers import audio_data as audio_data_module from mediapipe.tasks.python.components.containers import classification_result as classification_result_module +from mediapipe.tasks.python.components.containers import audio_record from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.test import test_utils @@ -34,6 +35,7 @@ _AudioClassifier = audio_classifier.AudioClassifier _AudioClassifierOptions = audio_classifier.AudioClassifierOptions _AudioClassifierResult = classification_result_module.ClassificationResult _AudioData = audio_data_module.AudioData +_AudioRecord = audio_record.AudioRecord _BaseOptions = base_options_module.BaseOptions _RUNNING_MODE = audio_task_running_mode.AudioTaskRunningMode @@ -204,6 +206,18 @@ class AudioClassifierTest(parameterized.TestCase): self._read_wav_file(audio_file)) self._check_yamnet_result(classification_result_list) + @mock.patch("sounddevice.InputStream", return_value=mock.MagicMock()) + def test_create_audio_record_from_classifier_succeeds(self, _): + # Creates AudioRecord instance using the classifier successfully. 
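+    # sounddevice.InputStream is mocked out above, so no audio hardware is
+    # required; the assertions below only check that the task forwards its
+    # channel, sample-rate, and buffer-size settings to the AudioRecord.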
+ with _AudioClassifier.create_from_model_path( + self.yamnet_model_path) as classifier: + self.assertIsInstance(classifier, _AudioClassifier) + record = classifier.create_audio_record(1, 16000, 16000) + self.assertIsInstance(record, _AudioRecord) + self.assertEqual(record.channels, 1) + self.assertEqual(record.sampling_rate, 16000) + self.assertEqual(record.buffer_size, 16000) + def test_max_result_options(self): with _AudioClassifier.create_from_options( _AudioClassifierOptions( diff --git a/mediapipe/tasks/python/test/audio/audio_embedder_test.py b/mediapipe/tasks/python/test/audio/audio_embedder_test.py index 934cdc8db..2015d2bce 100644 --- a/mediapipe/tasks/python/test/audio/audio_embedder_test.py +++ b/mediapipe/tasks/python/test/audio/audio_embedder_test.py @@ -26,6 +26,7 @@ from scipy.io import wavfile from mediapipe.tasks.python.audio import audio_embedder from mediapipe.tasks.python.audio.core import audio_task_running_mode from mediapipe.tasks.python.components.containers import audio_data as audio_data_module +from mediapipe.tasks.python.components.containers import audio_record from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.test import test_utils @@ -33,6 +34,7 @@ _AudioEmbedder = audio_embedder.AudioEmbedder _AudioEmbedderOptions = audio_embedder.AudioEmbedderOptions _AudioEmbedderResult = audio_embedder.AudioEmbedderResult _AudioData = audio_data_module.AudioData +_AudioRecord = audio_record.AudioRecord _BaseOptions = base_options_module.BaseOptions _RUNNING_MODE = audio_task_running_mode.AudioTaskRunningMode @@ -165,6 +167,18 @@ class AudioEmbedderTest(parameterized.TestCase): self.assertLen(embedding_result0_list, 5) self.assertLen(embedding_result1_list, 5) + @mock.patch("sounddevice.InputStream", return_value=mock.MagicMock()) + def test_create_audio_record_from_embedder_succeeds(self, _): + # Creates AudioRecord instance using the embedder successfully. + with _AudioEmbedder.create_from_model_path( + self.yamnet_model_path) as embedder: + self.assertIsInstance(embedder, _AudioEmbedder) + record = embedder.create_audio_record(1, 16000, 16000) + self.assertIsInstance(record, _AudioRecord) + self.assertEqual(record.channels, 1) + self.assertEqual(record.sampling_rate, 16000) + self.assertEqual(record.buffer_size, 16000) + def test_embed_with_yamnet_model_and_different_inputs(self): with _AudioEmbedder.create_from_model_path( self.yamnet_model_path) as embedder: diff --git a/mediapipe/tasks/python/test/audio/core/BUILD b/mediapipe/tasks/python/test/audio/core/BUILD new file mode 100644 index 000000000..14f2e4f6c --- /dev/null +++ b/mediapipe/tasks/python/test/audio/core/BUILD @@ -0,0 +1,27 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Placeholder for internal Python strict test compatibility macro. 
+ +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +py_test( + name = "audio_record_test", + srcs = ["audio_record_test.py"], + deps = [ + "//mediapipe/tasks/python/components/containers:audio_record", + ], +) diff --git a/mediapipe/tasks/python/test/audio/core/audio_record_test.py b/mediapipe/tasks/python/test/audio/core/audio_record_test.py new file mode 100644 index 000000000..dfa72a822 --- /dev/null +++ b/mediapipe/tasks/python/test/audio/core/audio_record_test.py @@ -0,0 +1,97 @@ +# Copyright 2022 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for audio_record.""" + +import numpy as np +import unittest + +from absl.testing import absltest +from absl.testing import parameterized +from mediapipe.tasks.python.components.containers import audio_record + +_mock = unittest.mock + +_CHANNELS = 2 +_SAMPLING_RATE = 16000 +_BUFFER_SIZE = 15600 + + +class AudioRecordTest(parameterized.TestCase): + + def setUp(self): + super().setUp() + + # Mock sounddevice.InputStream + with _mock.patch("sounddevice.InputStream") as mock_input_stream_new_method: + self.mock_input_stream = _mock.MagicMock() + mock_input_stream_new_method.return_value = self.mock_input_stream + self.record = audio_record.AudioRecord(_CHANNELS, _SAMPLING_RATE, + _BUFFER_SIZE) + + # Save the initialization arguments of InputStream for later assertion. + _, self.init_args = mock_input_stream_new_method.call_args + + def test_init_args(self): + # Assert parameters of InputStream initialization + self.assertEqual( + self.init_args["channels"], _CHANNELS, + "InputStream's channels doesn't match the initialization argument.") + self.assertEqual( + self.init_args["samplerate"], _SAMPLING_RATE, + "InputStream's samplerate doesn't match the initialization argument.") + + def test_life_cycle(self): + # Assert start recording routine. + self.record.start_recording() + self.mock_input_stream.start.assert_called_once() + + # Assert stop recording routine. + self.record.stop() + self.mock_input_stream.stop.assert_called_once() + + def test_read_succeeds_with_valid_sample_size(self): + callback_fn = self.init_args["callback"] + + # Create dummy data to feed to the AudioRecord instance. + chunk_size = int(_BUFFER_SIZE * 0.5) + input_data = [] + for _ in range(3): + dummy_data = np.random.rand(chunk_size, _CHANNELS).astype(float) + input_data.append(dummy_data) + callback_fn(dummy_data) + + # Assert read data of a single chunk. + recorded_audio_data = self.record.read(chunk_size) + self.assertTrue(np.array_equal(recorded_audio_data, input_data[-1])) + + # Assert read all data in buffer. + recorded_audio_data = self.record.read(chunk_size * 2) + print(input_data[-2].shape) + expected_data = np.concatenate(input_data[-2:]) + self.assertTrue(np.array_equal(recorded_audio_data, expected_data)) + + def test_read_fails_with_invalid_sample_size(self): + callback_fn = self.init_args["callback"] + + # Create dummy data to feed to the AudioRecord instance. 
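+    # A zero-filled chunk is sufficient here: read() validates the requested
+    # size before it ever inspects the buffer contents.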
+ dummy_data = np.zeros([_BUFFER_SIZE, 1], dtype=float) + callback_fn(dummy_data) + + # Assert exception if read too much data. + with self.assertRaises(ValueError): + self.record.read(_BUFFER_SIZE + 1) + + +if __name__ == "__main__": + absltest.main() From f56a3088e374d0c9b22c964b3c2d732b350c3294 Mon Sep 17 00:00:00 2001 From: Kinar R <42828719+kinaryml@users.noreply.github.com> Date: Sat, 11 Mar 2023 00:02:41 +0530 Subject: [PATCH 03/63] Update audio_record_test.py --- mediapipe/tasks/python/test/audio/core/audio_record_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mediapipe/tasks/python/test/audio/core/audio_record_test.py b/mediapipe/tasks/python/test/audio/core/audio_record_test.py index dfa72a822..579e4c582 100644 --- a/mediapipe/tasks/python/test/audio/core/audio_record_test.py +++ b/mediapipe/tasks/python/test/audio/core/audio_record_test.py @@ -1,10 +1,10 @@ -# Copyright 2022 The TensorFlow Authors. All Rights Reserved. +# Copyright 2023 The MediaPipe Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, From 9d8200014802969ee7c9a8242aa2ee1f6cc88384 Mon Sep 17 00:00:00 2001 From: Kinar R <42828719+kinaryml@users.noreply.github.com> Date: Sat, 11 Mar 2023 00:04:07 +0530 Subject: [PATCH 04/63] Update BUILD --- mediapipe/tasks/python/test/audio/core/BUILD | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mediapipe/tasks/python/test/audio/core/BUILD b/mediapipe/tasks/python/test/audio/core/BUILD index 14f2e4f6c..41d0755ee 100644 --- a/mediapipe/tasks/python/test/audio/core/BUILD +++ b/mediapipe/tasks/python/test/audio/core/BUILD @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2023 The MediaPipe Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
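A minimal sketch of how the AudioRecord API added above is meant to compose
with a streaming audio task (an illustrative aside, not part of the patches
themselves; it assumes the existing AudioClassifier stream-mode API, and
'yamnet.tflite' is a placeholder model path):

    import time
    from mediapipe.tasks.python.audio import audio_classifier
    from mediapipe.tasks.python.audio.core import audio_task_running_mode
    from mediapipe.tasks.python.components.containers import audio_data
    from mediapipe.tasks.python.core import base_options

    options = audio_classifier.AudioClassifierOptions(
        base_options=base_options.BaseOptions(
            model_asset_path='yamnet.tflite'),  # placeholder path
        running_mode=audio_task_running_mode.AudioTaskRunningMode.AUDIO_STREAM,
        result_callback=lambda result, timestamp_ms: print(timestamp_ms))
    with audio_classifier.AudioClassifier.create_from_options(
        options) as classifier:
      # At this point in the series, create_audio_record is still a static
      # method and AudioRecord lives under components/containers; later
      # patches make it an instance method and move it to audio/core.
      record = classifier.create_audio_record(1, 16000, 15600)
      record.start_recording()
      for i in range(10):
        time.sleep(15600 / 16000)  # let the ring buffer fill with one block
        block = record.read(15600)
        classifier.classify_async(
            audio_data.AudioData.create_from_array(block, sample_rate=16000),
            i * 975)  # 15600 samples at 16 kHz is 975 ms per block
      record.stop()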
From 78e48825ae14139131eb0d308e0da6286139342c Mon Sep 17 00:00:00 2001 From: kinaryml Date: Sat, 11 Mar 2023 07:39:34 -0800 Subject: [PATCH 05/63] Make create_audio_record not a static method --- mediapipe/tasks/python/audio/core/base_audio_task_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mediapipe/tasks/python/audio/core/base_audio_task_api.py b/mediapipe/tasks/python/audio/core/base_audio_task_api.py index 80e8ad605..73b1273cd 100644 --- a/mediapipe/tasks/python/audio/core/base_audio_task_api.py +++ b/mediapipe/tasks/python/audio/core/base_audio_task_api.py @@ -127,8 +127,8 @@ class BaseAudioTaskApi(object): + self._running_mode.name) self._runner.send(inputs) - @staticmethod def create_audio_record( + self, num_channels: int, sample_rate: int, required_input_buffer_size: int From 444cd00ee66bf2716e33f6d8907fd2d59dc473bf Mon Sep 17 00:00:00 2001 From: kinaryml Date: Tue, 21 Mar 2023 23:15:55 -0700 Subject: [PATCH 06/63] Moved audio_record.py to tasks/python/audio/core --- mediapipe/tasks/python/audio/core/BUILD | 7 ++++++- .../{components/containers => audio/core}/audio_record.py | 0 mediapipe/tasks/python/audio/core/base_audio_task_api.py | 2 +- mediapipe/tasks/python/components/containers/BUILD | 5 ----- mediapipe/tasks/python/test/audio/BUILD | 4 ++-- mediapipe/tasks/python/test/audio/audio_classifier_test.py | 2 +- mediapipe/tasks/python/test/audio/audio_embedder_test.py | 2 +- 7 files changed, 11 insertions(+), 11 deletions(-) rename mediapipe/tasks/python/{components/containers => audio/core}/audio_record.py (100%) diff --git a/mediapipe/tasks/python/audio/core/BUILD b/mediapipe/tasks/python/audio/core/BUILD index 28dc4b960..461ffb8b5 100644 --- a/mediapipe/tasks/python/audio/core/BUILD +++ b/mediapipe/tasks/python/audio/core/BUILD @@ -23,6 +23,11 @@ py_library( srcs = ["audio_task_running_mode.py"], ) +py_library( + name = "audio_record", + srcs = ["audio_record.py"], +) + py_library( name = "base_audio_task_api", srcs = [ @@ -30,10 +35,10 @@ py_library( ], deps = [ ":audio_task_running_mode", + ":audio_record", "//mediapipe/framework:calculator_py_pb2", "//mediapipe/python:_framework_bindings", "//mediapipe/python:packet_creator", "//mediapipe/tasks/python/core:optional_dependencies", - "//mediapipe/tasks/python/components/containers:audio_record", ], ) diff --git a/mediapipe/tasks/python/components/containers/audio_record.py b/mediapipe/tasks/python/audio/core/audio_record.py similarity index 100% rename from mediapipe/tasks/python/components/containers/audio_record.py rename to mediapipe/tasks/python/audio/core/audio_record.py diff --git a/mediapipe/tasks/python/audio/core/base_audio_task_api.py b/mediapipe/tasks/python/audio/core/base_audio_task_api.py index 73b1273cd..aa5b28fcb 100644 --- a/mediapipe/tasks/python/audio/core/base_audio_task_api.py +++ b/mediapipe/tasks/python/audio/core/base_audio_task_api.py @@ -21,8 +21,8 @@ from mediapipe.python._framework_bindings import packet as packet_module from mediapipe.python._framework_bindings import task_runner as task_runner_module from mediapipe.python._framework_bindings import timestamp as timestamp_module from mediapipe.tasks.python.audio.core import audio_task_running_mode as running_mode_module +from mediapipe.tasks.python.audio.core import audio_record from mediapipe.tasks.python.core.optional_dependencies import doc_controls -from mediapipe.tasks.python.components.containers import audio_record _TaskRunner = task_runner_module.TaskRunner _Packet = packet_module.Packet diff --git 
a/mediapipe/tasks/python/components/containers/BUILD b/mediapipe/tasks/python/components/containers/BUILD index 61163365c..7108617ff 100644 --- a/mediapipe/tasks/python/components/containers/BUILD +++ b/mediapipe/tasks/python/components/containers/BUILD @@ -23,11 +23,6 @@ py_library( srcs = ["audio_data.py"], ) -py_library( - name = "audio_record", - srcs = ["audio_record.py"], -) - py_library( name = "bounding_box", srcs = ["bounding_box.py"], diff --git a/mediapipe/tasks/python/test/audio/BUILD b/mediapipe/tasks/python/test/audio/BUILD index d6e0788f2..6bf69a278 100644 --- a/mediapipe/tasks/python/test/audio/BUILD +++ b/mediapipe/tasks/python/test/audio/BUILD @@ -27,9 +27,9 @@ py_test( ], deps = [ "//mediapipe/tasks/python/audio:audio_classifier", + "//mediapipe/tasks/python/audio/core:audio_record", "//mediapipe/tasks/python/audio/core:audio_task_running_mode", "//mediapipe/tasks/python/components/containers:audio_data", - "//mediapipe/tasks/python/components/containers:audio_record", "//mediapipe/tasks/python/components/containers:classification_result", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/test:test_utils", @@ -45,9 +45,9 @@ py_test( ], deps = [ "//mediapipe/tasks/python/audio:audio_embedder", + "//mediapipe/tasks/python/audio/core:audio_record", "//mediapipe/tasks/python/audio/core:audio_task_running_mode", "//mediapipe/tasks/python/components/containers:audio_data", - "//mediapipe/tasks/python/components/containers:audio_record", "//mediapipe/tasks/python/components/containers:embedding_result", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/test:test_utils", diff --git a/mediapipe/tasks/python/test/audio/audio_classifier_test.py b/mediapipe/tasks/python/test/audio/audio_classifier_test.py index 665a5ca13..8880c02f4 100644 --- a/mediapipe/tasks/python/test/audio/audio_classifier_test.py +++ b/mediapipe/tasks/python/test/audio/audio_classifier_test.py @@ -25,9 +25,9 @@ from scipy.io import wavfile from mediapipe.tasks.python.audio import audio_classifier from mediapipe.tasks.python.audio.core import audio_task_running_mode +from mediapipe.tasks.python.audio.core import audio_record from mediapipe.tasks.python.components.containers import audio_data as audio_data_module from mediapipe.tasks.python.components.containers import classification_result as classification_result_module -from mediapipe.tasks.python.components.containers import audio_record from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.test import test_utils diff --git a/mediapipe/tasks/python/test/audio/audio_embedder_test.py b/mediapipe/tasks/python/test/audio/audio_embedder_test.py index 2015d2bce..f1e8ec6e7 100644 --- a/mediapipe/tasks/python/test/audio/audio_embedder_test.py +++ b/mediapipe/tasks/python/test/audio/audio_embedder_test.py @@ -25,8 +25,8 @@ from scipy.io import wavfile from mediapipe.tasks.python.audio import audio_embedder from mediapipe.tasks.python.audio.core import audio_task_running_mode +from mediapipe.tasks.python.audio.core import audio_record from mediapipe.tasks.python.components.containers import audio_data as audio_data_module -from mediapipe.tasks.python.components.containers import audio_record from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.test import test_utils From da70497f3514b03c16b986ef157addea280077aa Mon Sep 17 00:00:00 2001 From: kinaryml Date: Wed, 22 Mar 2023 21:15:04 -0700 Subject: [PATCH 07/63] Updated Face 
Stylizer implementation and tests --- .../python/test/vision/face_stylizer_test.py | 8 +++----- mediapipe/tasks/python/vision/BUILD | 2 +- mediapipe/tasks/python/vision/face_stylizer.py | 8 +++++--- mediapipe/tasks/testdata/vision/BUILD | 1 + .../vision/face_stylization_dummy.tflite | Bin 0 -> 21430 bytes 5 files changed, 10 insertions(+), 9 deletions(-) create mode 100644 mediapipe/tasks/testdata/vision/face_stylization_dummy.tflite diff --git a/mediapipe/tasks/python/test/vision/face_stylizer_test.py b/mediapipe/tasks/python/test/vision/face_stylizer_test.py index 3c39851dd..32643821f 100644 --- a/mediapipe/tasks/python/test/vision/face_stylizer_test.py +++ b/mediapipe/tasks/python/test/vision/face_stylizer_test.py @@ -36,7 +36,7 @@ _FaceStylizerOptions = face_stylizer.FaceStylizerOptions _RUNNING_MODE = running_mode_module.VisionTaskRunningMode _ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions -_MODEL = 'face_stylizer_model_placeholder.tflite' +_MODEL = 'face_stylization_dummy.tflite' _IMAGE = 'cats_and_dogs.jpg' _STYLIZED_IMAGE = 'stylized_image_placeholder.jpg' _TEST_DATA_DIR = 'mediapipe/tasks/testdata/vision' @@ -104,11 +104,9 @@ class FaceStylizerTest(parameterized.TestCase): stylizer = _FaceStylizer.create_from_options(options) # Performs face stylization on the input. - stylized_image = stylizer.detect(self.test_image) + stylized_image = stylizer.stylize(self.test_image) # Comparing results. - self.assertTrue( - np.array_equal(stylized_image.numpy_view(), - self.test_image.numpy_view())) + # TODO: # Closes the stylizer explicitly when the stylizer is not used in # a context. stylizer.close() diff --git a/mediapipe/tasks/python/vision/BUILD b/mediapipe/tasks/python/vision/BUILD index e21171fc2..f89ef04da 100644 --- a/mediapipe/tasks/python/vision/BUILD +++ b/mediapipe/tasks/python/vision/BUILD @@ -162,7 +162,7 @@ py_library( "//mediapipe/python:_framework_bindings", "//mediapipe/python:packet_creator", "//mediapipe/python:packet_getter", - "//mediapipe/tasks/cc/vision/image_segmenter/proto:face_stylizer_graph_options_py_pb2", + "//mediapipe/tasks/cc/vision/face_stylizer/proto:face_stylizer_graph_options_py_pb2", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/core:optional_dependencies", "//mediapipe/tasks/python/core:task_info", diff --git a/mediapipe/tasks/python/vision/face_stylizer.py b/mediapipe/tasks/python/vision/face_stylizer.py index cd840fe85..1393982da 100644 --- a/mediapipe/tasks/python/vision/face_stylizer.py +++ b/mediapipe/tasks/python/vision/face_stylizer.py @@ -124,8 +124,10 @@ class FaceStylizer(base_vision_task_api.BaseVisionTaskApi): return image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME]) stylized_image_packet = output_packets[_STYLIZED_IMAGE_NAME] + stylized_image = packet_getter.get_image(stylized_image_packet) + options.result_callback( - stylized_image_packet, image, + stylized_image, image, stylized_image_packet.timestamp.value // _MICRO_SECONDS_PER_MILLISECOND) task_info = _TaskInfo( @@ -173,7 +175,7 @@ class FaceStylizer(base_vision_task_api.BaseVisionTaskApi): _NORM_RECT_STREAM_NAME: packet_creator.create_proto(normalized_rect.to_pb2()) }) - return output_packets[_STYLIZED_IMAGE_NAME] + return packet_getter.get_image(output_packets[_STYLIZED_IMAGE_NAME]) def stylize_for_video( self, @@ -209,7 +211,7 @@ class FaceStylizer(base_vision_task_api.BaseVisionTaskApi): packet_creator.create_proto(normalized_rect.to_pb2()).at( timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND) }) - return 
output_packets[_STYLIZED_IMAGE_NAME]
+    return packet_getter.get_image(output_packets[_STYLIZED_IMAGE_NAME])
 
   def stylize_async(
       self,
diff --git a/mediapipe/tasks/testdata/vision/BUILD b/mediapipe/tasks/testdata/vision/BUILD
index 097acad43..f77e1780c 100644
--- a/mediapipe/tasks/testdata/vision/BUILD
+++ b/mediapipe/tasks/testdata/vision/BUILD
@@ -170,6 +170,7 @@ filegroup(
         "face_detection_short_range.tflite",
         "face_landmark.tflite",
         "face_landmark_with_attention.tflite",
+        "face_stylization_dummy.tflite",
         "face_landmarker.task",
         "hair_segmentation.tflite",
         "hand_landmark_full.tflite",
diff --git a/mediapipe/tasks/testdata/vision/face_stylization_dummy.tflite b/mediapipe/tasks/testdata/vision/face_stylization_dummy.tflite
new file mode 100644
index 0000000000000000000000000000000000000000..2523a3a5d16fe29058a1beae2433d3a2e3fa4ac6
GIT binary patch
literal 21430
[21430 bytes of base85-encoded binary data for face_stylization_dummy.tflite omitted]
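A minimal sketch of the stylizer API after this patch, which makes stylize()
return an Image rather than a raw packet (an illustrative aside, not part of
the patches themselves; the image path is a placeholder):

    from mediapipe.python._framework_bindings import image as image_module
    from mediapipe.tasks.python.vision import face_stylizer

    with face_stylizer.FaceStylizer.create_from_model_path(
        'face_stylization_dummy.tflite') as stylizer:
      test_image = image_module.Image.create_from_file('portrait.jpg')
      stylized = stylizer.stylize(test_image)
      print(stylized.numpy_view().shape)  # stylized output as a NumPy array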
z46d`?%-)mQr=u>LqdLUiDRt)U!SNywn02jPWKrw{A)8|)`K@NmvmTlyDGP3QbHJUx3F^R!#)Fkl#( z)Jnj0coBZvH~^~DS7TmLG5TAwal_BLP&-`?GRvGnZ}}yu&#Pp(mlgvPX$R?yx=J71 zh=dAp0<86(2mNEj(no%dD7Q@l<9`o__?bm$Bl{JTAl~h&kU6A+7>7neyW1@Cz$pV=1sj}lYb4M(XBxk30+nHF+bf0-f}Nq> zqF0%eT+c`VTZNx#kf?(OWlpE6G#a&3ma`e(9oY5_3U{X_38~Fnq_GL?IeyO@B{Uav zJ2v1A^-e0U>;|02&?aC!5v>G`nLX0sjUY49jLyS9{MXDPCrS5J+(HF ze##Fc3%y{Xt}GmQeSlWRHj%9dr{I3)Ea}{eF0#rl2f7~21`CfB7(OB#r~CP$bZr4N z>#Qd~-txhoTXoQtKN@SUis_}{P12hyEkO777T7n^4TlXFPvi9hrJk!(&}!ab>bO7$ z4b<2gX_rzV$gCC&F1J#(v!RgavjNE&-VE-FP+zb;VmVHNd&$J~cOD zYdKms5r<_-c*0zh=ym}X$W+3htO#tIeFQvqc1lYE4^suhWw;}`3Y<@;V29ZekXe#L z%9gp{?~6UK_iQ%o8`6*b7-E5;FB_<{i5iZw&xUhvE#Z%z#khQ814%sV0UZwV;QG=^ z2v-qH&0hJTSQ0=h=4>HxE>#exu|ufGG?-SA36Gk8B&ub5p^ZJyie8t(p&Mc%tV@E% zIfHQi(MY)08-Y*b<#6Tg1(1tKOUvU?fMw?=^Wh0rt+8({MheKlL zR+!K|4$gULU{&g2s0vR3g$^BjQ<0B1-?&2VeF=2h`a|mYA87H*GSG~c!QSpt8lSft zTiz$pgU5Hk>^FhnaWe~shz-%!MF(4}^HF7IKKxR#09=yWSzEZ_%=sf=|MV?%ky;|X zZTu~UJedURwyuJp#w~cga37I;pAN(N5{ZLrDQrBvloUJ0Lg~$SRB=WHJ@&Rx@RF;d z-mV7v|K3eU|2qSmHtUmT)4qqP{nwMIL2RGs^k_V2wORDbs36G9 zSOwvi(lD$x5-pO?P@UoFpg)E26ztup`YSUKPDil$gF3i-?=}g2JP>Gbue7aU5op_G zNNXyyvEk9*w8gRz_V(;V`;3|R>f0=wV%SD5rn9{%<<+3e=Fc_VcVS_)5ytfM#E0vv zr8;F#$y?Q#=>Fm&8K%PaFjj2Ay(^T#zH|{er`1D#4Tz);ipOEiP&wFZ?7@bl56SwA zu4K)K!O+lhMS5xOXldlqTDFeS0OOrCq0+Pj_n(*ow|;vfgjEkEfnA~4V7VUyRKCNJ zj`?V4nGc%#eCgBG=F-BeX_zZyvayviPFy!1I!z=rSY8R_v-60>_!ctDaXm?Wa)ss=!PC;Y2=t0F!;hRG;u;0*<$Yme_OJB)r#B6szIJ;E{+AI z6dTf}*+Oh5dtmI8WDunlVo}RJY8Rf46WHFOsm6XdRaOu0*vU(kM#bahsBO?1SWJFm z&xh+TYe>t*cyu4L1AU*m!DY8&qQACx(5YU#ap$OX>|QbuAE?!k5ZzK}b!YqNvfV(g zT12y7TfrOY3+cM3AvAZ&9+X>d2&QH&r0}g6Zg19sK&=y`wO$L?>URjU2G7OXk_hx^ z*ato)JMluH64;M$M7OJ<5EQ=*`#!Dnr9oJ7RjLGX@vP*g!GmzA&R<1 z_^-1pj{S}i+ueInt>_(D{znq^bU6S97bam>!3w-&=Le1ovG9Y|3bbelhditkntLKK)(X&aA(0!FWFj54vKZflg%YrkZjU;DI8oJ)DA&>jm`mfF3XMGXahbsv#Z&RA__p4k9P^gzka)@Uk%rf5~W+R>ruIi3_4} z>+CnuFspYotw)w-9W%lP;UqnBBLe=|I3IRR9)S1fB!GF2H7PdmBDUk^z_U|&c-`6v zu6CUyhLxrGt9cx;Ktc3<(agLE7@sxC!rd*ah2-w{^i*&hZuQT?WS9Xb-%q1X@7@bfEP|jp zp6&ls$ipbRMDQ&w1Y_mz;q}o}FfvdeDqAn;(j6MX!aoYa__c{9Pa+I?!{Y7!_)wNVn;FS|!mU{l-~<*C=IdTcwQor&8&f zF>+LXoeSty8)Id_9?~2VC%v61L6f)1&{e91*ZmSiNo+q=eLp4iw=xp_UaE< zPDyZ<(=@OyABHh*VUY4Si`9$Qsr%^8x+M!;u(2#2GKJH0;Da1tgH0*END{;O^jx^- zyctfvOrb5Y2kTa-FUKBY1CJO-jPtZOm+SF@NMYiU9irA+W8q64R}4E}U^rKrvxL^F zyeyNluQEP~TPy?b4y>yQz9M+EUbkp^>>$;MonF_z{~9YVBa=HL-BJ=Xu2$9NU7c?9 zRzbmM?D2-kGr|x{KOyRB=}!+gwwUwl$^|3A*{z~B)F0|<%GF47=>tLYp(+{g=vV8v zcAV(Rg_U)?H8zP_di(m?nQ*S0lcW7YZ7T+<^9>5&$T7|?ju_QZ^eeiM&9oEOca~q$9^T#-Qk8{p= zIj)1-=00$8ALp)d+q^tWvg6r@9+M%T=y9B<_JOVTg*eV)+mGAi_V|&fflS(& z!HrDbvdm&8j!ZZoj{D;L?xVll7jBmeFGHB{^Q=039~@@oxeW_;{r|_O8PhrPk-%X*u4CCh=;AL!L_?b3nX1LHdAXNC;$vOVs!LxmQ zzIvX=f{FE4Pk$=TYC5J73dCz%Su9w$~a~(NTl3&I9 X=g0Pt>F+X6o`?2-QUPOc%E Date: Wed, 22 Mar 2023 21:47:38 -0700 Subject: [PATCH 08/63] Removed unit tests --- .../python/test/vision/face_stylizer_test.py | 116 ------------------ 1 file changed, 116 deletions(-) delete mode 100644 mediapipe/tasks/python/test/vision/face_stylizer_test.py diff --git a/mediapipe/tasks/python/test/vision/face_stylizer_test.py b/mediapipe/tasks/python/test/vision/face_stylizer_test.py deleted file mode 100644 index 32643821f..000000000 --- a/mediapipe/tasks/python/test/vision/face_stylizer_test.py +++ /dev/null @@ -1,116 +0,0 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Tests for face stylizer.""" - -import enum -import os -from unittest import mock - -import numpy as np -from absl.testing import absltest -from absl.testing import parameterized - -from mediapipe.python._framework_bindings import image as image_module -from mediapipe.tasks.python.core import base_options as base_options_module -from mediapipe.tasks.python.test import test_utils -from mediapipe.tasks.python.vision import face_stylizer -from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module -from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module - - -_BaseOptions = base_options_module.BaseOptions -_Image = image_module.Image -_FaceStylizer = face_stylizer.FaceStylizer -_FaceStylizerOptions = face_stylizer.FaceStylizerOptions -_RUNNING_MODE = running_mode_module.VisionTaskRunningMode -_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions - -_MODEL = 'face_stylization_dummy.tflite' -_IMAGE = 'cats_and_dogs.jpg' -_STYLIZED_IMAGE = 'stylized_image_placeholder.jpg' -_TEST_DATA_DIR = 'mediapipe/tasks/testdata/vision' - - -class ModelFileType(enum.Enum): - FILE_CONTENT = 1 - FILE_NAME = 2 - - -class FaceStylizerTest(parameterized.TestCase): - - def setUp(self): - super().setUp() - self.test_image = _Image.create_from_file( - test_utils.get_test_data_path( - os.path.join(_TEST_DATA_DIR, _IMAGE))) - self.model_path = test_utils.get_test_data_path( - os.path.join(_TEST_DATA_DIR, _MODEL)) - - def test_create_from_file_succeeds_with_valid_model_path(self): - # Creates with default option and valid model file successfully. - with _FaceStylizer.create_from_model_path(self.model_path) as stylizer: - self.assertIsInstance(stylizer, _FaceStylizer) - - def test_create_from_options_succeeds_with_valid_model_path(self): - # Creates with options containing model file successfully. - base_options = _BaseOptions(model_asset_path=self.model_path) - options = _FaceStylizerOptions(base_options=base_options) - with _FaceStylizer.create_from_options(options) as stylizer: - self.assertIsInstance(stylizer, _FaceStylizer) - - def test_create_from_options_fails_with_invalid_model_path(self): - with self.assertRaisesRegex( - RuntimeError, 'Unable to open file at /path/to/invalid/model.tflite'): - base_options = _BaseOptions( - model_asset_path='/path/to/invalid/model.tflite') - options = _FaceStylizerOptions(base_options=base_options) - _FaceStylizer.create_from_options(options) - - def test_create_from_options_succeeds_with_valid_model_content(self): - # Creates with options containing model content successfully. - with open(self.model_path, 'rb') as f: - base_options = _BaseOptions(model_asset_buffer=f.read()) - options = _FaceStylizerOptions(base_options=base_options) - stylizer = _FaceStylizer.create_from_options(options) - self.assertIsInstance(stylizer, _FaceStylizer) - - @parameterized.parameters( - (ModelFileType.FILE_NAME, _STYLIZED_IMAGE), - (ModelFileType.FILE_CONTENT, _STYLIZED_IMAGE)) - def test_stylize(self, model_file_type, expected_detection_result_file): - # Creates stylizer. 
-    if model_file_type is ModelFileType.FILE_NAME:
-      base_options = _BaseOptions(model_asset_path=self.model_path)
-    elif model_file_type is ModelFileType.FILE_CONTENT:
-      with open(self.model_path, 'rb') as f:
-        model_content = f.read()
-      base_options = _BaseOptions(model_asset_buffer=model_content)
-    else:
-      # Should never happen
-      raise ValueError('model_file_type is invalid.')
-
-    options = _FaceStylizerOptions(base_options=base_options)
-    stylizer = _FaceStylizer.create_from_options(options)
-
-    # Performs face stylization on the input.
-    stylized_image = stylizer.stylize(self.test_image)
-    # Comparing results.
-    # TODO:
-    # Closes the stylizer explicitly when the stylizer is not used in
-    # a context.
-    stylizer.close()
-
-
-if __name__ == '__main__':
-  absltest.main()

From 613bcf99f48072154c7043f2de8f8c11f0fc7094 Mon Sep 17 00:00:00 2001
From: kinaryml
Date: Wed, 22 Mar 2023 21:49:11 -0700
Subject: [PATCH 09/63] Removed model

---
 .../vision/face_stylization_dummy.tflite | Bin 21430 -> 0 bytes
 1 file changed, 0 insertions(+), 0 deletions(-)
 delete mode 100644 mediapipe/tasks/testdata/vision/face_stylization_dummy.tflite

diff --git a/mediapipe/tasks/testdata/vision/face_stylization_dummy.tflite b/mediapipe/tasks/testdata/vision/face_stylization_dummy.tflite
deleted file mode 100644
index 2523a3a5d16fe29058a1beae2433d3a2e3fa4ac6..0000000000000000000000000000000000000000
GIT binary patch
literal 0
HcmV?d00001

literal 21430
[... base85-encoded binary payload omitted (21430 bytes, a verbatim copy of the blob shown earlier in the series) ...]
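The two patches above retire the placeholder unit tests and the dummy model together. For reference, the API flow that the deleted face_stylizer_test.py exercised is sketched below. This is an illustrative snippet, not part of the patch series: the model and image file names are placeholders, and every call mirrors the deleted test above.

# Hedged sketch of the FaceStylizer usage covered by the deleted test.
from mediapipe.python._framework_bindings import image as image_module
from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.vision import face_stylizer

_MODEL = 'face_stylizer_model_placeholder.tflite'  # placeholder file name
_IMAGE = 'cats_and_dogs.jpg'

test_image = image_module.Image.create_from_file(_IMAGE)

# Creation from a model path; the context manager closes the task on exit.
with face_stylizer.FaceStylizer.create_from_model_path(_MODEL) as stylizer:
  stylized_image = stylizer.stylize(test_image)

# Creation from options with an in-memory model buffer; close() is then
# required explicitly, as the deleted test's comment noted.
with open(_MODEL, 'rb') as f:
  base_options = base_options_module.BaseOptions(model_asset_buffer=f.read())
options = face_stylizer.FaceStylizerOptions(base_options=base_options)
stylizer = face_stylizer.FaceStylizer.create_from_options(options)
stylized_image = stylizer.stylize(test_image)
stylizer.close()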
From: kinaryml
Date: Wed, 22 Mar 2023 21:50:07 -0700
Subject: [PATCH 10/63] Updated BUILD

---
 mediapipe/tasks/python/test/vision/BUILD | 17 -----------------
 1 file changed, 17 deletions(-)

diff --git a/mediapipe/tasks/python/test/vision/BUILD b/mediapipe/tasks/python/test/vision/BUILD
index db0bd66b2..a3285a6d6 100644
--- a/mediapipe/tasks/python/test/vision/BUILD
+++ b/mediapipe/tasks/python/test/vision/BUILD
@@ -161,20 +161,3 @@ py_test(
         "@com_google_protobuf//:protobuf_python",
     ],
 )
-
-py_test(
-    name = "face_stylizer_test",
-    srcs = ["face_stylizer_test.py"],
-    data = [
-        "//mediapipe/tasks/testdata/vision:test_images",
-        "//mediapipe/tasks/testdata/vision:test_models",
-    ],
-    deps = [
-        "//mediapipe/python:_framework_bindings",
-        "//mediapipe/tasks/python/core:base_options",
-        "//mediapipe/tasks/python/test:test_utils",
- "//mediapipe/tasks/python/vision:face_stylizer", - "//mediapipe/tasks/python/vision/core:image_processing_options", - "//mediapipe/tasks/python/vision/core:vision_task_running_mode", - ], -) From ca18b95510e243faaa962eeaf8784d37f6b3b484 Mon Sep 17 00:00:00 2001 From: kinaryml Date: Wed, 22 Mar 2023 21:50:40 -0700 Subject: [PATCH 11/63] Updated BUILD --- mediapipe/tasks/testdata/vision/BUILD | 1 - 1 file changed, 1 deletion(-) diff --git a/mediapipe/tasks/testdata/vision/BUILD b/mediapipe/tasks/testdata/vision/BUILD index f77e1780c..097acad43 100644 --- a/mediapipe/tasks/testdata/vision/BUILD +++ b/mediapipe/tasks/testdata/vision/BUILD @@ -170,7 +170,6 @@ filegroup( "face_detection_short_range.tflite", "face_landmark.tflite", "face_landmark_with_attention.tflite", - "face_stylization_dummy.tflite", "face_landmarker.task", "hair_segmentation.tflite", "hand_landmark_full.tflite", From 29a40413534a9b1a4653ebbc5e77813a602455aa Mon Sep 17 00:00:00 2001 From: kinaryml Date: Wed, 22 Mar 2023 21:51:45 -0700 Subject: [PATCH 12/63] Fixed a typo in docstring --- mediapipe/tasks/python/vision/face_stylizer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mediapipe/tasks/python/vision/face_stylizer.py b/mediapipe/tasks/python/vision/face_stylizer.py index 1393982da..83c46ae63 100644 --- a/mediapipe/tasks/python/vision/face_stylizer.py +++ b/mediapipe/tasks/python/vision/face_stylizer.py @@ -82,7 +82,7 @@ class FaceStylizer(base_vision_task_api.BaseVisionTaskApi): def create_from_model_path(cls, model_path: str) -> 'FaceStylizer': """Creates an `FaceStylizer` object from a TensorFlow Lite model and the default `FaceStylizerOptions`. - Note that the created `FaceDetector` instance is in image mode, for + Note that the created `FaceStylizer` instance is in image mode, for stylizing faces on single image inputs. 
Args: From 1ff80f906cd46b1961b6a9071b91f14f6a5d4ea4 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Fri, 31 Mar 2023 09:02:35 -0700 Subject: [PATCH 13/63] draw mouth to shoulder line after connection, to align with python viz code PiperOrigin-RevId: 520935390 --- mediapipe/util/pose_util.cc | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/mediapipe/util/pose_util.cc b/mediapipe/util/pose_util.cc index dd907fcdd..e5d1f2c9f 100644 --- a/mediapipe/util/pose_util.cc +++ b/mediapipe/util/pose_util.cc @@ -89,19 +89,6 @@ void DrawPose(const mediapipe::NormalizedLandmarkList& pose, int target_width, constexpr int draw_line_width = 5; constexpr int draw_circle_radius = 7; - const int lm = static_cast(PoseLandmarkName::kMouthLeft); - const int rm = static_cast(PoseLandmarkName::kMouthRight); - const int ls = static_cast(PoseLandmarkName::kLeftShoulder); - const int rs = static_cast(PoseLandmarkName::kRightShoulder); - if (visible_landmarks.find(lm) != visible_landmarks.end() && - visible_landmarks.find(rm) != visible_landmarks.end() && - visible_landmarks.find(ls) != visible_landmarks.end() && - visible_landmarks.find(rs) != visible_landmarks.end()) { - cv::line(*image, (visible_landmarks[lm] + visible_landmarks[rm]) * 0.5f, - (visible_landmarks[ls] + visible_landmarks[rs]) * 0.5f, - cv::Scalar(255, 255, 255), draw_line_width); - } - for (int j = 0; j < 35; ++j) { if (visible_landmarks.find(kJointConnection[j][0]) != visible_landmarks.end() && @@ -115,6 +102,19 @@ void DrawPose(const mediapipe::NormalizedLandmarkList& pose, int target_width, } } + const int lm = static_cast(PoseLandmarkName::kMouthLeft); + const int rm = static_cast(PoseLandmarkName::kMouthRight); + const int ls = static_cast(PoseLandmarkName::kLeftShoulder); + const int rs = static_cast(PoseLandmarkName::kRightShoulder); + if (visible_landmarks.find(lm) != visible_landmarks.end() && + visible_landmarks.find(rm) != visible_landmarks.end() && + visible_landmarks.find(ls) != visible_landmarks.end() && + visible_landmarks.find(rs) != visible_landmarks.end()) { + cv::line(*image, (visible_landmarks[lm] + visible_landmarks[rm]) * 0.5f, + (visible_landmarks[ls] + visible_landmarks[rs]) * 0.5f, + cv::Scalar(255, 255, 255), draw_line_width); + } + for (const auto& landmark : visible_landmarks) { cv::circle(*image, landmark.second, draw_circle_radius, cv::Scalar(kJointColorMap[landmark.first][0], From c40e0fb6d59bc755f5289de6efa109bb13351440 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Fri, 31 Mar 2023 10:29:34 -0700 Subject: [PATCH 14/63] Internal change PiperOrigin-RevId: 520956127 --- .bazelversion | 2 +- Dockerfile | 2 +- WORKSPACE | 140 ++++++++++++++++---------------- docs/getting_started/install.md | 2 +- 4 files changed, 73 insertions(+), 73 deletions(-) diff --git a/.bazelversion b/.bazelversion index 91ff57278..f3b5af39e 100644 --- a/.bazelversion +++ b/.bazelversion @@ -1 +1 @@ -5.2.0 +6.1.1 diff --git a/Dockerfile b/Dockerfile index 3df22cc04..03b335823 100644 --- a/Dockerfile +++ b/Dockerfile @@ -61,7 +61,7 @@ RUN pip3 install tf_slim RUN ln -s /usr/bin/python3 /usr/bin/python # Install bazel -ARG BAZEL_VERSION=5.2.0 +ARG BAZEL_VERSION=6.1.1 RUN mkdir /bazel && \ wget --no-check-certificate -O /bazel/installer.sh "https://github.com/bazelbuild/bazel/releases/download/${BAZEL_VERSION}/b\ azel-${BAZEL_VERSION}-installer-linux-x86_64.sh" && \ diff --git a/WORKSPACE b/WORKSPACE index 17e96c0b2..199b6a000 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -54,6 +54,76 @@ 
load("@rules_foreign_cc//:workspace_definitions.bzl", "rules_foreign_cc_dependen rules_foreign_cc_dependencies() +http_archive( + name = "com_google_protobuf", + sha256 = "87407cd28e7a9c95d9f61a098a53cf031109d451a7763e7dd1253abf8b4df422", + strip_prefix = "protobuf-3.19.1", + urls = ["https://github.com/protocolbuffers/protobuf/archive/v3.19.1.tar.gz"], + patches = [ + "@//third_party:com_google_protobuf_fixes.diff" + ], + patch_args = [ + "-p1", + ], +) + +# Load Zlib before initializing TensorFlow and the iOS build rules to guarantee +# that the target @zlib//:mini_zlib is available +http_archive( + name = "zlib", + build_file = "@//third_party:zlib.BUILD", + sha256 = "c3e5e9fdd5004dcb542feda5ee4f0ff0744628baf8ed2dd5d66f8ca1197cb1a1", + strip_prefix = "zlib-1.2.11", + urls = [ + "http://mirror.bazel.build/zlib.net/fossils/zlib-1.2.11.tar.gz", + "http://zlib.net/fossils/zlib-1.2.11.tar.gz", # 2017-01-15 + ], + patches = [ + "@//third_party:zlib.diff", + ], + patch_args = [ + "-p1", + ], +) + +# iOS basic build deps. +http_archive( + name = "build_bazel_rules_apple", + sha256 = "3e2c7ae0ddd181c4053b6491dad1d01ae29011bc322ca87eea45957c76d3a0c3", + url = "https://github.com/bazelbuild/rules_apple/releases/download/2.1.0/rules_apple.2.1.0.tar.gz", + patches = [ + # Bypass checking ios unit test runner when building MP ios applications. + "@//third_party:build_bazel_rules_apple_bypass_test_runner_check.diff" + ], + patch_args = [ + "-p1", + ], +) + +load( + "@build_bazel_rules_apple//apple:repositories.bzl", + "apple_rules_dependencies", +) +apple_rules_dependencies() + +load( + "@build_bazel_rules_swift//swift:repositories.bzl", + "swift_rules_dependencies", +) +swift_rules_dependencies() + +load( + "@build_bazel_rules_swift//swift:extras.bzl", + "swift_rules_extra_dependencies", +) +swift_rules_extra_dependencies() + +load( + "@build_bazel_apple_support//lib:repositories.bzl", + "apple_support_dependencies", +) +apple_support_dependencies() + # This is used to select all contents of the archives for CMake-based packages to give CMake access to them. all_content = """filegroup(name = "all", srcs = glob(["**"]), visibility = ["//visibility:public"])""" @@ -133,19 +203,6 @@ http_archive( urls = ["https://github.com/protocolbuffers/protobuf/archive/v3.19.1.tar.gz"], ) -http_archive( - name = "com_google_protobuf", - sha256 = "87407cd28e7a9c95d9f61a098a53cf031109d451a7763e7dd1253abf8b4df422", - strip_prefix = "protobuf-3.19.1", - urls = ["https://github.com/protocolbuffers/protobuf/archive/v3.19.1.tar.gz"], - patches = [ - "@//third_party:com_google_protobuf_fixes.diff" - ], - patch_args = [ - "-p1", - ], -) - load("@//third_party/flatbuffers:workspace.bzl", flatbuffers = "repo") flatbuffers() @@ -319,63 +376,6 @@ http_archive( ], ) -# Load Zlib before initializing TensorFlow and the iOS build rules to guarantee -# that the target @zlib//:mini_zlib is available -http_archive( - name = "zlib", - build_file = "@//third_party:zlib.BUILD", - sha256 = "c3e5e9fdd5004dcb542feda5ee4f0ff0744628baf8ed2dd5d66f8ca1197cb1a1", - strip_prefix = "zlib-1.2.11", - urls = [ - "http://mirror.bazel.build/zlib.net/fossils/zlib-1.2.11.tar.gz", - "http://zlib.net/fossils/zlib-1.2.11.tar.gz", # 2017-01-15 - ], - patches = [ - "@//third_party:zlib.diff", - ], - patch_args = [ - "-p1", - ], -) - -# iOS basic build deps. 
-http_archive( - name = "build_bazel_rules_apple", - sha256 = "f94e6dddf74739ef5cb30f000e13a2a613f6ebfa5e63588305a71fce8a8a9911", - url = "https://github.com/bazelbuild/rules_apple/releases/download/1.1.3/rules_apple.1.1.3.tar.gz", - patches = [ - # Bypass checking ios unit test runner when building MP ios applications. - "@//third_party:build_bazel_rules_apple_bypass_test_runner_check.diff" - ], - patch_args = [ - "-p1", - ], -) - -load( - "@build_bazel_rules_apple//apple:repositories.bzl", - "apple_rules_dependencies", -) -apple_rules_dependencies() - -load( - "@build_bazel_rules_swift//swift:repositories.bzl", - "swift_rules_dependencies", -) -swift_rules_dependencies() - -load( - "@build_bazel_rules_swift//swift:extras.bzl", - "swift_rules_extra_dependencies", -) -swift_rules_extra_dependencies() - -load( - "@build_bazel_apple_support//lib:repositories.bzl", - "apple_support_dependencies", -) -apple_support_dependencies() - # More iOS deps. http_archive( diff --git a/docs/getting_started/install.md b/docs/getting_started/install.md index cc5c0241d..05f291e5c 100644 --- a/docs/getting_started/install.md +++ b/docs/getting_started/install.md @@ -577,7 +577,7 @@ next section. Option 1. Follow [the official Bazel documentation](https://docs.bazel.build/versions/master/install-windows.html) - to install Bazel 5.2.0 or higher. + to install Bazel 6.1.1 or higher. Option 2. Follow the official [Bazel documentation](https://docs.bazel.build/versions/master/install-bazelisk.html) From 7f9fd4f154eb1864c908f7a2e8a81646466798c6 Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Fri, 31 Mar 2023 13:23:43 -0700 Subject: [PATCH 15/63] Add the minimum version number requirement to sounddevice in requirements.txt. PiperOrigin-RevId: 520998845 --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 326f21694..85d02d59a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,4 +5,4 @@ matplotlib numpy opencv-contrib-python protobuf>=3.11,<4 -sounddevice +sounddevice>=0.4.4 From d9f940f8b2623f9695a4a739d87b27ffdab86247 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Fri, 31 Mar 2023 15:17:13 -0700 Subject: [PATCH 16/63] Model Maker object detector change learning_rate_boundaries to learning_rate_epoch_boundaries. PiperOrigin-RevId: 521024056 --- .../vision/object_detector/hyperparameters.py | 38 ++++++++++--------- .../vision/object_detector/object_detector.py | 35 ++++++++++++----- 2 files changed, 47 insertions(+), 26 deletions(-) diff --git a/mediapipe/model_maker/python/vision/object_detector/hyperparameters.py b/mediapipe/model_maker/python/vision/object_detector/hyperparameters.py index 435dd9745..241104cf8 100644 --- a/mediapipe/model_maker/python/vision/object_detector/hyperparameters.py +++ b/mediapipe/model_maker/python/vision/object_detector/hyperparameters.py @@ -29,9 +29,9 @@ class HParams(hp.BaseHParams): epochs: Number of training iterations over the dataset. do_fine_tuning: If true, the base module is trained together with the classification layer on top. - learning_rate_boundaries: List of epoch boundaries where - learning_rate_boundaries[i] is the epoch where the learning rate will - decay to learning_rate * learning_rate_decay_multipliers[i]. + learning_rate_epoch_boundaries: List of epoch boundaries where + learning_rate_epoch_boundaries[i] is the epoch where the learning rate + will decay to learning_rate * learning_rate_decay_multipliers[i]. 
learning_rate_decay_multipliers: List of learning rate multipliers which calculates the learning rate at the ith boundary as learning_rate * learning_rate_decay_multipliers[i]. @@ -43,35 +43,39 @@ class HParams(hp.BaseHParams): epochs: int = 10 # Parameters for learning rate decay - learning_rate_boundaries: List[int] = dataclasses.field( - default_factory=lambda: [5, 8] + learning_rate_epoch_boundaries: List[int] = dataclasses.field( + default_factory=lambda: [] ) learning_rate_decay_multipliers: List[float] = dataclasses.field( - default_factory=lambda: [0.1, 0.01] + default_factory=lambda: [] ) def __post_init__(self): # Validate stepwise learning rate parameters - lr_boundary_len = len(self.learning_rate_boundaries) + lr_boundary_len = len(self.learning_rate_epoch_boundaries) lr_decay_multipliers_len = len(self.learning_rate_decay_multipliers) if lr_boundary_len != lr_decay_multipliers_len: raise ValueError( - "Length of learning_rate_boundaries and ", + "Length of learning_rate_epoch_boundaries and ", "learning_rate_decay_multipliers do not match: ", f"{lr_boundary_len}!={lr_decay_multipliers_len}", ) - # Validate learning_rate_boundaries - if sorted(self.learning_rate_boundaries) != self.learning_rate_boundaries: - raise ValueError( - "learning_rate_boundaries is not in ascending order: ", - self.learning_rate_boundaries, - ) + # Validate learning_rate_epoch_boundaries if ( - self.learning_rate_boundaries - and self.learning_rate_boundaries[-1] > self.epochs + sorted(self.learning_rate_epoch_boundaries) + != self.learning_rate_epoch_boundaries ): raise ValueError( - "Values in learning_rate_boundaries cannot be greater ", "than epochs" + "learning_rate_epoch_boundaries is not in ascending order: ", + self.learning_rate_epoch_boundaries, + ) + if ( + self.learning_rate_epoch_boundaries + and self.learning_rate_epoch_boundaries[-1] > self.epochs + ): + raise ValueError( + "Values in learning_rate_epoch_boundaries cannot be greater ", + "than epochs", ) diff --git a/mediapipe/model_maker/python/vision/object_detector/object_detector.py b/mediapipe/model_maker/python/vision/object_detector/object_detector.py index a6f678cd9..316df85a9 100644 --- a/mediapipe/model_maker/python/vision/object_detector/object_detector.py +++ b/mediapipe/model_maker/python/vision/object_detector/object_detector.py @@ -57,7 +57,6 @@ class ObjectDetector(classifier.Classifier): self._preprocessor = preprocessor.Preprocessor(model_spec) self._hparams = hparams self._model_options = model_options - self._optimizer = self._create_optimizer() self._is_qat = False @classmethod @@ -104,6 +103,11 @@ class ObjectDetector(classifier.Classifier): train_data: Training data. validation_data: Validation data. """ + self._optimizer = self._create_optimizer( + model_util.get_steps_per_epoch( + self._hparams.steps_per_epoch, + ) + ) self._create_model() self._train_model( train_data, validation_data, preprocessor=self._preprocessor @@ -333,21 +337,34 @@ class ObjectDetector(classifier.Classifier): with open(metadata_file, 'w') as f: f.write(metadata_json) - def _create_optimizer(self) -> tf.keras.optimizers.Optimizer: + def _create_optimizer( + self, steps_per_epoch: int + ) -> tf.keras.optimizers.Optimizer: """Creates an optimizer with learning rate schedule for regular training. Uses Keras PiecewiseConstantDecay schedule by default. + Args: + steps_per_epoch: Steps per epoch to calculate the step boundaries from the + learning_rate_epoch_boundaries + Returns: A tf.keras.optimizer.Optimizer for model training. 
""" init_lr = self._hparams.learning_rate * self._hparams.batch_size / 256 - lr_values = [init_lr] + [ - init_lr * m for m in self._hparams.learning_rate_decay_multipliers - ] - learning_rate_fn = tf.keras.optimizers.schedules.PiecewiseConstantDecay( - self._hparams.learning_rate_boundaries, lr_values - ) + if self._hparams.learning_rate_epoch_boundaries: + lr_values = [init_lr] + [ + init_lr * m for m in self._hparams.learning_rate_decay_multipliers + ] + lr_step_boundaries = [ + steps_per_epoch * epoch_boundary + for epoch_boundary in self._hparams.learning_rate_epoch_boundaries + ] + learning_rate = tf.keras.optimizers.schedules.PiecewiseConstantDecay( + lr_step_boundaries, lr_values + ) + else: + learning_rate = init_lr return tf.keras.optimizers.experimental.SGD( - learning_rate=learning_rate_fn, momentum=0.9 + learning_rate=learning_rate, momentum=0.9 ) From 50a49fd16c812afe6fbdb2800c29a2b06a6e055c Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Sat, 1 Apr 2023 22:26:07 -0700 Subject: [PATCH 17/63] Internal change PiperOrigin-RevId: 521226781 --- mediapipe/framework/tool/template_parser.cc | 34 ++++++++++----------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/mediapipe/framework/tool/template_parser.cc b/mediapipe/framework/tool/template_parser.cc index e26275387..f012ac418 100644 --- a/mediapipe/framework/tool/template_parser.cc +++ b/mediapipe/framework/tool/template_parser.cc @@ -511,7 +511,7 @@ class TemplateParser::Parser::ParserImpl { DO(ConsumeIdentifier(&field_name)); if (allow_field_number_) { - int32 field_number = std::atoi(field_name.c_str()); // NOLINT + int32_t field_number = std::atoi(field_name.c_str()); // NOLINT if (descriptor->IsExtensionNumber(field_number)) { field = reflection->FindKnownExtensionByNumber(field_number); } else if (descriptor->IsReservedNumber(field_number)) { @@ -765,28 +765,28 @@ class TemplateParser::Parser::ParserImpl { switch (field->cpp_type()) { case FieldDescriptor::CPPTYPE_INT32: { - int64 value; + int64_t value; DO(ConsumeSignedInteger(&value, kint32max)); - SET_FIELD(Int32, static_cast(value)); + SET_FIELD(Int32, static_cast(value)); break; } case FieldDescriptor::CPPTYPE_UINT32: { - uint64 value; + uint64_t value; DO(ConsumeUnsignedInteger(&value, kuint32max)); - SET_FIELD(UInt32, static_cast(value)); + SET_FIELD(UInt32, static_cast(value)); break; } case FieldDescriptor::CPPTYPE_INT64: { - int64 value; + int64_t value; DO(ConsumeSignedInteger(&value, kint64max)); SET_FIELD(Int64, value); break; } case FieldDescriptor::CPPTYPE_UINT64: { - uint64 value; + uint64_t value; DO(ConsumeUnsignedInteger(&value, kuint64max)); SET_FIELD(UInt64, value); break; @@ -815,7 +815,7 @@ class TemplateParser::Parser::ParserImpl { case FieldDescriptor::CPPTYPE_BOOL: { if (LookingAtType(io::Tokenizer::TYPE_INTEGER)) { - uint64 value; + uint64_t value; DO(ConsumeUnsignedInteger(&value, 1)); SET_FIELD(Bool, value); } else { @@ -836,7 +836,7 @@ class TemplateParser::Parser::ParserImpl { case FieldDescriptor::CPPTYPE_ENUM: { std::string value; - int64 int_value = kint64max; + int64_t int_value = kint64max; const EnumDescriptor* enum_type = field->enum_type(); const EnumValueDescriptor* enum_value = NULL; @@ -1037,7 +1037,7 @@ class TemplateParser::Parser::ParserImpl { // Consumes a uint64 and saves its value in the value parameter. // Returns false if the token is not of type INTEGER. 
- bool ConsumeUnsignedInteger(uint64* value, uint64 max_value) { + bool ConsumeUnsignedInteger(uint64_t* value, uint64_t max_value) { if (!LookingAtType(io::Tokenizer::TYPE_INTEGER)) { ReportError("Expected integer, got: " + tokenizer_.current().text); return false; @@ -1058,7 +1058,7 @@ class TemplateParser::Parser::ParserImpl { // we actually may consume an additional token (for the minus sign) in this // method. Returns false if the token is not an integer // (signed or otherwise). - bool ConsumeSignedInteger(int64* value, uint64 max_value) { + bool ConsumeSignedInteger(int64_t* value, uint64_t max_value) { bool negative = false; #ifndef PROTO2_OPENSOURCE if (absl::StartsWith(tokenizer_.current().text, "0x")) { @@ -1075,18 +1075,18 @@ class TemplateParser::Parser::ParserImpl { ++max_value; } - uint64 unsigned_value; + uint64_t unsigned_value; DO(ConsumeUnsignedInteger(&unsigned_value, max_value)); if (negative) { - if ((static_cast(kint64max) + 1) == unsigned_value) { + if ((static_cast(kint64max) + 1) == unsigned_value) { *value = kint64min; } else { - *value = -static_cast(unsigned_value); + *value = -static_cast(unsigned_value); } } else { - *value = static_cast(unsigned_value); + *value = static_cast(unsigned_value); } return true; @@ -1094,7 +1094,7 @@ class TemplateParser::Parser::ParserImpl { // Consumes a uint64 and saves its value in the value parameter. // Accepts decimal numbers only, rejects hex or oct numbers. - bool ConsumeUnsignedDecimalInteger(uint64* value, uint64 max_value) { + bool ConsumeUnsignedDecimalInteger(uint64_t* value, uint64_t max_value) { if (!LookingAtType(io::Tokenizer::TYPE_INTEGER)) { ReportError("Expected integer, got: " + tokenizer_.current().text); return false; @@ -1131,7 +1131,7 @@ class TemplateParser::Parser::ParserImpl { // Therefore, we must check both cases here. if (LookingAtType(io::Tokenizer::TYPE_INTEGER)) { // We have found an integer value for the double. 
-      uint64 integer_value;
+      uint64_t integer_value;
       DO(ConsumeUnsignedDecimalInteger(&integer_value, kuint64max));
 
       *value = static_cast<double>(integer_value);

From 1fa9b2c9852710dcd4c1acca36a1ad4b1137aad8 Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Sun, 2 Apr 2023 08:30:02 -0700
Subject: [PATCH 18/63] Internal change

PiperOrigin-RevId: 521279971
---
 .../calculators/classification_aggregation_calculator.cc       | 2 +-
 .../components/calculators/embedding_aggregation_calculator.cc | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.cc b/mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.cc
index 145076cd3..01e1292c3 100644
--- a/mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.cc
+++ b/mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.cc
@@ -111,7 +111,7 @@ class ClassificationAggregationCalculator : public Node {
  private:
   std::vector<std::string> head_names_;
   bool time_aggregation_enabled_;
-  std::unordered_map<int64, std::vector<ClassificationResult>>
+  std::unordered_map<int64_t, std::vector<ClassificationResult>>
       cached_classifications_;
 
   ClassificationResult ConvertToClassificationResult(CalculatorContext* cc);
diff --git a/mediapipe/tasks/cc/components/calculators/embedding_aggregation_calculator.cc b/mediapipe/tasks/cc/components/calculators/embedding_aggregation_calculator.cc
index 6e06c4e32..94e0fcb36 100644
--- a/mediapipe/tasks/cc/components/calculators/embedding_aggregation_calculator.cc
+++ b/mediapipe/tasks/cc/components/calculators/embedding_aggregation_calculator.cc
@@ -83,7 +83,7 @@ class EmbeddingAggregationCalculator : public Node {
 
  private:
   bool time_aggregation_enabled_;
-  std::unordered_map<int64, Embedding> cached_embeddings_;
+  std::unordered_map<int64_t, Embedding> cached_embeddings_;
 };
 
 absl::Status EmbeddingAggregationCalculator::UpdateContract(

From 696bedcaa10ed24f4be77bf57c828197206d3aef Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Sun, 2 Apr 2023 17:41:50 -0700
Subject: [PATCH 19/63] Internal change

PiperOrigin-RevId: 521327449
---
 mediapipe/tasks/cc/vision/utils/image_tensor_specs.cc | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/mediapipe/tasks/cc/vision/utils/image_tensor_specs.cc b/mediapipe/tasks/cc/vision/utils/image_tensor_specs.cc
index 4d3d2cb96..3f0425a69 100644
--- a/mediapipe/tasks/cc/vision/utils/image_tensor_specs.cc
+++ b/mediapipe/tasks/cc/vision/utils/image_tensor_specs.cc
@@ -191,8 +191,9 @@ absl::StatusOr<ImageTensorSpecs> BuildInputImageTensorSpecs(
         MediaPipeTasksStatus::kInvalidInputTensorDimensionsError);
   }
 
-  size_t byte_depth =
-      tensor_type == tflite::TensorType_FLOAT32 ? sizeof(float) : sizeof(uint8);
+  size_t byte_depth = tensor_type == tflite::TensorType_FLOAT32
+                          ? sizeof(float)
+                          : sizeof(uint8_t);
   int bytes_size = byte_depth * batch * height * width * depth;
   // Sanity checks.
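// (Illustrative aside, not part of the patch: with these byte depths, a
// float32 1x224x224x3 input tensor implies bytes_size = 4 * 1 * 224 * 224 * 3
// = 602112, while the same shape in uint8_t implies 150528.)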
   if (tensor_type == tflite::TensorType_FLOAT32) {

From b5bbed8ebb6a1d417413629a9bb535f9da5c77ed Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Mon, 3 Apr 2023 02:58:49 -0700
Subject: [PATCH 20/63] Internal change

PiperOrigin-RevId: 521406957
---
 .../framework/formats/motion/optical_flow_field.cc | 14 +++++++-------
 .../formats/motion/optical_flow_field_test.cc      |  8 ++++----
 2 files changed, 11 insertions(+), 11 deletions(-)

diff --git a/mediapipe/framework/formats/motion/optical_flow_field.cc b/mediapipe/framework/formats/motion/optical_flow_field.cc
index 1e6adef48..a96504192 100644
--- a/mediapipe/framework/formats/motion/optical_flow_field.cc
+++ b/mediapipe/framework/formats/motion/optical_flow_field.cc
@@ -66,12 +66,12 @@ cv::Mat MakeVisualizationHsv(const cv::Mat_<float>& angles,
   cv::Mat hsv(angles.size(), CV_8UC3);
   for (int r = 0; r < hsv.rows; ++r) {
     for (int c = 0; c < hsv.cols; ++c) {
-      const uint8 hue = static_cast<uint8>(255.0f * angles(r, c) / 360.0f);
-      uint8 saturation = 255;
+      const uint8_t hue = static_cast<uint8_t>(255.0f * angles(r, c) / 360.0f);
+      uint8_t saturation = 255;
       if (magnitudes(r, c) < max_mag) {
-        saturation = static_cast<uint8>(255.0f * magnitudes(r, c) / max_mag);
+        saturation = static_cast<uint8_t>(255.0f * magnitudes(r, c) / max_mag);
       }
-      const uint8 value = 255;
+      const uint8_t value = 255;
 
       hsv.at<cv::Vec3b>(r, c) = cv::Vec3b(hue, saturation, value);
     }
@@ -282,7 +282,7 @@ void OpticalFlowField::EstimateMotionConsistencyOcclusions(
 Location OpticalFlowField::FindMotionInconsistentPixels(
     const OpticalFlowField& forward, const OpticalFlowField& backward,
     double spatial_distance_threshold) {
-  const uint8 kOccludedPixelValue = 1;
+  const uint8_t kOccludedPixelValue = 1;
   const double threshold_sq =
       spatial_distance_threshold * spatial_distance_threshold;
   cv::Mat occluded = cv::Mat::zeros(forward.height(), forward.width(), CV_8UC1);
@@ -301,10 +301,10 @@ Location OpticalFlowField::FindMotionInconsistentPixels(
       if (!in_bounds_in_next_frame ||
           Point2_f(x - round_trip_x, y - round_trip_y).ToVector().Norm2() >
               threshold_sq) {
-        occluded.at<uint8>(y, x) = kOccludedPixelValue;
+        occluded.at<uint8_t>(y, x) = kOccludedPixelValue;
       }
     }
   }
-  return CreateCvMaskLocation<uint8>(occluded);
+  return CreateCvMaskLocation<uint8_t>(occluded);
 }
 
 }  // namespace mediapipe
diff --git a/mediapipe/framework/formats/motion/optical_flow_field_test.cc b/mediapipe/framework/formats/motion/optical_flow_field_test.cc
index 521256c48..fdce418fa 100644
--- a/mediapipe/framework/formats/motion/optical_flow_field_test.cc
+++ b/mediapipe/framework/formats/motion/optical_flow_field_test.cc
@@ -300,15 +300,15 @@ TEST(OpticalFlowField, Occlusions) {
   for (int y = 0; y < occlusion_mat->rows; ++y) {
     // Bottom row and pixel at (x, y) = (1, 0) are occluded.
     if (y == occlusion_mat->rows - 1 || (x == 1 && y == 0)) {
-      EXPECT_GT(occlusion_mat->at<uint8>(y, x), 0);
+      EXPECT_GT(occlusion_mat->at<uint8_t>(y, x), 0);
     } else {
-      EXPECT_EQ(0, occlusion_mat->at<uint8>(y, x));
+      EXPECT_EQ(0, occlusion_mat->at<uint8_t>(y, x));
     }
     // Top row and pixel at (x, y) = (1, 2) are disoccluded.
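// (Illustrative aside, not part of the patch: FindMotionInconsistentPixels
// above marks a pixel p as inconsistent when the forward/backward round trip
// p -> p + forward(p) -> back via backward(...) either leaves the frame or
// lands farther than spatial_distance_threshold from p; the test below checks
// the resulting occlusion and disocclusion masks.)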
if (y == 0 || (x == 1 && y == 2)) { - EXPECT_GT(disocclusion_mat->at(y, x), 0); + EXPECT_GT(disocclusion_mat->at(y, x), 0); } else { - EXPECT_EQ(0, disocclusion_mat->at(y, x)); + EXPECT_EQ(0, disocclusion_mat->at(y, x)); } } } From 4a490cd27c0b1fef76992418408ef64e2d1509e4 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Mon, 3 Apr 2023 03:02:30 -0700 Subject: [PATCH 21/63] This CL fixes the multiple typos in the new task api solution PiperOrigin-RevId: 521407588 --- .../google/mediapipe/tasks/core/OutputHandler.java | 6 +++--- .../com/google/mediapipe/tasks/core/TaskInfo.java | 2 +- .../vision/facelandmarker/FaceLandmarkerResult.java | 2 +- .../vision/gesturerecognizer/GestureRecognizer.java | 12 ++++++------ .../tasks/vision/imagesegmenter/ImageSegmenter.java | 8 ++++---- .../interactivesegmenter/InteractiveSegmenter.java | 4 ++-- mediapipe/tasks/metadata/metadata_schema.fbs | 2 +- mediapipe/tasks/python/metadata/metadata.py | 4 ++-- .../metadata/metadata_writers/metadata_info.py | 2 +- .../tasks/python/test/metadata/metadata_test.py | 2 +- .../tasks/python/vision/core/base_vision_task_api.py | 2 +- .../tasks/web/components/containers/matrix.d.ts | 2 +- mediapipe/tasks/web/vision/core/types.d.ts | 2 +- .../tasks/web/vision/core/vision_task_options.d.ts | 2 +- 14 files changed, 26 insertions(+), 26 deletions(-) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/OutputHandler.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/OutputHandler.java index 49c459ef1..c330b1a56 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/OutputHandler.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/OutputHandler.java @@ -33,7 +33,7 @@ public class OutputHandler { /** * Interface for the customizable MediaPipe task result listener that can reteive both task result - * objects and the correpsonding input data. + * objects and the corresponding input data. */ public interface ResultListener { void run(OutputT result, InputT input); @@ -90,8 +90,8 @@ public class OutputHandler { } /** - * Sets whether the output handler should react to the timestamp bound changes that are reprsented - * as empty output {@link Packet}s. + * Sets whether the output handler should react to the timestamp bound changes that are + * represented as empty output {@link Packet}s. * * @param handleTimestampBoundChanges A boolean value. */ diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskInfo.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskInfo.java index 310f5739c..31af80f5c 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskInfo.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskInfo.java @@ -24,7 +24,7 @@ import java.util.ArrayList; import java.util.List; /** - * {@link TaskInfo} contains all needed informaton to initialize a MediaPipe Task {@link + * {@link TaskInfo} contains all needed information to initialize a MediaPipe Task {@link * com.google.mediapipe.framework.Graph}. 
*/ @AutoValue diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/facelandmarker/FaceLandmarkerResult.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/facelandmarker/FaceLandmarkerResult.java index bafa40e19..7054856fc 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/facelandmarker/FaceLandmarkerResult.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/facelandmarker/FaceLandmarkerResult.java @@ -108,7 +108,7 @@ public abstract class FaceLandmarkerResult implements TaskResult { public abstract Optional>> faceBlendshapes(); /** - * Optional facial transformation matrix list from cannonical face to the detected face landmarks. + * Optional facial transformation matrix list from canonical face to the detected face landmarks. * The 4x4 facial transformation matrix is represetned as a flat column-major float array. */ public abstract Optional> facialTransformationMatrixes(); diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizer.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizer.java index a933d2f65..5b2d7191f 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizer.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizer.java @@ -403,10 +403,10 @@ public final class GestureRecognizer extends BaseVisionTaskApi { public abstract Builder setMinTrackingConfidence(Float value); /** - * Sets the optional {@link ClassifierOptions} controling the canned gestures classifier, such - * as score threshold, allow list and deny list of gestures. The categories for canned gesture - * classifiers are: ["None", "Closed_Fist", "Open_Palm", "Pointing_Up", "Thumb_Down", - * "Thumb_Up", "Victory", "ILoveYou"] + * Sets the optional {@link ClassifierOptions} controlling the canned gestures classifier, + * such as score threshold, allow list and deny list of gestures. The categories + * for canned gesture classifiers are: ["None", "Closed_Fist", "Open_Palm", + * "Pointing_Up", "Thumb_Down", "Thumb_Up", "Victory", "ILoveYou"] * *
TODO Note this option is subject to change, after scoring merging * calculator is implemented. @@ -415,8 +415,8 @@ public final class GestureRecognizer extends BaseVisionTaskApi { ClassifierOptions classifierOptions); /** - * Sets the optional {@link ClassifierOptions} controling the custom gestures classifier, such - * as score threshold, allow list and deny list of gestures. + * Sets the optional {@link ClassifierOptions} controlling the custom gestures classifier, + * such as score threshold, allow list and deny list of gestures. * *
TODO Note this option is subject to change, after scoring merging * calculator is implemented. diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenter.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenter.java index f1a08d425..b809ab963 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenter.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenter.java @@ -302,7 +302,7 @@ public final class ImageSegmenter extends BaseVisionTaskApi { * @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a * region-of-interest. * @throws MediaPipeException if there is an internal error. Or if {@link ImageSegmenter} is not - * created wtih {@link ResultListener} set in {@link ImageSegmenterOptions}. + * created with {@link ResultListener} set in {@link ImageSegmenterOptions}. */ public void segmentWithResultListener(MPImage image) { segmentWithResultListener(image, ImageProcessingOptions.builder().build()); @@ -329,7 +329,7 @@ public final class ImageSegmenter extends BaseVisionTaskApi { * @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a * region-of-interest. * @throws MediaPipeException if there is an internal error. Or if {@link ImageSegmenter} is not - * created wtih {@link ResultListener} set in {@link ImageSegmenterOptions}. + * created with {@link ResultListener} set in {@link ImageSegmenterOptions}. */ public void segmentWithResultListener( MPImage image, ImageProcessingOptions imageProcessingOptions) { @@ -421,7 +421,7 @@ public final class ImageSegmenter extends BaseVisionTaskApi { * @param image a MediaPipe {@link MPImage} object for processing. * @param timestampMs the input timestamp (in milliseconds). * @throws MediaPipeException if there is an internal error. Or if {@link ImageSegmenter} is not - * created wtih {@link ResultListener} set in {@link ImageSegmenterOptions}. + * created with {@link ResultListener} set in {@link ImageSegmenterOptions}. */ public void segmentForVideoWithResultListener(MPImage image, long timestampMs) { segmentForVideoWithResultListener(image, ImageProcessingOptions.builder().build(), timestampMs); @@ -444,7 +444,7 @@ public final class ImageSegmenter extends BaseVisionTaskApi { * @param image a MediaPipe {@link MPImage} object for processing. * @param timestampMs the input timestamp (in milliseconds). * @throws MediaPipeException if there is an internal error. Or if {@link ImageSegmenter} is not - * created wtih {@link ResultListener} set in {@link ImageSegmenterOptions}. + * created with {@link ResultListener} set in {@link ImageSegmenterOptions}. */ public void segmentForVideoWithResultListener( MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) { diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenter.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenter.java index 8ee6951f8..657716b6b 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenter.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenter.java @@ -327,7 +327,7 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi { * @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a * region-of-interest. 
* @throws MediaPipeException if there is an internal error. Or if {@link InteractiveSegmenter} is - * not created wtih {@link ResultListener} set in {@link InteractiveSegmenterOptions}. + * not created with {@link ResultListener} set in {@link InteractiveSegmenterOptions}. */ public void segmentWithResultListener(MPImage image, RegionOfInterest roi) { segmentWithResultListener(image, roi, ImageProcessingOptions.builder().build()); @@ -357,7 +357,7 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi { * @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a * region-of-interest. * @throws MediaPipeException if there is an internal error. Or if {@link InteractiveSegmenter} is - * not created wtih {@link ResultListener} set in {@link InteractiveSegmenterOptions}. + * not created with {@link ResultListener} set in {@link InteractiveSegmenterOptions}. */ public void segmentWithResultListener( MPImage image, RegionOfInterest roi, ImageProcessingOptions imageProcessingOptions) { diff --git a/mediapipe/tasks/metadata/metadata_schema.fbs b/mediapipe/tasks/metadata/metadata_schema.fbs index 8fe7a08fa..8660ba38c 100644 --- a/mediapipe/tasks/metadata/metadata_schema.fbs +++ b/mediapipe/tasks/metadata/metadata_schema.fbs @@ -142,7 +142,7 @@ enum AssociatedFileType : byte { // TODO: introduce the ScaNN index file with links once the code // is released. - // Contains on-devide ScaNN index file with LevelDB format. + // Contains on-device ScaNN index file with LevelDB format. // Added in: 1.4.0 SCANN_INDEX_FILE = 6, } diff --git a/mediapipe/tasks/python/metadata/metadata.py b/mediapipe/tasks/python/metadata/metadata.py index 25d83cae8..6a107c8d8 100644 --- a/mediapipe/tasks/python/metadata/metadata.py +++ b/mediapipe/tasks/python/metadata/metadata.py @@ -121,7 +121,7 @@ class MetadataPopulator(object): Then, pack the metadata and label file into the model as follows. ```python - # Populating a metadata file (or a metadta buffer) and associated files to + # Populating a metadata file (or a metadata buffer) and associated files to a model file: populator = MetadataPopulator.with_model_file(model_file) # For metadata buffer (bytearray read from the metadata file), use: @@ -332,7 +332,7 @@ class MetadataPopulator(object): Raises: IOError: File not found. ValueError: The metadata to be populated is empty. - ValueError: The metadata does not have the expected flatbuffer identifer. + ValueError: The metadata does not have the expected flatbuffer identifier. ValueError: Cannot get minimum metadata parser version. ValueError: The number of SubgraphMetadata is not 1. ValueError: The number of input/output tensors does not match the number diff --git a/mediapipe/tasks/python/metadata/metadata_writers/metadata_info.py b/mediapipe/tasks/python/metadata/metadata_writers/metadata_info.py index f201ab7e0..10b66ff18 100644 --- a/mediapipe/tasks/python/metadata/metadata_writers/metadata_info.py +++ b/mediapipe/tasks/python/metadata/metadata_writers/metadata_info.py @@ -559,7 +559,7 @@ class InputTextTensorMd(TensorMd): name: name of the tensor. description: description of what the tensor is. tokenizer_md: information of the tokenizer in the input text tensor, if - any. Only `RegexTokenizer` [1] is currenly supported. If the tokenizer + any. Only `RegexTokenizer` [1] is currently supported. If the tokenizer is `BertTokenizer` [2] or `SentencePieceTokenizer` [3], refer to `BertInputTensorsMd` class. 
[1]: diff --git a/mediapipe/tasks/python/test/metadata/metadata_test.py b/mediapipe/tasks/python/test/metadata/metadata_test.py index d892f1b61..c91bcce6e 100644 --- a/mediapipe/tasks/python/test/metadata/metadata_test.py +++ b/mediapipe/tasks/python/test/metadata/metadata_test.py @@ -388,7 +388,7 @@ class MetadataPopulatorTest(MetadataTest): populator = _metadata.MetadataPopulator.with_model_file(self._model_file) populator.load_metadata_file(self._metadata_file) populator.load_associated_files([self._file1]) - # Suppose to populate self._file2, because it is recorded in the metadta. + # Suppose to populate self._file2, because it is recorded in the metadata. with self.assertRaises(ValueError) as error: populator.populate() self.assertEqual(("File, '{0}', is recorded in the metadata, but has " diff --git a/mediapipe/tasks/python/vision/core/base_vision_task_api.py b/mediapipe/tasks/python/vision/core/base_vision_task_api.py index 0c8262d4b..768d392f1 100644 --- a/mediapipe/tasks/python/vision/core/base_vision_task_api.py +++ b/mediapipe/tasks/python/vision/core/base_vision_task_api.py @@ -144,7 +144,7 @@ class BaseVisionTaskApi(object): set. By default, it's set to True. Returns: - A normalized rect proto that repesents the image processing options. + A normalized rect proto that represents the image processing options. """ normalized_rect = _NormalizedRect( rotation=0, x_center=0.5, y_center=0.5, width=1, height=1) diff --git a/mediapipe/tasks/web/components/containers/matrix.d.ts b/mediapipe/tasks/web/components/containers/matrix.d.ts index fd4bda4c3..e0bad58c8 100644 --- a/mediapipe/tasks/web/components/containers/matrix.d.ts +++ b/mediapipe/tasks/web/components/containers/matrix.d.ts @@ -14,7 +14,7 @@ * limitations under the License. */ -/** A two-dimenionsal matrix. */ +/** A two-dimensional matrix. */ export declare interface Matrix { /** The number of rows. */ rows: number; diff --git a/mediapipe/tasks/web/vision/core/types.d.ts b/mediapipe/tasks/web/vision/core/types.d.ts index c04366ac0..b48b5045d 100644 --- a/mediapipe/tasks/web/vision/core/types.d.ts +++ b/mediapipe/tasks/web/vision/core/types.d.ts @@ -19,7 +19,7 @@ import {NormalizedKeypoint} from '../../../../tasks/web/components/containers/ke /** * The segmentation tasks return the segmentation either as a WebGLTexture (when * the output is on GPU) or as a typed JavaScript arrays for CPU-based - * category or confidence masks. `Uint8ClampedArray`s are used to represend + * category or confidence masks. `Uint8ClampedArray`s are used to represent * CPU-based category masks and `Float32Array`s are used for CPU-based * confidence masks. */ diff --git a/mediapipe/tasks/web/vision/core/vision_task_options.d.ts b/mediapipe/tasks/web/vision/core/vision_task_options.d.ts index 72bc2efb1..a45efd6d3 100644 --- a/mediapipe/tasks/web/vision/core/vision_task_options.d.ts +++ b/mediapipe/tasks/web/vision/core/vision_task_options.d.ts @@ -27,7 +27,7 @@ export type RunningMode = 'IMAGE'|'VIDEO'; export declare interface VisionTaskOptions extends TaskRunnerOptions { /** * The canvas element to bind textures to. This has to be set for GPU - * processing. The task will initialize a WebGL context and throw an eror if + * processing. The task will initialize a WebGL context and throw an error if * this fails (e.g. if you have already initialized a different type of * context). 
*/ From cfe91f3c8c3c3b1618a72d776db171b50f5d4491 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Mon, 3 Apr 2023 04:44:26 -0700 Subject: [PATCH 22/63] Internal change PiperOrigin-RevId: 521424672 --- mediapipe/framework/calculator_graph.cc | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/mediapipe/framework/calculator_graph.cc b/mediapipe/framework/calculator_graph.cc index b49930b7a..06a57fa6d 100644 --- a/mediapipe/framework/calculator_graph.cc +++ b/mediapipe/framework/calculator_graph.cc @@ -192,8 +192,7 @@ absl::Status CalculatorGraph::InitializeStreams() { auto input_tag_map, tool::TagMap::Create(validated_graph_->Config().input_stream())); for (const auto& stream_name : input_tag_map->Names()) { - RET_CHECK(!mediapipe::ContainsKey(graph_input_streams_, stream_name)) - .SetNoLogging() + RET_CHECK(!graph_input_streams_.contains(stream_name)).SetNoLogging() << "CalculatorGraph Initialization failed, graph input stream \"" << stream_name << "\" was specified twice."; int output_stream_index = validated_graph_->OutputStreamIndex(stream_name); From 9421249de18b7d33415d74694b61063a99101c6f Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Mon, 3 Apr 2023 20:07:45 +0530 Subject: [PATCH 23/63] Added MPPDetection --- .../tasks/ios/components/containers/BUILD | 7 ++ .../containers/sources/MPPDetection.h | 100 ++++++++++++++++++ .../containers/sources/MPPDetection.m | 70 ++++++++++++ 3 files changed, 177 insertions(+) create mode 100644 mediapipe/tasks/ios/components/containers/sources/MPPDetection.h create mode 100644 mediapipe/tasks/ios/components/containers/sources/MPPDetection.m diff --git a/mediapipe/tasks/ios/components/containers/BUILD b/mediapipe/tasks/ios/components/containers/BUILD index 36c8ef2e0..06df9576a 100644 --- a/mediapipe/tasks/ios/components/containers/BUILD +++ b/mediapipe/tasks/ios/components/containers/BUILD @@ -44,3 +44,10 @@ objc_library( hdrs = ["sources/MPPEmbeddingResult.h"], deps = [":MPPEmbedding"], ) + +objc_library( + name = "MPPDetection", + srcs = ["sources/MPPDetection.m"], + hdrs = ["sources/MPPDetection.h"], + deps = [":MPPCategory"], +) diff --git a/mediapipe/tasks/ios/components/containers/sources/MPPDetection.h b/mediapipe/tasks/ios/components/containers/sources/MPPDetection.h new file mode 100644 index 000000000..cc7c2ebeb --- /dev/null +++ b/mediapipe/tasks/ios/components/containers/sources/MPPDetection.h @@ -0,0 +1,100 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import +#import "mediapipe/tasks/ios/components/containers/sources/MPPCategory.h" + +NS_ASSUME_NONNULL_BEGIN + +/** + * Normalized keypoint represents a point in 2D space with x, y coordinates. x and y are normalized + * to [0.0, 1.0] by the image width and height respectively. + */ +NS_SWIFT_NAME(NormalizedKeypoint) +@interface MPPNormalizedKeypoint : NSObject + +/** The (x,y) coordinates location of the normalized keypoint. 
*/
+@property(nonatomic, readonly) CGPoint location;
+
+/** The optional label of the normalized keypoint. */
+@property(nonatomic, readonly, nullable) NSString *label;
+
+/** The optional score of the normalized keypoint. If score is absent, it will be equal to 0.0. */
+@property(nonatomic, readonly) float score;
+
+/**
+ * Initializes a new `MPPNormalizedKeypoint` object with the given location, label and score.
+ * You must pass 0.0 for `score` if it is not present.
+ *
+ * @param location The (x,y) coordinates location of the normalized keypoint.
+ * @param label The optional label of the normalized keypoint.
+ * @param score The optional score of the normalized keypoint. You must pass 0.0 for score if it
+ * is not present.
+ *
+ * @return An instance of `MPPNormalizedKeypoint` initialized with the given location, label
+ * and score.
+ */
+- (instancetype)initWithLocation:(CGPoint)location
+                           label:(nullable NSString *)label
+                           score:(float)score NS_DESIGNATED_INITIALIZER;
+
+- (instancetype)init NS_UNAVAILABLE;
+
++ (instancetype)new NS_UNAVAILABLE;
+
+@end
+
+/** Represents one detected object in the results of `MPPObjectDetector`. */
+NS_SWIFT_NAME(Detection)
+@interface MPPDetection : NSObject
+
+/** An array of `MPPCategory` objects containing the predicted categories. */
+@property(nonatomic, readonly) NSArray<MPPCategory *> *categories;
+
+/** The bounding box of the detected object. */
+@property(nonatomic, readonly) CGRect boundingBox;
+
+/** An optional array of `MPPNormalizedKeypoint` objects associated with the detection. Keypoints
+ * represent interesting points related to the detection. For example, the keypoints can represent
+ * the eyes, ears and mouth from a face detection model. Or, in template matching detection (e.g.
+ * KNIFT), they can represent the feature points for template matching. */
+@property(nonatomic, readonly, nullable) NSArray<MPPNormalizedKeypoint *> *keypoints;
+
+/**
+ * Initializes a new `MPPDetection` object with the given array of categories, bounding box and
+ * optional array of keypoints.
+ *
+ * @param categories A list of `MPPCategory` objects that contain category name, display name,
+ * score, and the label index.
+ * @param boundingBox A `CGRect` that represents the bounding box.
+ * @param keypoints An optional array of `MPPNormalizedKeypoint` objects associated with the
+ * detection. Keypoints represent interesting points related to the detection. For example, the
+ * keypoints can represent the eyes, ears and mouth from a face detection model. Or, in template
+ * matching detection (e.g. KNIFT), they can represent the feature points for template matching.
+ *
+ * @return An instance of `MPPDetection` initialized with the given array of categories, bounding
+ * box and keypoints.
+ */
+- (instancetype)initWithCategories:(NSArray<MPPCategory *> *)categories
+                       boundingBox:(CGRect)boundingBox
+                         keypoints:(nullable NSArray<MPPNormalizedKeypoint *> *)keypoints
+    NS_DESIGNATED_INITIALIZER;
+
+- (instancetype)init NS_UNAVAILABLE;
+
++ (instancetype)new NS_UNAVAILABLE;
+
+@end
+
+NS_ASSUME_NONNULL_END
diff --git a/mediapipe/tasks/ios/components/containers/sources/MPPDetection.m b/mediapipe/tasks/ios/components/containers/sources/MPPDetection.m
new file mode 100644
index 000000000..42259ffde
--- /dev/null
+++ b/mediapipe/tasks/ios/components/containers/sources/MPPDetection.m
@@ -0,0 +1,70 @@
+// Copyright 2023 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#import "mediapipe/tasks/ios/components/containers/sources/MPPDetection.h"
+
+@implementation MPPNormalizedKeypoint
+
+- (instancetype)initWithLocation:(CGPoint)location
+                           label:(nullable NSString *)label
+                           score:(float)score {
+  self = [super init];
+  if (self) {
+    _location = location;
+    _label = label;
+    _score = score;
+  }
+  return self;
+}
+
+// TODO: Implement hash
+
+- (BOOL)isEqual:(nullable id)object {
+  if (!object) {
+    return NO;
+  }
+
+  if (self == object) {
+    return YES;
+  }
+
+  if (![object isKindOfClass:[MPPNormalizedKeypoint class]]) {
+    return NO;
+  }
+
+  MPPNormalizedKeypoint *otherKeypoint = (MPPNormalizedKeypoint *)object;
+
+  // Labels are compared by value, treating two nil labels as equal.
+  return CGPointEqualToPoint(self.location, otherKeypoint.location) &&
+         (self.label == otherKeypoint.label || [self.label isEqual:otherKeypoint.label]) &&
+         self.score == otherKeypoint.score;
+}
+
+@end
+
+@implementation MPPDetection
+
+- (instancetype)initWithCategories:(NSArray<MPPCategory *> *)categories
+                       boundingBox:(CGRect)boundingBox
+                         keypoints:(nullable NSArray<MPPNormalizedKeypoint *> *)keypoints {
+  self = [super init];
+  if (self) {
+    _categories = categories;
+    _boundingBox = boundingBox;
+    _keypoints = keypoints;
+  }
+  return self;
+}
+
+@end
\ No newline at end of file

From 4943029d62d42b26eeb4589b0d963b853f69e98d Mon Sep 17 00:00:00 2001
From: Prianka Liz Kariat
Date: Mon, 3 Apr 2023 20:11:44 +0530
Subject: [PATCH 24/63] Added MPPDetectionHelpers

---
 .../ios/components/containers/utils/BUILD     | 11 +++
 .../utils/sources/MPPDetection+Helpers.h      | 26 ++++++
 .../utils/sources/MPPDetection+Helpers.mm     | 82 +++++++++++++++
 3 files changed, 119 insertions(+)
 create mode 100644 mediapipe/tasks/ios/components/containers/utils/sources/MPPDetection+Helpers.h
 create mode 100644 mediapipe/tasks/ios/components/containers/utils/sources/MPPDetection+Helpers.mm

diff --git a/mediapipe/tasks/ios/components/containers/utils/BUILD b/mediapipe/tasks/ios/components/containers/utils/BUILD
index 64ca29b88..3520740b0 100644
--- a/mediapipe/tasks/ios/components/containers/utils/BUILD
+++ b/mediapipe/tasks/ios/components/containers/utils/BUILD
@@ -61,3 +61,14 @@ objc_library(
         "//mediapipe/tasks/ios/components/containers:MPPEmbeddingResult",
     ],
 )
+
+objc_library(
+    name = "MPPDetectionHelpers",
+    srcs = ["sources/MPPDetection+Helpers.mm"],
+    hdrs = ["sources/MPPDetection+Helpers.h"],
+    deps = [
+        "//mediapipe/framework/formats:detection_cc_proto",
+        "//mediapipe/tasks/ios/common/utils:NSStringHelpers",
+        "//mediapipe/tasks/ios/components/containers:MPPDetection",
+    ],
+)
diff --git a/mediapipe/tasks/ios/components/containers/utils/sources/MPPDetection+Helpers.h b/mediapipe/tasks/ios/components/containers/utils/sources/MPPDetection+Helpers.h
new file mode 100644
index 000000000..c06c04d3e
--- /dev/null
+++ b/mediapipe/tasks/ios/components/containers/utils/sources/MPPDetection+Helpers.h
@@ -0,0 +1,26 @@
+// Copyright 2023 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
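A side note on the `// TODO: Implement hash` in `MPPDetection.m` above: since `isEqual:` is overridden, `hash` should be overridden as well so that equal keypoints hash equally. A possible sketch of how that TODO might be resolved, purely an assumption and not part of the patch:

```objectivec
// Hypothetical -hash for MPPNormalizedKeypoint: combines the same fields that
// participate in -isEqual:, so equal keypoints produce equal hashes.
- (NSUInteger)hash {
  NSUInteger labelHash = self.label ? self.label.hash : 0;
  return @(self.location.x).hash ^ @(self.location.y).hash ^ labelHash ^ @(self.score).hash;
}
```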
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/framework/formats/detection.pb.h" +#import "mediapipe/tasks/ios/components/containers/sources/MPPDetection.h" + +NS_ASSUME_NONNULL_BEGIN + +@interface MPPDetection (Helpers) + ++ (MPPDetection *)detectionWithProto:(const mediapipe::Detection &)detectionProto; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/components/containers/utils/sources/MPPDetection+Helpers.mm b/mediapipe/tasks/ios/components/containers/utils/sources/MPPDetection+Helpers.mm new file mode 100644 index 000000000..e5cc8dc03 --- /dev/null +++ b/mediapipe/tasks/ios/components/containers/utils/sources/MPPDetection+Helpers.mm @@ -0,0 +1,82 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "mediapipe/tasks/ios/components/containers/utils/sources/MPPDetection+Helpers.h" + +#import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h" + +static const NSInteger kDefaultCategoryIndex = -1; + +namespace { +using DetectionProto = ::mediapipe::Detection; +using BoundingBoxProto = ::mediapipe::LocationData::BoundingBox; +} // namespace + +@implementation MPPDetection (Helpers) + ++ (MPPDetection *)detectionWithProto:(const DetectionProto &)detectionProto { + NSMutableArray *categories = + [NSMutableArray arrayWithCapacity:(NSUInteger)detectionProto.score_size()]; + + for (int idx = 0; idx < detectionProto.score_size(); ++idx) { + NSInteger categoryIndex = + detectionProto.label_id_size() > idx ? detectionProto.label_id(idx) : kDefaultCategoryIndex; + NSString *categoryName = detectionProto.label_size() > idx + ? [NSString stringWithCppString:detectionProto.label(idx)] + : nil; + + NSString *displayName = detectionProto.display_name_size() > idx + ? 
[NSString stringWithCppString:detectionProto.display_name(idx)] + : nil; + + [categories addObject:[[MPPCategory alloc] initWithIndex:categoryIndex + score:detectionProto.score(idx) + categoryName:categoryName + displayName:displayName]]; + } + + CGRect boundingBox = CGRectZero; + + if (detectionProto.location_data().has_bounding_box()) { + const BoundingBoxProto &boundingBoxProto = detectionProto.location_data().bounding_box(); + boundingBox.origin.x = boundingBoxProto.xmin(); + boundingBox.origin.y = boundingBoxProto.ymin(); + boundingBox.size.width = boundingBoxProto.width(); + boundingBox.size.height = boundingBoxProto.height(); + } + + NSMutableArray *normalizedKeypoints; + + if (!detectionProto.location_data().relative_keypoints().empty()) { + normalizedKeypoints = [NSMutableArray + arrayWithCapacity:(NSUInteger)detectionProto.location_data().relative_keypoints_size()]; + for (const auto &keypoint : detectionProto.location_data().relative_keypoints()) { + NSString *label = keypoint.has_keypoint_label() + ? [NSString stringWithCppString:keypoint.keypoint_label()] + : nil; + CGPoint location = CGPointMake(keypoint.x(), keypoint.y()); + float score = keypoint.has_score() ? keypoint.score() : 0.0f; + + [normalizedKeypoints addObject:[[MPPNormalizedKeypoint alloc] initWithLocation:location + label:label + score:score]]; + } + } + + return [[MPPDetection alloc] initWithCategories:categories + boundingBox:boundingBox + keypoints:normalizedKeypoints]; +} + +@end From 67fcf9196eaf7e8a3f0369fd4c09234b58300583 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Mon, 3 Apr 2023 20:14:26 +0530 Subject: [PATCH 25/63] Added MPPObjectDetectionResult --- .../tasks/ios/vision/object_detector/BUILD | 28 +++++++++++ .../sources/MPPObjectDetectionResult.h | 47 +++++++++++++++++++ .../sources/MPPObjectDetectionResult.m | 28 +++++++++++ 3 files changed, 103 insertions(+) create mode 100644 mediapipe/tasks/ios/vision/object_detector/BUILD create mode 100644 mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectionResult.h create mode 100644 mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectionResult.m diff --git a/mediapipe/tasks/ios/vision/object_detector/BUILD b/mediapipe/tasks/ios/vision/object_detector/BUILD new file mode 100644 index 000000000..218e1a8cc --- /dev/null +++ b/mediapipe/tasks/ios/vision/object_detector/BUILD @@ -0,0 +1,28 @@ +# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
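For orientation, here is a minimal sketch (not part of the patch) of how the containers introduced above fit together. It uses only the initializers declared in `MPPDetection.h`, plus the `MPPCategory` initializer as it is invoked in `MPPDetection+Helpers.mm`; all concrete values are illustrative.

```objectivec
#import <CoreGraphics/CoreGraphics.h>
#import "mediapipe/tasks/ios/components/containers/sources/MPPDetection.h"

static MPPDetection *MakeExampleDetection(void) {
  // Keypoint coordinates are normalized to [0.0, 1.0]; score is 0.0 when absent.
  MPPNormalizedKeypoint *leftEye =
      [[MPPNormalizedKeypoint alloc] initWithLocation:CGPointMake(0.42f, 0.31f)
                                                label:@"left_eye"
                                                score:0.0f];

  // Initializer signature taken from its use in MPPDetection+Helpers.mm.
  MPPCategory *face = [[MPPCategory alloc] initWithIndex:0
                                                   score:0.92f
                                            categoryName:@"face"
                                             displayName:nil];

  // The bounding box is expressed in pixel coordinates of the underlying image.
  return [[MPPDetection alloc] initWithCategories:@[ face ]
                                      boundingBox:CGRectMake(10, 20, 100, 80)
                                        keypoints:@[ leftEye ]];
}
```

This is the same shape of object that `detectionWithProto:` assembles from a `mediapipe::Detection` proto carrying one score, one label and one relative keypoint.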
+ +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +objc_library( + name = "MPPObjectDetectionResult", + srcs = ["sources/MPPObjectDetectionResult.m"], + hdrs = ["sources/MPPObjectDetectionResult.h"], + deps = [ + "//mediapipe/tasks/ios/components/containers:MPPDetection", + "//mediapipe/tasks/ios/core:MPPTaskResult", + ], +) + diff --git a/mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectionResult.h b/mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectionResult.h new file mode 100644 index 000000000..6e4921efc --- /dev/null +++ b/mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectionResult.h @@ -0,0 +1,47 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import +#import "mediapipe/tasks/ios/components/containers/sources/MPPDetection.h" +#import "mediapipe/tasks/ios/core/sources/MPPTaskResult.h" + +NS_ASSUME_NONNULL_BEGIN + +/** Represents the detection results generated by `MPPObjectDetector`. */ +NS_SWIFT_NAME(ObjectDetectionResult) +@interface MPPObjectDetectionResult : MPPTaskResult + +/** The array of `MPPDetection` objects each of which has a bounding box that is expressed in the + * unrotated input frame of reference coordinates system, i.e. in `[0,image_width) x + * [0,image_height)`, which are the dimensions of the underlying image data. */ +@property(nonatomic, readonly) NSArray *detections; + +/** + * Initializes a new `MPPObjectDetectionResult` with the given array of detections and timestamp (in + * milliseconds). + * + * @param detections An array of `MPPDetection` objects each of which has a bounding box that is + * expressed in the unrotated input frame of reference coordinates system, i.e. in `[0,image_width) + * x [0,image_height)`, which are the dimensions of the underlying image data. + * @param timestampMs The timestamp for this result. + * + * @return An instance of `MPPObjectDetectionResult` initialized with the given array of detections + * and timestamp (in milliseconds). + */ +- (instancetype)initWithDetections:(NSArray *)detections + timestampMs:(NSInteger)timestampMs; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectionResult.m b/mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectionResult.m new file mode 100644 index 000000000..ac24c19fa --- /dev/null +++ b/mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectionResult.m @@ -0,0 +1,28 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#import "mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectionResult.h" + +@implementation MPPObjectDetectionResult + +- (instancetype)initWithDetections:(NSArray *)detections + timestampMs:(NSInteger)timestampMs { + self = [super initWithTimestampMs:timestampMs]; + if (self) { + _detections = detections; + } + return self; +} + +@end From 1ab9b138efe76cd10e041662fe6e5c33cd42c86e Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Mon, 3 Apr 2023 20:14:41 +0530 Subject: [PATCH 26/63] Added MPPObjectDetectorOptions --- .../tasks/ios/vision/object_detector/BUILD | 10 +++ .../sources/MPPObjectDetectorOptions.h | 71 +++++++++++++++++++ .../sources/MPPObjectDetectorOptions.m | 41 +++++++++++ 3 files changed, 122 insertions(+) create mode 100644 mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectorOptions.h create mode 100644 mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectorOptions.m diff --git a/mediapipe/tasks/ios/vision/object_detector/BUILD b/mediapipe/tasks/ios/vision/object_detector/BUILD index 218e1a8cc..f1325b050 100644 --- a/mediapipe/tasks/ios/vision/object_detector/BUILD +++ b/mediapipe/tasks/ios/vision/object_detector/BUILD @@ -26,3 +26,13 @@ objc_library( ], ) +objc_library( + name = "MPPObjectDetectorOptions", + srcs = ["sources/MPPObjectDetectorOptions.m"], + hdrs = ["sources/MPPObjectDetectorOptions.h"], + deps = [ + ":MPPObjectDetectionResult", + "//mediapipe/tasks/ios/core:MPPTaskOptions", + "//mediapipe/tasks/ios/vision/core:MPPRunningMode", + ], +) diff --git a/mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectorOptions.h b/mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectorOptions.h new file mode 100644 index 000000000..bf92e9a44 --- /dev/null +++ b/mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectorOptions.h @@ -0,0 +1,71 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import + +#import "mediapipe/tasks/ios/core/sources/MPPTaskOptions.h" +#import "mediapipe/tasks/ios/vision/core/sources/MPPRunningMode.h" +#import "mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectionResult.h" + +NS_ASSUME_NONNULL_BEGIN + +/** Options for setting up a `MPPObjectDetector`. */ +NS_SWIFT_NAME(ObjectDetectorOptions) +@interface MPPObjectDetectorOptions : MPPTaskOptions + +@property(nonatomic) MPPRunningMode runningMode; + +/** + * The user-defined result callback for processing live stream data. The result callback should only + * be specified when the running mode is set to the live stream mode. + * TODO: Add parameter `MPPImage` in the callback. + */ +@property(nonatomic, copy) void (^completion) + (MPPObjectDetectionResult *result, NSInteger timestampMs, NSError *error); + +/** + * The locale to use for display names specified through the TFLite Model Metadata, if any. Defaults + * to English. 
+ */ +@property(nonatomic, copy) NSString *displayNamesLocale; + +/** + * The maximum number of top-scored classification results to return. If < 0, all available results + * will be returned. If 0, an invalid argument error is returned. + */ +@property(nonatomic) NSInteger maxResults; + +/** + * Score threshold to override the one provided in the model metadata (if any). Results below this + * value are rejected. + */ +@property(nonatomic) float scoreThreshold; + +/** + * The allowlist of category names. If non-empty, detection results whose category name is not in + * this set will be filtered out. Duplicate or unknown category names are ignored. Mutually + * exclusive with categoryDenylist. + */ +@property(nonatomic, copy) NSArray *categoryAllowlist; + +/** + * The denylist of category names. If non-empty, detection results whose category name is in this + * set will be filtered out. Duplicate or unknown category names are ignored. Mutually exclusive + * with categoryAllowlist. + */ +@property(nonatomic, copy) NSArray *categoryDenylist; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectorOptions.m b/mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectorOptions.m new file mode 100644 index 000000000..73f8ce5b5 --- /dev/null +++ b/mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectorOptions.m @@ -0,0 +1,41 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
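As a usage sketch for the options above (not part of the patch): the detector is configured through plain properties, and live-stream results arrive through the `completion` block. The `MPPRunningModeLiveStream` constant and the `categoryName` property on `MPPCategory` are assumptions, since `MPPRunningMode.h` and `MPPCategory.h` are not shown in this patch series.

```objectivec
#import "mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectorOptions.h"

static MPPObjectDetectorOptions *MakeExampleOptions(void) {
  MPPObjectDetectorOptions *options = [[MPPObjectDetectorOptions alloc] init];
  options.maxResults = 5;          // keep only the five top-scored detections
  options.scoreThreshold = 0.3f;   // overrides any threshold in the model metadata
  options.categoryAllowlist = @[ @"cat", @"dog" ];  // mutually exclusive with categoryDenylist
  options.runningMode = MPPRunningModeLiveStream;   // assumed constant from MPPRunningMode.h

  options.completion = ^(MPPObjectDetectionResult *result, NSInteger timestampMs, NSError *error) {
    // Bounding boxes are in the unrotated input frame of reference, in pixels.
    for (MPPDetection *detection in result.detections) {
      MPPCategory *topCategory = detection.categories.firstObject;
      NSLog(@"%@ at t=%ld ms", topCategory.categoryName, (long)timestampMs);  // categoryName assumed
    }
  };
  return options;
}
```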
+ +#import "mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectorOptions.h" + +@implementation MPPObjectDetectorOptions + +- (instancetype)init { + self = [super init]; + if (self) { + _maxResults = -1; + _scoreThreshold = 0; + } + return self; +} + +- (id)copyWithZone:(NSZone *)zone { + MPPObjectDetectorOptions *objectDetectorOptions = [super copyWithZone:zone]; + + objectDetectorOptions.scoreThreshold = self.scoreThreshold; + objectDetectorOptions.maxResults = self.maxResults; + objectDetectorOptions.categoryDenylist = self.categoryDenylist; + objectDetectorOptions.categoryAllowlist = self.categoryAllowlist; + objectDetectorOptions.displayNamesLocale = self.displayNamesLocale; + objectDetectorOptions.completion = self.completion; + + return objectDetectorOptions; +} + +@end From 048cc51e136dc5d2f3a4839ad0ea4c2422593c7b Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Mon, 3 Apr 2023 20:15:35 +0530 Subject: [PATCH 27/63] Added new line --- .../tasks/ios/components/containers/sources/MPPDetection.m | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mediapipe/tasks/ios/components/containers/sources/MPPDetection.m b/mediapipe/tasks/ios/components/containers/sources/MPPDetection.m index 42259ffde..a4dd642e2 100644 --- a/mediapipe/tasks/ios/components/containers/sources/MPPDetection.m +++ b/mediapipe/tasks/ios/components/containers/sources/MPPDetection.m @@ -67,4 +67,4 @@ return self; } -@end \ No newline at end of file +@end From e84799ee37f55d20944a48d4e06092654226d46f Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Mon, 3 Apr 2023 09:43:05 -0700 Subject: [PATCH 28/63] Internal change PiperOrigin-RevId: 521483663 --- mediapipe/calculators/internal/BUILD | 17 +- mediapipe/calculators/video/BUILD | 106 +++-------- mediapipe/examples/desktop/autoflip/BUILD | 27 +-- .../desktop/autoflip/calculators/BUILD | 114 +++--------- .../examples/desktop/autoflip/quality/BUILD | 39 +---- mediapipe/framework/formats/motion/BUILD | 10 +- mediapipe/framework/stream_handler/BUILD | 42 +---- mediapipe/framework/tool/mediapipe_proto.bzl | 1 - mediapipe/gpu/BUILD | 38 ++-- .../graphs/iris_tracking/calculators/BUILD | 30 +--- mediapipe/util/tracking/BUILD | 165 ++---------------- third_party/halide.BUILD | 2 +- 12 files changed, 109 insertions(+), 482 deletions(-) diff --git a/mediapipe/calculators/internal/BUILD b/mediapipe/calculators/internal/BUILD index 8647e3f3f..a92a2f252 100644 --- a/mediapipe/calculators/internal/BUILD +++ b/mediapipe/calculators/internal/BUILD @@ -12,25 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library") +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") licenses(["notice"]) package(default_visibility = ["//visibility:private"]) -proto_library( +mediapipe_proto_library( name = "callback_packet_calculator_proto", srcs = ["callback_packet_calculator.proto"], visibility = ["//mediapipe/framework:__subpackages__"], - deps = ["//mediapipe/framework:calculator_proto"], -) - -mediapipe_cc_proto_library( - name = "callback_packet_calculator_cc_proto", - srcs = ["callback_packet_calculator.proto"], - cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = ["//mediapipe/framework:__subpackages__"], - deps = [":callback_packet_calculator_proto"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + ], ) cc_library( diff --git a/mediapipe/calculators/video/BUILD b/mediapipe/calculators/video/BUILD index e4aa1bff8..7245b13c2 100644 --- a/mediapipe/calculators/video/BUILD +++ b/mediapipe/calculators/video/BUILD @@ -13,7 +13,7 @@ # limitations under the License. # -load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library") +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") load( "//mediapipe/framework/tool:mediapipe_graph.bzl", "mediapipe_binary_graph", @@ -23,28 +23,35 @@ licenses(["notice"]) package(default_visibility = ["//visibility:public"]) -proto_library( +mediapipe_proto_library( name = "flow_to_image_calculator_proto", srcs = ["flow_to_image_calculator.proto"], - deps = ["//mediapipe/framework:calculator_proto"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + ], ) -proto_library( +mediapipe_proto_library( name = "opencv_video_encoder_calculator_proto", srcs = ["opencv_video_encoder_calculator.proto"], - deps = ["//mediapipe/framework:calculator_proto"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + ], ) -proto_library( +mediapipe_proto_library( name = "motion_analysis_calculator_proto", srcs = ["motion_analysis_calculator.proto"], deps = [ + "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", "//mediapipe/util/tracking:motion_analysis_proto", ], ) -proto_library( +mediapipe_proto_library( name = "flow_packager_calculator_proto", srcs = ["flow_packager_calculator.proto"], deps = [ @@ -54,114 +61,45 @@ proto_library( ], ) -proto_library( +mediapipe_proto_library( name = "box_tracker_calculator_proto", srcs = ["box_tracker_calculator.proto"], deps = [ + "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", "//mediapipe/util/tracking:box_tracker_proto", ], ) -proto_library( +mediapipe_proto_library( name = "tracked_detection_manager_calculator_proto", srcs = ["tracked_detection_manager_calculator.proto"], deps = [ + "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", "//mediapipe/util/tracking:tracked_detection_manager_config_proto", ], ) -proto_library( +mediapipe_proto_library( name = "box_detector_calculator_proto", srcs = ["box_detector_calculator.proto"], deps = [ + "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", "//mediapipe/util/tracking:box_detector_proto", ], ) -proto_library( +mediapipe_proto_library( name = "video_pre_stream_calculator_proto", srcs = 
["video_pre_stream_calculator.proto"], deps = [ + "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", ], ) -mediapipe_cc_proto_library( - name = "motion_analysis_calculator_cc_proto", - srcs = ["motion_analysis_calculator.proto"], - cc_deps = [ - "//mediapipe/framework:calculator_cc_proto", - "//mediapipe/util/tracking:motion_analysis_cc_proto", - ], - deps = [":motion_analysis_calculator_proto"], -) - -mediapipe_cc_proto_library( - name = "flow_packager_calculator_cc_proto", - srcs = ["flow_packager_calculator.proto"], - cc_deps = [ - "//mediapipe/framework:calculator_cc_proto", - "//mediapipe/util/tracking:flow_packager_cc_proto", - ], - deps = [":flow_packager_calculator_proto"], -) - -mediapipe_cc_proto_library( - name = "box_tracker_calculator_cc_proto", - srcs = ["box_tracker_calculator.proto"], - cc_deps = [ - "//mediapipe/framework:calculator_cc_proto", - "//mediapipe/util/tracking:box_tracker_cc_proto", - ], - deps = [":box_tracker_calculator_proto"], -) - -mediapipe_cc_proto_library( - name = "tracked_detection_manager_calculator_cc_proto", - srcs = ["tracked_detection_manager_calculator.proto"], - cc_deps = [ - "//mediapipe/framework:calculator_cc_proto", - "//mediapipe/util/tracking:tracked_detection_manager_config_cc_proto", - ], - deps = [":tracked_detection_manager_calculator_proto"], -) - -mediapipe_cc_proto_library( - name = "box_detector_calculator_cc_proto", - srcs = ["box_detector_calculator.proto"], - cc_deps = [ - "//mediapipe/framework:calculator_cc_proto", - "//mediapipe/util/tracking:box_detector_cc_proto", - ], - deps = [":box_detector_calculator_proto"], -) - -mediapipe_cc_proto_library( - name = "video_pre_stream_calculator_cc_proto", - srcs = ["video_pre_stream_calculator.proto"], - cc_deps = [ - "//mediapipe/framework:calculator_cc_proto", - ], - deps = [":video_pre_stream_calculator_proto"], -) - -mediapipe_cc_proto_library( - name = "flow_to_image_calculator_cc_proto", - srcs = ["flow_to_image_calculator.proto"], - cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - deps = [":flow_to_image_calculator_proto"], -) - -mediapipe_cc_proto_library( - name = "opencv_video_encoder_calculator_cc_proto", - srcs = ["opencv_video_encoder_calculator.proto"], - cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - deps = [":opencv_video_encoder_calculator_proto"], -) - cc_library( name = "flow_to_image_calculator", srcs = ["flow_to_image_calculator.cc"], diff --git a/mediapipe/examples/desktop/autoflip/BUILD b/mediapipe/examples/desktop/autoflip/BUILD index 340205caa..fe994e2e0 100644 --- a/mediapipe/examples/desktop/autoflip/BUILD +++ b/mediapipe/examples/desktop/autoflip/BUILD @@ -1,4 +1,4 @@ -load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library") +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") # Copyright 2019 The MediaPipe Authors. 
# @@ -22,7 +22,7 @@ package(default_visibility = [ "//photos/editing/mobile/mediapipe/proto:__subpackages__", ]) -proto_library( +mediapipe_proto_library( name = "autoflip_messages_proto", srcs = ["autoflip_messages.proto"], deps = [ @@ -30,29 +30,6 @@ proto_library( ], ) -java_lite_proto_library( - name = "autoflip_messages_java_proto_lite", - visibility = [ - "//java/com/google/android/apps/photos:__subpackages__", - "//javatests/com/google/android/apps/photos:__subpackages__", - ], - deps = [ - ":autoflip_messages_proto", - ], -) - -mediapipe_cc_proto_library( - name = "autoflip_messages_cc_proto", - srcs = ["autoflip_messages.proto"], - cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = [ - "//mediapipe/examples:__subpackages__", - "//photos/editing/mobile/mediapipe/calculators:__pkg__", - "//photos/editing/mobile/mediapipe/calculators:__subpackages__", - ], - deps = [":autoflip_messages_proto"], -) - cc_binary( name = "run_autoflip", data = [ diff --git a/mediapipe/examples/desktop/autoflip/calculators/BUILD b/mediapipe/examples/desktop/autoflip/calculators/BUILD index 18f56cc4f..a3b2ace2a 100644 --- a/mediapipe/examples/desktop/autoflip/calculators/BUILD +++ b/mediapipe/examples/desktop/autoflip/calculators/BUILD @@ -1,4 +1,4 @@ -load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library") +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") # Copyright 2019 The MediaPipe Authors. # @@ -40,22 +40,16 @@ cc_library( alwayslink = 1, ) -proto_library( +mediapipe_proto_library( name = "border_detection_calculator_proto", srcs = ["border_detection_calculator.proto"], + visibility = ["//mediapipe/examples:__subpackages__"], deps = [ + "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", ], ) -mediapipe_cc_proto_library( - name = "border_detection_calculator_cc_proto", - srcs = ["border_detection_calculator.proto"], - cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = ["//mediapipe/examples:__subpackages__"], - deps = [":border_detection_calculator_proto"], -) - cc_library( name = "content_zooming_calculator_state", hdrs = ["content_zooming_calculator_state.h"], @@ -85,27 +79,16 @@ cc_library( alwayslink = 1, ) -proto_library( +mediapipe_proto_library( name = "content_zooming_calculator_proto", srcs = ["content_zooming_calculator.proto"], - deps = [ - "//mediapipe/examples/desktop/autoflip/quality:kinematic_path_solver_proto", - "//mediapipe/framework:calculator_proto", - ], -) - -mediapipe_cc_proto_library( - name = "content_zooming_calculator_cc_proto", - srcs = ["content_zooming_calculator.proto"], - cc_deps = [ - "//mediapipe/examples/desktop/autoflip/quality:kinematic_path_solver_cc_proto", - "//mediapipe/framework:calculator_cc_proto", - ], visibility = [ "//mediapipe/examples:__subpackages__", ], deps = [ - ":content_zooming_calculator_proto", + "//mediapipe/examples/desktop/autoflip/quality:kinematic_path_solver_proto", + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", ], ) @@ -177,23 +160,16 @@ cc_library( alwayslink = 1, ) -proto_library( +mediapipe_proto_library( name = "video_filtering_calculator_proto", srcs = ["video_filtering_calculator.proto"], visibility = ["//visibility:public"], deps = [ + "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", ], ) -mediapipe_cc_proto_library( - name = "video_filtering_calculator_cc_proto", - srcs = 
["video_filtering_calculator.proto"], - cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = ["//visibility:public"], - deps = [":video_filtering_calculator_proto"], -) - cc_test( name = "video_filtering_calculator_test", srcs = ["video_filtering_calculator_test.cc"], @@ -209,27 +185,17 @@ cc_test( ], ) -proto_library( +mediapipe_proto_library( name = "scene_cropping_calculator_proto", srcs = ["scene_cropping_calculator.proto"], visibility = ["//visibility:public"], deps = [ "//mediapipe/examples/desktop/autoflip/quality:cropping_proto", + "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", ], ) -mediapipe_cc_proto_library( - name = "scene_cropping_calculator_cc_proto", - srcs = ["scene_cropping_calculator.proto"], - cc_deps = [ - "//mediapipe/examples/desktop/autoflip/quality:cropping_cc_proto", - "//mediapipe/framework:calculator_cc_proto", - ], - visibility = ["//visibility:public"], - deps = [":scene_cropping_calculator_proto"], -) - cc_library( name = "scene_cropping_calculator", srcs = ["scene_cropping_calculator.cc"], @@ -296,26 +262,17 @@ cc_library( alwayslink = 1, ) -proto_library( +mediapipe_proto_library( name = "signal_fusing_calculator_proto", srcs = ["signal_fusing_calculator.proto"], + visibility = ["//mediapipe/examples:__subpackages__"], deps = [ "//mediapipe/examples/desktop/autoflip:autoflip_messages_proto", + "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", ], ) -mediapipe_cc_proto_library( - name = "signal_fusing_calculator_cc_proto", - srcs = ["signal_fusing_calculator.proto"], - cc_deps = [ - "//mediapipe/examples/desktop/autoflip:autoflip_messages_cc_proto", - "//mediapipe/framework:calculator_cc_proto", - ], - visibility = ["//mediapipe/examples:__subpackages__"], - deps = [":signal_fusing_calculator_proto"], -) - cc_test( name = "signal_fusing_calculator_test", srcs = ["signal_fusing_calculator_test.cc"], @@ -353,18 +310,14 @@ cc_library( alwayslink = 1, ) -proto_library( +mediapipe_proto_library( name = "shot_boundary_calculator_proto", srcs = ["shot_boundary_calculator.proto"], - deps = ["//mediapipe/framework:calculator_proto"], -) - -mediapipe_cc_proto_library( - name = "shot_boundary_calculator_cc_proto", - srcs = ["shot_boundary_calculator.proto"], - cc_deps = ["//mediapipe/framework:calculator_cc_proto"], visibility = ["//mediapipe/examples:__subpackages__"], - deps = [":shot_boundary_calculator_proto"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + ], ) cc_test( @@ -413,26 +366,17 @@ cc_library( alwayslink = 1, ) -proto_library( +mediapipe_proto_library( name = "face_to_region_calculator_proto", srcs = ["face_to_region_calculator.proto"], + visibility = ["//mediapipe/examples:__subpackages__"], deps = [ "//mediapipe/examples/desktop/autoflip/quality:visual_scorer_proto", + "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", ], ) -mediapipe_cc_proto_library( - name = "face_to_region_calculator_cc_proto", - srcs = ["face_to_region_calculator.proto"], - cc_deps = [ - "//mediapipe/examples/desktop/autoflip/quality:visual_scorer_cc_proto", - "//mediapipe/framework:calculator_cc_proto", - ], - visibility = ["//mediapipe/examples:__subpackages__"], - deps = [":face_to_region_calculator_proto"], -) - cc_test( name = "face_to_region_calculator_test", srcs = ["face_to_region_calculator_test.cc"], @@ -454,22 +398,16 @@ cc_test( ], ) -proto_library( 
+mediapipe_proto_library( name = "localization_to_region_calculator_proto", srcs = ["localization_to_region_calculator.proto"], + visibility = ["//mediapipe/examples:__subpackages__"], deps = [ + "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", ], ) -mediapipe_cc_proto_library( - name = "localization_to_region_calculator_cc_proto", - srcs = ["localization_to_region_calculator.proto"], - cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = ["//mediapipe/examples:__subpackages__"], - deps = [":localization_to_region_calculator_proto"], -) - cc_library( name = "localization_to_region_calculator", srcs = ["localization_to_region_calculator.cc"], diff --git a/mediapipe/examples/desktop/autoflip/quality/BUILD b/mediapipe/examples/desktop/autoflip/quality/BUILD index 0b5970ee9..20e286107 100644 --- a/mediapipe/examples/desktop/autoflip/quality/BUILD +++ b/mediapipe/examples/desktop/autoflip/quality/BUILD @@ -1,4 +1,4 @@ -load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library") +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") # Copyright 2019 The MediaPipe Authors. # @@ -20,7 +20,7 @@ package(default_visibility = [ "//mediapipe/examples:__subpackages__", ]) -proto_library( +mediapipe_proto_library( name = "cropping_proto", srcs = ["cropping.proto"], deps = [ @@ -29,41 +29,18 @@ proto_library( ], ) -mediapipe_cc_proto_library( - name = "cropping_cc_proto", - srcs = ["cropping.proto"], - cc_deps = [ - ":kinematic_path_solver_cc_proto", - "//mediapipe/examples/desktop/autoflip:autoflip_messages_cc_proto", - ], - visibility = ["//mediapipe/examples:__subpackages__"], - deps = [":cropping_proto"], -) - -proto_library( +mediapipe_proto_library( name = "kinematic_path_solver_proto", srcs = ["kinematic_path_solver.proto"], -) - -mediapipe_cc_proto_library( - name = "kinematic_path_solver_cc_proto", - srcs = ["kinematic_path_solver.proto"], visibility = [ "//mediapipe/examples:__subpackages__", ], - deps = [":kinematic_path_solver_proto"], ) -proto_library( +mediapipe_proto_library( name = "focus_point_proto", srcs = ["focus_point.proto"], -) - -mediapipe_cc_proto_library( - name = "focus_point_cc_proto", - srcs = ["focus_point.proto"], visibility = ["//mediapipe/examples:__subpackages__"], - deps = [":focus_point_proto"], ) cc_library( @@ -333,16 +310,10 @@ cc_test( ], ) -proto_library( +mediapipe_proto_library( name = "visual_scorer_proto", srcs = ["visual_scorer.proto"], -) - -mediapipe_cc_proto_library( - name = "visual_scorer_cc_proto", - srcs = ["visual_scorer.proto"], visibility = ["//mediapipe/examples:__subpackages__"], - deps = [":visual_scorer_proto"], ) cc_library( diff --git a/mediapipe/framework/formats/motion/BUILD b/mediapipe/framework/formats/motion/BUILD index c9bb8b4ff..919b82406 100644 --- a/mediapipe/framework/formats/motion/BUILD +++ b/mediapipe/framework/formats/motion/BUILD @@ -16,23 +16,17 @@ # Description: # Working with dense optical flow in mediapipe. 
-load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library") +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") licenses(["notice"]) package(default_visibility = ["//visibility:public"]) -proto_library( +mediapipe_proto_library( name = "optical_flow_field_data_proto", srcs = ["optical_flow_field_data.proto"], ) -mediapipe_cc_proto_library( - name = "optical_flow_field_data_cc_proto", - srcs = ["optical_flow_field_data.proto"], - deps = [":optical_flow_field_data_proto"], -) - cc_library( name = "optical_flow_field", srcs = ["optical_flow_field.cc"], diff --git a/mediapipe/framework/stream_handler/BUILD b/mediapipe/framework/stream_handler/BUILD index 68a9af52d..8b54ade8b 100644 --- a/mediapipe/framework/stream_handler/BUILD +++ b/mediapipe/framework/stream_handler/BUILD @@ -13,7 +13,7 @@ # limitations under the License. # -load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library") +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") licenses(["notice"]) @@ -22,56 +22,32 @@ package( features = ["-layering_check"], ) -proto_library( +mediapipe_proto_library( name = "default_input_stream_handler_proto", srcs = ["default_input_stream_handler.proto"], deps = ["//mediapipe/framework:mediapipe_options_proto"], + alwayslink = 1, ) -proto_library( +mediapipe_proto_library( name = "fixed_size_input_stream_handler_proto", srcs = ["fixed_size_input_stream_handler.proto"], deps = ["//mediapipe/framework:mediapipe_options_proto"], + alwayslink = 1, ) -proto_library( +mediapipe_proto_library( name = "sync_set_input_stream_handler_proto", srcs = ["sync_set_input_stream_handler.proto"], deps = ["//mediapipe/framework:mediapipe_options_proto"], + alwayslink = 1, ) -proto_library( +mediapipe_proto_library( name = "timestamp_align_input_stream_handler_proto", srcs = ["timestamp_align_input_stream_handler.proto"], deps = ["//mediapipe/framework:mediapipe_options_proto"], -) - -mediapipe_cc_proto_library( - name = "default_input_stream_handler_cc_proto", - srcs = ["default_input_stream_handler.proto"], - cc_deps = ["//mediapipe/framework:mediapipe_options_cc_proto"], - deps = [":default_input_stream_handler_proto"], -) - -mediapipe_cc_proto_library( - name = "fixed_size_input_stream_handler_cc_proto", - srcs = ["fixed_size_input_stream_handler.proto"], - cc_deps = ["//mediapipe/framework:mediapipe_options_cc_proto"], - deps = [":fixed_size_input_stream_handler_proto"], -) - -mediapipe_cc_proto_library( - name = "sync_set_input_stream_handler_cc_proto", - srcs = ["sync_set_input_stream_handler.proto"], - cc_deps = ["//mediapipe/framework:mediapipe_options_cc_proto"], - deps = [":sync_set_input_stream_handler_proto"], -) - -mediapipe_cc_proto_library( - name = "timestamp_align_input_stream_handler_cc_proto", - srcs = ["timestamp_align_input_stream_handler.proto"], - cc_deps = ["//mediapipe/framework:mediapipe_options_cc_proto"], - deps = [":timestamp_align_input_stream_handler_proto"], + alwayslink = 1, ) cc_library( diff --git a/mediapipe/framework/tool/mediapipe_proto.bzl b/mediapipe/framework/tool/mediapipe_proto.bzl index 7ed87aba9..527774ff3 100644 --- a/mediapipe/framework/tool/mediapipe_proto.bzl +++ b/mediapipe/framework/tool/mediapipe_proto.bzl @@ -90,7 +90,6 @@ def mediapipe_proto_library_impl( visibility = visibility, testonly = testonly, compatible_with = compatible_with, - alwayslink = alwayslink, )) if def_cc_proto: diff --git a/mediapipe/gpu/BUILD b/mediapipe/gpu/BUILD index 2f7f7ec33..ca2912ac3 
100644 --- a/mediapipe/gpu/BUILD +++ b/mediapipe/gpu/BUILD @@ -15,7 +15,7 @@ load("@bazel_skylib//lib:selects.bzl", "selects") load("//mediapipe/gpu:metal.bzl", "metal_library") load("@build_bazel_rules_apple//apple:ios.bzl", "ios_unit_test") -load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library", "mediapipe_proto_library") +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") load("//mediapipe/framework:mediapipe_cc_test.bzl", "mediapipe_cc_test") load("//mediapipe/framework:more_selects.bzl", "more_selects") @@ -555,7 +555,10 @@ mediapipe_proto_library( name = "gl_context_options_proto", srcs = ["gl_context_options.proto"], visibility = ["//visibility:public"], - deps = ["//mediapipe/framework:calculator_proto"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + ], ) # This is a hack needed to work around some issues with strict hdrs_check. @@ -929,6 +932,7 @@ mediapipe_proto_library( srcs = ["gl_animation_overlay_calculator.proto"], visibility = ["//visibility:public"], deps = [ + "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", ], ) @@ -939,6 +943,7 @@ mediapipe_proto_library( visibility = ["//visibility:public"], deps = [ ":scale_mode_proto", + "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", ], ) @@ -982,26 +987,16 @@ cc_library( alwayslink = 1, ) -proto_library( +mediapipe_proto_library( name = "gl_surface_sink_calculator_proto", srcs = ["gl_surface_sink_calculator.proto"], deps = [ ":scale_mode_proto", + "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", ], ) -mediapipe_cc_proto_library( - name = "gl_surface_sink_calculator_cc_proto", - srcs = ["gl_surface_sink_calculator.proto"], - cc_deps = [ - ":scale_mode_cc_proto", - "//mediapipe/framework:calculator_cc_proto", - ], - visibility = ["//visibility:public"], - deps = [":gl_surface_sink_calculator_proto"], -) - ### Metal calculators metal_library( @@ -1017,21 +1012,14 @@ objc_library( deps = [":simple_shaders_mtl"], ) -proto_library( +mediapipe_proto_library( name = "copy_calculator_proto", srcs = ["copy_calculator.proto"], visibility = ["//visibility:public"], - deps = ["//mediapipe/framework:calculator_proto"], -) - -mediapipe_cc_proto_library( - name = "copy_calculator_cc_proto", - srcs = ["copy_calculator.proto"], - cc_deps = [ - "//mediapipe/framework:calculator_cc_proto", + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", ], - visibility = ["//visibility:public"], - deps = [":copy_calculator_proto"], ) objc_library( diff --git a/mediapipe/graphs/iris_tracking/calculators/BUILD b/mediapipe/graphs/iris_tracking/calculators/BUILD index f5124b464..9ddce7f36 100644 --- a/mediapipe/graphs/iris_tracking/calculators/BUILD +++ b/mediapipe/graphs/iris_tracking/calculators/BUILD @@ -12,33 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library") +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") licenses(["notice"]) -proto_library( +mediapipe_proto_library( name = "iris_to_render_data_calculator_proto", srcs = ["iris_to_render_data_calculator.proto"], visibility = ["//visibility:public"], deps = [ + "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", "//mediapipe/util:color_proto", "//mediapipe/util:render_data_proto", ], ) -mediapipe_cc_proto_library( - name = "iris_to_render_data_calculator_cc_proto", - srcs = ["iris_to_render_data_calculator.proto"], - cc_deps = [ - "//mediapipe/framework:calculator_cc_proto", - "//mediapipe/util:color_cc_proto", - "//mediapipe/util:render_data_cc_proto", - ], - visibility = ["//visibility:public"], - deps = [":iris_to_render_data_calculator_proto"], -) - cc_library( name = "iris_to_render_data_calculator", srcs = ["iris_to_render_data_calculator.cc"], @@ -56,25 +45,16 @@ cc_library( alwayslink = 1, ) -proto_library( +mediapipe_proto_library( name = "iris_to_depth_calculator_proto", srcs = ["iris_to_depth_calculator.proto"], visibility = ["//visibility:public"], deps = [ + "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", ], ) -mediapipe_cc_proto_library( - name = "iris_to_depth_calculator_cc_proto", - srcs = ["iris_to_depth_calculator.proto"], - cc_deps = [ - "//mediapipe/framework:calculator_cc_proto", - ], - visibility = ["//visibility:public"], - deps = [":iris_to_depth_calculator_proto"], -) - cc_library( name = "iris_to_depth_calculator", srcs = ["iris_to_depth_calculator.cc"], diff --git a/mediapipe/util/tracking/BUILD b/mediapipe/util/tracking/BUILD index 816af2533..5a271ffac 100644 --- a/mediapipe/util/tracking/BUILD +++ b/mediapipe/util/tracking/BUILD @@ -13,24 +13,24 @@ # limitations under the License. 
# -load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library") +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") licenses(["notice"]) package(default_visibility = ["//visibility:public"]) -proto_library( +mediapipe_proto_library( name = "tone_models_proto", srcs = ["tone_models.proto"], ) -proto_library( +mediapipe_proto_library( name = "tone_estimation_proto", srcs = ["tone_estimation.proto"], deps = [":tone_models_proto"], ) -proto_library( +mediapipe_proto_library( name = "region_flow_computation_proto", srcs = ["region_flow_computation.proto"], deps = [ @@ -38,17 +38,17 @@ proto_library( ], ) -proto_library( +mediapipe_proto_library( name = "motion_saliency_proto", srcs = ["motion_saliency.proto"], ) -proto_library( +mediapipe_proto_library( name = "motion_estimation_proto", srcs = ["motion_estimation.proto"], ) -proto_library( +mediapipe_proto_library( name = "motion_analysis_proto", srcs = ["motion_analysis.proto"], deps = [ @@ -58,33 +58,33 @@ proto_library( ], ) -proto_library( +mediapipe_proto_library( name = "region_flow_proto", srcs = ["region_flow.proto"], ) -proto_library( +mediapipe_proto_library( name = "motion_models_proto", srcs = ["motion_models.proto"], ) -proto_library( +mediapipe_proto_library( name = "camera_motion_proto", srcs = ["camera_motion.proto"], deps = [":motion_models_proto"], ) -proto_library( +mediapipe_proto_library( name = "push_pull_filtering_proto", srcs = ["push_pull_filtering.proto"], ) -proto_library( +mediapipe_proto_library( name = "frame_selection_solution_evaluator_proto", srcs = ["frame_selection_solution_evaluator.proto"], ) -proto_library( +mediapipe_proto_library( name = "frame_selection_proto", srcs = ["frame_selection.proto"], deps = [ @@ -94,7 +94,7 @@ proto_library( ], ) -proto_library( +mediapipe_proto_library( name = "flow_packager_proto", srcs = ["flow_packager.proto"], deps = [ @@ -103,7 +103,7 @@ proto_library( ], ) -proto_library( +mediapipe_proto_library( name = "tracking_proto", srcs = ["tracking.proto"], deps = [ @@ -111,18 +111,18 @@ proto_library( ], ) -proto_library( +mediapipe_proto_library( name = "box_tracker_proto", srcs = ["box_tracker.proto"], deps = [":tracking_proto"], ) -proto_library( +mediapipe_proto_library( name = "tracked_detection_manager_config_proto", srcs = ["tracked_detection_manager_config.proto"], ) -proto_library( +mediapipe_proto_library( name = "box_detector_proto", srcs = ["box_detector.proto"], deps = [ @@ -131,135 +131,6 @@ proto_library( ], ) -mediapipe_cc_proto_library( - name = "tone_models_cc_proto", - srcs = ["tone_models.proto"], - deps = [":tone_models_proto"], -) - -mediapipe_cc_proto_library( - name = "tone_estimation_cc_proto", - srcs = ["tone_estimation.proto"], - cc_deps = [":tone_models_cc_proto"], - deps = [":tone_estimation_proto"], -) - -mediapipe_cc_proto_library( - name = "region_flow_computation_cc_proto", - srcs = ["region_flow_computation.proto"], - cc_deps = [ - ":tone_estimation_cc_proto", - ":tone_models_cc_proto", - ], - deps = [":region_flow_computation_proto"], -) - -mediapipe_cc_proto_library( - name = "motion_saliency_cc_proto", - srcs = ["motion_saliency.proto"], - deps = [":motion_saliency_proto"], -) - -mediapipe_cc_proto_library( - name = "motion_estimation_cc_proto", - srcs = ["motion_estimation.proto"], - deps = [":motion_estimation_proto"], -) - -mediapipe_cc_proto_library( - name = "motion_analysis_cc_proto", - srcs = ["motion_analysis.proto"], - cc_deps = [ - ":motion_estimation_cc_proto", - 
":motion_saliency_cc_proto", - ":region_flow_computation_cc_proto", - ], - deps = [":motion_analysis_proto"], -) - -mediapipe_cc_proto_library( - name = "region_flow_cc_proto", - srcs = ["region_flow.proto"], - cc_deps = [":motion_models_cc_proto"], - deps = [":region_flow_proto"], -) - -mediapipe_cc_proto_library( - name = "motion_models_cc_proto", - srcs = ["motion_models.proto"], - deps = [":motion_models_proto"], -) - -mediapipe_cc_proto_library( - name = "camera_motion_cc_proto", - srcs = ["camera_motion.proto"], - cc_deps = [":motion_models_cc_proto"], - deps = [":camera_motion_proto"], -) - -mediapipe_cc_proto_library( - name = "push_pull_filtering_cc_proto", - srcs = ["push_pull_filtering.proto"], - deps = [":push_pull_filtering_proto"], -) - -mediapipe_cc_proto_library( - name = "frame_selection_solution_evaluator_cc_proto", - srcs = ["frame_selection_solution_evaluator.proto"], - deps = [":frame_selection_solution_evaluator_proto"], -) - -mediapipe_cc_proto_library( - name = "frame_selection_cc_proto", - srcs = ["frame_selection.proto"], - cc_deps = [ - ":camera_motion_cc_proto", - ":frame_selection_solution_evaluator_cc_proto", - ":region_flow_cc_proto", - ], - deps = [":frame_selection_proto"], -) - -mediapipe_cc_proto_library( - name = "flow_packager_cc_proto", - srcs = ["flow_packager.proto"], - cc_deps = [ - ":motion_models_cc_proto", - ":region_flow_cc_proto", - ], - deps = [":flow_packager_proto"], -) - -mediapipe_cc_proto_library( - name = "tracking_cc_proto", - srcs = ["tracking.proto"], - cc_deps = [":motion_models_cc_proto"], - deps = [":tracking_proto"], -) - -mediapipe_cc_proto_library( - name = "box_tracker_cc_proto", - srcs = ["box_tracker.proto"], - cc_deps = [":tracking_cc_proto"], - deps = [":box_tracker_proto"], -) - -mediapipe_cc_proto_library( - name = "tracked_detection_manager_config_cc_proto", - srcs = ["tracked_detection_manager_config.proto"], - deps = [":tracked_detection_manager_config_proto"], -) - -mediapipe_cc_proto_library( - name = "box_detector_cc_proto", - srcs = ["box_detector.proto"], - cc_deps = [ - ":box_tracker_cc_proto", - ":region_flow_cc_proto", - ], - deps = [":box_detector_proto"], -) - cc_library( name = "motion_models", srcs = ["motion_models.cc"], diff --git a/third_party/halide.BUILD b/third_party/halide.BUILD index 02e701585..677fa9f38 100644 --- a/third_party/halide.BUILD +++ b/third_party/halide.BUILD @@ -43,8 +43,8 @@ cc_library( name = "lib_halide_static", srcs = select({ "@halide//:halide_config_windows_x86_64": [ - "lib/Release/Halide.lib", "bin/Release/Halide.dll", + "lib/Release/Halide.lib", ], "//conditions:default": [ "lib/libHalide.a", From 7c7eb74ef2ec0d50228fe7e6519cbb0b1eb9d670 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Mon, 3 Apr 2023 15:58:04 -0700 Subject: [PATCH 29/63] Internal change PiperOrigin-RevId: 521586389 --- mediapipe/calculators/tensorflow/BUILD | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/mediapipe/calculators/tensorflow/BUILD b/mediapipe/calculators/tensorflow/BUILD index e7cc9cc94..0b30689eb 100644 --- a/mediapipe/calculators/tensorflow/BUILD +++ b/mediapipe/calculators/tensorflow/BUILD @@ -399,7 +399,7 @@ cc_library( # On android, this calculator is configured to run with lite protos. Therefore, # compile your binary with the flag TENSORFLOW_PROTOS=lite. 
cc_library( - name = "tensorflow_inference_calculator", + name = "tensorflow_inference_calculator_no_envelope_loader", srcs = ["tensorflow_inference_calculator.cc"], deps = [ ":tensorflow_inference_calculator_cc_proto", @@ -432,6 +432,19 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "tensorflow_inference_calculator", + deps = [ + ":tensorflow_inference_calculator_no_envelope_loader", + ] + select({ + # Since "select" has "exactly one match" rule, we will need default condition to avoid + # "no matching conditions" error. Since all necessary dependencies are specified in + # "tensorflow_inference_calculator_no_envelope_loader" dependency, it is empty here. + "//conditions:default": [], + }), + alwayslink = 1, +) + cc_library( name = "tensorflow_session", hdrs = [ From c31a4681e5aa833bd53b7001969905efa1d826fb Mon Sep 17 00:00:00 2001 From: Joe Fernandez Date: Mon, 3 Apr 2023 17:41:28 -0700 Subject: [PATCH 30/63] Fix left navigation mediapipe.dev docs legacy solutions web pages PiperOrigin-RevId: 521609493 --- docs/solutions/autoflip.md | 2 +- docs/solutions/box_tracking.md | 2 +- docs/solutions/face_detection.md | 2 +- docs/solutions/face_mesh.md | 2 +- docs/solutions/hair_segmentation.md | 2 +- docs/solutions/hands.md | 2 +- docs/solutions/holistic.md | 2 +- docs/solutions/instant_motion_tracking.md | 2 +- docs/solutions/iris.md | 2 +- docs/solutions/knift.md | 2 +- docs/solutions/media_sequence.md | 2 +- docs/solutions/models.md | 2 +- docs/solutions/object_detection.md | 2 +- docs/solutions/object_detection_saved_model.md | 2 +- docs/solutions/objectron.md | 2 +- docs/solutions/pose.md | 2 +- docs/solutions/pose_classification.md | 2 +- docs/solutions/selfie_segmentation.md | 2 +- docs/solutions/youtube_8m.md | 2 +- 19 files changed, 19 insertions(+), 19 deletions(-) diff --git a/docs/solutions/autoflip.md b/docs/solutions/autoflip.md index 9c5ca8766..a9e1e7052 100644 --- a/docs/solutions/autoflip.md +++ b/docs/solutions/autoflip.md @@ -2,7 +2,7 @@ layout: forward target: https://developers.google.com/mediapipe/solutions/guide#legacy title: AutoFlip (Saliency-aware Video Cropping) -parent: Solutions +parent: MediaPipe Legacy Solutions nav_order: 14 --- diff --git a/docs/solutions/box_tracking.md b/docs/solutions/box_tracking.md index 944059b80..537916ac4 100644 --- a/docs/solutions/box_tracking.md +++ b/docs/solutions/box_tracking.md @@ -2,7 +2,7 @@ layout: forward target: https://developers.google.com/mediapipe/solutions/guide#legacy title: Box Tracking -parent: Solutions +parent: MediaPipe Legacy Solutions nav_order: 10 --- diff --git a/docs/solutions/face_detection.md b/docs/solutions/face_detection.md index 3966fab9d..f060d062c 100644 --- a/docs/solutions/face_detection.md +++ b/docs/solutions/face_detection.md @@ -2,7 +2,7 @@ layout: forward target: https://developers.google.com/mediapipe/solutions/vision/face_detector/ title: Face Detection -parent: Solutions +parent: MediaPipe Legacy Solutions nav_order: 1 --- diff --git a/docs/solutions/face_mesh.md b/docs/solutions/face_mesh.md index 406e405b5..ab34ba401 100644 --- a/docs/solutions/face_mesh.md +++ b/docs/solutions/face_mesh.md @@ -2,7 +2,7 @@ layout: forward target: https://developers.google.com/mediapipe/solutions/vision/face_landmarker/ title: Face Mesh -parent: Solutions +parent: MediaPipe Legacy Solutions nav_order: 2 --- diff --git a/docs/solutions/hair_segmentation.md b/docs/solutions/hair_segmentation.md index a59bb93b8..feb40f9c0 100644 --- a/docs/solutions/hair_segmentation.md +++ 
b/docs/solutions/hair_segmentation.md @@ -2,7 +2,7 @@ layout: forward target: https://developers.google.com/mediapipe/solutions/vision/image_segmenter/ title: Hair Segmentation -parent: Solutions +parent: MediaPipe Legacy Solutions nav_order: 8 --- diff --git a/docs/solutions/hands.md b/docs/solutions/hands.md index 280677f0f..6cf2264ed 100644 --- a/docs/solutions/hands.md +++ b/docs/solutions/hands.md @@ -2,7 +2,7 @@ layout: forward target: https://developers.google.com/mediapipe/solutions/vision/hand_landmarker title: Hands -parent: Solutions +parent: MediaPipe Legacy Solutions nav_order: 4 --- diff --git a/docs/solutions/holistic.md b/docs/solutions/holistic.md index 70e6b5aff..25288ab55 100644 --- a/docs/solutions/holistic.md +++ b/docs/solutions/holistic.md @@ -2,7 +2,7 @@ layout: forward target: https://github.com/google/mediapipe/blob/master/docs/solutions/holistic.md title: Holistic -parent: Solutions +parent: MediaPipe Legacy Solutions nav_order: 6 --- diff --git a/docs/solutions/instant_motion_tracking.md b/docs/solutions/instant_motion_tracking.md index 76e36d12e..361bc91ff 100644 --- a/docs/solutions/instant_motion_tracking.md +++ b/docs/solutions/instant_motion_tracking.md @@ -2,7 +2,7 @@ layout: forward target: https://developers.google.com/mediapipe/solutions/guide#legacy title: Instant Motion Tracking -parent: Solutions +parent: MediaPipe Legacy Solutions nav_order: 11 --- diff --git a/docs/solutions/iris.md b/docs/solutions/iris.md index 1f5486afd..eab3dedf6 100644 --- a/docs/solutions/iris.md +++ b/docs/solutions/iris.md @@ -2,7 +2,7 @@ layout: forward target: https://developers.google.com/mediapipe/solutions/vision/face_landmarker/ title: Iris -parent: Solutions +parent: MediaPipe Legacy Solutions nav_order: 3 --- diff --git a/docs/solutions/knift.md b/docs/solutions/knift.md index 50bc1df62..19e04cb5e 100644 --- a/docs/solutions/knift.md +++ b/docs/solutions/knift.md @@ -2,7 +2,7 @@ layout: forward target: https://developers.google.com/mediapipe/solutions/guide#legacy title: KNIFT (Template-based Feature Matching) -parent: Solutions +parent: MediaPipe Legacy Solutions nav_order: 13 --- diff --git a/docs/solutions/media_sequence.md b/docs/solutions/media_sequence.md index 769292a76..5224dd371 100644 --- a/docs/solutions/media_sequence.md +++ b/docs/solutions/media_sequence.md @@ -2,7 +2,7 @@ layout: forward target: https://developers.google.com/mediapipe/solutions/guide#legacy title: Dataset Preparation with MediaSequence -parent: Solutions +parent: MediaPipe Legacy Solutions nav_order: 15 --- diff --git a/docs/solutions/models.md b/docs/solutions/models.md index d0aae8e77..0af91eb48 100644 --- a/docs/solutions/models.md +++ b/docs/solutions/models.md @@ -2,7 +2,7 @@ layout: forward target: https://developers.google.com/mediapipe/solutions/guide#legacy title: Models and Model Cards -parent: Solutions +parent: MediaPipe Legacy Solutions nav_order: 30 --- diff --git a/docs/solutions/object_detection.md b/docs/solutions/object_detection.md index 22eaed563..efa2e5266 100644 --- a/docs/solutions/object_detection.md +++ b/docs/solutions/object_detection.md @@ -2,7 +2,7 @@ layout: forward target: https://developers.google.com/mediapipe/solutions/vision/object_detector/ title: Object Detection -parent: Solutions +parent: MediaPipe Legacy Solutions nav_order: 9 --- diff --git a/docs/solutions/object_detection_saved_model.md b/docs/solutions/object_detection_saved_model.md index 262bacbab..1c67bca68 100644 --- a/docs/solutions/object_detection_saved_model.md +++ 
b/docs/solutions/object_detection_saved_model.md @@ -2,7 +2,7 @@ layout: forward target: https://developers.google.com/mediapipe/solutions/vision/object_detector title: Object Detection -parent: Solutions +parent: MediaPipe Legacy Solutions nav_order: 9 --- diff --git a/docs/solutions/objectron.md b/docs/solutions/objectron.md index 4eee7c31d..09f8028bc 100644 --- a/docs/solutions/objectron.md +++ b/docs/solutions/objectron.md @@ -2,7 +2,7 @@ layout: forward target: https://developers.google.com/mediapipe/solutions/guide#legacy title: Objectron (3D Object Detection) -parent: Solutions +parent: MediaPipe Legacy Solutions nav_order: 12 --- diff --git a/docs/solutions/pose.md b/docs/solutions/pose.md index fe9ffc0dc..3c9f14f54 100644 --- a/docs/solutions/pose.md +++ b/docs/solutions/pose.md @@ -2,7 +2,7 @@ layout: forward target: https://developers.google.com/mediapipe/solutions/vision/pose_landmarker/ title: Pose -parent: Solutions +parent: MediaPipe Legacy Solutions has_children: true has_toc: false nav_order: 5 diff --git a/docs/solutions/pose_classification.md b/docs/solutions/pose_classification.md index d2c4d5575..8420e2d7c 100644 --- a/docs/solutions/pose_classification.md +++ b/docs/solutions/pose_classification.md @@ -3,7 +3,7 @@ layout: forward target: https://developers.google.com/mediapipe/solutions/vision/pose_landmarker/ title: Pose Classification parent: Pose -grand_parent: Solutions +grand_parent: MediaPipe Legacy Solutions nav_order: 1 --- diff --git a/docs/solutions/selfie_segmentation.md b/docs/solutions/selfie_segmentation.md index 54be628cd..17e6fc252 100644 --- a/docs/solutions/selfie_segmentation.md +++ b/docs/solutions/selfie_segmentation.md @@ -2,7 +2,7 @@ layout: forward target: https://developers.google.com/mediapipe/solutions/vision/image_segmenter/ title: Selfie Segmentation -parent: Solutions +parent: MediaPipe Legacy Solutions nav_order: 7 --- diff --git a/docs/solutions/youtube_8m.md b/docs/solutions/youtube_8m.md index 3a615a8a7..80fb9d9a6 100644 --- a/docs/solutions/youtube_8m.md +++ b/docs/solutions/youtube_8m.md @@ -2,7 +2,7 @@ layout: forward target: https://developers.google.com/mediapipe/solutions/guide#legacy title: YouTube-8M Feature Extraction and Model Inference -parent: Solutions +parent: MediaPipe Legacy Solutions nav_order: 16 --- From 367ccbfdf3c59e9c4baa891dbdcd0c1d944d93f9 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Tue, 4 Apr 2023 00:23:03 -0700 Subject: [PATCH 31/63] update ImageSegmenterGraph to always output confidence mask and optionally output category mask PiperOrigin-RevId: 521679910 --- .../tasks/cc/vision/image_segmenter/BUILD | 9 ++ .../tensors_to_segmentation_calculator.cc | 96 +++++++++---- ...tensors_to_segmentation_calculator_test.cc | 86 +++++------- .../vision/image_segmenter/image_segmenter.cc | 80 ++++++----- .../vision/image_segmenter/image_segmenter.h | 60 +++------ .../image_segmenter/image_segmenter_graph.cc | 127 +++++++++++++----- .../image_segmenter/image_segmenter_result.h | 43 ++++++ .../image_segmenter/image_segmenter_test.cc | 124 ++++++++--------- .../proto/segmenter_options.proto | 2 +- 9 files changed, 370 insertions(+), 257 deletions(-) create mode 100644 mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_result.h diff --git a/mediapipe/tasks/cc/vision/image_segmenter/BUILD b/mediapipe/tasks/cc/vision/image_segmenter/BUILD index 69833a5f6..ee1cd3693 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/BUILD +++ b/mediapipe/tasks/cc/vision/image_segmenter/BUILD @@ -16,6 +16,13 @@ 
 package(default_visibility = ["//mediapipe/tasks:internal"])

 licenses(["notice"])

+cc_library(
+    name = "image_segmenter_result",
+    hdrs = ["image_segmenter_result.h"],
+    visibility = ["//visibility:public"],
+    deps = ["//mediapipe/framework/formats:image"],
+)
+
 # Docs for Mediapipe Tasks Image Segmenter
 # https://developers.google.com/mediapipe/solutions/vision/image_segmenter
 cc_library(
@@ -25,6 +32,7 @@ cc_library(
     visibility = ["//visibility:public"],
     deps = [
         ":image_segmenter_graph",
+        ":image_segmenter_result",
         "//mediapipe/framework:calculator_cc_proto",
         "//mediapipe/framework/api2:builder",
         "//mediapipe/framework/formats:image",
@@ -82,6 +90,7 @@ cc_library(
         "//mediapipe/tasks/cc/vision/utils:image_tensor_specs",
         "//mediapipe/tasks/metadata:image_segmenter_metadata_schema_cc",
         "//mediapipe/tasks/metadata:metadata_schema_cc",
+        "//mediapipe/util:graph_builder_utils",
         "//mediapipe/util:label_map_cc_proto",
         "//mediapipe/util:label_map_util",
         "@com_google_absl//absl/status",
diff --git a/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.cc b/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.cc
index 0cdc8fe0f..49ad18029 100644
--- a/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.cc
+++ b/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.cc
@@ -80,10 +80,10 @@ void Sigmoid(absl::Span<float> values,
                  [](float value) { return 1. / (1 + std::exp(-value)); });
 }

-std::vector<Image> ProcessForCategoryMaskCpu(const Shape& input_shape,
-                                             const Shape& output_shape,
-                                             const SegmenterOptions& options,
-                                             const float* tensors_buffer) {
+Image ProcessForCategoryMaskCpu(const Shape& input_shape,
+                                const Shape& output_shape,
+                                const SegmenterOptions& options,
+                                const float* tensors_buffer) {
   cv::Mat resized_tensors_mat;
   cv::Mat tensors_mat_view(
       input_shape.height, input_shape.width, CV_32FC(input_shape.channels),
@@ -135,7 +135,7 @@ std::vector<Image> ProcessForCategoryMaskCpu(const Shape& input_shape,
       pixel = maximum_category_idx;
     }
   });
-  return {category_mask};
+  return category_mask;
 }

 std::vector<Image> ProcessForConfidenceMaskCpu(const Shape& input_shape,
@@ -209,7 +209,9 @@ std::vector<Image> ProcessForConfidenceMaskCpu(const Shape& input_shape,

 }  // namespace

-// Converts Tensors from a vector of Tensor to Segmentation.
+// Converts Tensors from a vector of Tensor to Segmentation masks. The
+// calculator always outputs confidence masks, and an optional category mask if
+// CATEGORY_MASK is connected.
 //
 // Performs optional resizing to OUTPUT_SIZE dimension if provided,
 // otherwise the segmented masks are the same size as the input tensor.
@@ -221,7 +223,12 @@ std::vector<Image> ProcessForConfidenceMaskCpu(const Shape& input_shape,
 //   the size to resize masks to.
 //
 // Output:
-//   Segmentation: Segmentation proto.
+//   CONFIDENCE_MASK @Multiple: Multiple masks of float image where, for each
+//   mask, each pixel represents the prediction confidence, usually in the
+//   [0, 1] range.
+//   CATEGORY_MASK @Optional: A category mask of uint8 image where each pixel
+//   represents the class which the pixel in the original image was predicted
+//   to belong to.
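The category-mask path above reduces each pixel's per-channel confidences to the index of the largest one. A standalone sketch of that per-pixel reduction (a hypothetical ArgmaxMask helper, not the calculator code itself), assuming a row-major height x width x channels float buffer:

#include <cstdint>
#include <vector>

// Reduces an H x W x C row-major confidence buffer to an H x W category mask,
// writing the arg-max channel index of each pixel.
std::vector<uint8_t> ArgmaxMask(const float* buf, int height, int width,
                                int channels) {
  std::vector<uint8_t> mask(height * width);
  for (int p = 0; p < height * width; ++p) {
    int best = 0;
    for (int c = 1; c < channels; ++c) {
      if (buf[p * channels + c] > buf[p * channels + best]) best = c;
    }
    mask[p] = static_cast<uint8_t>(best);
  }
  return mask;
}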
// // Options: // See tensors_to_segmentation_calculator.proto @@ -231,13 +238,13 @@ std::vector ProcessForConfidenceMaskCpu(const Shape& input_shape, // calculator: "TensorsToSegmentationCalculator" // input_stream: "TENSORS:tensors" // input_stream: "OUTPUT_SIZE:size" -// output_stream: "SEGMENTATION:0:segmentation" -// output_stream: "SEGMENTATION:1:segmentation" +// output_stream: "CONFIDENCE_MASK:0:confidence_mask" +// output_stream: "CONFIDENCE_MASK:1:confidence_mask" +// output_stream: "CATEGORY_MASK:category_mask" // options { // [mediapipe.tasks.TensorsToSegmentationCalculatorOptions.ext] { // segmenter_options { // activation: SOFTMAX -// output_type: CONFIDENCE_MASK // } // } // } @@ -248,7 +255,11 @@ class TensorsToSegmentationCalculator : public Node { static constexpr Input>::Optional kOutputSizeIn{ "OUTPUT_SIZE"}; static constexpr Output::Multiple kSegmentationOut{"SEGMENTATION"}; - MEDIAPIPE_NODE_CONTRACT(kTensorsIn, kOutputSizeIn, kSegmentationOut); + static constexpr Output::Multiple kConfidenceMaskOut{ + "CONFIDENCE_MASK"}; + static constexpr Output::Optional kCategoryMaskOut{"CATEGORY_MASK"}; + MEDIAPIPE_NODE_CONTRACT(kTensorsIn, kOutputSizeIn, kSegmentationOut, + kConfidenceMaskOut, kCategoryMaskOut); static absl::Status UpdateContract(CalculatorContract* cc); @@ -279,9 +290,13 @@ absl::Status TensorsToSegmentationCalculator::UpdateContract( absl::Status TensorsToSegmentationCalculator::Open( mediapipe::CalculatorContext* cc) { options_ = cc->Options(); - RET_CHECK_NE(options_.segmenter_options().output_type(), - SegmenterOptions::UNSPECIFIED) - << "Must specify output_type as one of [CONFIDENCE_MASK|CATEGORY_MASK]."; + // TODO: remove deprecated output type support. + if (options_.segmenter_options().has_output_type()) { + RET_CHECK_NE(options_.segmenter_options().output_type(), + SegmenterOptions::UNSPECIFIED) + << "Must specify output_type as one of " + "[CONFIDENCE_MASK|CATEGORY_MASK]."; + } #ifdef __EMSCRIPTEN__ MP_RETURN_IF_ERROR(postprocessor_.Initialize(cc, options_)); #endif // __EMSCRIPTEN__ @@ -309,6 +324,10 @@ absl::Status TensorsToSegmentationCalculator::Process( if (cc->Inputs().HasTag("OUTPUT_SIZE")) { std::tie(output_width, output_height) = kOutputSizeIn(cc).Get(); } + + // Use GPU postprocessing on web when Tensor is there already and has <= 12 + // categories. +#ifdef __EMSCRIPTEN__ Shape output_shape = { /* height= */ output_height, /* width= */ output_width, @@ -316,10 +335,6 @@ absl::Status TensorsToSegmentationCalculator::Process( SegmenterOptions::CATEGORY_MASK ? 1 : input_shape.channels}; - - // Use GPU postprocessing on web when Tensor is there already and has <= 12 - // categories. -#ifdef __EMSCRIPTEN__ if (input_tensor.ready_as_opengl_texture_2d() && input_shape.channels <= 12) { std::vector> segmented_masks = postprocessor_.GetSegmentationResultGpu(input_shape, output_shape, @@ -332,10 +347,41 @@ absl::Status TensorsToSegmentationCalculator::Process( #endif // __EMSCRIPTEN__ // Otherwise, use CPU postprocessing. - std::vector segmented_masks = GetSegmentationResultCpu( - input_shape, output_shape, input_tensor.GetCpuReadView().buffer()); - for (int i = 0; i < segmented_masks.size(); ++i) { - kSegmentationOut(cc)[i].Send(std::move(segmented_masks[i])); + const float* tensors_buffer = input_tensor.GetCpuReadView().buffer(); + + // TODO: remove deprecated output type support. 
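For the SIGMOID activation shown in the options above, each raw value is mapped through the logistic function, as in the file's Sigmoid() helper; a self-contained sketch:

#include <algorithm>
#include <cmath>
#include <vector>

// Applies the logistic function in place, mapping raw logits to [0, 1]
// confidences, matching activation: SIGMOID.
void ApplySigmoid(std::vector<float>& values) {
  std::transform(values.begin(), values.end(), values.begin(),
                 [](float v) { return 1.f / (1.f + std::exp(-v)); });
}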
+ if (options_.segmenter_options().has_output_type()) { + std::vector segmented_masks = GetSegmentationResultCpu( + input_shape, + {/* height= */ output_height, + /* width= */ output_width, + /* channels= */ options_.segmenter_options().output_type() == + SegmenterOptions::CATEGORY_MASK + ? 1 + : input_shape.channels}, + input_tensor.GetCpuReadView().buffer()); + for (int i = 0; i < segmented_masks.size(); ++i) { + kSegmentationOut(cc)[i].Send(std::move(segmented_masks[i])); + } + return absl::OkStatus(); + } + + std::vector confidence_masks = + ProcessForConfidenceMaskCpu(input_shape, + {/* height= */ output_height, + /* width= */ output_width, + /* channels= */ input_shape.channels}, + options_.segmenter_options(), tensors_buffer); + for (int i = 0; i < confidence_masks.size(); ++i) { + kConfidenceMaskOut(cc)[i].Send(std::move(confidence_masks[i])); + } + if (cc->Outputs().HasTag("CATEGORY_MASK")) { + kCategoryMaskOut(cc).Send(ProcessForCategoryMaskCpu( + input_shape, + {/* height= */ output_height, + /* width= */ output_width, + /* channels= */ 1}, + options_.segmenter_options(), tensors_buffer)); } return absl::OkStatus(); } @@ -345,9 +391,9 @@ std::vector TensorsToSegmentationCalculator::GetSegmentationResultCpu( const float* tensors_buffer) { if (options_.segmenter_options().output_type() == SegmenterOptions::CATEGORY_MASK) { - return ProcessForCategoryMaskCpu(input_shape, output_shape, - options_.segmenter_options(), - tensors_buffer); + return {ProcessForCategoryMaskCpu(input_shape, output_shape, + options_.segmenter_options(), + tensors_buffer)}; } else { return ProcessForConfidenceMaskCpu(input_shape, output_shape, options_.segmenter_options(), diff --git a/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator_test.cc b/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator_test.cc index 54fb9b816..d6a2f3fd9 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator_test.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator_test.cc @@ -79,8 +79,9 @@ void PushTensorsToRunner(int tensor_height, int tensor_width, std::vector GetPackets(const CalculatorRunner& runner) { std::vector mask_packets; for (int i = 0; i < runner.Outputs().NumEntries(); ++i) { - EXPECT_EQ(runner.Outputs().Get("SEGMENTATION", i).packets.size(), 1); - mask_packets.push_back(runner.Outputs().Get("SEGMENTATION", i).packets[0]); + EXPECT_EQ(runner.Outputs().Get("CONFIDENCE_MASK", i).packets.size(), 1); + mask_packets.push_back( + runner.Outputs().Get("CONFIDENCE_MASK", i).packets[0]); } return mask_packets; } @@ -118,13 +119,10 @@ TEST(TensorsToSegmentationCalculatorTest, FailsInvalidTensorDimensionOne) { R"pb( calculator: "mediapipe.tasks.TensorsToSegmentationCalculator" input_stream: "TENSORS:tensors" - output_stream: "SEGMENTATION:segmentation" + output_stream: "CONFIDENCE_MASK:segmentation" options { [mediapipe.tasks.TensorsToSegmentationCalculatorOptions.ext] { - segmenter_options { - activation: SOFTMAX - output_type: CONFIDENCE_MASK - } + segmenter_options { activation: SOFTMAX } } } )pb")); @@ -145,13 +143,10 @@ TEST(TensorsToSegmentationCalculatorTest, FailsInvalidTensorDimensionFive) { R"pb( calculator: "mediapipe.tasks.TensorsToSegmentationCalculator" input_stream: "TENSORS:tensors" - output_stream: "SEGMENTATION:segmentation" + output_stream: "CONFIDENCE_MASK:segmentation" options { [mediapipe.tasks.TensorsToSegmentationCalculatorOptions.ext] { - 
segmenter_options { - activation: SOFTMAX - output_type: CONFIDENCE_MASK - } + segmenter_options { activation: SOFTMAX } } } )pb")); @@ -173,16 +168,13 @@ TEST(TensorsToSegmentationCalculatorTest, SucceedsConfidenceMaskWithSoftmax) { R"pb( calculator: "mediapipe.tasks.TensorsToSegmentationCalculator" input_stream: "TENSORS:tensors" - output_stream: "SEGMENTATION:0:segmented_mask_0" - output_stream: "SEGMENTATION:1:segmented_mask_1" - output_stream: "SEGMENTATION:2:segmented_mask_2" - output_stream: "SEGMENTATION:3:segmented_mask_3" + output_stream: "CONFIDENCE_MASK:0:segmented_mask_0" + output_stream: "CONFIDENCE_MASK:1:segmented_mask_1" + output_stream: "CONFIDENCE_MASK:2:segmented_mask_2" + output_stream: "CONFIDENCE_MASK:3:segmented_mask_3" options { [mediapipe.tasks.TensorsToSegmentationCalculatorOptions.ext] { - segmenter_options { - activation: SOFTMAX - output_type: CONFIDENCE_MASK - } + segmenter_options { activation: SOFTMAX } } } )pb")); @@ -218,16 +210,13 @@ TEST(TensorsToSegmentationCalculatorTest, SucceedsConfidenceMaskWithNone) { R"pb( calculator: "mediapipe.tasks.TensorsToSegmentationCalculator" input_stream: "TENSORS:tensors" - output_stream: "SEGMENTATION:0:segmented_mask_0" - output_stream: "SEGMENTATION:1:segmented_mask_1" - output_stream: "SEGMENTATION:2:segmented_mask_2" - output_stream: "SEGMENTATION:3:segmented_mask_3" + output_stream: "CONFIDENCE_MASK:0:segmented_mask_0" + output_stream: "CONFIDENCE_MASK:1:segmented_mask_1" + output_stream: "CONFIDENCE_MASK:2:segmented_mask_2" + output_stream: "CONFIDENCE_MASK:3:segmented_mask_3" options { [mediapipe.tasks.TensorsToSegmentationCalculatorOptions.ext] { - segmenter_options { - activation: NONE - output_type: CONFIDENCE_MASK - } + segmenter_options { activation: NONE } } } )pb")); @@ -259,16 +248,13 @@ TEST(TensorsToSegmentationCalculatorTest, SucceedsConfidenceMaskWithSigmoid) { R"pb( calculator: "mediapipe.tasks.TensorsToSegmentationCalculator" input_stream: "TENSORS:tensors" - output_stream: "SEGMENTATION:0:segmented_mask_0" - output_stream: "SEGMENTATION:1:segmented_mask_1" - output_stream: "SEGMENTATION:2:segmented_mask_2" - output_stream: "SEGMENTATION:3:segmented_mask_3" + output_stream: "CONFIDENCE_MASK:0:segmented_mask_0" + output_stream: "CONFIDENCE_MASK:1:segmented_mask_1" + output_stream: "CONFIDENCE_MASK:2:segmented_mask_2" + output_stream: "CONFIDENCE_MASK:3:segmented_mask_3" options { [mediapipe.tasks.TensorsToSegmentationCalculatorOptions.ext] { - segmenter_options { - activation: SIGMOID - output_type: CONFIDENCE_MASK - } + segmenter_options { activation: SIGMOID } } } )pb")); @@ -301,13 +287,14 @@ TEST(TensorsToSegmentationCalculatorTest, SucceedsCategoryMask) { R"pb( calculator: "mediapipe.tasks.TensorsToSegmentationCalculator" input_stream: "TENSORS:tensors" - output_stream: "SEGMENTATION:segmentation" + output_stream: "CONFIDENCE_MASK:0:segmented_mask_0" + output_stream: "CONFIDENCE_MASK:1:segmented_mask_1" + output_stream: "CONFIDENCE_MASK:2:segmented_mask_2" + output_stream: "CONFIDENCE_MASK:3:segmented_mask_3" + output_stream: "CATEGORY_MASK:segmentation" options { [mediapipe.tasks.TensorsToSegmentationCalculatorOptions.ext] { - segmenter_options { - activation: NONE - output_type: CATEGORY_MASK - } + segmenter_options { activation: NONE } } } )pb")); @@ -318,11 +305,11 @@ TEST(TensorsToSegmentationCalculatorTest, SucceedsCategoryMask) { tensor_height, tensor_width, std::vector(kTestValues.begin(), kTestValues.end()), &runner); MP_ASSERT_OK(runner.Run()); - 
ASSERT_EQ(runner.Outputs().NumEntries(), 1); + ASSERT_EQ(runner.Outputs().NumEntries(), 5); // Largest element index is 3. const int expected_index = 3; const std::vector buffer_indices = {0}; - std::vector packets = GetPackets(runner); + std::vector packets = runner.Outputs().Tag("CATEGORY_MASK").packets; EXPECT_THAT(packets, testing::ElementsAre( Uint8ImagePacket(tensor_height, tensor_width, expected_index, buffer_indices))); @@ -335,13 +322,14 @@ TEST(TensorsToSegmentationCalculatorTest, SucceedsCategoryMaskResize) { calculator: "mediapipe.tasks.TensorsToSegmentationCalculator" input_stream: "TENSORS:tensors" input_stream: "OUTPUT_SIZE:size" - output_stream: "SEGMENTATION:segmentation" + output_stream: "CONFIDENCE_MASK:0:segmented_mask_0" + output_stream: "CONFIDENCE_MASK:1:segmented_mask_1" + output_stream: "CONFIDENCE_MASK:2:segmented_mask_2" + output_stream: "CONFIDENCE_MASK:3:segmented_mask_3" + output_stream: "CATEGORY_MASK:segmentation" options { [mediapipe.tasks.TensorsToSegmentationCalculatorOptions.ext] { - segmenter_options { - activation: NONE - output_type: CATEGORY_MASK - } + segmenter_options { activation: NONE } } } )pb")); @@ -367,7 +355,7 @@ TEST(TensorsToSegmentationCalculatorTest, SucceedsCategoryMaskResize) { const std::vector buffer_indices = { 0 * output_width + 0, 0 * output_width + 1, 1 * output_width + 0, 1 * output_width + 1}; - std::vector packets = GetPackets(runner); + std::vector packets = runner.Outputs().Tag("CATEGORY_MASK").packets; EXPECT_THAT(packets, testing::ElementsAre( Uint8ImagePacket(output_height, output_width, expected_index, buffer_indices))); diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc index ab1d3c84b..8f03ff086 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc @@ -37,8 +37,10 @@ namespace vision { namespace image_segmenter { namespace { -constexpr char kSegmentationStreamName[] = "segmented_mask_out"; -constexpr char kGroupedSegmentationTag[] = "GROUPED_SEGMENTATION"; +constexpr char kConfidenceMasksTag[] = "CONFIDENCE_MASKS"; +constexpr char kConfidenceMasksStreamName[] = "confidence_masks"; +constexpr char kCategoryMaskTag[] = "CATEGORY_MASK"; +constexpr char kCategoryMaskStreamName[] = "category_mask"; constexpr char kImageInStreamName[] = "image_in"; constexpr char kImageOutStreamName[] = "image_out"; constexpr char kImageTag[] = "IMAGE"; @@ -51,7 +53,6 @@ constexpr int kMicroSecondsPerMilliSecond = 1000; using ::mediapipe::CalculatorGraphConfig; using ::mediapipe::Image; using ::mediapipe::NormalizedRect; -using ::mediapipe::tasks::vision::image_segmenter::proto::SegmenterOptions; using ImageSegmenterGraphOptionsProto = ::mediapipe::tasks::vision:: image_segmenter::proto::ImageSegmenterGraphOptions; @@ -59,21 +60,24 @@ using ImageSegmenterGraphOptionsProto = ::mediapipe::tasks::vision:: // "mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph". 
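Before the graph-construction helper below, it may help to see how a client exercises the renamed streams through the public API; a hypothetical live-stream sketch (the model path and the lambda body are placeholders, and the snippet assumes the image_segmenter namespace):

auto options = std::make_unique<ImageSegmenterOptions>();
options->base_options.model_asset_path = "/path/to/model.tflite";  // placeholder
options->running_mode = core::RunningMode::LIVE_STREAM;
options->output_category_mask = true;
options->result_callback = [](absl::StatusOr<ImageSegmenterResult> result,
                              const Image& input, int64_t timestamp_ms) {
  // confidence_masks is always populated; category_mask only if requested.
  if (result.ok() && result->category_mask.has_value()) {
    // Consume *result->category_mask here.
  }
};
auto segmenter = ImageSegmenter::Create(std::move(options));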
CalculatorGraphConfig CreateGraphConfig( std::unique_ptr options, - bool enable_flow_limiting) { + bool output_category_mask, bool enable_flow_limiting) { api2::builder::Graph graph; auto& task_subgraph = graph.AddNode(kSubgraphTypeName); task_subgraph.GetOptions().Swap( options.get()); graph.In(kImageTag).SetName(kImageInStreamName); graph.In(kNormRectTag).SetName(kNormRectStreamName); - task_subgraph.Out(kGroupedSegmentationTag).SetName(kSegmentationStreamName) >> - graph.Out(kGroupedSegmentationTag); + task_subgraph.Out(kConfidenceMasksTag).SetName(kConfidenceMasksStreamName) >> + graph.Out(kConfidenceMasksTag); + if (output_category_mask) { + task_subgraph.Out(kCategoryMaskTag).SetName(kCategoryMaskStreamName) >> + graph.Out(kCategoryMaskTag); + } task_subgraph.Out(kImageTag).SetName(kImageOutStreamName) >> graph.Out(kImageTag); if (enable_flow_limiting) { - return tasks::core::AddFlowLimiterCalculator(graph, task_subgraph, - {kImageTag, kNormRectTag}, - kGroupedSegmentationTag); + return tasks::core::AddFlowLimiterCalculator( + graph, task_subgraph, {kImageTag, kNormRectTag}, kConfidenceMasksTag); } graph.In(kImageTag) >> task_subgraph.In(kImageTag); graph.In(kNormRectTag) >> task_subgraph.In(kNormRectTag); @@ -91,16 +95,6 @@ ConvertImageSegmenterOptionsToProto(ImageSegmenterOptions* options) { options_proto->mutable_base_options()->set_use_stream_mode( options->running_mode != core::RunningMode::IMAGE); options_proto->set_display_names_locale(options->display_names_locale); - switch (options->output_type) { - case ImageSegmenterOptions::OutputType::CATEGORY_MASK: - options_proto->mutable_segmenter_options()->set_output_type( - SegmenterOptions::CATEGORY_MASK); - break; - case ImageSegmenterOptions::OutputType::CONFIDENCE_MASK: - options_proto->mutable_segmenter_options()->set_output_type( - SegmenterOptions::CONFIDENCE_MASK); - break; - } return options_proto; } @@ -145,6 +139,7 @@ absl::StatusOr> ImageSegmenter::Create( tasks::core::PacketsCallback packets_callback = nullptr; if (options->result_callback) { auto result_callback = options->result_callback; + bool output_category_mask = options->output_category_mask; packets_callback = [=](absl::StatusOr status_or_packets) { if (!status_or_packets.ok()) { @@ -156,34 +151,41 @@ absl::StatusOr> ImageSegmenter::Create( if (status_or_packets.value()[kImageOutStreamName].IsEmpty()) { return; } - Packet segmented_masks = - status_or_packets.value()[kSegmentationStreamName]; + Packet confidence_masks = + status_or_packets.value()[kConfidenceMasksStreamName]; + std::optional category_mask; + if (output_category_mask) { + category_mask = + status_or_packets.value()[kCategoryMaskStreamName].Get(); + } Packet image_packet = status_or_packets.value()[kImageOutStreamName]; - result_callback(segmented_masks.Get>(), - image_packet.Get(), - segmented_masks.Timestamp().Value() / - kMicroSecondsPerMilliSecond); + result_callback( + {{confidence_masks.Get>(), category_mask}}, + image_packet.Get(), + confidence_masks.Timestamp().Value() / + kMicroSecondsPerMilliSecond); }; } - auto image_segmenter = core::VisionTaskApiFactory::Create( CreateGraphConfig( - std::move(options_proto), + std::move(options_proto), options->output_category_mask, options->running_mode == core::RunningMode::LIVE_STREAM), std::move(options->base_options.op_resolver), options->running_mode, std::move(packets_callback)); if (!image_segmenter.ok()) { return image_segmenter.status(); } + image_segmenter.value()->output_category_mask_ = + options->output_category_mask; 
ASSIGN_OR_RETURN( (*image_segmenter)->labels_, GetLabelsFromGraphConfig((*image_segmenter)->runner_->GetGraphConfig())); return image_segmenter; } -absl::StatusOr> ImageSegmenter::Segment( +absl::StatusOr ImageSegmenter::Segment( mediapipe::Image image, std::optional image_processing_options) { if (image.UsesGpu()) { @@ -201,11 +203,17 @@ absl::StatusOr> ImageSegmenter::Segment( {{kImageInStreamName, mediapipe::MakePacket(std::move(image))}, {kNormRectStreamName, MakePacket(std::move(norm_rect))}})); - return output_packets[kSegmentationStreamName].Get>(); + std::vector confidence_masks = + output_packets[kConfidenceMasksStreamName].Get>(); + std::optional category_mask; + if (output_category_mask_) { + category_mask = output_packets[kCategoryMaskStreamName].Get(); + } + return {{confidence_masks, category_mask}}; } -absl::StatusOr> ImageSegmenter::SegmentForVideo( - mediapipe::Image image, int64 timestamp_ms, +absl::StatusOr ImageSegmenter::SegmentForVideo( + mediapipe::Image image, int64_t timestamp_ms, std::optional image_processing_options) { if (image.UsesGpu()) { return CreateStatusWithPayload( @@ -225,11 +233,17 @@ absl::StatusOr> ImageSegmenter::SegmentForVideo( {kNormRectStreamName, MakePacket(std::move(norm_rect)) .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}})); - return output_packets[kSegmentationStreamName].Get>(); + std::vector confidence_masks = + output_packets[kConfidenceMasksStreamName].Get>(); + std::optional category_mask; + if (output_category_mask_) { + category_mask = output_packets[kCategoryMaskStreamName].Get(); + } + return {{confidence_masks, category_mask}}; } absl::Status ImageSegmenter::SegmentAsync( - Image image, int64 timestamp_ms, + Image image, int64_t timestamp_ms, std::optional image_processing_options) { if (image.UsesGpu()) { return CreateStatusWithPayload( diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h index 076a5016c..1d18e3903 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h @@ -26,6 +26,7 @@ limitations under the License. #include "mediapipe/tasks/cc/core/base_options.h" #include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h" #include "mediapipe/tasks/cc/vision/core/image_processing_options.h" +#include "mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_result.h" #include "tensorflow/lite/kernels/register.h" namespace mediapipe { @@ -52,23 +53,14 @@ struct ImageSegmenterOptions { // Metadata, if any. Defaults to English. std::string display_names_locale = "en"; - // The output type of segmentation results. - enum OutputType { - // Gives a single output mask where each pixel represents the class which - // the pixel in the original image was predicted to belong to. - CATEGORY_MASK = 0, - // Gives a list of output masks where, for each mask, each pixel represents - // the prediction confidence, usually in the [0, 1] range. - CONFIDENCE_MASK = 1, - }; - - OutputType output_type = OutputType::CATEGORY_MASK; + // Whether to output category mask. + bool output_category_mask = false; // The user-defined result callback for processing live stream data. // The result callback should only be specified when the running mode is set // to RunningMode::LIVE_STREAM. - std::function>, - const Image&, int64)> + std::function, const Image&, + int64_t)> result_callback = nullptr; }; @@ -84,13 +76,9 @@ struct ImageSegmenterOptions { // 1 or 3). 
// - if type is kTfLiteFloat32, NormalizationOptions are required to be // attached to the metadata for input normalization. -// Output tensors: -// (kTfLiteUInt8/kTfLiteFloat32) -// - list of segmented masks. -// - if `output_type` is CATEGORY_MASK, uint8 Image, Image vector of size 1. -// - if `output_type` is CONFIDENCE_MASK, float32 Image list of size -// `channels`. -// - batch is always 1 +// Output ImageSegmenterResult: +// Provides confidence masks and an optional category mask if +// `output_category_mask` is set true. // An example of such model can be found at: // https://tfhub.dev/tensorflow/lite-model/deeplabv3/1/metadata/2 class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi { @@ -114,12 +102,8 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi { // setting its 'rotation_degrees' field. Note that specifying a // region-of-interest using the 'region_of_interest' field is NOT supported // and will result in an invalid argument error being returned. - // - // If the output_type is CATEGORY_MASK, the returned vector of images is - // per-category segmented image mask. - // If the output_type is CONFIDENCE_MASK, the returned vector of images - // contains only one confidence image mask. - absl::StatusOr> Segment( + + absl::StatusOr Segment( mediapipe::Image image, std::optional image_processing_options = std::nullopt); @@ -137,13 +121,8 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi { // setting its 'rotation_degrees' field. Note that specifying a // region-of-interest using the 'region_of_interest' field is NOT supported // and will result in an invalid argument error being returned. - // - // If the output_type is CATEGORY_MASK, the returned vector of images is - // per-category segmented image mask. - // If the output_type is CONFIDENCE_MASK, the returned vector of images - // contains only one confidence image mask. - absl::StatusOr> SegmentForVideo( - mediapipe::Image image, int64 timestamp_ms, + absl::StatusOr SegmentForVideo( + mediapipe::Image image, int64_t timestamp_ms, std::optional image_processing_options = std::nullopt); @@ -164,17 +143,13 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi { // and will result in an invalid argument error being returned. // // The "result_callback" prvoides - // - A vector of segmented image masks. - // If the output_type is CATEGORY_MASK, the returned vector of images is - // per-category segmented image mask. - // If the output_type is CONFIDENCE_MASK, the returned vector of images - // contains only one confidence image mask. + // - An ImageSegmenterResult. // - The const reference to the corresponding input image that the image // segmentation runs on. Note that the const reference to the image will // no longer be valid when the callback returns. To access the image data // outside of the callback, callers need to make a copy of the image. // - The input timestamp in milliseconds. - absl::Status SegmentAsync(mediapipe::Image image, int64 timestamp_ms, + absl::Status SegmentAsync(mediapipe::Image image, int64_t timestamp_ms, std::optional image_processing_options = std::nullopt); @@ -182,9 +157,9 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi { absl::Status Close() { return runner_->Close(); } // Get the category label list of the ImageSegmenter can recognize. For - // CATEGORY_MASK type, the index in the category mask corresponds to the - // category in the label list. 
For CONFIDENCE_MASK type, the output mask list - // at index corresponds to the category in the label list. + // CATEGORY_MASK, the index in the category mask corresponds to the category + // in the label list. For CONFIDENCE_MASK, the output mask list at index + // corresponds to the category in the label list. // // If there is no labelmap provided in the model file, empty label list is // returned. @@ -192,6 +167,7 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi { private: std::vector labels_; + bool output_category_mask_; }; } // namespace image_segmenter diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc index fe6265b73..4b9e7618b 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include +#include #include #include @@ -42,6 +43,7 @@ limitations under the License. #include "mediapipe/tasks/cc/vision/utils/image_tensor_specs.h" #include "mediapipe/tasks/metadata/image_segmenter_metadata_schema_generated.h" #include "mediapipe/tasks/metadata/metadata_schema_generated.h" +#include "mediapipe/util/graph_builder_utils.h" #include "mediapipe/util/label_map.pb.h" #include "mediapipe/util/label_map_util.h" #include "tensorflow/lite/schema/schema_generated.h" @@ -65,10 +67,13 @@ using ::mediapipe::tasks::vision::image_segmenter::proto:: ImageSegmenterGraphOptions; using ::mediapipe::tasks::vision::image_segmenter::proto::SegmenterOptions; using ::tflite::TensorMetadata; -using LabelItems = mediapipe::proto_ns::Map; +using LabelItems = mediapipe::proto_ns::Map; constexpr char kSegmentationTag[] = "SEGMENTATION"; constexpr char kGroupedSegmentationTag[] = "GROUPED_SEGMENTATION"; +constexpr char kConfidenceMaskTag[] = "CONFIDENCE_MASK"; +constexpr char kConfidenceMasksTag[] = "CONFIDENCE_MASKS"; +constexpr char kCategoryMaskTag[] = "CATEGORY_MASK"; constexpr char kImageTag[] = "IMAGE"; constexpr char kImageCpuTag[] = "IMAGE_CPU"; constexpr char kImageGpuTag[] = "IMAGE_GPU"; @@ -80,7 +85,9 @@ constexpr char kSegmentationMetadataName[] = "SEGMENTER_METADATA"; // Struct holding the different output streams produced by the image segmenter // subgraph. struct ImageSegmenterOutputs { - std::vector> segmented_masks; + std::optional>> segmented_masks; + std::optional>> confidence_masks; + std::optional> category_mask; // The same as the input image, mainly used for live stream mode. Source image; }; @@ -95,8 +102,10 @@ struct ImageAndTensorsOnDevice { } // namespace absl::Status SanityCheckOptions(const ImageSegmenterGraphOptions& options) { - if (options.segmenter_options().output_type() == - SegmenterOptions::UNSPECIFIED) { + // TODO: remove deprecated output type support. 
+  if (options.segmenter_options().has_output_type() &&
+      options.segmenter_options().output_type() ==
+          SegmenterOptions::UNSPECIFIED) {
     return CreateStatusWithPayload(absl::StatusCode::kInvalidArgument,
                                    "`output_type` must not be UNSPECIFIED",
                                    MediaPipeTasksStatus::kInvalidArgumentError);
@@ -133,9 +142,8 @@ absl::Status ConfigureTensorsToSegmentationCalculator(
     const core::ModelResources& model_resources,
     TensorsToSegmentationCalculatorOptions* options) {
   // Set default activation function NONE
-  options->mutable_segmenter_options()->set_output_type(
-      segmenter_option.segmenter_options().output_type());
-  options->mutable_segmenter_options()->set_activation(SegmenterOptions::NONE);
+  options->mutable_segmenter_options()->CopyFrom(
+      segmenter_option.segmenter_options());
   // Find the custom metadata of ImageSegmenterOptions type in model metadata.
   const auto* metadata_extractor = model_resources.GetMetadataExtractor();
   bool found_activation_in_metadata = false;
@@ -317,12 +325,14 @@ absl::StatusOr<ImageAndTensorsOnDevice> ConvertImageToTensors(
   }
 }

-// An "mediapipe.tasks.vision.ImageSegmenterGraph" performs semantic
-// segmentation.
-// Two kinds of outputs are provided: SEGMENTATION and GROUPED_SEGMENTATION.
-// Users can retrieve segmented mask of only particular category/channel from
-// SEGMENTATION, and users can also get all segmented masks from
-// GROUPED_SEGMENTATION.
+// An "mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph" performs
+// semantic segmentation. The graph always outputs confidence masks, and an
+// optional category mask if CATEGORY_MASK is connected.
+//
+// Two kinds of outputs for confidence masks are provided: CONFIDENCE_MASK and
+// CONFIDENCE_MASKS. Users can retrieve the segmented mask of a particular
+// category/channel from CONFIDENCE_MASK, and can also get all segmented
+// confidence masks from CONFIDENCE_MASKS.
 // - Accepts CPU input images and outputs segmented masks on CPU.
 //
 // Inputs:
 //   IMAGE - Image
 //     Image to perform segmentation on.
 //   NORM_RECT - NormalizedRect @Optional
 //     Describes image rotation and region of image to perform detection on.
 //     @Optional: rect covering the whole image is used if not specified.
 //
 // Outputs:
-//   SEGMENTATION - mediapipe::Image @Multiple
-//     Segmented masks for individual category. Segmented mask of single
+//   CONFIDENCE_MASK - mediapipe::Image @Multiple
+//     Confidence masks for individual category. Confidence mask of single
 //     category can be accessed by index based output stream.
-//   GROUPED_SEGMENTATION - std::vector<mediapipe::Image>
-//     The output segmented masks grouped in a vector.
+//   CONFIDENCE_MASKS - std::vector<mediapipe::Image>
+//     The output confidence masks grouped in a vector.
+//   CATEGORY_MASK - mediapipe::Image @Optional
+//     Optional category mask.
 //   IMAGE - mediapipe::Image
 //     The image that image segmenter runs on.
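Wired into a CalculatorGraphConfig, the renamed tags look roughly like this (an illustrative node with placeholder stream names, not text taken from the patch):

constexpr char kSegmenterNode[] = R"pb(
  node {
    calculator: "mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph"
    input_stream: "IMAGE:image_in"
    input_stream: "NORM_RECT:norm_rect_in"
    output_stream: "CONFIDENCE_MASKS:confidence_masks"
    output_stream: "CATEGORY_MASK:category_mask"
    output_stream: "IMAGE:image_out"
  }
)pb";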
// @@ -369,23 +381,39 @@ class ImageSegmenterGraph : public core::ModelTaskGraph { ASSIGN_OR_RETURN(const auto* model_resources, CreateModelResources(sc)); Graph graph; + const auto& options = sc->Options(); ASSIGN_OR_RETURN( auto output_streams, BuildSegmentationTask( - sc->Options(), *model_resources, - graph[Input(kImageTag)], - graph[Input::Optional(kNormRectTag)], graph)); + options, *model_resources, graph[Input(kImageTag)], + graph[Input::Optional(kNormRectTag)], + HasOutput(sc->OriginalNode(), kCategoryMaskTag), graph)); auto& merge_images_to_vector = graph.AddNode("MergeImagesToVectorCalculator"); - for (int i = 0; i < output_streams.segmented_masks.size(); ++i) { - output_streams.segmented_masks[i] >> - merge_images_to_vector[Input::Multiple("")][i]; - output_streams.segmented_masks[i] >> - graph[Output::Multiple(kSegmentationTag)][i]; + // TODO: remove deprecated output type support. + if (options.segmenter_options().has_output_type()) { + for (int i = 0; i < output_streams.segmented_masks->size(); ++i) { + output_streams.segmented_masks->at(i) >> + merge_images_to_vector[Input::Multiple("")][i]; + output_streams.segmented_masks->at(i) >> + graph[Output::Multiple(kSegmentationTag)][i]; + } + merge_images_to_vector.Out("") >> + graph[Output>(kGroupedSegmentationTag)]; + } else { + for (int i = 0; i < output_streams.confidence_masks->size(); ++i) { + output_streams.confidence_masks->at(i) >> + merge_images_to_vector[Input::Multiple("")][i]; + output_streams.confidence_masks->at(i) >> + graph[Output::Multiple(kConfidenceMaskTag)][i]; + } + merge_images_to_vector.Out("") >> + graph[Output>(kConfidenceMasksTag)]; + if (output_streams.category_mask) { + *output_streams.category_mask >> graph[Output(kCategoryMaskTag)]; + } } - merge_images_to_vector.Out("") >> - graph[Output>(kGroupedSegmentationTag)]; output_streams.image >> graph[Output(kImageTag)]; return graph.GetConfig(); } @@ -403,7 +431,8 @@ class ImageSegmenterGraph : public core::ModelTaskGraph { absl::StatusOr BuildSegmentationTask( const ImageSegmenterGraphOptions& task_options, const core::ModelResources& model_resources, Source image_in, - Source norm_rect_in, Graph& graph) { + Source norm_rect_in, bool output_category_mask, + Graph& graph) { MP_RETURN_IF_ERROR(SanityCheckOptions(task_options)); // Adds preprocessing calculators and connects them to the graph input image @@ -435,22 +464,46 @@ class ImageSegmenterGraph : public core::ModelTaskGraph { image_properties.Out("SIZE") >> tensor_to_images.In(kOutputSizeTag); // Exports multiple segmented masks. - std::vector> segmented_masks; - if (task_options.segmenter_options().output_type() == - SegmenterOptions::CATEGORY_MASK) { - segmented_masks.push_back( - Source(tensor_to_images[Output(kSegmentationTag)])); + // TODO: remove deprecated output type support. 
+ if (task_options.segmenter_options().has_output_type()) { + std::vector<Source<Image>> segmented_masks; + if (task_options.segmenter_options().output_type() == + SegmenterOptions::CATEGORY_MASK) { + segmented_masks.push_back( + Source<Image>(tensor_to_images[Output<Image>(kSegmentationTag)])); + } else { + ASSIGN_OR_RETURN(const tflite::Tensor* output_tensor, + GetOutputTensor(model_resources)); + int segmentation_streams_num = *output_tensor->shape()->rbegin(); + for (int i = 0; i < segmentation_streams_num; ++i) { + segmented_masks.push_back(Source<Image>( + tensor_to_images[Output<Image>::Multiple(kSegmentationTag)][i])); + } + } + return ImageSegmenterOutputs{/*segmented_masks=*/segmented_masks, + /*confidence_masks=*/std::nullopt, + /*category_mask=*/std::nullopt, + /*image=*/image_and_tensors.image}; } else { ASSIGN_OR_RETURN(const tflite::Tensor* output_tensor, GetOutputTensor(model_resources)); int segmentation_streams_num = *output_tensor->shape()->rbegin(); + std::vector<Source<Image>> confidence_masks; + confidence_masks.reserve(segmentation_streams_num); for (int i = 0; i < segmentation_streams_num; ++i) { - segmented_masks.push_back(Source<Image>( - tensor_to_images[Output<Image>::Multiple(kSegmentationTag)][i])); + confidence_masks.push_back(Source<Image>( + tensor_to_images[Output<Image>::Multiple(kConfidenceMaskTag)][i])); } + return ImageSegmenterOutputs{ + /*segmented_masks=*/std::nullopt, + /*confidence_masks=*/confidence_masks, + /*category_mask=*/ + output_category_mask + ? std::make_optional( + tensor_to_images[Output<Image>(kCategoryMaskTag)]) + : std::nullopt, + /*image=*/image_and_tensors.image}; } - return ImageSegmenterOutputs{/*segmented_masks=*/segmented_masks, - /*image=*/image_and_tensors.image}; } }; diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_result.h b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_result.h new file mode 100644 index 000000000..fb2ec05f1 --- /dev/null +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_result.h @@ -0,0 +1,43 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_CC_VISION_IMAGE_SEGMENTER_IMAGE_SEGMENTER_RESULT_H_ +#define MEDIAPIPE_TASKS_CC_VISION_IMAGE_SEGMENTER_IMAGE_SEGMENTER_RESULT_H_ + +#include <optional> + +#include "mediapipe/framework/formats/image.h" + +namespace mediapipe { +namespace tasks { +namespace vision { +namespace image_segmenter { + +// The output result of ImageSegmenter +struct ImageSegmenterResult { + // Multiple masks of float image in VEC32F1 format where, for each mask, each + // pixel represents the prediction confidence, usually in the [0, 1] range. + std::vector<Image> confidence_masks; + // A category mask of uint8 image in GRAY8 format where each pixel represents + // the class which the pixel in the original image was predicted to belong to.
+ std::optional<Image> category_mask; +}; + +} // namespace image_segmenter +} // namespace vision +} // namespace tasks +} // namespace mediapipe + +#endif // MEDIAPIPE_TASKS_CC_VISION_IMAGE_SEGMENTER_IMAGE_SEGMENTER_RESULT_H_ diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc index 1d75a3fb7..1e4387491 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc @@ -36,6 +36,7 @@ limitations under the License. #include "mediapipe/tasks/cc/core/proto/external_file.pb.h" #include "mediapipe/tasks/cc/vision/core/image_processing_options.h" #include "mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.pb.h" +#include "mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_result.h" #include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h" #include "mediapipe/tasks/cc/vision/utils/image_utils.h" #include "tensorflow/lite/core/shims/cc/shims_test_util.h" @@ -256,7 +257,6 @@ TEST(GetLabelsTest, SucceedsWithLabelsInModel) { auto options = std::make_unique<ImageSegmenterOptions>(); options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata); - options->output_type = ImageSegmenterOptions::OutputType::CATEGORY_MASK; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter, ImageSegmenter::Create(std::move(options))); @@ -278,15 +278,14 @@ TEST_F(ImageModeTest, SucceedsWithCategoryMask) { auto options = std::make_unique<ImageSegmenterOptions>(); options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata); - options->output_type = ImageSegmenterOptions::OutputType::CATEGORY_MASK; - + options->output_category_mask = true; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter, ImageSegmenter::Create(std::move(options))); - MP_ASSERT_OK_AND_ASSIGN(auto category_masks, segmenter->Segment(image)); - EXPECT_EQ(category_masks.size(), 1); + MP_ASSERT_OK_AND_ASSIGN(auto result, segmenter->Segment(image)); + EXPECT_TRUE(result.category_mask.has_value()); cv::Mat actual_mask = mediapipe::formats::MatView( - category_masks[0].GetImageFrameSharedPtr().get()); + result.category_mask->GetImageFrameSharedPtr().get()); cv::Mat expected_mask = cv::imread( JoinPath("./", kTestDataDirectory, "segmentation_golden_rotation0.png"), @@ -303,12 +302,11 @@ TEST_F(ImageModeTest, SucceedsWithConfidenceMask) { auto options = std::make_unique<ImageSegmenterOptions>(); options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata); - options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter, ImageSegmenter::Create(std::move(options))); - MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks, segmenter->Segment(image)); - EXPECT_EQ(confidence_masks.size(), 21); + MP_ASSERT_OK_AND_ASSIGN(auto result, segmenter->Segment(image)); + EXPECT_EQ(result.confidence_masks.size(), 21); cv::Mat expected_mask = cv::imread( JoinPath("./", kTestDataDirectory, "cat_mask.jpg"), cv::IMREAD_GRAYSCALE); @@ -317,7 +315,7 @@ TEST_F(ImageModeTest, SucceedsWithConfidenceMask) { // Cat category index 8.
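// Editor's aside (not part of the patch): mediapipe::formats::MatView wraps
// the mask's underlying ImageFrame buffer in a cv::Mat header without
// copying, so the comparison below reads the segmenter's output in place;
// index 8 is the cat class among the 21 DeepLabV3 categories checked above.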
cv::Mat cat_mask = mediapipe::formats::MatView( - confidence_masks[8].GetImageFrameSharedPtr().get()); + result.confidence_masks[8].GetImageFrameSharedPtr().get()); EXPECT_THAT(cat_mask, SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity)); } @@ -331,15 +329,14 @@ TEST_F(ImageModeTest, DISABLED_SucceedsWithRotation) { auto options = std::make_unique(); options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata); - options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, ImageSegmenter::Create(std::move(options))); ImageProcessingOptions image_processing_options; image_processing_options.rotation_degrees = -90; - MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks, + MP_ASSERT_OK_AND_ASSIGN(auto result, segmenter->Segment(image, image_processing_options)); - EXPECT_EQ(confidence_masks.size(), 21); + EXPECT_EQ(result.confidence_masks.size(), 21); cv::Mat expected_mask = cv::imread(JoinPath("./", kTestDataDirectory, "cat_rotated_mask.jpg"), @@ -349,7 +346,7 @@ TEST_F(ImageModeTest, DISABLED_SucceedsWithRotation) { // Cat category index 8. cv::Mat cat_mask = mediapipe::formats::MatView( - confidence_masks[8].GetImageFrameSharedPtr().get()); + result.confidence_masks[8].GetImageFrameSharedPtr().get()); EXPECT_THAT(cat_mask, SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity)); } @@ -361,7 +358,6 @@ TEST_F(ImageModeTest, FailsWithRegionOfInterest) { auto options = std::make_unique(); options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata); - options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, ImageSegmenter::Create(std::move(options))); @@ -384,12 +380,11 @@ TEST_F(ImageModeTest, SucceedsSelfie128x128Segmentation) { auto options = std::make_unique(); options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kSelfie128x128WithMetadata); - options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, ImageSegmenter::Create(std::move(options))); - MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks, segmenter->Segment(image)); - EXPECT_EQ(confidence_masks.size(), 2); + MP_ASSERT_OK_AND_ASSIGN(auto result, segmenter->Segment(image)); + EXPECT_EQ(result.confidence_masks.size(), 2); cv::Mat expected_mask = cv::imread(JoinPath("./", kTestDataDirectory, @@ -400,7 +395,7 @@ TEST_F(ImageModeTest, SucceedsSelfie128x128Segmentation) { // Selfie category index 1. 
cv::Mat selfie_mask = mediapipe::formats::MatView( - confidence_masks[1].GetImageFrameSharedPtr().get()); + result.confidence_masks[1].GetImageFrameSharedPtr().get()); EXPECT_THAT(selfie_mask, SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity)); } @@ -411,11 +406,10 @@ TEST_F(ImageModeTest, SucceedsSelfie144x256Segmentations) { auto options = std::make_unique(); options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kSelfie144x256WithMetadata); - options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, ImageSegmenter::Create(std::move(options))); - MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks, segmenter->Segment(image)); - EXPECT_EQ(confidence_masks.size(), 1); + MP_ASSERT_OK_AND_ASSIGN(auto result, segmenter->Segment(image)); + EXPECT_EQ(result.confidence_masks.size(), 1); cv::Mat expected_mask = cv::imread(JoinPath("./", kTestDataDirectory, @@ -425,7 +419,7 @@ TEST_F(ImageModeTest, SucceedsSelfie144x256Segmentations) { expected_mask.convertTo(expected_mask_float, CV_32FC1, 1 / 255.f); cv::Mat selfie_mask = mediapipe::formats::MatView( - confidence_masks[0].GetImageFrameSharedPtr().get()); + result.confidence_masks[0].GetImageFrameSharedPtr().get()); EXPECT_THAT(selfie_mask, SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity)); } @@ -436,12 +430,11 @@ TEST_F(ImageModeTest, SucceedsPortraitSelfieSegmentationConfidenceMask) { auto options = std::make_unique(); options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kSelfieSegmentation); - options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, ImageSegmenter::Create(std::move(options))); - MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks, segmenter->Segment(image)); - EXPECT_EQ(confidence_masks.size(), 1); + MP_ASSERT_OK_AND_ASSIGN(auto result, segmenter->Segment(image)); + EXPECT_EQ(result.confidence_masks.size(), 1); MP_ASSERT_OK(segmenter->Close()); cv::Mat expected_mask = cv::imread( @@ -452,7 +445,7 @@ TEST_F(ImageModeTest, SucceedsPortraitSelfieSegmentationConfidenceMask) { expected_mask.convertTo(expected_mask_float, CV_32FC1, 1 / 255.f); cv::Mat selfie_mask = mediapipe::formats::MatView( - confidence_masks[0].GetImageFrameSharedPtr().get()); + result.confidence_masks[0].GetImageFrameSharedPtr().get()); EXPECT_THAT(selfie_mask, SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity)); } @@ -463,16 +456,15 @@ TEST_F(ImageModeTest, SucceedsPortraitSelfieSegmentationCategoryMask) { auto options = std::make_unique(); options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kSelfieSegmentation); - options->output_type = ImageSegmenterOptions::OutputType::CATEGORY_MASK; - + options->output_category_mask = true; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, ImageSegmenter::Create(std::move(options))); - MP_ASSERT_OK_AND_ASSIGN(auto category_mask, segmenter->Segment(image)); - EXPECT_EQ(category_mask.size(), 1); + MP_ASSERT_OK_AND_ASSIGN(auto result, segmenter->Segment(image)); + EXPECT_TRUE(result.category_mask.has_value()); MP_ASSERT_OK(segmenter->Close()); cv::Mat selfie_mask = mediapipe::formats::MatView( - category_mask[0].GetImageFrameSharedPtr().get()); + result.category_mask->GetImageFrameSharedPtr().get()); cv::Mat expected_mask = cv::imread( JoinPath("./", kTestDataDirectory, "portrait_selfie_segmentation_expected_category_mask.jpg"), @@ -487,16 +479,15 @@ TEST_F(ImageModeTest, 
SucceedsPortraitSelfieSegmentationLandscapeCategoryMask) { auto options = std::make_unique(); options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kSelfieSegmentationLandscape); - options->output_type = ImageSegmenterOptions::OutputType::CATEGORY_MASK; - + options->output_category_mask = true; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, ImageSegmenter::Create(std::move(options))); - MP_ASSERT_OK_AND_ASSIGN(auto category_mask, segmenter->Segment(image)); - EXPECT_EQ(category_mask.size(), 1); + MP_ASSERT_OK_AND_ASSIGN(auto result, segmenter->Segment(image)); + EXPECT_TRUE(result.category_mask.has_value()); MP_ASSERT_OK(segmenter->Close()); cv::Mat selfie_mask = mediapipe::formats::MatView( - category_mask[0].GetImageFrameSharedPtr().get()); + result.category_mask->GetImageFrameSharedPtr().get()); cv::Mat expected_mask = cv::imread( JoinPath( "./", kTestDataDirectory, @@ -512,14 +503,13 @@ TEST_F(ImageModeTest, SucceedsHairSegmentation) { auto options = std::make_unique(); options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kHairSegmentationWithMetadata); - options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, ImageSegmenter::Create(std::move(options))); - MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks, segmenter->Segment(image)); - EXPECT_EQ(confidence_masks.size(), 2); + MP_ASSERT_OK_AND_ASSIGN(auto result, segmenter->Segment(image)); + EXPECT_EQ(result.confidence_masks.size(), 2); cv::Mat hair_mask = mediapipe::formats::MatView( - confidence_masks[1].GetImageFrameSharedPtr().get()); + result.confidence_masks[1].GetImageFrameSharedPtr().get()); MP_ASSERT_OK(segmenter->Close()); cv::Mat expected_mask = cv::imread( JoinPath("./", kTestDataDirectory, "portrait_hair_expected_mask.jpg"), @@ -540,7 +530,6 @@ TEST_F(VideoModeTest, FailsWithCallingWrongMethod) { auto options = std::make_unique(); options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata); - options->output_type = ImageSegmenterOptions::OutputType::CATEGORY_MASK; options->running_mode = core::RunningMode::VIDEO; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, @@ -572,7 +561,7 @@ TEST_F(VideoModeTest, Succeeds) { auto options = std::make_unique(); options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata); - options->output_type = ImageSegmenterOptions::OutputType::CATEGORY_MASK; + options->output_category_mask = true; options->running_mode = core::RunningMode::VIDEO; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, ImageSegmenter::Create(std::move(options))); @@ -580,11 +569,10 @@ TEST_F(VideoModeTest, Succeeds) { JoinPath("./", kTestDataDirectory, "segmentation_golden_rotation0.png"), cv::IMREAD_GRAYSCALE); for (int i = 0; i < iterations; ++i) { - MP_ASSERT_OK_AND_ASSIGN(auto category_masks, - segmenter->SegmentForVideo(image, i)); - EXPECT_EQ(category_masks.size(), 1); + MP_ASSERT_OK_AND_ASSIGN(auto result, segmenter->SegmentForVideo(image, i)); + EXPECT_TRUE(result.category_mask.has_value()); cv::Mat actual_mask = mediapipe::formats::MatView( - category_masks[0].GetImageFrameSharedPtr().get()); + result.category_mask->GetImageFrameSharedPtr().get()); EXPECT_THAT(actual_mask, SimilarToUint8Mask(expected_mask, kGoldenMaskSimilarity, kGoldenMaskMagnificationFactor)); @@ -601,11 +589,10 @@ TEST_F(LiveStreamModeTest, FailsWithCallingWrongMethod) { auto options = std::make_unique(); options->base_options.model_asset_path = 
JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata); - options->output_type = ImageSegmenterOptions::OutputType::CATEGORY_MASK; options->running_mode = core::RunningMode::LIVE_STREAM; options->result_callback = - [](absl::StatusOr> segmented_masks, const Image& image, - int64 timestamp_ms) {}; + [](absl::StatusOr segmented_masks, + const Image& image, int64_t timestamp_ms) {}; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, ImageSegmenter::Create(std::move(options))); @@ -634,11 +621,9 @@ TEST_F(LiveStreamModeTest, FailsWithOutOfOrderInputTimestamps) { auto options = std::make_unique(); options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata); - options->output_type = ImageSegmenterOptions::OutputType::CATEGORY_MASK; options->running_mode = core::RunningMode::LIVE_STREAM; - options->result_callback = - [](absl::StatusOr> segmented_masks, const Image& image, - int64 timestamp_ms) {}; + options->result_callback = [](absl::StatusOr result, + const Image& image, int64_t timestamp_ms) {}; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, ImageSegmenter::Create(std::move(options))); MP_ASSERT_OK(segmenter->SegmentAsync(image, 1)); @@ -660,23 +645,23 @@ TEST_F(LiveStreamModeTest, Succeeds) { Image image, DecodeImageFromFile(JoinPath("./", kTestDataDirectory, "segmentation_input_rotation0.jpg"))); - std::vector> segmented_masks_results; + std::vector segmented_masks_results; std::vector> image_sizes; - std::vector timestamps; + std::vector timestamps; auto options = std::make_unique(); options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata); - options->output_type = ImageSegmenterOptions::OutputType::CATEGORY_MASK; + options->output_category_mask = true; options->running_mode = core::RunningMode::LIVE_STREAM; - options->result_callback = - [&segmented_masks_results, &image_sizes, ×tamps]( - absl::StatusOr> segmented_masks, - const Image& image, int64 timestamp_ms) { - MP_ASSERT_OK(segmented_masks.status()); - segmented_masks_results.push_back(std::move(segmented_masks).value()); - image_sizes.push_back({image.width(), image.height()}); - timestamps.push_back(timestamp_ms); - }; + options->result_callback = [&segmented_masks_results, &image_sizes, + ×tamps]( + absl::StatusOr result, + const Image& image, int64_t timestamp_ms) { + MP_ASSERT_OK(result.status()); + segmented_masks_results.push_back(std::move(*result->category_mask)); + image_sizes.push_back({image.width(), image.height()}); + timestamps.push_back(timestamp_ms); + }; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, ImageSegmenter::Create(std::move(options))); for (int i = 0; i < iterations; ++i) { @@ -690,10 +675,9 @@ TEST_F(LiveStreamModeTest, Succeeds) { cv::Mat expected_mask = cv::imread( JoinPath("./", kTestDataDirectory, "segmentation_golden_rotation0.png"), cv::IMREAD_GRAYSCALE); - for (const auto& segmented_masks : segmented_masks_results) { - EXPECT_EQ(segmented_masks.size(), 1); + for (const auto& category_mask : segmented_masks_results) { cv::Mat actual_mask = mediapipe::formats::MatView( - segmented_masks[0].GetImageFrameSharedPtr().get()); + category_mask.GetImageFrameSharedPtr().get()); EXPECT_THAT(actual_mask, SimilarToUint8Mask(expected_mask, kGoldenMaskSimilarity, kGoldenMaskMagnificationFactor)); @@ -702,7 +686,7 @@ TEST_F(LiveStreamModeTest, Succeeds) { EXPECT_EQ(image_size.first, image.width()); EXPECT_EQ(image_size.second, image.height()); } - int64 timestamp_ms = -1; + int64_t timestamp_ms = -1; for (const auto& 
timestamp : timestamps) { EXPECT_GT(timestamp, timestamp_ms); timestamp_ms = timestamp; diff --git a/mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.proto b/mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.proto index be2b8a589..b1ec529d0 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.proto +++ b/mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.proto @@ -33,7 +33,7 @@ message SegmenterOptions { CONFIDENCE_MASK = 2; } // Optional output mask type. - optional OutputType output_type = 1 [default = CATEGORY_MASK]; + optional OutputType output_type = 1 [deprecated = true]; // Supported activation functions for filtering. enum Activation { From 53fa35e40c00f2e58b6ef75d75fa6c0ee15e4c09 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Tue, 4 Apr 2023 01:16:50 -0700 Subject: [PATCH 32/63] Add FrameBuffer view on ImageFrame. PiperOrigin-RevId: 521689386 --- mediapipe/gpu/BUILD | 3 + mediapipe/gpu/frame_buffer_view.h | 37 +++ .../gpu/gpu_buffer_storage_image_frame.cc | 71 ++++++ .../gpu/gpu_buffer_storage_image_frame.h | 24 +- mediapipe/gpu/gpu_buffer_storage_yuv_image.cc | 228 ++++++++++++++++++ mediapipe/gpu/gpu_buffer_storage_yuv_image.h | 84 +++++++ 6 files changed, 446 insertions(+), 1 deletion(-) create mode 100644 mediapipe/gpu/frame_buffer_view.h create mode 100644 mediapipe/gpu/gpu_buffer_storage_image_frame.cc create mode 100644 mediapipe/gpu/gpu_buffer_storage_yuv_image.cc create mode 100644 mediapipe/gpu/gpu_buffer_storage_yuv_image.h diff --git a/mediapipe/gpu/BUILD b/mediapipe/gpu/BUILD index ca2912ac3..c785e5624 100644 --- a/mediapipe/gpu/BUILD +++ b/mediapipe/gpu/BUILD @@ -423,12 +423,15 @@ cc_library( cc_library( name = "gpu_buffer_storage_image_frame", + srcs = ["gpu_buffer_storage_image_frame.cc"], hdrs = ["gpu_buffer_storage_image_frame.h"], visibility = ["//visibility:public"], deps = [ + ":frame_buffer_view", ":gpu_buffer_format", ":gpu_buffer_storage", ":image_frame_view", + "//mediapipe/framework/formats:frame_buffer", "//mediapipe/framework/formats:image_frame", ], ) diff --git a/mediapipe/gpu/frame_buffer_view.h b/mediapipe/gpu/frame_buffer_view.h new file mode 100644 index 000000000..76d773a5e --- /dev/null +++ b/mediapipe/gpu/frame_buffer_view.h @@ -0,0 +1,37 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef MEDIAPIPE_GPU_FRAME_BUFFER_VIEW_H_ +#define MEDIAPIPE_GPU_FRAME_BUFFER_VIEW_H_ + +#include "mediapipe/framework/formats/frame_buffer.h" +#include "mediapipe/gpu/gpu_buffer_storage.h" + +namespace mediapipe { +namespace internal { + +template <> +class ViewProvider<FrameBuffer> { + public: + virtual ~ViewProvider() = default; + virtual std::shared_ptr<FrameBuffer> GetReadView( + types<FrameBuffer>) const = 0; + virtual std::shared_ptr<FrameBuffer> GetWriteView(types<FrameBuffer>) = 0; +}; + +} // namespace internal +} // namespace mediapipe + +#endif // MEDIAPIPE_GPU_FRAME_BUFFER_VIEW_H_ diff --git a/mediapipe/gpu/gpu_buffer_storage_image_frame.cc b/mediapipe/gpu/gpu_buffer_storage_image_frame.cc new file mode 100644 index 000000000..1cd661d37 --- /dev/null +++ b/mediapipe/gpu/gpu_buffer_storage_image_frame.cc @@ -0,0 +1,71 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/gpu/gpu_buffer_storage_image_frame.h" + +#include <memory> +#include <vector> + +#include "mediapipe/framework/formats/frame_buffer.h" +#include "mediapipe/framework/formats/image_frame.h" + +namespace mediapipe { + +namespace { + +FrameBuffer::Format FrameBufferFormatForImageFrameFormat( + ImageFormat::Format format) { + switch (format) { + case ImageFormat::SRGB: + return FrameBuffer::Format::kRGB; + case ImageFormat::SRGBA: + return FrameBuffer::Format::kRGBA; + case ImageFormat::GRAY8: + return FrameBuffer::Format::kGRAY; + default: + return FrameBuffer::Format::kUNKNOWN; + } +} + +std::shared_ptr<FrameBuffer> ImageFrameToFrameBuffer( + std::shared_ptr<ImageFrame> image_frame) { + FrameBuffer::Format format = + FrameBufferFormatForImageFrameFormat(image_frame->Format()); + CHECK(format != FrameBuffer::Format::kUNKNOWN) + << "Invalid format. Only SRGB, SRGBA and GRAY8 are supported."; + const FrameBuffer::Dimension dimension{/*width=*/image_frame->Width(), + /*height=*/image_frame->Height()}; + const FrameBuffer::Stride stride{ + /*row_stride_bytes=*/image_frame->WidthStep(), + /*pixel_stride_bytes=*/image_frame->ByteDepth() * + image_frame->NumberOfChannels()}; + const std::vector<FrameBuffer::Plane> planes{ + {image_frame->MutablePixelData(), stride}}; + return std::make_shared<FrameBuffer>(planes, dimension, format); +} + +} // namespace + +std::shared_ptr<FrameBuffer> GpuBufferStorageImageFrame::GetReadView( + internal::types<FrameBuffer>) const { + return ImageFrameToFrameBuffer(image_frame_); +} + +std::shared_ptr<FrameBuffer> GpuBufferStorageImageFrame::GetWriteView( + internal::types<FrameBuffer>) { + return ImageFrameToFrameBuffer(image_frame_); +} + +} // namespace mediapipe diff --git a/mediapipe/gpu/gpu_buffer_storage_image_frame.h b/mediapipe/gpu/gpu_buffer_storage_image_frame.h index ab547b9ea..542791f98 100644 --- a/mediapipe/gpu/gpu_buffer_storage_image_frame.h +++ b/mediapipe/gpu/gpu_buffer_storage_image_frame.h @@ -1,9 +1,26 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + #ifndef MEDIAPIPE_GPU_GPU_BUFFER_STORAGE_IMAGE_FRAME_H_ #define MEDIAPIPE_GPU_GPU_BUFFER_STORAGE_IMAGE_FRAME_H_ #include +#include "mediapipe/framework/formats/frame_buffer.h" #include "mediapipe/framework/formats/image_frame.h" +#include "mediapipe/gpu/frame_buffer_view.h" #include "mediapipe/gpu/gpu_buffer_format.h" #include "mediapipe/gpu/gpu_buffer_storage.h" #include "mediapipe/gpu/image_frame_view.h" @@ -13,7 +30,8 @@ namespace mediapipe { // Implements support for ImageFrame as a backing storage of GpuBuffer. class GpuBufferStorageImageFrame : public internal::GpuBufferStorageImpl< - GpuBufferStorageImageFrame, internal::ViewProvider> { + GpuBufferStorageImageFrame, internal::ViewProvider, + internal::ViewProvider> { public: explicit GpuBufferStorageImageFrame(std::shared_ptr image_frame) : image_frame_(image_frame) {} @@ -36,6 +54,10 @@ class GpuBufferStorageImageFrame internal::types) override { return image_frame_; } + std::shared_ptr GetReadView( + internal::types) const override; + std::shared_ptr GetWriteView( + internal::types) override; private: std::shared_ptr image_frame_; diff --git a/mediapipe/gpu/gpu_buffer_storage_yuv_image.cc b/mediapipe/gpu/gpu_buffer_storage_yuv_image.cc new file mode 100644 index 000000000..c7acd1340 --- /dev/null +++ b/mediapipe/gpu/gpu_buffer_storage_yuv_image.cc @@ -0,0 +1,228 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/gpu/gpu_buffer_storage_yuv_image.h" + +#include +#include +#include + +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "libyuv/video_common.h" +#include "mediapipe/framework/formats/frame_buffer.h" +#include "mediapipe/framework/formats/image_frame.h" +#include "mediapipe/framework/formats/yuv_image.h" +#include "mediapipe/gpu/gpu_buffer_format.h" +#include "mediapipe/util/frame_buffer/frame_buffer_util.h" + +namespace mediapipe { + +namespace { + +// Default data alignment. 
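// Editor's aside: the constant below is the row-alignment granularity the
// class header advertises ("row boundaries align to 16 bytes"); the
// constructor further down also allocates the U/V planes at half resolution
// in each dimension, since NV12/NV21/YV12/I420 are all 4:2:0 layouts.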
+constexpr int kDefaultDataAligment = 16; + +GpuBufferFormat GpuBufferFormatForFourCC(libyuv::FourCC fourcc) { + switch (fourcc) { + case libyuv::FOURCC_NV12: + return GpuBufferFormat::kNV12; + case libyuv::FOURCC_NV21: + return GpuBufferFormat::kNV21; + case libyuv::FOURCC_YV12: + return GpuBufferFormat::kYV12; + case libyuv::FOURCC_I420: + return GpuBufferFormat::kI420; + default: + return GpuBufferFormat::kUnknown; + } +} + +libyuv::FourCC FourCCForGpuBufferFormat(GpuBufferFormat format) { + switch (format) { + case GpuBufferFormat::kNV12: + return libyuv::FOURCC_NV12; + case GpuBufferFormat::kNV21: + return libyuv::FOURCC_NV21; + case GpuBufferFormat::kYV12: + return libyuv::FOURCC_YV12; + case GpuBufferFormat::kI420: + return libyuv::FOURCC_I420; + default: + return libyuv::FOURCC_ANY; + } +} + +FrameBuffer::Format FrameBufferFormatForFourCC(libyuv::FourCC fourcc) { + switch (fourcc) { + case libyuv::FOURCC_NV12: + return FrameBuffer::Format::kNV12; + case libyuv::FOURCC_NV21: + return FrameBuffer::Format::kNV21; + case libyuv::FOURCC_YV12: + return FrameBuffer::Format::kYV12; + case libyuv::FOURCC_I420: + return FrameBuffer::Format::kYV21; + default: + return FrameBuffer::Format::kUNKNOWN; + } +} + +// Converts a YuvImage into a FrameBuffer that shares the same data buffers. +std::shared_ptr YuvImageToFrameBuffer( + std::shared_ptr yuv_image) { + FrameBuffer::Format format = FrameBufferFormatForFourCC(yuv_image->fourcc()); + FrameBuffer::Dimension dimension{/*width=*/yuv_image->width(), + /*height=*/yuv_image->height()}; + std::vector planes; + CHECK(yuv_image->mutable_data(0) != nullptr && yuv_image->stride(0) > 0) + << "Invalid YuvImage. Expected plane at index 0 to be non-null and have " + "stride > 0."; + planes.emplace_back( + yuv_image->mutable_data(0), + FrameBuffer::Stride{/*row_stride_bytes=*/yuv_image->stride(0), + /*pixel_stride_bytes=*/1}); + switch (format) { + case FrameBuffer::Format::kNV12: + case FrameBuffer::Format::kNV21: { + CHECK(yuv_image->mutable_data(1) != nullptr && yuv_image->stride(1) > 0) + << "Invalid YuvImage. Expected plane at index 1 to be non-null and " + "have stride > 0."; + planes.emplace_back( + yuv_image->mutable_data(1), + FrameBuffer::Stride{/*row_stride_bytes=*/yuv_image->stride(1), + /*pixel_stride_bytes=*/2}); + break; + } + case FrameBuffer::Format::kYV12: + case FrameBuffer::Format::kYV21: { + CHECK(yuv_image->mutable_data(1) != nullptr && yuv_image->stride(1) > 0 && + yuv_image->mutable_data(2) != nullptr && yuv_image->stride(2) > 0) + << "Invalid YuvImage. Expected planes at indices 1 and 2 to be " + "non-null and have stride > 0."; + planes.emplace_back( + yuv_image->mutable_data(1), + FrameBuffer::Stride{/*row_stride_bytes=*/yuv_image->stride(1), + /*pixel_stride_bytes=*/1}); + planes.emplace_back( + yuv_image->mutable_data(2), + FrameBuffer::Stride{/*row_stride_bytes=*/yuv_image->stride(2), + /*pixel_stride_bytes=*/1}); + break; + } + default: + LOG(FATAL) + << "Invalid format. Only FOURCC_NV12, FOURCC_NV21, FOURCC_YV12 and " + "FOURCC_I420 are supported."; + } + return std::make_shared(planes, dimension, format); +} + +// Converts a YUVImage into an ImageFrame with ImageFormat::SRGB format. +// Note that this requires YUV -> RGB conversion. +std::shared_ptr YuvImageToImageFrame( + std::shared_ptr yuv_image) { + auto yuv_buffer = YuvImageToFrameBuffer(yuv_image); + // Allocate the RGB ImageFrame to return. 
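// Editor's aside: unlike YuvImageToFrameBuffer above, which aliases the YUV
// planes in place, this helper must allocate a fresh SRGB ImageFrame and run
// a YUV -> RGB pixel conversion; that copy is also why
// GetWriteView(internal::types<ImageFrame>) later in this file refuses to
// hand out a writable ImageFrame view -- edits to the RGB copy could not be
// propagated back into the backing YuvImage.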
+ auto image_frame = std::make_shared( + ImageFormat::SRGB, yuv_buffer->dimension().width, + yuv_buffer->dimension().height); + // Wrap it into a FrameBuffer + std::vector planes{ + {image_frame->MutablePixelData(), + {/*row_stride_bytes=*/image_frame->WidthStep(), + /*pixel_stride_bytes=*/image_frame->NumberOfChannels() * + image_frame->ChannelSize()}}}; + auto rgb_buffer = + FrameBuffer(planes, yuv_buffer->dimension(), FrameBuffer::Format::kRGB); + // Convert. + CHECK_OK(frame_buffer::Convert(*yuv_buffer, &rgb_buffer)); + return image_frame; +} + +} // namespace + +GpuBufferStorageYuvImage::GpuBufferStorageYuvImage( + std::shared_ptr yuv_image) { + CHECK(GpuBufferFormatForFourCC(yuv_image->fourcc()) != + GpuBufferFormat::kUnknown) + << "Invalid format. Only FOURCC_NV12, FOURCC_NV21, FOURCC_YV12 and " + "FOURCC_I420 are supported."; + yuv_image_ = yuv_image; +} + +GpuBufferStorageYuvImage::GpuBufferStorageYuvImage(int width, int height, + GpuBufferFormat format) { + libyuv::FourCC fourcc = FourCCForGpuBufferFormat(format); + int y_stride = std::ceil(1.0f * width / kDefaultDataAligment); + auto y_data = std::make_unique(y_stride * height); + switch (fourcc) { + case libyuv::FOURCC_NV12: + case libyuv::FOURCC_NV21: { + // Interleaved U/V planes, 2x2 downsampling. + int uv_width = 2 * std::ceil(0.5f * width); + int uv_height = std::ceil(0.5f * height); + int uv_stride = std::ceil(1.0f * uv_width / kDefaultDataAligment); + auto uv_data = std::make_unique(uv_stride * uv_height); + yuv_image_ = std::make_shared( + fourcc, std::move(y_data), y_stride, std::move(uv_data), uv_stride, + nullptr, 0, width, height); + break; + } + case libyuv::FOURCC_YV12: + case libyuv::FOURCC_I420: { + // Non-interleaved U/V planes, 2x2 downsampling. + int uv_width = std::ceil(0.5f * width); + int uv_height = std::ceil(0.5f * height); + int uv_stride = std::ceil(1.0f * uv_width / kDefaultDataAligment); + auto u_data = std::make_unique(uv_stride * uv_height); + auto v_data = std::make_unique(uv_stride * uv_height); + yuv_image_ = std::make_shared( + fourcc, std::move(y_data), y_stride, std::move(u_data), uv_stride, + std::move(v_data), uv_stride, width, height); + break; + } + default: + LOG(FATAL) + << "Invalid format. Only kNV12, kNV21, kYV12 and kYV21 are supported"; + } +} + +GpuBufferFormat GpuBufferStorageYuvImage::format() const { + return GpuBufferFormatForFourCC(yuv_image_->fourcc()); +} + +std::shared_ptr GpuBufferStorageYuvImage::GetReadView( + internal::types) const { + return YuvImageToFrameBuffer(yuv_image_); +} + +std::shared_ptr GpuBufferStorageYuvImage::GetWriteView( + internal::types) { + return YuvImageToFrameBuffer(yuv_image_); +} + +std::shared_ptr GpuBufferStorageYuvImage::GetReadView( + internal::types) const { + return YuvImageToImageFrame(yuv_image_); +} + +std::shared_ptr GpuBufferStorageYuvImage::GetWriteView( + internal::types) { + // Not supported on purpose: writes into the resulting ImageFrame cannot + // easily be ported back to the original YUV image. + LOG(FATAL) << "GetWriteView is not supported."; +} +} // namespace mediapipe diff --git a/mediapipe/gpu/gpu_buffer_storage_yuv_image.h b/mediapipe/gpu/gpu_buffer_storage_yuv_image.h new file mode 100644 index 000000000..6b34f4948 --- /dev/null +++ b/mediapipe/gpu/gpu_buffer_storage_yuv_image.h @@ -0,0 +1,84 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "mediapipe/framework/formats/frame_buffer.h" +#include "mediapipe/framework/formats/image_frame.h" +#include "mediapipe/framework/formats/yuv_image.h" +#include "mediapipe/gpu/frame_buffer_view.h" +#include "mediapipe/gpu/gpu_buffer_format.h" +#include "mediapipe/gpu/gpu_buffer_storage.h" +#include "mediapipe/gpu/image_frame_view.h" + +#ifndef MEDIAPIPE_GPU_GPU_BUFFER_STORAGE_YUV_IMAGE_H_ +#define MEDIAPIPE_GPU_GPU_BUFFER_STORAGE_YUV_IMAGE_H_ + +namespace mediapipe { + +namespace internal { + +template <> +class ViewProvider { + public: + virtual ~ViewProvider() = default; + virtual std::shared_ptr GetReadView( + types) const = 0; + virtual std::shared_ptr GetWriteView(types) = 0; +}; + +} // namespace internal + +// TODO: add support for I444. +class GpuBufferStorageYuvImage + : public internal::GpuBufferStorageImpl< + GpuBufferStorageYuvImage, internal::ViewProvider, + internal::ViewProvider, + internal::ViewProvider> { + public: + // Constructor from an existing YUVImage with FOURCC_NV12, FOURCC_NV21, + // FOURCC_YV12 or FOURCC_I420 format. + explicit GpuBufferStorageYuvImage(std::shared_ptr yuv_image); + // Constructor. Supported formats are kNV12, kNV21, kYV12 and kI420. + // Stride is set by default so that row boundaries align to 16 bytes. + GpuBufferStorageYuvImage(int width, int height, GpuBufferFormat format); + + int width() const override { return yuv_image_->width(); } + int height() const override { return yuv_image_->height(); } + GpuBufferFormat format() const override; + + std::shared_ptr GetReadView( + internal::types) const override { + return yuv_image_; + } + std::shared_ptr GetWriteView(internal::types) override { + return yuv_image_; + } + + std::shared_ptr GetReadView( + internal::types) const override; + std::shared_ptr GetWriteView( + internal::types) override; + std::shared_ptr GetReadView( + internal::types) const override; + std::shared_ptr GetWriteView( + internal::types) override; + + private: + std::shared_ptr yuv_image_; +}; +} // namespace mediapipe + +#endif // MEDIAPIPE_GPU_GPU_BUFFER_STORAGE_YUV_IMAGE_H_ From e95f465d5804e7aa9d3c28b5d031f34e2b7a91e6 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Tue, 4 Apr 2023 03:42:46 -0700 Subject: [PATCH 33/63] Internal change PiperOrigin-RevId: 521716263 --- .../autoflip/quality/scene_camera_motion_analyzer.cc | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/mediapipe/examples/desktop/autoflip/quality/scene_camera_motion_analyzer.cc b/mediapipe/examples/desktop/autoflip/quality/scene_camera_motion_analyzer.cc index 0bfe72548..96fc5f888 100644 --- a/mediapipe/examples/desktop/autoflip/quality/scene_camera_motion_analyzer.cc +++ b/mediapipe/examples/desktop/autoflip/quality/scene_camera_motion_analyzer.cc @@ -34,7 +34,7 @@ absl::Status SceneCameraMotionAnalyzer::AnalyzeSceneAndPopulateFocusPointFrames( const KeyFrameCropOptions& key_frame_crop_options, const std::vector& key_frame_crop_results, const int scene_frame_width, const int scene_frame_height, - const std::vector& 
scene_frame_timestamps, + const std::vector& scene_frame_timestamps, const bool has_solid_color_background, SceneKeyFrameCropSummary* scene_summary, std::vector* focus_point_frames, @@ -45,7 +45,7 @@ absl::Status SceneCameraMotionAnalyzer::AnalyzeSceneAndPopulateFocusPointFrames( key_frame_crop_options, key_frame_crop_results, scene_frame_width, scene_frame_height, scene_summary)); - const int64 scene_span_ms = + const int64_t scene_span_ms = scene_frame_timestamps.empty() ? 0 : scene_frame_timestamps.back() - scene_frame_timestamps.front(); @@ -103,7 +103,7 @@ absl::Status SceneCameraMotionAnalyzer::ToUseSweepingMotion( absl::Status SceneCameraMotionAnalyzer::DecideCameraMotionType( const KeyFrameCropOptions& key_frame_crop_options, - const double scene_span_sec, const int64 end_time_us, + const double scene_span_sec, const int64_t end_time_us, SceneKeyFrameCropSummary* scene_summary, SceneCameraMotion* scene_camera_motion) const { RET_CHECK_GE(scene_span_sec, 0.0) << "Scene time span is negative."; @@ -298,7 +298,7 @@ absl::Status SceneCameraMotionAnalyzer::AddFocusPointsFromCenterTypeAndWeight( absl::Status SceneCameraMotionAnalyzer::PopulateFocusPointFrames( const SceneKeyFrameCropSummary& scene_summary, const SceneCameraMotion& scene_camera_motion, - const std::vector& scene_frame_timestamps, + const std::vector& scene_frame_timestamps, std::vector* focus_point_frames) const { RET_CHECK_NE(focus_point_frames, nullptr) << "Output vector of FocusPointFrame is null."; @@ -380,7 +380,7 @@ absl::Status SceneCameraMotionAnalyzer::PopulateFocusPointFrames( absl::Status SceneCameraMotionAnalyzer::PopulateFocusPointFramesForTracking( const SceneKeyFrameCropSummary& scene_summary, const FocusPointFrameType focus_point_frame_type, - const std::vector& scene_frame_timestamps, + const std::vector& scene_frame_timestamps, std::vector* focus_point_frames) const { RET_CHECK_GE(scene_summary.key_frame_max_score(), 0.0) << "Maximum score is negative."; @@ -392,7 +392,7 @@ absl::Status SceneCameraMotionAnalyzer::PopulateFocusPointFramesForTracking( const int scene_frame_height = scene_summary.scene_frame_height(); PiecewiseLinearFunction center_x_function, center_y_function, score_function; - const int64 timestamp_offset = key_frame_compact_infos[0].timestamp_ms(); + const int64_t timestamp_offset = key_frame_compact_infos[0].timestamp_ms(); for (int i = 0; i < num_key_frames; ++i) { const float center_x = key_frame_compact_infos[i].center_x(); const float center_y = key_frame_compact_infos[i].center_y(); From ec1d84aff7f6a4d0a28da6fca8d164a883021bab Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Tue, 4 Apr 2023 03:57:32 -0700 Subject: [PATCH 34/63] Internal change PiperOrigin-RevId: 521718577 --- mediapipe/framework/tool/executor_util.cc | 2 +- .../framework/tool/options_field_util.cc | 30 +++++------ mediapipe/framework/tool/proto_util_lite.cc | 50 +++++++++---------- .../framework/tool/simulation_clock_test.cc | 14 +++--- .../framework/tool/switch_container_test.cc | 6 +-- mediapipe/framework/tool/test_util.cc | 16 +++--- mediapipe/framework/tool/validate_name.cc | 4 +- 7 files changed, 61 insertions(+), 61 deletions(-) diff --git a/mediapipe/framework/tool/executor_util.cc b/mediapipe/framework/tool/executor_util.cc index 91089cc71..6d967768e 100644 --- a/mediapipe/framework/tool/executor_util.cc +++ b/mediapipe/framework/tool/executor_util.cc @@ -22,7 +22,7 @@ namespace mediapipe { namespace tool { -void EnsureMinimumDefaultExecutorStackSize(const int32 min_stack_size, +void 
EnsureMinimumDefaultExecutorStackSize(const int32_t min_stack_size, CalculatorGraphConfig* config) { mediapipe::ExecutorConfig* default_executor_config = nullptr; for (mediapipe::ExecutorConfig& executor_config : diff --git a/mediapipe/framework/tool/options_field_util.cc b/mediapipe/framework/tool/options_field_util.cc index 483b023b9..308932d4f 100644 --- a/mediapipe/framework/tool/options_field_util.cc +++ b/mediapipe/framework/tool/options_field_util.cc @@ -487,24 +487,24 @@ FieldData AsFieldData(const proto_ns::MessageLite& message) { // Represents a protobuf enum value stored in a Packet. struct ProtoEnum { - ProtoEnum(int32 v) : value(v) {} - int32 value; + ProtoEnum(int32_t v) : value(v) {} + int32_t value; }; absl::StatusOr AsPacket(const FieldData& data) { Packet result; switch (data.value_case()) { case FieldData::ValueCase::kInt32Value: - result = MakePacket(data.int32_value()); + result = MakePacket(data.int32_value()); break; case FieldData::ValueCase::kInt64Value: - result = MakePacket(data.int64_value()); + result = MakePacket(data.int64_value()); break; case FieldData::ValueCase::kUint32Value: - result = MakePacket(data.uint32_value()); + result = MakePacket(data.uint32_value()); break; case FieldData::ValueCase::kUint64Value: - result = MakePacket(data.uint64_value()); + result = MakePacket(data.uint64_value()); break; case FieldData::ValueCase::kDoubleValue: result = MakePacket(data.double_value()); @@ -538,11 +538,11 @@ absl::StatusOr AsPacket(const FieldData& data) { } absl::StatusOr AsFieldData(Packet packet) { - static const auto* kTypeIds = new std::map{ - {kTypeId, WireFormatLite::CPPTYPE_INT32}, - {kTypeId, WireFormatLite::CPPTYPE_INT64}, - {kTypeId, WireFormatLite::CPPTYPE_UINT32}, - {kTypeId, WireFormatLite::CPPTYPE_UINT64}, + static const auto* kTypeIds = new std::map{ + {kTypeId, WireFormatLite::CPPTYPE_INT32}, + {kTypeId, WireFormatLite::CPPTYPE_INT64}, + {kTypeId, WireFormatLite::CPPTYPE_UINT32}, + {kTypeId, WireFormatLite::CPPTYPE_UINT64}, {kTypeId, WireFormatLite::CPPTYPE_DOUBLE}, {kTypeId, WireFormatLite::CPPTYPE_FLOAT}, {kTypeId, WireFormatLite::CPPTYPE_BOOL}, @@ -566,16 +566,16 @@ absl::StatusOr AsFieldData(Packet packet) { switch (kTypeIds->at(packet.GetTypeId())) { case WireFormatLite::CPPTYPE_INT32: - result.set_int32_value(packet.Get()); + result.set_int32_value(packet.Get()); break; case WireFormatLite::CPPTYPE_INT64: - result.set_int64_value(packet.Get()); + result.set_int64_value(packet.Get()); break; case WireFormatLite::CPPTYPE_UINT32: - result.set_uint32_value(packet.Get()); + result.set_uint32_value(packet.Get()); break; case WireFormatLite::CPPTYPE_UINT64: - result.set_uint64_value(packet.Get()); + result.set_uint64_value(packet.Get()); break; case WireFormatLite::CPPTYPE_DOUBLE: result.set_double_value(packet.Get()); diff --git a/mediapipe/framework/tool/proto_util_lite.cc b/mediapipe/framework/tool/proto_util_lite.cc index a810ce129..745f4a13b 100644 --- a/mediapipe/framework/tool/proto_util_lite.cc +++ b/mediapipe/framework/tool/proto_util_lite.cc @@ -48,11 +48,11 @@ bool IsLengthDelimited(WireFormatLite::WireType wire_type) { } // Reads a single data value for a wire type. 
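// Editor's aside: in the protobuf wire format, length-delimited fields
// (strings, bytes, sub-messages, packed repeated fields) are prefixed with a
// varint byte count, whereas varint/fixed32/fixed64 values carry no length;
// ReadFieldValue below therefore branches on IsLengthDelimited to decide
// whether to read an explicit length or consume the value according to its
// wire type.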
-absl::Status ReadFieldValue(uint32 tag, CodedInputStream* in, +absl::Status ReadFieldValue(uint32_t tag, CodedInputStream* in, std::string* result) { WireFormatLite::WireType wire_type = WireFormatLite::GetTagWireType(tag); if (IsLengthDelimited(wire_type)) { - uint32 length; + uint32_t length; RET_CHECK_NO_LOG(in->ReadVarint32(&length)); RET_CHECK_NO_LOG(in->ReadString(result, length)); } else { @@ -72,10 +72,10 @@ absl::Status ReadFieldValue(uint32 tag, CodedInputStream* in, absl::Status ReadPackedValues(WireFormatLite::WireType wire_type, CodedInputStream* in, std::vector* field_values) { - uint32 data_size; + uint32_t data_size; RET_CHECK_NO_LOG(in->ReadVarint32(&data_size)); // fake_tag encodes the wire-type for calls to WireFormatLite::SkipField. - uint32 fake_tag = WireFormatLite::MakeTag(1, wire_type); + uint32_t fake_tag = WireFormatLite::MakeTag(1, wire_type); while (data_size > 0) { std::string number; MP_RETURN_IF_ERROR(ReadFieldValue(fake_tag, in, &number)); @@ -88,10 +88,10 @@ absl::Status ReadPackedValues(WireFormatLite::WireType wire_type, // Extracts the data value(s) for one field from a serialized message. // The message with these field values removed is written to |out|. -absl::Status GetFieldValues(uint32 field_id, CodedInputStream* in, +absl::Status GetFieldValues(uint32_t field_id, CodedInputStream* in, CodedOutputStream* out, std::vector* field_values) { - uint32 tag; + uint32_t tag; while ((tag = in->ReadTag()) != 0) { int field_number = WireFormatLite::GetTagFieldNumber(tag); WireFormatLite::WireType wire_type = WireFormatLite::GetTagWireType(tag); @@ -112,10 +112,10 @@ absl::Status GetFieldValues(uint32 field_id, CodedInputStream* in, } // Injects the data value(s) for one field into a serialized message. -void SetFieldValues(uint32 field_id, WireFormatLite::WireType wire_type, +void SetFieldValues(uint32_t field_id, WireFormatLite::WireType wire_type, const std::vector& field_values, CodedOutputStream* out) { - uint32 tag = WireFormatLite::MakeTag(field_id, wire_type); + uint32_t tag = WireFormatLite::MakeTag(field_id, wire_type); for (const std::string& field_value : field_values) { out->WriteVarint32(tag); if (IsLengthDelimited(wire_type)) { @@ -125,7 +125,7 @@ void SetFieldValues(uint32 field_id, WireFormatLite::WireType wire_type, } } -FieldAccess::FieldAccess(uint32 field_id, FieldType field_type) +FieldAccess::FieldAccess(uint32_t field_id, FieldType field_type) : field_id_(field_id), field_type_(field_type) {} absl::Status FieldAccess::SetMessage(const std::string& message) { @@ -397,11 +397,11 @@ static absl::Status DeserializeValue(const FieldValue& bytes, case W::TYPE_UINT64: return ReadPrimitive(&input, result); case W::TYPE_INT32: - return ReadPrimitive(&input, result); + return ReadPrimitive(&input, result); case W::TYPE_FIXED64: return ReadPrimitive(&input, result); case W::TYPE_FIXED32: - return ReadPrimitive(&input, result); + return ReadPrimitive(&input, result); case W::TYPE_BOOL: return ReadPrimitive(&input, result); case W::TYPE_BYTES: @@ -413,15 +413,15 @@ static absl::Status DeserializeValue(const FieldValue& bytes, case W::TYPE_MESSAGE: CHECK(false) << "DeserializeValue cannot deserialize a Message."; case W::TYPE_UINT32: - return ReadPrimitive(&input, result); + return ReadPrimitive(&input, result); case W::TYPE_ENUM: return ReadPrimitive(&input, result); case W::TYPE_SFIXED32: - return ReadPrimitive(&input, result); + return ReadPrimitive(&input, result); case W::TYPE_SFIXED64: return ReadPrimitive(&input, result); case W::TYPE_SINT32: 
- return ReadPrimitive(&input, result); + return ReadPrimitive(&input, result); case W::TYPE_SINT64: return ReadPrimitive(&input, result); } @@ -523,27 +523,27 @@ absl::Status ReadValue(absl::string_view field_bytes, FieldType field_type, switch (field_type) { case WireFormatLite::TYPE_INT32: result->set_int32_value( - ReadValue(field_bytes, &status)); + ReadValue(field_bytes, &status)); break; case WireFormatLite::TYPE_SINT32: - result->set_int32_value( - ReadValue(field_bytes, &status)); + result->set_int32_value(ReadValue( + field_bytes, &status)); break; case WireFormatLite::TYPE_INT64: result->set_int64_value( - ReadValue(field_bytes, &status)); + ReadValue(field_bytes, &status)); break; case WireFormatLite::TYPE_SINT64: - result->set_int64_value( - ReadValue(field_bytes, &status)); + result->set_int64_value(ReadValue( + field_bytes, &status)); break; case WireFormatLite::TYPE_UINT32: - result->set_uint32_value( - ReadValue(field_bytes, &status)); + result->set_uint32_value(ReadValue( + field_bytes, &status)); break; case WireFormatLite::TYPE_UINT64: - result->set_uint64_value( - ReadValue(field_bytes, &status)); + result->set_uint64_value(ReadValue( + field_bytes, &status)); break; case WireFormatLite::TYPE_DOUBLE: result->set_double_value( @@ -559,7 +559,7 @@ absl::Status ReadValue(absl::string_view field_bytes, FieldType field_type, break; case WireFormatLite::TYPE_ENUM: result->set_enum_value( - ReadValue(field_bytes, &status)); + ReadValue(field_bytes, &status)); break; case WireFormatLite::TYPE_STRING: result->set_string_value(std::string(field_bytes)); diff --git a/mediapipe/framework/tool/simulation_clock_test.cc b/mediapipe/framework/tool/simulation_clock_test.cc index 3f2c3615c..c4c76e37e 100644 --- a/mediapipe/framework/tool/simulation_clock_test.cc +++ b/mediapipe/framework/tool/simulation_clock_test.cc @@ -99,17 +99,17 @@ class SimulationClockTest : public ::testing::Test { void SetupRealClock() { clock_ = mediapipe::Clock::RealClock(); } // Return the values of the timestamps of a vector of Packets. - static std::vector TimestampValues( + static std::vector TimestampValues( const std::vector& packets) { - std::vector result; + std::vector result; for (const Packet& p : packets) { result.push_back(p.Timestamp().Value()); } return result; } - static std::vector TimeValues(const std::vector& times) { - std::vector result; + static std::vector TimeValues(const std::vector& times) { + std::vector result; for (const absl::Time& t : times) { result.push_back(absl::ToUnixMicros(t)); } @@ -225,9 +225,9 @@ TEST_F(SimulationClockTest, InFlight) { // Add 10 input packets to the graph, one each 10 ms, starting after 11 ms // of clock time. Timestamps lag clock times by 1 ms. 
clock_->Sleep(absl::Microseconds(11000)); - for (uint64 ts = 10000; ts <= 100000; ts += 10000) { + for (uint64_t ts = 10000; ts <= 100000; ts += 10000) { MP_EXPECT_OK(graph_.AddPacketToInputStream( - "input_packets_0", MakePacket(ts).At(Timestamp(ts)))); + "input_packets_0", MakePacket(ts).At(Timestamp(ts)))); clock_->Sleep(absl::Microseconds(10000)); } @@ -266,7 +266,7 @@ TEST_F(SimulationClockTest, DestroyClock) { clock_->Sleep(absl::Microseconds(20000)); if (++input_count < 4) { outputs->Index(0).AddPacket( - MakePacket(input_count).At(Timestamp(input_count))); + MakePacket(input_count).At(Timestamp(input_count))); return absl::OkStatus(); } else { return tool::StatusStop(); diff --git a/mediapipe/framework/tool/switch_container_test.cc b/mediapipe/framework/tool/switch_container_test.cc index b20979b10..08cc4ab5a 100644 --- a/mediapipe/framework/tool/switch_container_test.cc +++ b/mediapipe/framework/tool/switch_container_test.cc @@ -144,7 +144,7 @@ void RunTestContainer(CalculatorGraphConfig supergraph, if (!send_bounds) { // Send enable == true signal at 5000 us. - const int64 enable_ts = 5000; + const int64_t enable_ts = 5000; MP_EXPECT_OK(graph.AddPacketToInputStream( "enable", MakePacket(true).At(Timestamp(enable_ts)))); MP_ASSERT_OK(graph.WaitUntilIdle()); @@ -152,7 +152,7 @@ void RunTestContainer(CalculatorGraphConfig supergraph, const int packet_count = 10; // Send int value packets at {10K, 20K, 30K, ..., 100K}. - for (uint64 t = 1; t <= packet_count; ++t) { + for (uint64_t t = 1; t <= packet_count; ++t) { if (send_bounds) { MP_EXPECT_OK(graph.AddPacketToInputStream( "enable", MakePacket(true).At(Timestamp(t * 10000)))); @@ -180,7 +180,7 @@ void RunTestContainer(CalculatorGraphConfig supergraph, } // Send int value packets at {110K, 120K, ..., 200K}. 
- for (uint64 t = 11; t <= packet_count * 2; ++t) { + for (uint64_t t = 11; t <= packet_count * 2; ++t) { if (send_bounds) { MP_EXPECT_OK(graph.AddPacketToInputStream( "enable", MakePacket(false).At(Timestamp(t * 10000)))); diff --git a/mediapipe/framework/tool/test_util.cc b/mediapipe/framework/tool/test_util.cc index e8b02084b..d05171d20 100644 --- a/mediapipe/framework/tool/test_util.cc +++ b/mediapipe/framework/tool/test_util.cc @@ -182,13 +182,13 @@ absl::Status CompareImageFrames(const ImageFrame& image1, case ImageFormat::SRGB: case ImageFormat::SRGBA: case ImageFormat::LAB8: - return CompareDiff(image1, image2, max_color_diff, max_alpha_diff, - max_avg_diff, diff_image); + return CompareDiff(image1, image2, max_color_diff, + max_alpha_diff, max_avg_diff, diff_image); case ImageFormat::GRAY16: case ImageFormat::SRGB48: case ImageFormat::SRGBA64: - return CompareDiff(image1, image2, max_color_diff, max_alpha_diff, - max_avg_diff, diff_image); + return CompareDiff(image1, image2, max_color_diff, + max_alpha_diff, max_avg_diff, diff_image); case ImageFormat::VEC32F1: case ImageFormat::VEC32F2: return CompareDiff(image1, image2, max_color_diff, max_alpha_diff, @@ -350,17 +350,17 @@ std::unique_ptr GenerateLuminanceImage( auto luminance_image = absl::make_unique(original_image.Format(), width, height, ImageFrame::kGlDefaultAlignmentBoundary); - const uint8* pixel1 = original_image.PixelData(); - uint8* pixel2 = luminance_image->MutablePixelData(); + const uint8_t* pixel1 = original_image.PixelData(); + uint8_t* pixel2 = luminance_image->MutablePixelData(); const int width_padding1 = original_image.WidthStep() - width * channels; const int width_padding2 = luminance_image->WidthStep() - width * channels; for (int row = 0; row < height; ++row) { for (int col = 0; col < width; ++col) { float luminance = pixel1[0] * 0.2125f + pixel1[1] * 0.7154f + pixel1[2] * 0.0721f; - uint8 luminance_byte = 255; + uint8_t luminance_byte = 255; if (luminance < 255.0f) { - luminance_byte = static_cast(luminance); + luminance_byte = static_cast(luminance); } pixel2[0] = luminance_byte; pixel2[1] = luminance_byte; diff --git a/mediapipe/framework/tool/validate_name.cc b/mediapipe/framework/tool/validate_name.cc index bea857dd4..ad66b43d8 100644 --- a/mediapipe/framework/tool/validate_name.cc +++ b/mediapipe/framework/tool/validate_name.cc @@ -185,7 +185,7 @@ absl::Status ParseTagIndexName(const std::string& tag_index_name, tag_status = ValidateTag(v[0]); number_status = ValidateNumber(v[1]); if (number_status.ok()) { - int64 index64; + int64_t index64; RET_CHECK(absl::SimpleAtoi(v[1], &index64)); RET_CHECK_LE(index64, internal::kMaxCollectionItemId); the_index = index64; @@ -227,7 +227,7 @@ absl::Status ParseTagIndex(const std::string& tag_index, std::string* tag, } number_status = ValidateNumber(v[1]); if (number_status.ok()) { - int64 index64; + int64_t index64; RET_CHECK(absl::SimpleAtoi(v[1], &index64)); RET_CHECK_LE(index64, internal::kMaxCollectionItemId); the_index = index64; From 65a98be80967896e4f6152f715be8fc36cf71b89 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Tue, 4 Apr 2023 08:37:33 -0700 Subject: [PATCH 35/63] Fixed comment and added note. 
PiperOrigin-RevId: 521772542
---
 mediapipe/framework/tool/sink.h | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mediapipe/framework/tool/sink.h b/mediapipe/framework/tool/sink.h
index d659115ee..f786e60a7 100644
--- a/mediapipe/framework/tool/sink.h
+++ b/mediapipe/framework/tool/sink.h
@@ -62,9 +62,9 @@ namespace tool {
 //   Example usage:
 //   CalculatorGraphConfig config = tool::ParseGraphFromFileOrDie("config.txt");
 //   std::vector<Packet> packet_dump;
-//   tool::AddVectorSink("output_samples", &config, &packet_dump,
-//                       /*use_std_function=*/true);
-//   // Call tool::AddVectorSink() more times if you wish.
+//   tool::AddVectorSink("output_samples", &config, &packet_dump);
+//   // Call tool::AddVectorSink() more times if you wish. Note that each stream
+//   // needs to get its own packet vector.
 //   CalculatorGraph graph;
 //   CHECK_OK(graph.Initialize(config));
 //   // Set other input side packets.

From 33cad24a5a960991e4ac7adada194f749af6dbd3 Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Tue, 4 Apr 2023 10:39:35 -0700
Subject: [PATCH 36/63] Update java image segmenter to always output
 confidence masks and optionally output category mask.

PiperOrigin-RevId: 521804641
---
 .../vision/imagesegmenter/ImageSegmenter.java | 112 +++++++++---------
 .../imagesegmenter/ImageSegmenterResult.java  |  19 ++-
 .../InteractiveSegmenter.java                 |   2 +
 .../imagesegmenter/ImageSegmenterTest.java    |  79 ++++++------
 .../InteractiveSegmenterTest.java             |   7 +-
 5 files changed, 107 insertions(+), 112 deletions(-)

diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenter.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenter.java
index b809ab963..e8e0e4051 100644
--- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenter.java
+++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenter.java
@@ -79,15 +79,10 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
   private static final List<String> INPUT_STREAMS =
       Collections.unmodifiableList(
           Arrays.asList("IMAGE:" + IMAGE_IN_STREAM_NAME, "NORM_RECT:" + NORM_RECT_IN_STREAM_NAME));
-  private static final List<String> OUTPUT_STREAMS =
-      Collections.unmodifiableList(
-          Arrays.asList(
-              "GROUPED_SEGMENTATION:segmented_mask_out",
-              "IMAGE:image_out",
-              "SEGMENTATION:0:segmentation"));
-  private static final int GROUPED_SEGMENTATION_OUT_STREAM_INDEX = 0;
+  private static final int CONFIDENCE_MASKS_OUT_STREAM_INDEX = 0;
   private static final int IMAGE_OUT_STREAM_INDEX = 1;
-  private static final int SEGMENTATION_OUT_STREAM_INDEX = 2;
+  private static final int CONFIDENCE_MASK_OUT_STREAM_INDEX = 2;
+  private static final int CATEGORY_MASK_OUT_STREAM_INDEX = 3;
   private static final String TASK_GRAPH_NAME =
       "mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph";
   private static final String TENSORS_TO_SEGMENTATION_CALCULATOR_NAME =
@@ -104,6 +99,13 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
    */
   public static ImageSegmenter createFromOptions(
       Context context, ImageSegmenterOptions segmenterOptions) {
+    List<String> outputStreams = new ArrayList<>();
+    outputStreams.add("CONFIDENCE_MASKS:confidence_masks");
+    outputStreams.add("IMAGE:image_out");
+    outputStreams.add("CONFIDENCE_MASK:0:confidence_mask");
+    if (segmenterOptions.outputCategoryMask()) {
+      outputStreams.add("CATEGORY_MASK:category_mask");
+    }
     // TODO: Consolidate OutputHandler and TaskRunner.
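+    // Note: the stream names added to outputStreams above must stay in sync
+    // with the *_OUT_STREAM_INDEX constants; CATEGORY_MASK is appended last,
+    // so CATEGORY_MASK_OUT_STREAM_INDEX (3) is only valid when
+    // segmenterOptions.outputCategoryMask() is set.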
    OutputHandler<ImageSegmenterResult, MPImage> handler = new OutputHandler<>();
    handler.setOutputPacketConverter(
@@ -111,50 +113,62 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
          @Override
          public ImageSegmenterResult convertToTaskResult(List<Packet> packets)
              throws MediaPipeException {
-            if (packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX).isEmpty()) {
+            if (packets.get(CONFIDENCE_MASKS_OUT_STREAM_INDEX).isEmpty()) {
              return ImageSegmenterResult.create(
                  new ArrayList<>(),
-                  packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX).getTimestamp());
+                  Optional.empty(),
+                  packets.get(CONFIDENCE_MASKS_OUT_STREAM_INDEX).getTimestamp());
            }
-            List<MPImage> segmentedMasks = new ArrayList<>();
-            int width = PacketGetter.getImageWidth(packets.get(SEGMENTATION_OUT_STREAM_INDEX));
-            int height = PacketGetter.getImageHeight(packets.get(SEGMENTATION_OUT_STREAM_INDEX));
-            int imageFormat =
-                segmenterOptions.outputType() == ImageSegmenterOptions.OutputType.CONFIDENCE_MASK
-                    ? MPImage.IMAGE_FORMAT_VEC32F1
-                    : MPImage.IMAGE_FORMAT_ALPHA;
-            int imageListSize =
-                PacketGetter.getImageListSize(packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX));
-            ByteBuffer[] buffersArray = new ByteBuffer[imageListSize];
+            List<MPImage> confidenceMasks = new ArrayList<>();
+            int width = PacketGetter.getImageWidth(packets.get(CONFIDENCE_MASK_OUT_STREAM_INDEX));
+            int height = PacketGetter.getImageHeight(packets.get(CONFIDENCE_MASK_OUT_STREAM_INDEX));
+            int confidenceMasksListSize =
+                PacketGetter.getImageListSize(packets.get(CONFIDENCE_MASKS_OUT_STREAM_INDEX));
+            ByteBuffer[] buffersArray = new ByteBuffer[confidenceMasksListSize];
            // If resultListener is not provided, the resulted MPImage is deep copied from mediapipe
            // graph. If provided, the result MPImage is wrapping the mediapipe packet memory.
-            if (!segmenterOptions.resultListener().isPresent()) {
-              for (int i = 0; i < imageListSize; i++) {
-                buffersArray[i] =
-                    ByteBuffer.allocateDirect(
-                        width * height * (imageFormat == MPImage.IMAGE_FORMAT_VEC32F1 ? 4 : 1));
+            boolean copyImage = !segmenterOptions.resultListener().isPresent();
+            if (copyImage) {
+              for (int i = 0; i < confidenceMasksListSize; i++) {
+                buffersArray[i] = ByteBuffer.allocateDirect(width * height * 4);
              }
            }
            if (!PacketGetter.getImageList(
-                packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX),
-                buffersArray,
-                !segmenterOptions.resultListener().isPresent())) {
+                packets.get(CONFIDENCE_MASKS_OUT_STREAM_INDEX), buffersArray, copyImage)) {
              throw new MediaPipeException(
                  MediaPipeException.StatusCode.INTERNAL.ordinal(),
-                  "There is an error getting segmented masks. It usually results from incorrect"
-                      + " options of unsupported OutputType of given model.");
+                  "There is an error getting segmented masks.");
            }
            for (ByteBuffer buffer : buffersArray) {
              ByteBufferImageBuilder builder =
-                  new ByteBufferImageBuilder(buffer, width, height, imageFormat);
-              segmentedMasks.add(builder.build());
+                  new ByteBufferImageBuilder(buffer, width, height, MPImage.IMAGE_FORMAT_VEC32F1);
+              confidenceMasks.add(builder.build());
+            }
+            Optional<MPImage> categoryMask = Optional.empty();
+            if (segmenterOptions.outputCategoryMask()) {
+              ByteBuffer buffer;
+              if (copyImage) {
+                buffer = ByteBuffer.allocateDirect(width * height);
+                if (!PacketGetter.getImageData(
+                    packets.get(CATEGORY_MASK_OUT_STREAM_INDEX), buffer)) {
+                  throw new MediaPipeException(
+                      MediaPipeException.StatusCode.INTERNAL.ordinal(),
+                      "There is an error getting category mask.");
+                }
+              } else {
+                buffer =
+                    PacketGetter.getImageDataDirectly(packets.get(CATEGORY_MASK_OUT_STREAM_INDEX));
+              }
+              ByteBufferImageBuilder builder =
+                  new ByteBufferImageBuilder(buffer, width, height, MPImage.IMAGE_FORMAT_ALPHA);
+              categoryMask = Optional.of(builder.build());
            }
            return ImageSegmenterResult.create(
-                segmentedMasks,
+                confidenceMasks,
+                categoryMask,
                BaseVisionTaskApi.generateResultTimestampMs(
                    segmenterOptions.runningMode(),
-                    packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX)));
+                    packets.get(CONFIDENCE_MASKS_OUT_STREAM_INDEX)));
          }

          @Override
@@ -174,7 +188,7 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
                .setTaskRunningModeName(segmenterOptions.runningMode().name())
                .setTaskGraphName(TASK_GRAPH_NAME)
                .setInputStreams(INPUT_STREAMS)
-                .setOutputStreams(OUTPUT_STREAMS)
+                .setOutputStreams(outputStreams)
                .setTaskOptions(segmenterOptions)
                .setEnableFlowLimiting(segmenterOptions.runningMode() == RunningMode.LIVE_STREAM)
                .build(),
@@ -553,8 +567,8 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
       */
      public abstract Builder setDisplayNamesLocale(String value);

-      /** The output type from image segmenter. */
-      public abstract Builder setOutputType(OutputType value);
+      /** Whether to output category mask. */
+      public abstract Builder setOutputCategoryMask(boolean value);

      /**
       * Sets an optional {@link ResultListener} to receive the segmentation results when the graph
@@ -594,27 +608,17 @@ public final class ImageSegmenter extends BaseVisionTaskApi {

    abstract String displayNamesLocale();

-    abstract OutputType outputType();
+    abstract boolean outputCategoryMask();

    abstract Optional<ResultListener<ImageSegmenterResult, MPImage>> resultListener();

    abstract Optional<ErrorListener> errorListener();

-    /** The output type of segmentation results. */
-    public enum OutputType {
-      // Gives a single output mask where each pixel represents the class which
-      // the pixel in the original image was predicted to belong to.
-      CATEGORY_MASK,
-      // Gives a list of output masks where, for each mask, each pixel represents
-      // the prediction confidence, usually in the [0, 1] range.
- CONFIDENCE_MASK - } - public static Builder builder() { return new AutoValue_ImageSegmenter_ImageSegmenterOptions.Builder() .setRunningMode(RunningMode.IMAGE) .setDisplayNamesLocale("en") - .setOutputType(OutputType.CATEGORY_MASK); + .setOutputCategoryMask(false); } /** @@ -633,14 +637,6 @@ public final class ImageSegmenter extends BaseVisionTaskApi { SegmenterOptionsProto.SegmenterOptions.Builder segmenterOptionsBuilder = SegmenterOptionsProto.SegmenterOptions.newBuilder(); - if (outputType() == OutputType.CONFIDENCE_MASK) { - segmenterOptionsBuilder.setOutputType( - SegmenterOptionsProto.SegmenterOptions.OutputType.CONFIDENCE_MASK); - } else if (outputType() == OutputType.CATEGORY_MASK) { - segmenterOptionsBuilder.setOutputType( - SegmenterOptionsProto.SegmenterOptions.OutputType.CATEGORY_MASK); - } - taskOptionsBuilder.setSegmenterOptions(segmenterOptionsBuilder); return CalculatorOptions.newBuilder() .setExtension( diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenterResult.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenterResult.java index 69ab79c13..400894a66 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenterResult.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenterResult.java @@ -19,6 +19,7 @@ import com.google.mediapipe.framework.image.MPImage; import com.google.mediapipe.tasks.core.TaskResult; import java.util.Collections; import java.util.List; +import java.util.Optional; /** Represents the segmentation results generated by {@link ImageSegmenter}. */ @AutoValue @@ -27,18 +28,24 @@ public abstract class ImageSegmenterResult implements TaskResult { /** * Creates an {@link ImageSegmenterResult} instance from a list of segmentation MPImage. * - * @param segmentations a {@link List} of MPImage representing the segmented masks. If OutputType - * is CATEGORY_MASK, the masks will be in IMAGE_FORMAT_ALPHA format. If OutputType is - * CONFIDENCE_MASK, the masks will be in IMAGE_FORMAT_VEC32F1 format. + * @param confidenceMasks a {@link List} of MPImage in IMAGE_FORMAT_VEC32F1 format representing + * the confidence masks, where, for each mask, each pixel represents the prediction + * confidence, usually in the [0, 1] range. + * @param categoryMask an {@link Optional} MPImage in IMAGE_FORMAT_ALPHA format representing a + * category mask, where each pixel represents the class which the pixel in the original image + * was predicted to belong to. * @param timestampMs a timestamp for this result. */ // TODO: consolidate output formats across platforms. 
- public static ImageSegmenterResult create(List segmentations, long timestampMs) { + public static ImageSegmenterResult create( + List confidenceMasks, Optional categoryMask, long timestampMs) { return new AutoValue_ImageSegmenterResult( - Collections.unmodifiableList(segmentations), timestampMs); + Collections.unmodifiableList(confidenceMasks), categoryMask, timestampMs); } - public abstract List segmentations(); + public abstract List confidenceMasks(); + + public abstract Optional categoryMask(); @Override public abstract long timestampMs(); diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenter.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenter.java index 657716b6b..2348aaadd 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenter.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenter.java @@ -133,6 +133,7 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi { if (packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX).isEmpty()) { return ImageSegmenterResult.create( new ArrayList<>(), + Optional.empty(), packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX).getTimestamp()); } List segmentedMasks = new ArrayList<>(); @@ -172,6 +173,7 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi { return ImageSegmenterResult.create( segmentedMasks, + Optional.empty(), BaseVisionTaskApi.generateResultTimestampMs( RunningMode.IMAGE, packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX))); } diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenterTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenterTest.java index 3b35c21bc..7acf1377e 100644 --- a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenterTest.java +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenterTest.java @@ -61,14 +61,13 @@ public class ImageSegmenterTest { ImageSegmenterOptions options = ImageSegmenterOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build()) - .setOutputType(ImageSegmenterOptions.OutputType.CATEGORY_MASK) + .setOutputCategoryMask(true) .build(); ImageSegmenter imageSegmenter = ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options); ImageSegmenterResult actualResult = imageSegmenter.segment(getImageFromAsset(inputImageName)); - List segmentations = actualResult.segmentations(); - assertThat(segmentations.size()).isEqualTo(1); - MPImage actualMaskBuffer = actualResult.segmentations().get(0); + assertThat(actualResult.categoryMask().isPresent()).isTrue(); + MPImage actualMaskBuffer = actualResult.categoryMask().get(); MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName); verifyCategoryMask( actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY, MAGNIFICATION_FACTOR); @@ -81,15 +80,14 @@ public class ImageSegmenterTest { ImageSegmenterOptions options = ImageSegmenterOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build()) - .setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK) .build(); ImageSegmenter imageSegmenter = ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options); ImageSegmenterResult actualResult = 
imageSegmenter.segment(getImageFromAsset(inputImageName)); - List segmentations = actualResult.segmentations(); + List segmentations = actualResult.confidenceMasks(); assertThat(segmentations.size()).isEqualTo(21); // Cat category index 8. - MPImage actualMaskBuffer = actualResult.segmentations().get(8); + MPImage actualMaskBuffer = segmentations.get(8); MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName); verifyConfidenceMask(actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY); } @@ -102,40 +100,36 @@ public class ImageSegmenterTest { ImageSegmenterOptions.builder() .setBaseOptions( BaseOptions.builder().setModelAssetPath(SELFIE_128x128_MODEL_FILE).build()) - .setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK) .build(); ImageSegmenter imageSegmenter = ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options); ImageSegmenterResult actualResult = imageSegmenter.segment(getImageFromAsset(inputImageName)); - List segmentations = actualResult.segmentations(); + List segmentations = actualResult.confidenceMasks(); assertThat(segmentations.size()).isEqualTo(2); // Selfie category index 1. - MPImage actualMaskBuffer = actualResult.segmentations().get(1); + MPImage actualMaskBuffer = segmentations.get(1); MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName); verifyConfidenceMask(actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY); } - // TODO: enable this unit test once activation option is supported in metadata. - // @Test - // public void segment_successWith144x256Segmentation() throws Exception { - // final String inputImageName = "mozart_square.jpg"; - // final String goldenImageName = "selfie_segm_144_256_3_expected_mask.jpg"; - // ImageSegmenterOptions options = - // ImageSegmenterOptions.builder() - // .setBaseOptions( - // BaseOptions.builder().setModelAssetPath(SELFIE_144x256_MODEL_FILE).build()) - // .setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK) - // .build(); - // ImageSegmenter imageSegmenter = - // ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options); - // ImageSegmenterResult actualResult = - // imageSegmenter.segment(getImageFromAsset(inputImageName)); - // List segmentations = actualResult.segmentations(); - // assertThat(segmentations.size()).isEqualTo(1); - // MPImage actualMaskBuffer = actualResult.segmentations().get(0); - // MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName); - // verifyConfidenceMask(actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY); - // } + @Test + public void segment_successWith144x256Segmentation() throws Exception { + final String inputImageName = "mozart_square.jpg"; + final String goldenImageName = "selfie_segm_144_256_3_expected_mask.jpg"; + ImageSegmenterOptions options = + ImageSegmenterOptions.builder() + .setBaseOptions( + BaseOptions.builder().setModelAssetPath(SELFIE_144x256_MODEL_FILE).build()) + .build(); + ImageSegmenter imageSegmenter = + ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options); + ImageSegmenterResult actualResult = imageSegmenter.segment(getImageFromAsset(inputImageName)); + List segmentations = actualResult.confidenceMasks(); + assertThat(segmentations.size()).isEqualTo(1); + MPImage actualMaskBuffer = segmentations.get(0); + MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName); + verifyConfidenceMask(actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY); + } @Test public void getLabels_success() throws Exception { 
@@ -165,7 +159,6 @@ public class ImageSegmenterTest { ImageSegmenterOptions options = ImageSegmenterOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build()) - .setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK) .build(); ImageSegmenter imageSegmenter = ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options); @@ -287,16 +280,15 @@ public class ImageSegmenterTest { ImageSegmenterOptions options = ImageSegmenterOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build()) - .setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK) .setRunningMode(RunningMode.IMAGE) .build(); ImageSegmenter imageSegmenter = ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options); ImageSegmenterResult actualResult = imageSegmenter.segment(getImageFromAsset(inputImageName)); - List segmentations = actualResult.segmentations(); + List segmentations = actualResult.confidenceMasks(); assertThat(segmentations.size()).isEqualTo(21); // Cat category index 8. - MPImage actualMaskBuffer = actualResult.segmentations().get(8); + MPImage actualMaskBuffer = segmentations.get(8); MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName); verifyConfidenceMask(actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY); } @@ -309,12 +301,11 @@ public class ImageSegmenterTest { ImageSegmenterOptions options = ImageSegmenterOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build()) - .setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK) .setRunningMode(RunningMode.IMAGE) .setResultListener( (segmenterResult, inputImage) -> { verifyConfidenceMask( - segmenterResult.segmentations().get(8), + segmenterResult.confidenceMasks().get(8), expectedResult, GOLDEN_MASK_SIMILARITY); }) @@ -331,7 +322,6 @@ public class ImageSegmenterTest { ImageSegmenterOptions options = ImageSegmenterOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build()) - .setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK) .setRunningMode(RunningMode.VIDEO) .build(); ImageSegmenter imageSegmenter = @@ -341,10 +331,10 @@ public class ImageSegmenterTest { ImageSegmenterResult actualResult = imageSegmenter.segmentForVideo( getImageFromAsset(inputImageName), /* timestampsMs= */ i); - List segmentations = actualResult.segmentations(); + List segmentations = actualResult.confidenceMasks(); assertThat(segmentations.size()).isEqualTo(21); // Cat category index 8. 
- MPImage actualMaskBuffer = actualResult.segmentations().get(8); + MPImage actualMaskBuffer = segmentations.get(8); verifyConfidenceMask(actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY); } } @@ -357,12 +347,11 @@ public class ImageSegmenterTest { ImageSegmenterOptions options = ImageSegmenterOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build()) - .setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK) .setRunningMode(RunningMode.VIDEO) .setResultListener( (segmenterResult, inputImage) -> { verifyConfidenceMask( - segmenterResult.segmentations().get(8), + segmenterResult.confidenceMasks().get(8), expectedResult, GOLDEN_MASK_SIMILARITY); }) @@ -384,12 +373,11 @@ public class ImageSegmenterTest { ImageSegmenterOptions options = ImageSegmenterOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build()) - .setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK) .setRunningMode(RunningMode.LIVE_STREAM) .setResultListener( (segmenterResult, inputImage) -> { verifyConfidenceMask( - segmenterResult.segmentations().get(8), + segmenterResult.confidenceMasks().get(8), expectedResult, GOLDEN_MASK_SIMILARITY); }) @@ -411,12 +399,11 @@ public class ImageSegmenterTest { ImageSegmenterOptions options = ImageSegmenterOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build()) - .setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK) .setRunningMode(RunningMode.LIVE_STREAM) .setResultListener( (segmenterResult, inputImage) -> { verifyConfidenceMask( - segmenterResult.segmentations().get(8), + segmenterResult.confidenceMasks().get(8), expectedResult, GOLDEN_MASK_SIMILARITY); }) diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenterTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenterTest.java index 0d9581437..9351bc721 100644 --- a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenterTest.java +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenterTest.java @@ -60,7 +60,10 @@ public class InteractiveSegmenterTest { ApplicationProvider.getApplicationContext(), options); MPImage image = getImageFromAsset(inputImageName); ImageSegmenterResult actualResult = imageSegmenter.segment(image, roi); - List segmentations = actualResult.segmentations(); + // TODO update to correct category mask output. + // After InteractiveSegmenter updated according to (b/276519300), update this to use + // categoryMask field instead of confidenceMasks. 
+ List segmentations = actualResult.confidenceMasks(); assertThat(segmentations.size()).isEqualTo(1); } @@ -79,7 +82,7 @@ public class InteractiveSegmenterTest { ApplicationProvider.getApplicationContext(), options); ImageSegmenterResult actualResult = imageSegmenter.segment(getImageFromAsset(inputImageName), roi); - List segmentations = actualResult.segmentations(); + List segmentations = actualResult.confidenceMasks(); assertThat(segmentations.size()).isEqualTo(2); } } From a98f6bf231360f52ed315f27bf1f12cc38ee197b Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Tue, 4 Apr 2023 11:20:17 -0700 Subject: [PATCH 37/63] FaceDetector Web API PiperOrigin-RevId: 521816795 --- .../tasks/web/components/containers/BUILD | 1 + .../containers/detection_result.d.ts | 10 + .../processors/detection_result.test.ts | 17 +- .../components/processors/detection_result.ts | 13 ++ mediapipe/tasks/web/vision/BUILD | 1 + mediapipe/tasks/web/vision/README.md | 16 ++ .../tasks/web/vision/face_detector/BUILD | 71 ++++++ .../web/vision/face_detector/face_detector.ts | 213 ++++++++++++++++++ .../face_detector/face_detector_options.d.ts | 33 +++ .../face_detector/face_detector_result.d.ts | 19 ++ .../face_detector/face_detector_test.ts | 193 ++++++++++++++++ mediapipe/tasks/web/vision/index.ts | 3 + mediapipe/tasks/web/vision/types.ts | 1 + 13 files changed, 590 insertions(+), 1 deletion(-) create mode 100644 mediapipe/tasks/web/vision/face_detector/BUILD create mode 100644 mediapipe/tasks/web/vision/face_detector/face_detector.ts create mode 100644 mediapipe/tasks/web/vision/face_detector/face_detector_options.d.ts create mode 100644 mediapipe/tasks/web/vision/face_detector/face_detector_result.d.ts create mode 100644 mediapipe/tasks/web/vision/face_detector/face_detector_test.ts diff --git a/mediapipe/tasks/web/components/containers/BUILD b/mediapipe/tasks/web/components/containers/BUILD index 477ca15c3..714b4613b 100644 --- a/mediapipe/tasks/web/components/containers/BUILD +++ b/mediapipe/tasks/web/components/containers/BUILD @@ -26,6 +26,7 @@ mediapipe_ts_declaration( deps = [ ":bounding_box", ":category", + ":keypoint", ], ) diff --git a/mediapipe/tasks/web/components/containers/detection_result.d.ts b/mediapipe/tasks/web/components/containers/detection_result.d.ts index a338cc901..37817307c 100644 --- a/mediapipe/tasks/web/components/containers/detection_result.d.ts +++ b/mediapipe/tasks/web/components/containers/detection_result.d.ts @@ -16,6 +16,7 @@ import {BoundingBox} from '../../../../tasks/web/components/containers/bounding_box'; import {Category} from '../../../../tasks/web/components/containers/category'; +import {NormalizedKeypoint} from '../../../../tasks/web/components/containers/keypoint'; /** Represents one detection by a detection task. */ export declare interface Detection { @@ -24,6 +25,15 @@ export declare interface Detection { /** The bounding box of the detected objects. */ boundingBox?: BoundingBox; + + /** + * Optional list of keypoints associated with the detection. Keypoints + * represent interesting points related to the detection. For example, the + * keypoints represent the eye, ear and mouth from face detection model. Or + * in the template matching detection, e.g. KNIFT, they can represent the + * feature points for template matching. + */ + keypoints?: NormalizedKeypoint[]; } /** Detection results of a model. 
*/ diff --git a/mediapipe/tasks/web/components/processors/detection_result.test.ts b/mediapipe/tasks/web/components/processors/detection_result.test.ts index 26f8bd8a5..289043506 100644 --- a/mediapipe/tasks/web/components/processors/detection_result.test.ts +++ b/mediapipe/tasks/web/components/processors/detection_result.test.ts @@ -31,6 +31,7 @@ describe('convertFromDetectionProto()', () => { detection.addLabelId(1); detection.addLabel('foo'); detection.addDisplayName('bar'); + const locationData = new LocationData(); const boundingBox = new LocationData.BoundingBox(); boundingBox.setXmin(1); @@ -38,6 +39,14 @@ describe('convertFromDetectionProto()', () => { boundingBox.setWidth(3); boundingBox.setHeight(4); locationData.setBoundingBox(boundingBox); + + const keypoint = new LocationData.RelativeKeypoint(); + keypoint.setX(5); + keypoint.setY(6); + keypoint.setScore(0.7); + keypoint.setKeypointLabel('bar'); + locationData.addRelativeKeypoints(new LocationData.RelativeKeypoint()); + detection.setLocationData(locationData); const result = convertFromDetectionProto(detection); @@ -49,7 +58,13 @@ describe('convertFromDetectionProto()', () => { categoryName: 'foo', displayName: 'bar', }], - boundingBox: {originX: 1, originY: 2, width: 3, height: 4} + boundingBox: {originX: 1, originY: 2, width: 3, height: 4}, + keypoints: [{ + x: 5, + y: 6, + score: 0.7, + label: 'bar', + }], }); }); diff --git a/mediapipe/tasks/web/components/processors/detection_result.ts b/mediapipe/tasks/web/components/processors/detection_result.ts index 01041c915..6b38820bf 100644 --- a/mediapipe/tasks/web/components/processors/detection_result.ts +++ b/mediapipe/tasks/web/components/processors/detection_result.ts @@ -46,5 +46,18 @@ export function convertFromDetectionProto(source: DetectionProto): Detection { }; } + if (source.getLocationData()?.getRelativeKeypointsList().length) { + detection.keypoints = []; + for (const keypoint of + source.getLocationData()!.getRelativeKeypointsList()) { + detection.keypoints.push({ + x: keypoint.getX() ?? 0.0, + y: keypoint.getY() ?? 0.0, + score: keypoint.getScore() ?? 0.0, + label: keypoint.getKeypointLabel() ?? '', + }); + } + } + return detection; } diff --git a/mediapipe/tasks/web/vision/BUILD b/mediapipe/tasks/web/vision/BUILD index 1f28cb0fe..19c795fd9 100644 --- a/mediapipe/tasks/web/vision/BUILD +++ b/mediapipe/tasks/web/vision/BUILD @@ -19,6 +19,7 @@ mediapipe_files(srcs = [ VISION_LIBS = [ "//mediapipe/tasks/web/core:fileset_resolver", + "//mediapipe/tasks/web/vision/face_detector", "//mediapipe/tasks/web/vision/face_landmarker", "//mediapipe/tasks/web/vision/face_stylizer", "//mediapipe/tasks/web/vision/gesture_recognizer", diff --git a/mediapipe/tasks/web/vision/README.md b/mediapipe/tasks/web/vision/README.md index ebeac54c5..d5109142b 100644 --- a/mediapipe/tasks/web/vision/README.md +++ b/mediapipe/tasks/web/vision/README.md @@ -2,6 +2,22 @@ This package contains the vision tasks for MediaPipe. +## Face Detection + +The MediaPipe Face Detector task lets you detect the presence and location of +faces within images or videos. 
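+
+The following example creates a detector from a hosted model file and runs it
+on an image element; each returned detection carries a bounding box and, when
+the model provides them, face keypoints such as the eyes, ears and mouth: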
+ +``` +const vision = await FilesetResolver.forVisionTasks( + "https://cdn.jsdelivr.net/npm/@mediapipe/tasks-vision@latest/wasm" +); +const faceDetector = await FaceDetector.createFromModelPath(vision, + "https://storage.googleapis.com/mediapipe-tasks/object_detector/efficientdet_lite0_uint8.tflite" +); +const image = document.getElementById("image") as HTMLImageElement; +const detections = faceDetector.detect(image); +``` + ## Face Landmark Detection The MediaPipe Face Landmarker task lets you detect the landmarks of faces in diff --git a/mediapipe/tasks/web/vision/face_detector/BUILD b/mediapipe/tasks/web/vision/face_detector/BUILD new file mode 100644 index 000000000..8225e4948 --- /dev/null +++ b/mediapipe/tasks/web/vision/face_detector/BUILD @@ -0,0 +1,71 @@ +# This contains the MediaPipe Face Detector Task. +# +# This task takes video frames and outputs synchronized frames along with +# the detection results for one or more faces, using Face Detector. + +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library") +load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test") + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +mediapipe_ts_library( + name = "face_detector", + srcs = ["face_detector.ts"], + visibility = ["//visibility:public"], + deps = [ + ":face_detector_types", + "//mediapipe/framework:calculator_jspb_proto", + "//mediapipe/framework:calculator_options_jspb_proto", + "//mediapipe/framework/formats:detection_jspb_proto", + "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", + "//mediapipe/tasks/cc/vision/face_detector/proto:face_detector_graph_options_jspb_proto", + "//mediapipe/tasks/web/components/containers:category", + "//mediapipe/tasks/web/components/processors:detection_result", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/vision/core:image_processing_options", + "//mediapipe/tasks/web/vision/core:vision_task_runner", + "//mediapipe/web/graph_runner:graph_runner_ts", + ], +) + +mediapipe_ts_declaration( + name = "face_detector_types", + srcs = [ + "face_detector_options.d.ts", + "face_detector_result.d.ts", + ], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/tasks/web/components/containers:bounding_box", + "//mediapipe/tasks/web/components/containers:category", + "//mediapipe/tasks/web/components/containers:detection_result", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:classifier_options", + "//mediapipe/tasks/web/vision/core:vision_task_options", + ], +) + +mediapipe_ts_library( + name = "face_detector_test_lib", + testonly = True, + srcs = [ + "face_detector_test.ts", + ], + deps = [ + ":face_detector", + ":face_detector_types", + "//mediapipe/framework:calculator_jspb_proto", + "//mediapipe/framework/formats:detection_jspb_proto", + "//mediapipe/framework/formats:location_data_jspb_proto", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:task_runner_test_utils", + ], +) + +jasmine_node_test( + name = "face_detector_test", + tags = ["nomsan"], + deps = [":face_detector_test_lib"], +) diff --git a/mediapipe/tasks/web/vision/face_detector/face_detector.ts b/mediapipe/tasks/web/vision/face_detector/face_detector.ts new file mode 100644 index 000000000..039f7dd44 --- /dev/null +++ b/mediapipe/tasks/web/vision/face_detector/face_detector.ts @@ -0,0 +1,213 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; +import {CalculatorOptions} from '../../../../framework/calculator_options_pb'; +import {Detection as DetectionProto} from '../../../../framework/formats/detection_pb'; +import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; +import {FaceDetectorGraphOptions as FaceDetectorGraphOptionsProto} from '../../../../tasks/cc/vision/face_detector/proto/face_detector_graph_options_pb'; +import {convertFromDetectionProto} from '../../../../tasks/web/components/processors/detection_result'; +import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; +import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options'; +import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; +import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner'; +// Placeholder for internal dependency on trusted resource url + +import {FaceDetectorOptions} from './face_detector_options'; +import {FaceDetectorResult} from './face_detector_result'; + +const IMAGE_STREAM = 'image_in'; +const NORM_RECT_STREAM = 'norm_rect_in'; +const DETECTIONS_STREAM = 'detections'; +const FACE_DETECTOR_GRAPH = + 'mediapipe.tasks.vision.face_detector.FaceDetectorGraph'; + +export * from './face_detector_options'; +export * from './face_detector_result'; +export {ImageSource}; // Used in the public API + +// The OSS JS API does not support the builder pattern. +// tslint:disable:jspb-use-builder-pattern + +/** Performs face detection on images. */ +export class FaceDetector extends VisionTaskRunner { + private result: FaceDetectorResult = {detections: []}; + private readonly options = new FaceDetectorGraphOptionsProto(); + + /** + * Initializes the Wasm runtime and creates a new face detector from the + * provided options. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. + * @param faceDetectorOptions The options for the FaceDetector. Note that + * either a path to the model asset or a model buffer needs to be + * provided (via `baseOptions`). + */ + static createFromOptions( + wasmFileset: WasmFileset, + faceDetectorOptions: FaceDetectorOptions): Promise { + return VisionTaskRunner.createVisionInstance( + FaceDetector, wasmFileset, faceDetectorOptions); + } + + /** + * Initializes the Wasm runtime and creates a new face detector based on the + * provided model asset buffer. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. + * @param modelAssetBuffer A binary representation of the model. 
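+   *     For example, this can be the raw bytes of a `.tflite` model file
+   *     read into a `Uint8Array`.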
+ */ + static createFromModelBuffer( + wasmFileset: WasmFileset, + modelAssetBuffer: Uint8Array): Promise { + return VisionTaskRunner.createVisionInstance( + FaceDetector, wasmFileset, {baseOptions: {modelAssetBuffer}}); + } + + /** + * Initializes the Wasm runtime and creates a new face detector based on the + * path to the model asset. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. + * @param modelAssetPath The path to the model asset. + */ + static async createFromModelPath( + wasmFileset: WasmFileset, + modelAssetPath: string): Promise { + return VisionTaskRunner.createVisionInstance( + FaceDetector, wasmFileset, {baseOptions: {modelAssetPath}}); + } + + /** @hideconstructor */ + constructor( + wasmModule: WasmModule, + glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { + super( + new VisionGraphRunner(wasmModule, glCanvas), IMAGE_STREAM, + NORM_RECT_STREAM, /* roiAllowed= */ false); + this.options.setBaseOptions(new BaseOptionsProto()); + this.options.setMinDetectionConfidence(0.5); + this.options.setMinSuppressionThreshold(0.3); + } + + protected override get baseOptions(): BaseOptionsProto { + return this.options.getBaseOptions()!; + } + + protected override set baseOptions(proto: BaseOptionsProto) { + this.options.setBaseOptions(proto); + } + + /** + * Sets new options for the FaceDetector. + * + * Calling `setOptions()` with a subset of options only affects those options. + * You can reset an option back to its default value by explicitly setting it + * to `undefined`. + * + * @param options The options for the FaceDetector. + */ + override setOptions(options: FaceDetectorOptions): Promise { + if ('minDetectionConfidence' in options) { + this.options.setMinDetectionConfidence( + options.minDetectionConfidence ?? 0.5); + } + if ('minSuppressionThreshold' in options) { + this.options.setMinSuppressionThreshold( + options.minSuppressionThreshold ?? 0.3); + } + return this.applyOptions(options); + } + + /** + * Performs face detection on the provided single image and waits + * synchronously for the response. Only use this method when the + * FaceDetector is created with running mode `image`. + * + * @param image An image to process. + * @param imageProcessingOptions the `ImageProcessingOptions` specifying how + * to process the input image before running inference. + * @return A result containing the list of detected faces. + */ + detect(image: ImageSource, imageProcessingOptions?: ImageProcessingOptions): + FaceDetectorResult { + this.result = {detections: []}; + this.processImageData(image, imageProcessingOptions); + return this.result; + } + + /** + * Performs face detection on the provided video frame and waits + * synchronously for the response. Only use this method when the + * FaceDetector is created with running mode `video`. + * + * @param videoFrame A video frame to process. + * @param timestamp The timestamp of the current frame, in ms. + * @param imageProcessingOptions the `ImageProcessingOptions` specifying how + * to process the input image before running inference. + * @return A result containing the list of detected faces. + */ + detectForVideo( + videoFrame: ImageSource, timestamp: number, + imageProcessingOptions?: ImageProcessingOptions): FaceDetectorResult { + this.result = {detections: []}; + this.processVideoData(videoFrame, imageProcessingOptions, timestamp); + return this.result; + } + + /** Converts raw data into a Detection, and adds it to our detection list. 
*/ + private addJsFaceDetections(data: Uint8Array[]): void { + for (const binaryProto of data) { + const detectionProto = DetectionProto.deserializeBinary(binaryProto); + this.result.detections.push(convertFromDetectionProto(detectionProto)); + } + } + + /** Updates the MediaPipe graph configuration. */ + protected override refreshGraph(): void { + const graphConfig = new CalculatorGraphConfig(); + graphConfig.addInputStream(IMAGE_STREAM); + graphConfig.addInputStream(NORM_RECT_STREAM); + graphConfig.addOutputStream(DETECTIONS_STREAM); + + const calculatorOptions = new CalculatorOptions(); + calculatorOptions.setExtension( + FaceDetectorGraphOptionsProto.ext, this.options); + + const detectorNode = new CalculatorGraphConfig.Node(); + detectorNode.setCalculator(FACE_DETECTOR_GRAPH); + detectorNode.addInputStream('IMAGE:' + IMAGE_STREAM); + detectorNode.addInputStream('NORM_RECT:' + NORM_RECT_STREAM); + detectorNode.addOutputStream('DETECTIONS:' + DETECTIONS_STREAM); + detectorNode.setOptions(calculatorOptions); + + graphConfig.addNode(detectorNode); + + this.graphRunner.attachProtoVectorListener( + DETECTIONS_STREAM, (binaryProto, timestamp) => { + this.addJsFaceDetections(binaryProto); + this.setLatestOutputTimestamp(timestamp); + }); + this.graphRunner.attachEmptyPacketListener(DETECTIONS_STREAM, timestamp => { + this.setLatestOutputTimestamp(timestamp); + }); + + const binaryGraph = graphConfig.serializeBinary(); + this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true); + } +} + + diff --git a/mediapipe/tasks/web/vision/face_detector/face_detector_options.d.ts b/mediapipe/tasks/web/vision/face_detector/face_detector_options.d.ts new file mode 100644 index 000000000..665035f7e --- /dev/null +++ b/mediapipe/tasks/web/vision/face_detector/face_detector_options.d.ts @@ -0,0 +1,33 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import {ClassifierOptions} from '../../../../tasks/web/core/classifier_options'; +import {VisionTaskOptions} from '../../../../tasks/web/vision/core/vision_task_options'; + +/** Options to configure the MediaPipe Face Detector Task */ +export interface FaceDetectorOptions extends VisionTaskOptions { + /** + * The minimum confidence score for the face detection to be considered + * successful. Defaults to 0.5. + */ + minDetectionConfidence?: number|undefined; + + /** + * The minimum non-maximum-suppression threshold for face detection to be + * considered overlapped. Defaults to 0.3. + */ + minSuppressionThreshold?: number|undefined; +} diff --git a/mediapipe/tasks/web/vision/face_detector/face_detector_result.d.ts b/mediapipe/tasks/web/vision/face_detector/face_detector_result.d.ts new file mode 100644 index 000000000..6a36559f7 --- /dev/null +++ b/mediapipe/tasks/web/vision/face_detector/face_detector_result.d.ts @@ -0,0 +1,19 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +export {BoundingBox} from '../../../../tasks/web/components/containers/bounding_box'; +export {Category} from '../../../../tasks/web/components/containers/category'; +export {Detection, DetectionResult as FaceDetectorResult} from '../../../../tasks/web/components/containers/detection_result'; diff --git a/mediapipe/tasks/web/vision/face_detector/face_detector_test.ts b/mediapipe/tasks/web/vision/face_detector/face_detector_test.ts new file mode 100644 index 000000000..88dd20d2b --- /dev/null +++ b/mediapipe/tasks/web/vision/face_detector/face_detector_test.ts @@ -0,0 +1,193 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import 'jasmine'; + +// Placeholder for internal dependency on encodeByteArray +import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; +import {Detection as DetectionProto} from '../../../../framework/formats/detection_pb'; +import {LocationData} from '../../../../framework/formats/location_data_pb'; +import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils'; + +import {FaceDetector} from './face_detector'; +import {FaceDetectorOptions} from './face_detector_options'; + +// The OSS JS API does not support the builder pattern. 
+// tslint:disable:jspb-use-builder-pattern + +class FaceDetectorFake extends FaceDetector implements MediapipeTasksFake { + lastSampleRate: number|undefined; + calculatorName = 'mediapipe.tasks.vision.face_detector.FaceDetectorGraph'; + attachListenerSpies: jasmine.Spy[] = []; + graph: CalculatorGraphConfig|undefined; + + fakeWasmModule: SpyWasmModule; + protoListener: + ((binaryProtos: Uint8Array[], timestamp: number) => void)|undefined; + + constructor() { + super(createSpyWasmModule(), /* glCanvas= */ null); + this.fakeWasmModule = + this.graphRunner.wasmModule as unknown as SpyWasmModule; + + this.attachListenerSpies[0] = + spyOn(this.graphRunner, 'attachProtoVectorListener') + .and.callFake((stream, listener) => { + expect(stream).toEqual('detections'); + this.protoListener = listener; + }); + spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => { + this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph); + }); + spyOn(this.graphRunner, 'addGpuBufferAsImageToStream'); + } +} + +describe('FaceDetector', () => { + let faceDetector: FaceDetectorFake; + + beforeEach(async () => { + addJasmineCustomFloatEqualityTester(); + faceDetector = new FaceDetectorFake(); + await faceDetector.setOptions( + {baseOptions: {modelAssetBuffer: new Uint8Array([])}}); + }); + + it('initializes graph', async () => { + verifyGraph(faceDetector); + verifyListenersRegistered(faceDetector); + }); + + it('reloads graph when settings are changed', async () => { + await faceDetector.setOptions({minDetectionConfidence: 0.1}); + verifyGraph(faceDetector, ['minDetectionConfidence', 0.1]); + verifyListenersRegistered(faceDetector); + + await faceDetector.setOptions({minDetectionConfidence: 0.2}); + verifyGraph(faceDetector, ['minDetectionConfidence', 0.2]); + verifyListenersRegistered(faceDetector); + }); + + it('can use custom models', async () => { + const newModel = new Uint8Array([0, 1, 2, 3, 4]); + const newModelBase64 = Buffer.from(newModel).toString('base64'); + await faceDetector.setOptions({ + baseOptions: { + modelAssetBuffer: newModel, + } + }); + + verifyGraph( + faceDetector, + /* expectedCalculatorOptions= */ undefined, + /* expectedBaseOptions= */ + [ + 'modelAsset', { + fileContent: newModelBase64, + fileName: undefined, + fileDescriptorMeta: undefined, + filePointerMeta: undefined + } + ]); + }); + + it('merges options', async () => { + await faceDetector.setOptions({minDetectionConfidence: 0.1}); + await faceDetector.setOptions({minSuppressionThreshold: 0.2}); + verifyGraph(faceDetector, ['minDetectionConfidence', 0.1]); + verifyGraph(faceDetector, ['minSuppressionThreshold', 0.2]); + }); + + describe('setOptions()', () => { + interface TestCase { + optionName: keyof FaceDetectorOptions; + protoName: string; + customValue: unknown; + defaultValue: unknown; + } + + const testCases: TestCase[] = [ + { + optionName: 'minDetectionConfidence', + protoName: 'minDetectionConfidence', + customValue: 0.1, + defaultValue: 0.5 + }, + { + optionName: 'minSuppressionThreshold', + protoName: 'minSuppressionThreshold', + customValue: 0.2, + defaultValue: 0.3 + }, + ]; + + for (const testCase of testCases) { + it(`can set ${testCase.optionName}`, async () => { + await faceDetector.setOptions( + {[testCase.optionName]: testCase.customValue}); + verifyGraph(faceDetector, [testCase.protoName, testCase.customValue]); + }); + + it(`can clear ${testCase.optionName}`, async () => { + await faceDetector.setOptions( + {[testCase.optionName]: testCase.customValue}); + verifyGraph(faceDetector, 
[testCase.protoName, testCase.customValue]); + await faceDetector.setOptions({[testCase.optionName]: undefined}); + verifyGraph(faceDetector, [testCase.protoName, testCase.defaultValue]); + }); + } + }); + + it('doesn\'t support region of interest', () => { + expect(() => { + faceDetector.detect( + {} as HTMLImageElement, + {regionOfInterest: {left: 0, right: 0, top: 0, bottom: 0}}); + }).toThrowError('This task doesn\'t support region-of-interest.'); + }); + + it('transforms results', async () => { + const detection = new DetectionProto(); + detection.addScore(0.1); + const locationData = new LocationData(); + const boundingBox = new LocationData.BoundingBox(); + locationData.setBoundingBox(boundingBox); + detection.setLocationData(locationData); + + const binaryProto = detection.serializeBinary(); + + // Pass the test data to our listener + faceDetector.fakeWasmModule._waitUntilIdle.and.callFake(() => { + verifyListenersRegistered(faceDetector); + faceDetector.protoListener!([binaryProto], 1337); + }); + + // Invoke the face detector + const {detections} = faceDetector.detect({} as HTMLImageElement); + + expect(faceDetector.fakeWasmModule._waitUntilIdle).toHaveBeenCalled(); + expect(detections.length).toEqual(1); + expect(detections[0]).toEqual({ + categories: [{ + score: 0.1, + index: -1, + categoryName: '', + displayName: '', + }], + boundingBox: {originX: 0, originY: 0, width: 0, height: 0} + }); + }); +}); diff --git a/mediapipe/tasks/web/vision/index.ts b/mediapipe/tasks/web/vision/index.ts index 856d84683..4882e22c4 100644 --- a/mediapipe/tasks/web/vision/index.ts +++ b/mediapipe/tasks/web/vision/index.ts @@ -15,6 +15,7 @@ */ import {FilesetResolver as FilesetResolverImpl} from '../../../tasks/web/core/fileset_resolver'; +import {FaceDetector as FaceDetectorImpl} from '../../../tasks/web/vision/face_detector/face_detector'; import {FaceLandmarker as FaceLandmarkerImpl} from '../../../tasks/web/vision/face_landmarker/face_landmarker'; import {FaceStylizer as FaceStylizerImpl} from '../../../tasks/web/vision/face_stylizer/face_stylizer'; import {GestureRecognizer as GestureRecognizerImpl} from '../../../tasks/web/vision/gesture_recognizer/gesture_recognizer'; @@ -28,6 +29,7 @@ import {ObjectDetector as ObjectDetectorImpl} from '../../../tasks/web/vision/ob // Declare the variables locally so that Rollup in OSS includes them explicitly // as exports. 
const FilesetResolver = FilesetResolverImpl; +const FaceDetector = FaceDetectorImpl; const FaceLandmarker = FaceLandmarkerImpl; const FaceStylizer = FaceStylizerImpl; const GestureRecognizer = GestureRecognizerImpl; @@ -40,6 +42,7 @@ const ObjectDetector = ObjectDetectorImpl; export { FilesetResolver, + FaceDetector, FaceLandmarker, FaceStylizer, GestureRecognizer, diff --git a/mediapipe/tasks/web/vision/types.ts b/mediapipe/tasks/web/vision/types.ts index 2756b05a5..f49161adf 100644 --- a/mediapipe/tasks/web/vision/types.ts +++ b/mediapipe/tasks/web/vision/types.ts @@ -15,6 +15,7 @@ */ export * from '../../../tasks/web/core/fileset_resolver'; +export * from '../../../tasks/web/vision/face_detector/face_detector'; export * from '../../../tasks/web/vision/face_landmarker/face_landmarker'; export * from '../../../tasks/web/vision/face_stylizer/face_stylizer'; export * from '../../../tasks/web/vision/gesture_recognizer/gesture_recognizer'; From 55bcfcb4f55158795979a8ef0d87aa165b49971c Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Tue, 4 Apr 2023 12:28:25 -0700 Subject: [PATCH 38/63] Internal change PiperOrigin-RevId: 521834742 --- mediapipe/calculators/tensor/BUILD | 26 ++++++++++++++------------ mediapipe/calculators/tflite/BUILD | 27 ++++++--------------------- mediapipe/objc/BUILD | 1 + third_party/apple_frameworks/BUILD | 5 +++++ 4 files changed, 26 insertions(+), 33 deletions(-) diff --git a/mediapipe/calculators/tensor/BUILD b/mediapipe/calculators/tensor/BUILD index fd926a8fe..9ae884253 100644 --- a/mediapipe/calculators/tensor/BUILD +++ b/mediapipe/calculators/tensor/BUILD @@ -467,10 +467,6 @@ cc_library( "-x objective-c++", "-fobjc-arc", # enable reference-counting ], - linkopts = [ - "-framework CoreVideo", - "-framework MetalKit", - ], tags = ["ios"], deps = [ "inference_calculator_interface", @@ -486,7 +482,13 @@ cc_library( "@org_tensorflow//tensorflow/lite/delegates/gpu:metal_delegate_internal", "@org_tensorflow//tensorflow/lite/delegates/gpu/common:shape", "@org_tensorflow//tensorflow/lite/delegates/gpu/metal:buffer_convert", - ], + ] + select({ + "//mediapipe:apple": [ + "//third_party/apple_frameworks:CoreVideo", + "//third_party/apple_frameworks:MetalKit", + ], + "//conditions:default": [], + }), alwayslink = 1, ) @@ -721,13 +723,6 @@ cc_library( "//conditions:default": [], }), features = ["-layering_check"], # allow depending on tensors_to_detections_calculator_gpu_deps - linkopts = select({ - "//mediapipe:apple": [ - "-framework CoreVideo", - "-framework MetalKit", - ], - "//conditions:default": [], - }), deps = [ ":tensors_to_detections_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -744,6 +739,12 @@ cc_library( ] + selects.with_or({ ":compute_shader_unavailable": [], "//conditions:default": [":tensors_to_detections_calculator_gpu_deps"], + }) + select({ + "//mediapipe:apple": [ + "//third_party/apple_frameworks:CoreVideo", + "//third_party/apple_frameworks:MetalKit", + ], + "//conditions:default": [], }), alwayslink = 1, ) @@ -1333,6 +1334,7 @@ cc_library( "//mediapipe:ios": [ "//mediapipe/gpu:MPPMetalUtil", "//mediapipe/gpu:MPPMetalHelper", + "//third_party/apple_frameworks:MetalKit", ], "//conditions:default": [ "@org_tensorflow//tensorflow/lite/delegates/gpu:gl_delegate", diff --git a/mediapipe/calculators/tflite/BUILD b/mediapipe/calculators/tflite/BUILD index 435ea9fc1..333de2069 100644 --- a/mediapipe/calculators/tflite/BUILD +++ b/mediapipe/calculators/tflite/BUILD @@ -193,13 +193,6 @@ cc_library( ":edge_tpu_pci": 
["MEDIAPIPE_EDGE_TPU=pci"], ":edge_tpu_all": ["MEDIAPIPE_EDGE_TPU=all"], }), - linkopts = select({ - "//mediapipe:ios": [ - "-framework CoreVideo", - "-framework MetalKit", - ], - "//conditions:default": [], - }), deps = [ ":tflite_inference_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -222,6 +215,8 @@ cc_library( "@org_tensorflow//tensorflow/lite/delegates/gpu/metal:buffer_convert", "@org_tensorflow//tensorflow/lite/delegates/gpu:metal_delegate", "@org_tensorflow//tensorflow/lite/delegates/gpu:metal_delegate_internal", + "//third_party/apple_frameworks:MetalKit", + "//third_party/apple_frameworks:CoreVideo", ], "//conditions:default": [ "//mediapipe/util/tflite:tflite_gpu_runner", @@ -271,13 +266,6 @@ cc_library( ], "//conditions:default": [], }), - linkopts = select({ - "//mediapipe:ios": [ - "-framework CoreVideo", - "-framework MetalKit", - ], - "//conditions:default": [], - }), deps = [ ":tflite_converter_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -296,6 +284,8 @@ cc_library( "//mediapipe/gpu:MPPMetalHelper", "//mediapipe/objc:mediapipe_framework_ios", "@org_tensorflow//tensorflow/lite/delegates/gpu:metal_delegate", + "//third_party/apple_frameworks:MetalKit", + "//third_party/apple_frameworks:CoreVideo", ], "//conditions:default": [ "//mediapipe/gpu:gl_calculator_helper", @@ -393,13 +383,6 @@ cc_library( ], "//conditions:default": [], }), - linkopts = select({ - "//mediapipe:ios": [ - "-framework CoreVideo", - "-framework MetalKit", - ], - "//conditions:default": [], - }), deps = [ ":tflite_tensors_to_detections_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -420,6 +403,8 @@ cc_library( "//mediapipe/gpu:MPPMetalHelper", "//mediapipe/objc:mediapipe_framework_ios", "@org_tensorflow//tensorflow/lite/delegates/gpu:metal_delegate", + "//third_party/apple_frameworks:MetalKit", + "//third_party/apple_frameworks:CoreVideo", ], "//conditions:default": [ "//mediapipe/gpu:gl_calculator_helper", diff --git a/mediapipe/objc/BUILD b/mediapipe/objc/BUILD index 7df6c8027..20d89d329 100644 --- a/mediapipe/objc/BUILD +++ b/mediapipe/objc/BUILD @@ -89,6 +89,7 @@ objc_library( "//mediapipe/gpu:metal_shared_resources", "//mediapipe/gpu:pixel_buffer_pool_util", "//mediapipe/util:cpu_util", + "//third_party/apple_frameworks:AVFoundation", "//third_party/apple_frameworks:Accelerate", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/memory", diff --git a/third_party/apple_frameworks/BUILD b/third_party/apple_frameworks/BUILD index 05f830e81..62f91b515 100644 --- a/third_party/apple_frameworks/BUILD +++ b/third_party/apple_frameworks/BUILD @@ -32,6 +32,11 @@ cc_library( linkopts = ["-framework Metal"], ) +cc_library( + name = "MetalKit", + linkopts = ["-framework MetalKit"], +) + cc_library( name = "MetalPerformanceShaders", linkopts = ["-framework MetalPerformanceShaders"], From 9554836145e528ac6a8e3abfc32271606d39c2b0 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Tue, 4 Apr 2023 13:37:10 -0700 Subject: [PATCH 39/63] Update java image segmenter to always output confidence masks and optionally output category mask. 
PiperOrigin-RevId: 521852718 --- .../vision/imagesegmenter/ImageSegmenter.java | 112 +++++++++--------- .../imagesegmenter/ImageSegmenterResult.java | 19 +-- .../InteractiveSegmenter.java | 2 - .../imagesegmenter/ImageSegmenterTest.java | 79 ++++++------ .../InteractiveSegmenterTest.java | 7 +- 5 files changed, 112 insertions(+), 107 deletions(-) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenter.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenter.java index e8e0e4051..b809ab963 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenter.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenter.java @@ -79,10 +79,15 @@ public final class ImageSegmenter extends BaseVisionTaskApi { private static final List INPUT_STREAMS = Collections.unmodifiableList( Arrays.asList("IMAGE:" + IMAGE_IN_STREAM_NAME, "NORM_RECT:" + NORM_RECT_IN_STREAM_NAME)); - private static final int CONFIDENCE_MASKS_OUT_STREAM_INDEX = 0; + private static final List OUTPUT_STREAMS = + Collections.unmodifiableList( + Arrays.asList( + "GROUPED_SEGMENTATION:segmented_mask_out", + "IMAGE:image_out", + "SEGMENTATION:0:segmentation")); + private static final int GROUPED_SEGMENTATION_OUT_STREAM_INDEX = 0; private static final int IMAGE_OUT_STREAM_INDEX = 1; - private static final int CONFIDENCE_MASK_OUT_STREAM_INDEX = 2; - private static final int CATEGORY_MASK_OUT_STREAM_INDEX = 3; + private static final int SEGMENTATION_OUT_STREAM_INDEX = 2; private static final String TASK_GRAPH_NAME = "mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph"; private static final String TENSORS_TO_SEGMENTATION_CALCULATOR_NAME = @@ -99,13 +104,6 @@ public final class ImageSegmenter extends BaseVisionTaskApi { */ public static ImageSegmenter createFromOptions( Context context, ImageSegmenterOptions segmenterOptions) { - List outputStreams = new ArrayList<>(); - outputStreams.add("CONFIDENCE_MASKS:confidence_masks"); - outputStreams.add("IMAGE:image_out"); - outputStreams.add("CONFIDENCE_MASK:0:confidence_mask"); - if (segmenterOptions.outputCategoryMask()) { - outputStreams.add("CATEGORY_MASK:category_mask"); - } // TODO: Consolidate OutputHandler and TaskRunner. 
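          // The converter below turns each list of raw graph output packets
          // into an ImageSegmenterResult: it reads the grouped segmentation
          // masks from the GROUPED_SEGMENTATION stream and stamps the result
          // with the packet timestamp.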
OutputHandler handler = new OutputHandler<>(); handler.setOutputPacketConverter( @@ -113,62 +111,50 @@ public final class ImageSegmenter extends BaseVisionTaskApi { @Override public ImageSegmenterResult convertToTaskResult(List packets) throws MediaPipeException { - if (packets.get(CONFIDENCE_MASKS_OUT_STREAM_INDEX).isEmpty()) { + if (packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX).isEmpty()) { return ImageSegmenterResult.create( new ArrayList<>(), - Optional.empty(), - packets.get(CONFIDENCE_MASKS_OUT_STREAM_INDEX).getTimestamp()); + packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX).getTimestamp()); } - List confidenceMasks = new ArrayList<>(); - int width = PacketGetter.getImageWidth(packets.get(CONFIDENCE_MASK_OUT_STREAM_INDEX)); - int height = PacketGetter.getImageHeight(packets.get(CONFIDENCE_MASK_OUT_STREAM_INDEX)); - int confidenceMasksListSize = - PacketGetter.getImageListSize(packets.get(CONFIDENCE_MASKS_OUT_STREAM_INDEX)); - ByteBuffer[] buffersArray = new ByteBuffer[confidenceMasksListSize]; + List segmentedMasks = new ArrayList<>(); + int width = PacketGetter.getImageWidth(packets.get(SEGMENTATION_OUT_STREAM_INDEX)); + int height = PacketGetter.getImageHeight(packets.get(SEGMENTATION_OUT_STREAM_INDEX)); + int imageFormat = + segmenterOptions.outputType() == ImageSegmenterOptions.OutputType.CONFIDENCE_MASK + ? MPImage.IMAGE_FORMAT_VEC32F1 + : MPImage.IMAGE_FORMAT_ALPHA; + int imageListSize = + PacketGetter.getImageListSize(packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX)); + ByteBuffer[] buffersArray = new ByteBuffer[imageListSize]; // If resultListener is not provided, the resulted MPImage is deep copied from mediapipe // graph. If provided, the result MPImage is wrapping the mediapipe packet memory. - boolean copyImage = !segmenterOptions.resultListener().isPresent(); - if (copyImage) { - for (int i = 0; i < confidenceMasksListSize; i++) { - buffersArray[i] = ByteBuffer.allocateDirect(width * height * 4); + if (!segmenterOptions.resultListener().isPresent()) { + for (int i = 0; i < imageListSize; i++) { + buffersArray[i] = + ByteBuffer.allocateDirect( + width * height * (imageFormat == MPImage.IMAGE_FORMAT_VEC32F1 ? 4 : 1)); } } if (!PacketGetter.getImageList( - packets.get(CONFIDENCE_MASKS_OUT_STREAM_INDEX), buffersArray, copyImage)) { + packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX), + buffersArray, + !segmenterOptions.resultListener().isPresent())) { throw new MediaPipeException( MediaPipeException.StatusCode.INTERNAL.ordinal(), - "There is an error getting segmented masks."); + "There is an error getting segmented masks. 
It usually results from incorrect" + + " options of unsupported OutputType of given model."); } for (ByteBuffer buffer : buffersArray) { ByteBufferImageBuilder builder = - new ByteBufferImageBuilder(buffer, width, height, MPImage.IMAGE_FORMAT_VEC32F1); - confidenceMasks.add(builder.build()); - } - Optional categoryMask = Optional.empty(); - if (segmenterOptions.outputCategoryMask()) { - ByteBuffer buffer; - if (copyImage) { - buffer = ByteBuffer.allocateDirect(width * height); - if (!PacketGetter.getImageData( - packets.get(CATEGORY_MASK_OUT_STREAM_INDEX), buffer)) { - throw new MediaPipeException( - MediaPipeException.StatusCode.INTERNAL.ordinal(), - "There is an error getting category mask."); - } - } else { - buffer = - PacketGetter.getImageDataDirectly(packets.get(CATEGORY_MASK_OUT_STREAM_INDEX)); - } - ByteBufferImageBuilder builder = - new ByteBufferImageBuilder(buffer, width, height, MPImage.IMAGE_FORMAT_ALPHA); - categoryMask = Optional.of(builder.build()); + new ByteBufferImageBuilder(buffer, width, height, imageFormat); + segmentedMasks.add(builder.build()); } + return ImageSegmenterResult.create( - confidenceMasks, - categoryMask, + segmentedMasks, BaseVisionTaskApi.generateResultTimestampMs( segmenterOptions.runningMode(), - packets.get(CONFIDENCE_MASKS_OUT_STREAM_INDEX))); + packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX))); } @Override @@ -188,7 +174,7 @@ public final class ImageSegmenter extends BaseVisionTaskApi { .setTaskRunningModeName(segmenterOptions.runningMode().name()) .setTaskGraphName(TASK_GRAPH_NAME) .setInputStreams(INPUT_STREAMS) - .setOutputStreams(outputStreams) + .setOutputStreams(OUTPUT_STREAMS) .setTaskOptions(segmenterOptions) .setEnableFlowLimiting(segmenterOptions.runningMode() == RunningMode.LIVE_STREAM) .build(), @@ -567,8 +553,8 @@ public final class ImageSegmenter extends BaseVisionTaskApi { */ public abstract Builder setDisplayNamesLocale(String value); - /** Whether to output category mask. */ - public abstract Builder setOutputCategoryMask(boolean value); + /** The output type from image segmenter. */ + public abstract Builder setOutputType(OutputType value); /** * Sets an optional {@link ResultListener} to receive the segmentation results when the graph @@ -608,17 +594,27 @@ public final class ImageSegmenter extends BaseVisionTaskApi { abstract String displayNamesLocale(); - abstract boolean outputCategoryMask(); + abstract OutputType outputType(); abstract Optional> resultListener(); abstract Optional errorListener(); + /** The output type of segmentation results. */ + public enum OutputType { + // Gives a single output mask where each pixel represents the class which + // the pixel in the original image was predicted to belong to. + CATEGORY_MASK, + // Gives a list of output masks where, for each mask, each pixel represents + // the prediction confidence, usually in the [0, 1] range. 
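+      // (For a model with a softmax head, these per-pixel confidences sum
+      // to 1 across the returned masks.)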
+ CONFIDENCE_MASK + } + public static Builder builder() { return new AutoValue_ImageSegmenter_ImageSegmenterOptions.Builder() .setRunningMode(RunningMode.IMAGE) .setDisplayNamesLocale("en") - .setOutputCategoryMask(false); + .setOutputType(OutputType.CATEGORY_MASK); } /** @@ -637,6 +633,14 @@ public final class ImageSegmenter extends BaseVisionTaskApi { SegmenterOptionsProto.SegmenterOptions.Builder segmenterOptionsBuilder = SegmenterOptionsProto.SegmenterOptions.newBuilder(); + if (outputType() == OutputType.CONFIDENCE_MASK) { + segmenterOptionsBuilder.setOutputType( + SegmenterOptionsProto.SegmenterOptions.OutputType.CONFIDENCE_MASK); + } else if (outputType() == OutputType.CATEGORY_MASK) { + segmenterOptionsBuilder.setOutputType( + SegmenterOptionsProto.SegmenterOptions.OutputType.CATEGORY_MASK); + } + taskOptionsBuilder.setSegmenterOptions(segmenterOptionsBuilder); return CalculatorOptions.newBuilder() .setExtension( diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenterResult.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenterResult.java index 400894a66..69ab79c13 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenterResult.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenterResult.java @@ -19,7 +19,6 @@ import com.google.mediapipe.framework.image.MPImage; import com.google.mediapipe.tasks.core.TaskResult; import java.util.Collections; import java.util.List; -import java.util.Optional; /** Represents the segmentation results generated by {@link ImageSegmenter}. */ @AutoValue @@ -28,24 +27,18 @@ public abstract class ImageSegmenterResult implements TaskResult { /** * Creates an {@link ImageSegmenterResult} instance from a list of segmentation MPImage. * - * @param confidenceMasks a {@link List} of MPImage in IMAGE_FORMAT_VEC32F1 format representing - * the confidence masks, where, for each mask, each pixel represents the prediction - * confidence, usually in the [0, 1] range. - * @param categoryMask an {@link Optional} MPImage in IMAGE_FORMAT_ALPHA format representing a - * category mask, where each pixel represents the class which the pixel in the original image - * was predicted to belong to. + * @param segmentations a {@link List} of MPImage representing the segmented masks. If OutputType + * is CATEGORY_MASK, the masks will be in IMAGE_FORMAT_ALPHA format. If OutputType is + * CONFIDENCE_MASK, the masks will be in IMAGE_FORMAT_VEC32F1 format. * @param timestampMs a timestamp for this result. */ // TODO: consolidate output formats across platforms. 
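  // For example, with the 21-class DeepLab model used by the tests in this
  // change, a CONFIDENCE_MASK result carries 21 masks and
  // segmentations().get(8) is the "cat" class confidence mask.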
- public static ImageSegmenterResult create( - List confidenceMasks, Optional categoryMask, long timestampMs) { + public static ImageSegmenterResult create(List segmentations, long timestampMs) { return new AutoValue_ImageSegmenterResult( - Collections.unmodifiableList(confidenceMasks), categoryMask, timestampMs); + Collections.unmodifiableList(segmentations), timestampMs); } - public abstract List confidenceMasks(); - - public abstract Optional categoryMask(); + public abstract List segmentations(); @Override public abstract long timestampMs(); diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenter.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenter.java index 2348aaadd..657716b6b 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenter.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenter.java @@ -133,7 +133,6 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi { if (packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX).isEmpty()) { return ImageSegmenterResult.create( new ArrayList<>(), - Optional.empty(), packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX).getTimestamp()); } List segmentedMasks = new ArrayList<>(); @@ -173,7 +172,6 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi { return ImageSegmenterResult.create( segmentedMasks, - Optional.empty(), BaseVisionTaskApi.generateResultTimestampMs( RunningMode.IMAGE, packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX))); } diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenterTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenterTest.java index 7acf1377e..3b35c21bc 100644 --- a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenterTest.java +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenterTest.java @@ -61,13 +61,14 @@ public class ImageSegmenterTest { ImageSegmenterOptions options = ImageSegmenterOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build()) - .setOutputCategoryMask(true) + .setOutputType(ImageSegmenterOptions.OutputType.CATEGORY_MASK) .build(); ImageSegmenter imageSegmenter = ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options); ImageSegmenterResult actualResult = imageSegmenter.segment(getImageFromAsset(inputImageName)); - assertThat(actualResult.categoryMask().isPresent()).isTrue(); - MPImage actualMaskBuffer = actualResult.categoryMask().get(); + List segmentations = actualResult.segmentations(); + assertThat(segmentations.size()).isEqualTo(1); + MPImage actualMaskBuffer = actualResult.segmentations().get(0); MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName); verifyCategoryMask( actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY, MAGNIFICATION_FACTOR); @@ -80,14 +81,15 @@ public class ImageSegmenterTest { ImageSegmenterOptions options = ImageSegmenterOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build()) + .setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK) .build(); ImageSegmenter imageSegmenter = ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options); ImageSegmenterResult actualResult = 
imageSegmenter.segment(getImageFromAsset(inputImageName)); - List segmentations = actualResult.confidenceMasks(); + List segmentations = actualResult.segmentations(); assertThat(segmentations.size()).isEqualTo(21); // Cat category index 8. - MPImage actualMaskBuffer = segmentations.get(8); + MPImage actualMaskBuffer = actualResult.segmentations().get(8); MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName); verifyConfidenceMask(actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY); } @@ -100,36 +102,40 @@ public class ImageSegmenterTest { ImageSegmenterOptions.builder() .setBaseOptions( BaseOptions.builder().setModelAssetPath(SELFIE_128x128_MODEL_FILE).build()) + .setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK) .build(); ImageSegmenter imageSegmenter = ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options); ImageSegmenterResult actualResult = imageSegmenter.segment(getImageFromAsset(inputImageName)); - List segmentations = actualResult.confidenceMasks(); + List segmentations = actualResult.segmentations(); assertThat(segmentations.size()).isEqualTo(2); // Selfie category index 1. - MPImage actualMaskBuffer = segmentations.get(1); + MPImage actualMaskBuffer = actualResult.segmentations().get(1); MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName); verifyConfidenceMask(actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY); } - @Test - public void segment_successWith144x256Segmentation() throws Exception { - final String inputImageName = "mozart_square.jpg"; - final String goldenImageName = "selfie_segm_144_256_3_expected_mask.jpg"; - ImageSegmenterOptions options = - ImageSegmenterOptions.builder() - .setBaseOptions( - BaseOptions.builder().setModelAssetPath(SELFIE_144x256_MODEL_FILE).build()) - .build(); - ImageSegmenter imageSegmenter = - ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options); - ImageSegmenterResult actualResult = imageSegmenter.segment(getImageFromAsset(inputImageName)); - List segmentations = actualResult.confidenceMasks(); - assertThat(segmentations.size()).isEqualTo(1); - MPImage actualMaskBuffer = segmentations.get(0); - MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName); - verifyConfidenceMask(actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY); - } + // TODO: enable this unit test once activation option is supported in metadata. 
+ // @Test + // public void segment_successWith144x256Segmentation() throws Exception { + // final String inputImageName = "mozart_square.jpg"; + // final String goldenImageName = "selfie_segm_144_256_3_expected_mask.jpg"; + // ImageSegmenterOptions options = + // ImageSegmenterOptions.builder() + // .setBaseOptions( + // BaseOptions.builder().setModelAssetPath(SELFIE_144x256_MODEL_FILE).build()) + // .setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK) + // .build(); + // ImageSegmenter imageSegmenter = + // ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options); + // ImageSegmenterResult actualResult = + // imageSegmenter.segment(getImageFromAsset(inputImageName)); + // List segmentations = actualResult.segmentations(); + // assertThat(segmentations.size()).isEqualTo(1); + // MPImage actualMaskBuffer = actualResult.segmentations().get(0); + // MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName); + // verifyConfidenceMask(actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY); + // } @Test public void getLabels_success() throws Exception { @@ -159,6 +165,7 @@ public class ImageSegmenterTest { ImageSegmenterOptions options = ImageSegmenterOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build()) + .setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK) .build(); ImageSegmenter imageSegmenter = ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options); @@ -280,15 +287,16 @@ public class ImageSegmenterTest { ImageSegmenterOptions options = ImageSegmenterOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build()) + .setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK) .setRunningMode(RunningMode.IMAGE) .build(); ImageSegmenter imageSegmenter = ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options); ImageSegmenterResult actualResult = imageSegmenter.segment(getImageFromAsset(inputImageName)); - List segmentations = actualResult.confidenceMasks(); + List segmentations = actualResult.segmentations(); assertThat(segmentations.size()).isEqualTo(21); // Cat category index 8. 
- MPImage actualMaskBuffer = segmentations.get(8); + MPImage actualMaskBuffer = actualResult.segmentations().get(8); MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName); verifyConfidenceMask(actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY); } @@ -301,11 +309,12 @@ public class ImageSegmenterTest { ImageSegmenterOptions options = ImageSegmenterOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build()) + .setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK) .setRunningMode(RunningMode.IMAGE) .setResultListener( (segmenterResult, inputImage) -> { verifyConfidenceMask( - segmenterResult.confidenceMasks().get(8), + segmenterResult.segmentations().get(8), expectedResult, GOLDEN_MASK_SIMILARITY); }) @@ -322,6 +331,7 @@ public class ImageSegmenterTest { ImageSegmenterOptions options = ImageSegmenterOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build()) + .setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK) .setRunningMode(RunningMode.VIDEO) .build(); ImageSegmenter imageSegmenter = @@ -331,10 +341,10 @@ public class ImageSegmenterTest { ImageSegmenterResult actualResult = imageSegmenter.segmentForVideo( getImageFromAsset(inputImageName), /* timestampsMs= */ i); - List segmentations = actualResult.confidenceMasks(); + List segmentations = actualResult.segmentations(); assertThat(segmentations.size()).isEqualTo(21); // Cat category index 8. - MPImage actualMaskBuffer = segmentations.get(8); + MPImage actualMaskBuffer = actualResult.segmentations().get(8); verifyConfidenceMask(actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY); } } @@ -347,11 +357,12 @@ public class ImageSegmenterTest { ImageSegmenterOptions options = ImageSegmenterOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build()) + .setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK) .setRunningMode(RunningMode.VIDEO) .setResultListener( (segmenterResult, inputImage) -> { verifyConfidenceMask( - segmenterResult.confidenceMasks().get(8), + segmenterResult.segmentations().get(8), expectedResult, GOLDEN_MASK_SIMILARITY); }) @@ -373,11 +384,12 @@ public class ImageSegmenterTest { ImageSegmenterOptions options = ImageSegmenterOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build()) + .setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK) .setRunningMode(RunningMode.LIVE_STREAM) .setResultListener( (segmenterResult, inputImage) -> { verifyConfidenceMask( - segmenterResult.confidenceMasks().get(8), + segmenterResult.segmentations().get(8), expectedResult, GOLDEN_MASK_SIMILARITY); }) @@ -399,11 +411,12 @@ public class ImageSegmenterTest { ImageSegmenterOptions options = ImageSegmenterOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build()) + .setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK) .setRunningMode(RunningMode.LIVE_STREAM) .setResultListener( (segmenterResult, inputImage) -> { verifyConfidenceMask( - segmenterResult.confidenceMasks().get(8), + segmenterResult.segmentations().get(8), expectedResult, GOLDEN_MASK_SIMILARITY); }) diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenterTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenterTest.java index 9351bc721..0d9581437 100644 --- 
a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenterTest.java +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenterTest.java @@ -60,10 +60,7 @@ public class InteractiveSegmenterTest { ApplicationProvider.getApplicationContext(), options); MPImage image = getImageFromAsset(inputImageName); ImageSegmenterResult actualResult = imageSegmenter.segment(image, roi); - // TODO update to correct category mask output. - // After InteractiveSegmenter updated according to (b/276519300), update this to use - // categoryMask field instead of confidenceMasks. - List segmentations = actualResult.confidenceMasks(); + List segmentations = actualResult.segmentations(); assertThat(segmentations.size()).isEqualTo(1); } @@ -82,7 +79,7 @@ public class InteractiveSegmenterTest { ApplicationProvider.getApplicationContext(), options); ImageSegmenterResult actualResult = imageSegmenter.segment(getImageFromAsset(inputImageName), roi); - List segmentations = actualResult.confidenceMasks(); + List segmentations = actualResult.segmentations(); assertThat(segmentations.size()).isEqualTo(2); } } From f8b2aa06331bc2d2eb0b41990e5d0fd2ea24f0ac Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Tue, 4 Apr 2023 17:33:45 -0700 Subject: [PATCH 40/63] Internal change PiperOrigin-RevId: 521909998 --- mediapipe/objc/BUILD | 2 + mediapipe/objc/MPPInputSource.h | 24 ++++++- mediapipe/objc/MPPPlayerInputSource.h | 10 ++- mediapipe/objc/MPPPlayerInputSource.m | 94 +++++++++++++++++++++++++++ 4 files changed, 128 insertions(+), 2 deletions(-) diff --git a/mediapipe/objc/BUILD b/mediapipe/objc/BUILD index 20d89d329..83567a4d8 100644 --- a/mediapipe/objc/BUILD +++ b/mediapipe/objc/BUILD @@ -125,8 +125,10 @@ objc_library( visibility = ["//visibility:public"], deps = [ "//third_party/apple_frameworks:AVFoundation", + "//third_party/apple_frameworks:CoreAudio", "//third_party/apple_frameworks:CoreVideo", "//third_party/apple_frameworks:Foundation", + "//third_party/apple_frameworks:MediaToolbox", ], ) diff --git a/mediapipe/objc/MPPInputSource.h b/mediapipe/objc/MPPInputSource.h index 2c518fdc4..121261c59 100644 --- a/mediapipe/objc/MPPInputSource.h +++ b/mediapipe/objc/MPPInputSource.h @@ -13,8 +13,11 @@ // limitations under the License. #import +#import #import +NS_ASSUME_NONNULL_BEGIN + @class MPPInputSource; /// A delegate that can receive frames from a source. @@ -31,7 +34,7 @@ timestamp:(CMTime)timestamp fromSource:(MPPInputSource*)source; -// Provides the delegate with a new depth frame data +// Provides the delegate with new depth frame data. @optional - (void)processDepthData:(AVDepthData*)depthData timestamp:(CMTime)timestamp @@ -40,6 +43,23 @@ @optional - (void)videoDidPlayToEnd:(CMTime)timestamp; +// Provides the delegate with the format of the audio track to be played. +@optional +- (void)willStartPlayingAudioWithFormat:(const AudioStreamBasicDescription*)format + fromSource:(MPPInputSource*)source; + +// Lets the delegate know that there is no audio track despite audio playback +// having been requested (or that audio playback failed for other reasons). +@optional +- (void)noAudioAvailableFromSource:(MPPInputSource*)source; + +// Provides the delegate with a new audio packet. +@optional +- (void)processAudioPacket:(const AudioBufferList*)audioPacket + numFrames:(CMItemCount)numFrames + timestamp:(CMTime)timestamp + fromSource:(MPPInputSource*)source; + @end /// Abstract class for a video source. 
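/// For sources that also deliver audio (see MPPPlayerInputSource), the
/// delegate receives willStartPlayingAudioWithFormat:fromSource: once before
/// the first processAudioPacket:numFrames:timestamp:fromSource: call, or
/// noAudioAvailableFromSource: if the asset has no audible track or audio
/// processing could not be set up.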
@@ -68,3 +88,5 @@ - (void)stop; @end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/objc/MPPPlayerInputSource.h b/mediapipe/objc/MPPPlayerInputSource.h index e1516abe9..87054d953 100644 --- a/mediapipe/objc/MPPPlayerInputSource.h +++ b/mediapipe/objc/MPPPlayerInputSource.h @@ -18,7 +18,15 @@ /// Not meant for batch processing of video. @interface MPPPlayerInputSource : MPPInputSource -/// Designated initializer. +/// Initializes the video source with optional audio processing. +/// +/// @param video The video asset to play. +/// @param audioProcessingEnabled If set, indicates that the (first) audio track +/// should be processed if it exists, and the corresponding methods for +/// audio will be invoked on the @c delegate. +- (instancetype)initWithAVAsset:(AVAsset*)video audioProcessingEnabled:(BOOL)audioProcessingEnabled; + +/// Initializes the video source to process @c video with audio processing disabled. - (instancetype)initWithAVAsset:(AVAsset*)video; /// Skip into video @c time from beginning (time 0), within error of +/- tolerance to closest time. diff --git a/mediapipe/objc/MPPPlayerInputSource.m b/mediapipe/objc/MPPPlayerInputSource.m index f5741f8af..6cd489ff7 100644 --- a/mediapipe/objc/MPPPlayerInputSource.m +++ b/mediapipe/objc/MPPPlayerInputSource.m @@ -13,11 +13,13 @@ // limitations under the License. #import +#import #import "MPPPlayerInputSource.h" #if !TARGET_OS_OSX #import "mediapipe/objc/MPPDisplayLinkWeakTarget.h" #endif +#import "mediapipe/objc/MPPInputSource.h" @implementation MPPPlayerInputSource { AVAsset* _video; @@ -35,7 +37,53 @@ BOOL _playing; } +void InitAudio(MTAudioProcessingTapRef tap, void* clientInfo, void** tapStorageOut) { + // `clientInfo` comes as a user-defined argument through + // `MTAudioProcessingTapCallbacks`; we pass our `MPPPlayerInputSource` + // there. Tap processing functions allow for user-defined "storage" - we just + // treat our input source as such. + *tapStorageOut = clientInfo; +} + +void PrepareAudio(MTAudioProcessingTapRef tap, CMItemCount maxFrames, + const AudioStreamBasicDescription* audioFormat) { + // See `InitAudio`. + MPPPlayerInputSource* source = + (__bridge MPPPlayerInputSource*)MTAudioProcessingTapGetStorage(tap); + if ([source.delegate respondsToSelector:@selector(willStartPlayingAudioWithFormat:fromSource:)]) { + [source.delegate willStartPlayingAudioWithFormat:audioFormat fromSource:source]; + } +} + +void ProcessAudio(MTAudioProcessingTapRef tap, CMItemCount numberFrames, + MTAudioProcessingTapFlags flags, AudioBufferList* bufferListInOut, + CMItemCount* numberFramesOut, MTAudioProcessingTapFlags* flagsOut) { + CMTimeRange timeRange; + OSStatus status = MTAudioProcessingTapGetSourceAudio(tap, numberFrames, bufferListInOut, flagsOut, + &timeRange, numberFramesOut); + if (status != 0) { + NSLog(@"Error from GetSourceAudio: %ld", (long)status); + return; + } + + // See `InitAudio`. 
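+  // The tap storage holds the unretained MPPPlayerInputSource that was
+  // passed as `clientInfo` at tap creation, so this __bridge cast reads it
+  // back without transferring ownership.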
+ MPPPlayerInputSource* source = + (__bridge MPPPlayerInputSource*)MTAudioProcessingTapGetStorage(tap); + if ([source.delegate respondsToSelector:@selector(processAudioPacket: + numFrames:timestamp:fromSource:)]) { + [source.delegate processAudioPacket:bufferListInOut + numFrames:numberFrames + timestamp:timeRange.start + fromSource:source]; + } +} + - (instancetype)initWithAVAsset:(AVAsset*)video { + return [self initWithAVAsset:video audioProcessingEnabled:NO]; +} + +- (instancetype)initWithAVAsset:(AVAsset*)video + audioProcessingEnabled:(BOOL)audioProcessingEnabled { self = [super init]; if (self) { _video = video; @@ -67,6 +115,11 @@ CVDisplayLinkStop(_videoDisplayLink); CVDisplayLinkSetOutputCallback(_videoDisplayLink, renderCallback, (__bridge void*)self); #endif // TARGET_OS_OSX + + if (audioProcessingEnabled) { + [self setupAudioPlayback]; + } + _videoPlayer = [AVPlayer playerWithPlayerItem:_videoItem]; _videoPlayer.actionAtItemEnd = AVPlayerActionAtItemEndNone; NSNotificationCenter* center = [NSNotificationCenter defaultCenter]; @@ -88,6 +141,47 @@ return self; } +- (void)setupAudioPlayback { + bool have_audio = false; + NSArray* audioTracks = + [_video tracksWithMediaCharacteristic:AVMediaCharacteristicAudible]; + if (audioTracks.count != 0) { + // We always limit ourselves to the first audio track if there are + // multiple (which is a rarity) - note that it can still be e.g. stereo. + AVAssetTrack* audioTrack = audioTracks[0]; + MTAudioProcessingTapCallbacks audioCallbacks; + audioCallbacks.version = kMTAudioProcessingTapCallbacksVersion_0; + audioCallbacks.clientInfo = (__bridge void*)(self); + audioCallbacks.init = InitAudio; + audioCallbacks.prepare = PrepareAudio; + audioCallbacks.process = ProcessAudio; + audioCallbacks.unprepare = NULL; + audioCallbacks.finalize = NULL; + + MTAudioProcessingTapRef audioTap; + OSStatus status = + MTAudioProcessingTapCreate(kCFAllocatorDefault, &audioCallbacks, + kMTAudioProcessingTapCreationFlag_PreEffects, &audioTap); + if (status == noErr && audioTap != NULL) { + AVMutableAudioMixInputParameters* audioMixInputParams = + [AVMutableAudioMixInputParameters audioMixInputParametersWithTrack:audioTrack]; + audioMixInputParams.audioTapProcessor = audioTap; + CFRelease(audioTap); + + AVMutableAudioMix* audioMix = [AVMutableAudioMix audioMix]; + + audioMix.inputParameters = @[ audioMixInputParams ]; + _videoItem.audioMix = audioMix; + have_audio = true; + } else { + NSLog(@"Error %ld when trying to create the audio processing tap", (long)status); + } + } + if (!have_audio && [self.delegate respondsToSelector:@selector(noAudioAvailableFromSource:)]) { + [self.delegate noAudioAvailableFromSource:self]; + } +} + - (void)start { [_videoPlayer play]; _playing = YES; From 190be2e1bd064eae88ad6e72d0a17d46f3b11a5a Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Tue, 4 Apr 2023 17:41:58 -0700 Subject: [PATCH 41/63] Internal change PiperOrigin-RevId: 521911790 --- .../segmentation_postprocessor_gl.cc | 354 ++++++++++++------ .../segmentation_postprocessor_gl.h | 19 +- 2 files changed, 262 insertions(+), 111 deletions(-) diff --git a/mediapipe/tasks/cc/vision/image_segmenter/calculators/segmentation_postprocessor_gl.cc b/mediapipe/tasks/cc/vision/image_segmenter/calculators/segmentation_postprocessor_gl.cc index 5a09d3a8d..da5dcacae 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/calculators/segmentation_postprocessor_gl.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/calculators/segmentation_postprocessor_gl.cc @@ -22,7 +22,21 @@ using 
mediapipe::kBasicVertexShader; using ::mediapipe::tasks::vision::Shape; using ::mediapipe::tasks::vision::image_segmenter::proto::SegmenterOptions; +// TODO: This part of the setup code is so common, we should really +// refactor to a helper utility. enum { ATTRIB_VERTEX, ATTRIB_TEXTURE_POSITION, NUM_ATTRIBUTES }; +const GLint attr_location[NUM_ATTRIBUTES] = { + ATTRIB_VERTEX, + ATTRIB_TEXTURE_POSITION, +}; +const GLchar* attr_name[NUM_ATTRIBUTES] = { + "position", + "texture_coordinate", +}; + +// We assume ES3.0+ for some of our shaders here so we can make liberal use of +// MRT easily. +static constexpr char kEs30RequirementHeader[] = "#version 300 es\n"; static constexpr char kActivationFragmentShader[] = R"( DEFAULT_PRECISION(mediump, float) @@ -140,55 +154,93 @@ void main() { gl_FragColor = vec4(out_value, out_value, out_value, out_value); })"; -// Quick softmax shader hardcoded to max of N=12 classes. Performs softmax -// calculations, but renders to one chunk at a time. -// TODO: For more efficiency, should at least use MRT to render all -// chunks simultaneously. -static constexpr char kSoftmaxShader[] = R"( +// Softmax is in 3 steps: +// - First we find max over all masks +// - Then we transform all masks to be exp(val - maxval), and also add to +// cumulative-sum image with MRT +// - Then we normalize all masks by cumulative-sum image + +// Part one: max shader +// To start with, we just do this chunk by chunk, using GL_MAX blend mode so we +// don't need to tap into the max-so-far texture. +static constexpr char kMaxShader[] = R"( DEFAULT_PRECISION(mediump, float) in vec2 sample_coordinate; -uniform sampler2D input_texture0; -uniform sampler2D input_texture1; -uniform sampler2D input_texture2; -uniform int chunk_select; +uniform sampler2D current_chunk; +uniform int num_channels; // how many channels from current chunk to use (1-4) float max4(vec4 vec) { return max(max(vec.x, vec.y), max(vec.z, vec.w)); } - -vec4 expTransform(vec4 vec, float maxval) { - return exp(vec - maxval); +float max3(vec4 vec) { + return max(max(vec.x, vec.y), vec.z); } +float max2(vec4 vec) { + return max(vec.x, vec.y); +} +void main() { + vec4 chunk_pixel = texture2D(current_chunk, sample_coordinate); + float new_max; + if (num_channels == 1) { + new_max = chunk_pixel.x; + } else if (num_channels == 2) { + new_max = max2(chunk_pixel); + } else if (num_channels == 3) { + new_max = max3(chunk_pixel); + } else { + new_max = max4(chunk_pixel); + } + gl_FragColor = vec4(new_max, 0.0, 0.0, 1.0); +})"; + +// Part two: transform-and-sum shader +// We use GL blending so we can more easily render a cumulative sum texture, and +// this only costs us a glClear for the output chunk (needed since using MRT). 
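+// Note that the caller enables additive blending (GL_FUNC_ADD with
+// GL_ONE/GL_ONE factors) around this pass, so writes to the cumulative-sum
+// attachment accumulate across chunks.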
+static constexpr char kTransformAndSumShader[] = R"( +DEFAULT_PRECISION(highp, float) +in vec2 sample_coordinate; +uniform sampler2D max_value_texture; +uniform sampler2D current_chunk; +uniform int num_channels; // how many channels from current chunk to use (1-4) + +layout(location = 0) out vec4 cumulative_sum_texture; +layout(location = 1) out vec4 out_chunk_texture; void main() { - // Grab all vecs - vec4 pixel0 = texture2D(input_texture0, sample_coordinate); - vec4 pixel1 = texture2D(input_texture1, sample_coordinate); - vec4 pixel2 = texture2D(input_texture2, sample_coordinate); + float max_pixel = texture(max_value_texture, sample_coordinate).r; + vec4 chunk_pixel = texture(current_chunk, sample_coordinate); + vec4 new_chunk_pixel = exp(chunk_pixel - max_pixel); - // Find maxval amongst all vectors - float max0 = max4(pixel0); - float max1 = max4(pixel1); - float max2 = max4(pixel2); - float maxval = max(max(max0, max1), max2); + float sum_so_far; + if (num_channels == 1) { + sum_so_far = new_chunk_pixel.x; + } else if (num_channels == 2) { + sum_so_far = dot(vec2(1.0, 1.0), new_chunk_pixel.xy); + } else if (num_channels == 3) { + sum_so_far = dot(vec3(1.0, 1.0, 1.0), new_chunk_pixel.xyz); + } else { + sum_so_far = dot(vec4(1.0, 1.0, 1.0, 1.0), new_chunk_pixel); + } - vec4 outPixel0 = expTransform(pixel0, maxval); - vec4 outPixel1 = expTransform(pixel1, maxval); - vec4 outPixel2 = expTransform(pixel2, maxval); + cumulative_sum_texture = vec4(sum_so_far, 0.0, 0.0, 1.0); + out_chunk_texture = new_chunk_pixel; +})"; - // Quick hack to sum all components in vec4: dot with <1, 1, 1, 1> - vec4 ones = vec4(1.0, 1.0, 1.0, 1.0); - float weightSum = dot(ones, outPixel0) + dot(ones, outPixel1) + dot(ones, outPixel2); +// Part three: normalization shader +static constexpr char kNormalizationShader[] = R"( +DEFAULT_PRECISION(mediump, float) +in vec2 sample_coordinate; +uniform sampler2D sum_texture; // cumulative summation value (to normalize by) +uniform sampler2D current_chunk; // current chunk - vec4 outPixel; - if (chunk_select == 0) { - outPixel = outPixel0 / weightSum; - } else if (chunk_select == 1) { - outPixel = outPixel1 / weightSum; - } else { - outPixel = outPixel2 / weightSum; - } - gl_FragColor = outPixel; +void main() { + float sum_pixel = texture2D(sum_texture, sample_coordinate).r; + vec4 chunk_pixel = texture2D(current_chunk, sample_coordinate); + + // NOTE: We assume non-zero sum_pixel here, which is a safe assumption for + // result of an exp transform, but not if this shader is extended to other + // uses. + gl_FragColor = chunk_pixel / sum_pixel; })"; } // namespace @@ -208,19 +260,38 @@ absl::Status SegmentationPostprocessorGl::Initialize( return absl::OkStatus(); } +absl::Status SegmentationPostprocessorGl::CreateBasicFragmentShaderProgram( + std::string const& program_name, std::string const& fragment_shader_source, + std::vector const& uniform_names, GlShader* shader_struct_ptr, + bool is_es30_only = false) { + // Format source and create basic ES3.0+ fragment-shader-only program + const std::string frag_shader_source = + absl::StrCat(is_es30_only ? std::string(kEs30RequirementHeader) : "", + std::string(mediapipe::kMediaPipeFragmentShaderPreamble), + std::string(fragment_shader_source)); + const std::string vert_shader_source = + absl::StrCat(is_es30_only ? 
std::string(kEs30RequirementHeader) : "", + std::string(kBasicVertexShader)); + mediapipe::GlhCreateProgram( + vert_shader_source.c_str(), frag_shader_source.c_str(), NUM_ATTRIBUTES, + &attr_name[0], attr_location, &shader_struct_ptr->program, + /* force_log_errors */ true); + RET_CHECK(shader_struct_ptr->program) + << "Problem initializing the " << program_name << " program."; + + // Hook up all desired uniforms + for (const auto& uniform_name : uniform_names) { + shader_struct_ptr->uniforms[uniform_name] = + glGetUniformLocation(shader_struct_ptr->program, uniform_name.c_str()); + RET_CHECK(shader_struct_ptr->uniforms[uniform_name] > 0) + << uniform_name << " uniform not found for " << program_name + << " program"; + } + return absl::OkStatus(); +} + absl::Status SegmentationPostprocessorGl::GlInit() { return helper_.RunInGlContext([this]() -> absl::Status { - // TODO: This part of the setup code is so common, we should really - // refactor to a helper utility. - const GLint attr_location[NUM_ATTRIBUTES] = { - ATTRIB_VERTEX, - ATTRIB_TEXTURE_POSITION, - }; - const GLchar* attr_name[NUM_ATTRIBUTES] = { - "position", - "texture_coordinate", - }; - // Default to passthrough/NONE std::string activation_fn = "vec4 out_value = in_value;"; switch (options_.segmenter_options().activation()) { @@ -263,9 +334,17 @@ absl::Status SegmentationPostprocessorGl::GlInit() { absl::StrCat(std::string(mediapipe::kMediaPipeFragmentShaderPreamble), std::string(kArgmaxShader)); - const std::string softmax_shader_source = - absl::StrCat(std::string(mediapipe::kMediaPipeFragmentShaderPreamble), - std::string(kSoftmaxShader)); + // Softmax shaders (Max, Transform+Sum, and Normalization) + MP_RETURN_IF_ERROR(CreateBasicFragmentShaderProgram( + "softmax max", kMaxShader, {"current_chunk", "num_channels"}, + &softmax_max_shader_)); + MP_RETURN_IF_ERROR(CreateBasicFragmentShaderProgram( + "softmax transform-and-sum", kTransformAndSumShader, + {"max_value_texture", "current_chunk", "num_channels"}, + &softmax_transform_and_sum_shader_, true /* is_es30_only */)); + MP_RETURN_IF_ERROR(CreateBasicFragmentShaderProgram( + "softmax normalization", kNormalizationShader, + {"sum_texture", "current_chunk"}, &softmax_normalization_shader_)); // Compile all our shader programs. // Note: we enable `force_log_errors` so that we get full debugging error @@ -299,12 +378,6 @@ absl::Status SegmentationPostprocessorGl::GlInit() { /* force_log_errors */ true); RET_CHECK(argmax_program_) << "Problem initializing the argmax program."; - mediapipe::GlhCreateProgram(kBasicVertexShader, - softmax_shader_source.c_str(), NUM_ATTRIBUTES, - &attr_name[0], attr_location, &softmax_program_, - /* force_log_errors */ true); - RET_CHECK(softmax_program_) << "Problem initializing the softmax program."; - // Get uniform locations. 
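    // glGetUniformLocation returns -1 for a name that is not an active
    // uniform, so the strictly-positive RET_CHECKs below treat a missing or
    // optimized-away uniform as a setup failure.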
activation_texture_uniform_ = glGetUniformLocation(activation_program_, "input_texture"); @@ -341,23 +414,6 @@ absl::Status SegmentationPostprocessorGl::GlInit() { RET_CHECK(argmax_texture2_uniform_ > 0) << "argmax input_texture2 uniform not found."; - softmax_texture0_uniform_ = - glGetUniformLocation(softmax_program_, "input_texture0"); - RET_CHECK(softmax_texture0_uniform_ > 0) - << "softmax input_texture0 uniform not found."; - softmax_texture1_uniform_ = - glGetUniformLocation(softmax_program_, "input_texture1"); - RET_CHECK(softmax_texture1_uniform_ > 0) - << "softmax input_texture1 uniform not found."; - softmax_texture2_uniform_ = - glGetUniformLocation(softmax_program_, "input_texture2"); - RET_CHECK(softmax_texture2_uniform_ > 0) - << "softmax input_texture2 uniform not found."; - softmax_chunk_select_uniform_ = - glGetUniformLocation(softmax_program_, "chunk_select"); - RET_CHECK(softmax_chunk_select_uniform_ > 0) - << "softmax chunk select uniform not found."; - // TODO: If ES3.0+ only, switch to VAO for handling attributes. glGenBuffers(1, &square_vertices_); glBindBuffer(GL_ARRAY_BUFFER, square_vertices_); @@ -408,6 +464,9 @@ SegmentationPostprocessorGl::GetSegmentationResultGpu(const Shape& input_shape, // Uint8 pipeline and conversions are lacking, so for now we just use F32 // textures even for category masks. + // TODO: Also, some platforms (like certain iOS devices) do not + // allow for rendering to RGBAF32 textures, so we should switch to using + // F16 textures in those instances. const GpuBufferFormat final_output_format = GpuBufferFormat::kGrayFloat32; const Tensor::OpenGlTexture2dView read_view = tensor.GetOpenGlTexture2dReadView(); @@ -467,7 +526,7 @@ SegmentationPostprocessorGl::GetSegmentationResultGpu(const Shape& input_shape, ((float)i + tex_offset) / (float)(input_width)); // Technically duplicated, but fine for now; we want this after the bind glBindTexture(GL_TEXTURE_2D, activated_texture.name()); - // Disable HW interpolation + // Disable hardware GPU interpolation glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_NEAREST); glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_NEAREST); // Render @@ -477,45 +536,126 @@ SegmentationPostprocessorGl::GetSegmentationResultGpu(const Shape& input_shape, std::vector softmax_chunks; if (is_softmax) { - // Step 2.5: For SOFTMAX, apply softmax shader with up to 3 textures to - // create softmax-transformed chunks before channel extraction. - RET_CHECK(num_chunks <= 3) - << "Cannot handle more than 12 classes in softmax shader."; + // Step 2.5: For SOFTMAX, apply softmax shaders (max, transformAndSum, and + // normalization) to create softmax-transformed chunks before channel + // extraction. + // NOTE: exp(x-C) / sum_over_x(exp(x-C)) = exp(x) / sum_over_x(exp(x)). So + // theoretically we can skip the max shader step entirely. However, + // applying it does bring all our values into a nice (0, 1] range, so it + // will likely be better for precision, especially when dealing with an + // exponential function on arbitrary values. Therefore, we keep it, but + // this is potentially a skippable step for known "good" models, if we + // ever want to provide that as an option. + // TODO: For a tiny bit more efficiency, could combine channel + // extraction into last step of this via MRT. 
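+      // (Math note: with C = max_i(x_i) every exponent satisfies x - C <= 0,
+      // so each exp(x - C) lies in (0, 1] and the per-pixel normalization
+      //   softmax_i = exp(x_i - C) / sum_j(exp(x_j - C))
+      // is computed entirely from well-scaled values.)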
- glUseProgram(softmax_program_); - glUniform1i(softmax_texture0_uniform_, 1); - glUniform1i(softmax_texture1_uniform_, 2); - glUniform1i(softmax_texture2_uniform_, 3); + // Max + glUseProgram(softmax_max_shader_.program); + glUniform1i(softmax_max_shader_.uniforms["current_chunk"], 1); + + // We just need one channel, so format will match final output confidence + // masks + auto max_texture = + helper_.CreateDestinationTexture(width, height, final_output_format); + helper_.BindFramebuffer(max_texture); + + // We clear our newly-created destination texture to a reasonable minimum. + glClearColor(0.0, 0.0, 0.0, 0.0); + glClear(GL_COLOR_BUFFER_BIT); + + // We will use hardware GPU blending to apply max to all our writes. + glEnable(GL_BLEND); + glBlendEquation(GL_MAX); + + glActiveTexture(GL_TEXTURE1); + for (int i = 0; i < num_chunks; i++) { + int num_channels = 4; + if ((i + 1) * 4 > num_outputs) num_channels = num_outputs % 4; + glUniform1i(softmax_max_shader_.uniforms["num_channels"], num_channels); + glBindTexture(GL_TEXTURE_2D, chunks[i].name()); + glDrawArrays(GL_TRIANGLE_STRIP, 0, 4); + } + + // Transform & Sum + std::vector unnormalized_softmax_chunks; + glUseProgram(softmax_transform_and_sum_shader_.program); + glUniform1i(softmax_transform_and_sum_shader_.uniforms["current_chunk"], + 1); + glUniform1i( + softmax_transform_and_sum_shader_.uniforms["max_value_texture"], 2); + + auto sum_texture = + helper_.CreateDestinationTexture(width, height, final_output_format); + helper_.BindFramebuffer(sum_texture); + glClear(GL_COLOR_BUFFER_BIT); + + glActiveTexture(GL_TEXTURE2); + glBindTexture(GL_TEXTURE_2D, max_texture.name()); + + glBlendEquation(GL_FUNC_ADD); + glBlendFunc(GL_ONE, GL_ONE); + glActiveTexture(GL_TEXTURE1); + + // We use glDrawBuffers to clear only the new texture, then again to + // draw to both textures simultaneously for rendering. + GLuint both_attachments[2] = {GL_COLOR_ATTACHMENT0, GL_COLOR_ATTACHMENT1}; + GLuint one_attachment[2] = {GL_NONE, GL_COLOR_ATTACHMENT1}; + for (int i = 0; i < num_chunks; i++) { + int num_channels = 4; + if ((i + 1) * 4 > num_outputs) num_channels = num_outputs % 4; + glUniform1i(softmax_transform_and_sum_shader_.uniforms["num_channels"], + num_channels); + unnormalized_softmax_chunks.push_back(helper_.CreateDestinationTexture( + width, height, chunk_output_format)); + glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT1, + GL_TEXTURE_2D, + unnormalized_softmax_chunks.back().name(), 0); + + // Note that we must bind AFTER the CreateDestinationTexture, or else we + // end up with (0, 0, 0, 1) data being read from an unbound texture + // unit. 
+ glBindTexture(GL_TEXTURE_2D, chunks[i].name()); + + // Clear *only* the new chunk + glDrawBuffers(2, one_attachment); + glClear(GL_COLOR_BUFFER_BIT); + + // Then draw into both + glDrawBuffers(2, both_attachments); + glDrawArrays(GL_TRIANGLE_STRIP, 0, 4); + } + + // Turn off MRT and blending, and unbind second color attachment + glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT1, + GL_TEXTURE_2D, 0, 0); + glDrawBuffers(1, both_attachments); + glDisable(GL_BLEND); + + // Normalize each chunk into a new chunk as our final step + glUseProgram(softmax_normalization_shader_.program); + glUniform1i(softmax_normalization_shader_.uniforms["current_chunk"], 1); + glUniform1i(softmax_normalization_shader_.uniforms["sum_texture"], 2); + + glActiveTexture(GL_TEXTURE2); + glBindTexture(GL_TEXTURE_2D, sum_texture.name()); + glActiveTexture(GL_TEXTURE1); for (int i = 0; i < num_chunks; i++) { - glUniform1i(softmax_chunk_select_uniform_, i); softmax_chunks.push_back(helper_.CreateDestinationTexture( - output_width, output_height, chunk_output_format)); + width, height, chunk_output_format)); helper_.BindFramebuffer(softmax_chunks.back()); - - // Bind however many chunks we have - for (int j = 0; j < num_chunks; ++j) { - glActiveTexture(GL_TEXTURE1 + j); - glBindTexture(GL_TEXTURE_2D, chunks[j].name()); - } - - for (int j = num_chunks; j < 3; ++j) { // 3 is hard-coded max chunks - glActiveTexture(GL_TEXTURE1 + j); - // If texture is unbound, sampling from it should always give zeros. - // This is not ideal, but is ok for now for not polluting the argmax - // shader results too much. - glBindTexture(GL_TEXTURE_2D, 0); - } - + glBindTexture(GL_TEXTURE_2D, unnormalized_softmax_chunks[i].name()); glClear(GL_COLOR_BUFFER_BIT); glDrawArrays(GL_TRIANGLE_STRIP, 0, 4); } - // Unbind the extra textures here. - for (int i = 0; i < num_chunks; ++i) { - glActiveTexture(GL_TEXTURE1 + i); - glBindTexture(GL_TEXTURE_2D, 0); - } + // Unbind textures here + glActiveTexture(GL_TEXTURE2); + glBindTexture(GL_TEXTURE_2D, 0); + // We make sure to switch back to texture unit 1, since our confidence + // mask extraction code assumes that's our default. 
+ glActiveTexture(GL_TEXTURE1); + glBindTexture(GL_TEXTURE_2D, 0); } std::vector outputs; @@ -607,17 +747,19 @@ SegmentationPostprocessorGl::~SegmentationPostprocessorGl() { glDeleteProgram(activation_program_); glDeleteProgram(argmax_program_); glDeleteProgram(channel_select_program_); - glDeleteProgram(softmax_program_); glDeleteProgram(split_program_); glDeleteBuffers(1, &square_vertices_); glDeleteBuffers(1, &texture_vertices_); activation_program_ = 0; argmax_program_ = 0; channel_select_program_ = 0; - softmax_program_ = 0; split_program_ = 0; square_vertices_ = 0; texture_vertices_ = 0; + + glDeleteProgram(softmax_max_shader_.program); + glDeleteProgram(softmax_transform_and_sum_shader_.program); + glDeleteProgram(softmax_normalization_shader_.program); }); } diff --git a/mediapipe/tasks/cc/vision/image_segmenter/calculators/segmentation_postprocessor_gl.h b/mediapipe/tasks/cc/vision/image_segmenter/calculators/segmentation_postprocessor_gl.h index aceb3c8d6..c50f93077 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/calculators/segmentation_postprocessor_gl.h +++ b/mediapipe/tasks/cc/vision/image_segmenter/calculators/segmentation_postprocessor_gl.h @@ -38,7 +38,17 @@ class SegmentationPostprocessorGl { const Tensor& tensor); private: + struct GlShader { + GLuint program = 0; + absl::flat_hash_map uniforms; + }; + absl::Status GlInit(); + absl::Status CreateBasicFragmentShaderProgram( + std::string const& program_name, + std::string const& fragment_shader_source, + std::vector const& uniform_names, + GlShader* shader_struct_ptr, bool is_es30_only); TensorsToSegmentationCalculatorOptions options_; GlCalculatorHelper helper_; @@ -47,7 +57,6 @@ class SegmentationPostprocessorGl { GLuint activation_program_ = 0; GLuint argmax_program_ = 0; GLuint channel_select_program_ = 0; - GLuint softmax_program_ = 0; GLuint split_program_ = 0; GLuint square_vertices_ = 0; GLuint texture_vertices_ = 0; @@ -57,12 +66,12 @@ class SegmentationPostprocessorGl { GLint argmax_texture2_uniform_; GLint channel_select_texture_uniform_; GLint channel_select_index_uniform_; - GLint softmax_texture0_uniform_; - GLint softmax_texture1_uniform_; - GLint softmax_texture2_uniform_; - GLint softmax_chunk_select_uniform_; GLint split_texture_uniform_; GLint split_x_offset_uniform_; + + GlShader softmax_max_shader_; + GlShader softmax_transform_and_sum_shader_; + GlShader softmax_normalization_shader_; }; } // namespace tasks From 7cb8c647ca14aa69be70fb3fb00dbb1f5ef46dff Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Tue, 4 Apr 2023 21:22:18 -0700 Subject: [PATCH 42/63] Internal change PiperOrigin-RevId: 521948037 --- mediapipe/util/tensor_to_detection.cc | 2 +- mediapipe/util/time_series_util.cc | 6 +++--- mediapipe/util/time_series_util_test.cc | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/mediapipe/util/tensor_to_detection.cc b/mediapipe/util/tensor_to_detection.cc index 0b3d1f68a..1c19b3510 100644 --- a/mediapipe/util/tensor_to_detection.cc +++ b/mediapipe/util/tensor_to_detection.cc @@ -87,7 +87,7 @@ Status TensorsToDetections(const ::tensorflow::Tensor& num_detections, const auto& num_boxes_scalar = num_detections.scalar(); num_boxes = static_cast(num_boxes_scalar()); } else { - num_boxes = num_detections.scalar()(); + num_boxes = num_detections.scalar()(); } if (boxes.dim_size(0) < num_boxes) { return InvalidArgumentError( diff --git a/mediapipe/util/time_series_util.cc b/mediapipe/util/time_series_util.cc index 1e20daa59..87f69475a 100644 --- 
a/mediapipe/util/time_series_util.cc +++ b/mediapipe/util/time_series_util.cc @@ -29,7 +29,7 @@ namespace time_series_util { bool LogWarningIfTimestampIsInconsistent(const Timestamp& current_timestamp, const Timestamp& initial_timestamp, - int64 cumulative_samples, + int64_t cumulative_samples, double sample_rate) { // Ignore the "special" timestamp value Done(). if (current_timestamp == Timestamp::Done()) return true; @@ -122,11 +122,11 @@ absl::Status IsMatrixShapeConsistentWithHeader(const Matrix& matrix, return absl::OkStatus(); } -int64 SecondsToSamples(double time_in_seconds, double sample_rate) { +int64_t SecondsToSamples(double time_in_seconds, double sample_rate) { return round(time_in_seconds * sample_rate); } -double SamplesToSeconds(int64 num_samples, double sample_rate) { +double SamplesToSeconds(int64_t num_samples, double sample_rate) { DCHECK_NE(sample_rate, 0.0); return (num_samples / sample_rate); } diff --git a/mediapipe/util/time_series_util_test.cc b/mediapipe/util/time_series_util_test.cc index 807bc4f03..e8d47dbc6 100644 --- a/mediapipe/util/time_series_util_test.cc +++ b/mediapipe/util/time_series_util_test.cc @@ -186,7 +186,7 @@ TEST(TimeSeriesUtilTest, SecondsToSamples) { TEST(TimeSeriesUtilTest, SamplesToSeconds) { double sample_rate = 32.5; - int64 num_samples = 128; + int64_t num_samples = 128; EXPECT_EQ(num_samples / sample_rate, SamplesToSeconds(num_samples, sample_rate)); } From 1990fe00d3680fb04e4157dc1da178b084d6bc7a Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Tue, 4 Apr 2023 21:32:34 -0700 Subject: [PATCH 43/63] Internal change PiperOrigin-RevId: 521949598 --- mediapipe/framework/deps/mathutil_unittest.cc | 124 +++++++++--------- .../framework/deps/monotonic_clock_test.cc | 8 +- mediapipe/framework/deps/safe_int_test.cc | 96 +++++++------- 3 files changed, 114 insertions(+), 114 deletions(-) diff --git a/mediapipe/framework/deps/mathutil_unittest.cc b/mediapipe/framework/deps/mathutil_unittest.cc index 7468e927a..b25b73306 100644 --- a/mediapipe/framework/deps/mathutil_unittest.cc +++ b/mediapipe/framework/deps/mathutil_unittest.cc @@ -75,17 +75,17 @@ BENCHMARK(BM_IntCast); static void BM_Int64Cast(benchmark::State& state) { double x = 0.1; - int64 sum = 0; + int64_t sum = 0; for (auto _ : state) { - sum += static_cast(x); + sum += static_cast(x); x += 0.1; - sum += static_cast(x); + sum += static_cast(x); x += 0.1; - sum += static_cast(x); + sum += static_cast(x); x += 0.1; - sum += static_cast(x); + sum += static_cast(x); x += 0.1; - sum += static_cast(x); + sum += static_cast(x); x += 0.1; } EXPECT_NE(sum, 0); // Don't let 'sum' get optimized away. @@ -134,15 +134,15 @@ static void BM_Int64Round(benchmark::State& state) { double x = 0.1; int sum = 0; for (auto _ : state) { - sum += mediapipe::MathUtil::Round(x); + sum += mediapipe::MathUtil::Round(x); x += 0.1; - sum += mediapipe::MathUtil::Round(x); + sum += mediapipe::MathUtil::Round(x); x += 0.1; - sum += mediapipe::MathUtil::Round(x); + sum += mediapipe::MathUtil::Round(x); x += 0.1; - sum += mediapipe::MathUtil::Round(x); + sum += mediapipe::MathUtil::Round(x); x += 0.1; - sum += mediapipe::MathUtil::Round(x); + sum += mediapipe::MathUtil::Round(x); x += 0.1; } EXPECT_NE(sum, 0); // Don't let 'sum' get optimized away. 
@@ -153,15 +153,15 @@ static void BM_UintRound(benchmark::State& state) { double x = 0.1; int sum = 0; for (auto _ : state) { - sum += mediapipe::MathUtil::Round(x); + sum += mediapipe::MathUtil::Round(x); x += 0.1; - sum += mediapipe::MathUtil::Round(x); + sum += mediapipe::MathUtil::Round(x); x += 0.1; - sum += mediapipe::MathUtil::Round(x); + sum += mediapipe::MathUtil::Round(x); x += 0.1; - sum += mediapipe::MathUtil::Round(x); + sum += mediapipe::MathUtil::Round(x); x += 0.1; - sum += mediapipe::MathUtil::Round(x); + sum += mediapipe::MathUtil::Round(x); x += 0.1; } EXPECT_NE(sum, 0); // Don't let 'sum' get optimized away. @@ -191,15 +191,15 @@ static void BM_SafeInt64Cast(benchmark::State& state) { double x = 0.1; int sum = 0; for (auto _ : state) { - sum += mediapipe::MathUtil::SafeCast(x); + sum += mediapipe::MathUtil::SafeCast(x); x += 0.1; - sum += mediapipe::MathUtil::SafeCast(x); + sum += mediapipe::MathUtil::SafeCast(x); x += 0.1; - sum += mediapipe::MathUtil::SafeCast(x); + sum += mediapipe::MathUtil::SafeCast(x); x += 0.1; - sum += mediapipe::MathUtil::SafeCast(x); + sum += mediapipe::MathUtil::SafeCast(x); x += 0.1; - sum += mediapipe::MathUtil::SafeCast(x); + sum += mediapipe::MathUtil::SafeCast(x); x += 0.1; } EXPECT_NE(sum, 0); // Don't let 'sum' get optimized away. @@ -229,15 +229,15 @@ static void BM_SafeInt64Round(benchmark::State& state) { double x = 0.1; int sum = 0; for (auto _ : state) { - sum += mediapipe::MathUtil::SafeRound(x); + sum += mediapipe::MathUtil::SafeRound(x); x += 0.1; - sum += mediapipe::MathUtil::SafeRound(x); + sum += mediapipe::MathUtil::SafeRound(x); x += 0.1; - sum += mediapipe::MathUtil::SafeRound(x); + sum += mediapipe::MathUtil::SafeRound(x); x += 0.1; - sum += mediapipe::MathUtil::SafeRound(x); + sum += mediapipe::MathUtil::SafeRound(x); x += 0.1; - sum += mediapipe::MathUtil::SafeRound(x); + sum += mediapipe::MathUtil::SafeRound(x); x += 0.1; } EXPECT_NE(sum, 0); // Don't let 'sum' get optimized away. @@ -262,8 +262,8 @@ TEST(MathUtil, IntRound) { // A double-precision number has a 53-bit mantissa (52 fraction bits), // so the following value can be represented exactly. - int64 value64 = static_cast(0x1234567890abcd00); - EXPECT_EQ(mediapipe::MathUtil::Round(static_cast(value64)), + int64_t value64 = static_cast(0x1234567890abcd00); + EXPECT_EQ(mediapipe::MathUtil::Round(static_cast(value64)), value64); } @@ -369,7 +369,7 @@ class SafeCastTester { if (sizeof(FloatIn) >= 64) { // A double-precision number has a 53-bit mantissa (52 fraction bits), // so the following value can be represented exactly by a double. - int64 value64 = static_cast(0x1234567890abcd00); + int64_t value64 = static_cast(0x1234567890abcd00); const IntOut expected = (sizeof(IntOut) >= 64) ? 
              static_cast<IntOut>(value64) : imax;
       EXPECT_EQ(
@@ -536,22 +536,22 @@ class SafeCastTester {
 };

 TEST(MathUtil, SafeCast) {
-  SafeCastTester<float, int8>::Run();
-  SafeCastTester<double, int8>::Run();
-  SafeCastTester<float, int16>::Run();
-  SafeCastTester<double, int16>::Run();
-  SafeCastTester<float, int32>::Run();
-  SafeCastTester<double, int32>::Run();
-  SafeCastTester<float, int64>::Run();
-  SafeCastTester<double, int64>::Run();
-  SafeCastTester<float, uint8>::Run();
-  SafeCastTester<double, uint8>::Run();
-  SafeCastTester<float, uint16>::Run();
-  SafeCastTester<double, uint16>::Run();
-  SafeCastTester<float, uint32>::Run();
-  SafeCastTester<double, uint32>::Run();
-  SafeCastTester<float, uint64>::Run();
-  SafeCastTester<double, uint64>::Run();
+  SafeCastTester<float, int8_t>::Run();
+  SafeCastTester<double, int8_t>::Run();
+  SafeCastTester<float, int16_t>::Run();
+  SafeCastTester<double, int16_t>::Run();
+  SafeCastTester<float, int32_t>::Run();
+  SafeCastTester<double, int32_t>::Run();
+  SafeCastTester<float, int64_t>::Run();
+  SafeCastTester<double, int64_t>::Run();
+  SafeCastTester<float, uint8_t>::Run();
+  SafeCastTester<double, uint8_t>::Run();
+  SafeCastTester<float, uint16_t>::Run();
+  SafeCastTester<double, uint16_t>::Run();
+  SafeCastTester<float, uint32_t>::Run();
+  SafeCastTester<double, uint32_t>::Run();
+  SafeCastTester<float, uint64_t>::Run();
+  SafeCastTester<double, uint64_t>::Run();

   // Spot-check SafeCast
   EXPECT_EQ(mediapipe::MathUtil::SafeCast<int>(static_cast<double>(12345.678)),
@@ -682,7 +682,7 @@ class SafeRoundTester {
     if (sizeof(FloatIn) >= 64) {
       // A double-precision number has a 53-bit mantissa (52 fraction bits),
       // so the following value can be represented exactly by a double.
-      int64 value64 = static_cast<int64>(0x1234567890abcd00);
+      int64_t value64 = static_cast<int64_t>(0x1234567890abcd00);
       const IntOut expected =
           (sizeof(IntOut) >= 64) ? static_cast<IntOut>(value64) : imax;
       EXPECT_EQ(
@@ -843,22 +843,22 @@ class SafeRoundTester {
 };

 TEST(MathUtil, SafeRound) {
-  SafeRoundTester<float, int8>::Run();
-  SafeRoundTester<double, int8>::Run();
-  SafeRoundTester<float, int16>::Run();
-  SafeRoundTester<double, int16>::Run();
-  SafeRoundTester<float, int32>::Run();
-  SafeRoundTester<double, int32>::Run();
-  SafeRoundTester<float, int64>::Run();
-  SafeRoundTester<double, int64>::Run();
-  SafeRoundTester<float, uint8>::Run();
-  SafeRoundTester<double, uint8>::Run();
-  SafeRoundTester<float, uint16>::Run();
-  SafeRoundTester<double, uint16>::Run();
-  SafeRoundTester<float, uint32>::Run();
-  SafeRoundTester<double, uint32>::Run();
-  SafeRoundTester<float, uint64>::Run();
-  SafeRoundTester<double, uint64>::Run();
+  SafeRoundTester<float, int8_t>::Run();
+  SafeRoundTester<double, int8_t>::Run();
+  SafeRoundTester<float, int16_t>::Run();
+  SafeRoundTester<double, int16_t>::Run();
+  SafeRoundTester<float, int32_t>::Run();
+  SafeRoundTester<double, int32_t>::Run();
+  SafeRoundTester<float, int64_t>::Run();
+  SafeRoundTester<double, int64_t>::Run();
+  SafeRoundTester<float, uint8_t>::Run();
+  SafeRoundTester<double, uint8_t>::Run();
+  SafeRoundTester<float, uint16_t>::Run();
+  SafeRoundTester<double, uint16_t>::Run();
+  SafeRoundTester<float, uint32_t>::Run();
+  SafeRoundTester<double, uint32_t>::Run();
+  SafeRoundTester<float, uint64_t>::Run();
+  SafeRoundTester<double, uint64_t>::Run();

   // Spot-check SafeRound
   EXPECT_EQ(mediapipe::MathUtil::SafeRound<int>(static_cast<double>(12345.678)),
diff --git a/mediapipe/framework/deps/monotonic_clock_test.cc b/mediapipe/framework/deps/monotonic_clock_test.cc
index 533830e43..0a049392f 100644
--- a/mediapipe/framework/deps/monotonic_clock_test.cc
+++ b/mediapipe/framework/deps/monotonic_clock_test.cc
@@ -244,7 +244,7 @@ TEST_F(MonotonicClockTest, RealTime) {
   // Call mono_clock->Now() continuously for FLAGS_real_test_secs seconds.
   absl::Time start = absl::Now();
   absl::Time time = start;
-  int64 num_calls = 0;
+  int64_t num_calls = 0;
   do {
     absl::Time last_time = time;
     time = mono_clock->TimeNow();
@@ -406,7 +406,7 @@ class ClockFrenzy {
     while (Running()) {
       // 40% of the time, advance a simulated clock.
       // 50% of the time, read a monotonic clock.
-      const int32 u = UniformRandom(100);
+      const int32_t u = UniformRandom(100);
       if (u < 40) {
         // Pick a simulated clock and advance it.
         const int nclocks = sim_clocks_.size();
@@ -463,9 +463,9 @@ class ClockFrenzy {
   // Thread-safe random number generation functions for use by other class
   // member functions.
-  int32 UniformRandom(int32 n) {
+  int32_t UniformRandom(int32_t n) {
     absl::MutexLock l(&lock_);
-    return std::uniform_int_distribution<int32>(0, n - 1)(*random_);
+    return std::uniform_int_distribution<int32_t>(0, n - 1)(*random_);
   }

   float RndFloatRandom() {
diff --git a/mediapipe/framework/deps/safe_int_test.cc b/mediapipe/framework/deps/safe_int_test.cc
index 7f385848f..83932d551 100644
--- a/mediapipe/framework/deps/safe_int_test.cc
+++ b/mediapipe/framework/deps/safe_int_test.cc
@@ -20,21 +20,21 @@

 #include "mediapipe/framework/port/gtest.h"

-MEDIAPIPE_DEFINE_SAFE_INT_TYPE(SafeInt8, int8,
+MEDIAPIPE_DEFINE_SAFE_INT_TYPE(SafeInt8, int8_t,
                                mediapipe::intops::LogFatalOnError);
-MEDIAPIPE_DEFINE_SAFE_INT_TYPE(SafeUInt8, uint8,
+MEDIAPIPE_DEFINE_SAFE_INT_TYPE(SafeUInt8, uint8_t,
                                mediapipe::intops::LogFatalOnError);
-MEDIAPIPE_DEFINE_SAFE_INT_TYPE(SafeInt16, int16,
+MEDIAPIPE_DEFINE_SAFE_INT_TYPE(SafeInt16, int16_t,
                                mediapipe::intops::LogFatalOnError);
-MEDIAPIPE_DEFINE_SAFE_INT_TYPE(SafeUInt16, uint16,
+MEDIAPIPE_DEFINE_SAFE_INT_TYPE(SafeUInt16, uint16_t,
                                mediapipe::intops::LogFatalOnError);
-MEDIAPIPE_DEFINE_SAFE_INT_TYPE(SafeInt32, int32,
+MEDIAPIPE_DEFINE_SAFE_INT_TYPE(SafeInt32, int32_t,
                                mediapipe::intops::LogFatalOnError);
-MEDIAPIPE_DEFINE_SAFE_INT_TYPE(SafeInt64, int64,
+MEDIAPIPE_DEFINE_SAFE_INT_TYPE(SafeInt64, int64_t,
                                mediapipe::intops::LogFatalOnError);
-MEDIAPIPE_DEFINE_SAFE_INT_TYPE(SafeUInt32, uint32,
+MEDIAPIPE_DEFINE_SAFE_INT_TYPE(SafeUInt32, uint32_t,
                                mediapipe::intops::LogFatalOnError);
-MEDIAPIPE_DEFINE_SAFE_INT_TYPE(SafeUInt64, uint64,
+MEDIAPIPE_DEFINE_SAFE_INT_TYPE(SafeUInt64, uint64_t,
                                mediapipe::intops::LogFatalOnError);

 namespace mediapipe {
@@ -102,8 +102,8 @@ TYPED_TEST(SignNeutralSafeIntTest, TestCtorFailures) {
   typedef typename T::ValueType V;
   {  // Test out-of-bounds construction.
-    if (std::numeric_limits<V>::is_signed || sizeof(V) < sizeof(uint64)) {
-      EXPECT_DEATH((T(std::numeric_limits<uint64>::max())), "bounds");
+    if (std::numeric_limits<V>::is_signed || sizeof(V) < sizeof(uint64_t)) {
+      EXPECT_DEATH((T(std::numeric_limits<uint64_t>::max())), "bounds");
     }
   }
   {  // Test out-of-bounds construction from float.
@@ -233,20 +233,20 @@ TYPED_TEST(SignNeutralSafeIntTest, TestMultiply) {
   typedef typename T::ValueType V;

   // Test positive vs. positive multiplication across types.
-  TEST_T_OP_NUM(9, *, int32, 3);
-  TEST_T_OP_NUM(9, *, uint32, 3);
+  TEST_T_OP_NUM(9, *, int32_t, 3);
+  TEST_T_OP_NUM(9, *, uint32_t, 3);
   TEST_T_OP_NUM(9, *, float, 3);
   TEST_T_OP_NUM(9, *, double, 3);

   // Test positive vs. zero multiplication commutatively across types. This
   // was a real bug.
-  TEST_T_OP_NUM(93, *, int32, 0);
-  TEST_T_OP_NUM(93, *, uint32, 0);
+  TEST_T_OP_NUM(93, *, int32_t, 0);
+  TEST_T_OP_NUM(93, *, uint32_t, 0);
   TEST_T_OP_NUM(93, *, float, 0);
   TEST_T_OP_NUM(93, *, double, 0);
-  TEST_T_OP_NUM(0, *, int32, 76);
-  TEST_T_OP_NUM(0, *, uint32, 76);
+  TEST_T_OP_NUM(0, *, int32_t, 76);
+  TEST_T_OP_NUM(0, *, uint32_t, 76);
   TEST_T_OP_NUM(0, *, float, 76);
   TEST_T_OP_NUM(0, *, double, 76);

@@ -279,14 +279,14 @@ TYPED_TEST(SignNeutralSafeIntTest, TestDivide) {
   typedef typename T::ValueType V;

   // Test positive vs. positive division across types.
-  TEST_T_OP_NUM(9, /, int32, 3);
-  TEST_T_OP_NUM(9, /, uint32, 3);
+  TEST_T_OP_NUM(9, /, int32_t, 3);
+  TEST_T_OP_NUM(9, /, uint32_t, 3);
   TEST_T_OP_NUM(9, /, float, 3);
   TEST_T_OP_NUM(9, /, double, 3);

   // Test zero vs. positive division across types.
- TEST_T_OP_NUM(0, /, int32, 76); - TEST_T_OP_NUM(0, /, uint32, 76); + TEST_T_OP_NUM(0, /, int32_t, 76); + TEST_T_OP_NUM(0, /, uint32_t, 76); TEST_T_OP_NUM(0, /, float, 76); TEST_T_OP_NUM(0, /, double, 76); } @@ -307,12 +307,12 @@ TYPED_TEST(SignNeutralSafeIntTest, TestModulo) { typedef typename T::ValueType V; // Test positive vs. positive modulo across signedness. - TEST_T_OP_NUM(7, %, int32, 6); - TEST_T_OP_NUM(7, %, uint32, 6); + TEST_T_OP_NUM(7, %, int32_t, 6); + TEST_T_OP_NUM(7, %, uint32_t, 6); // Test zero vs. positive modulo across signedness. - TEST_T_OP_NUM(0, %, int32, 6); - TEST_T_OP_NUM(0, %, uint32, 6); + TEST_T_OP_NUM(0, %, int32_t, 6); + TEST_T_OP_NUM(0, %, uint32_t, 6); } TYPED_TEST(SignNeutralSafeIntTest, TestModuloFailures) { @@ -534,28 +534,28 @@ TYPED_TEST(SignedSafeIntTest, TestMultiply) { typedef typename T::ValueType V; // Test negative vs. positive multiplication across types. - TEST_T_OP_NUM(-9, *, int32, 3); - TEST_T_OP_NUM(-9, *, uint32, 3); + TEST_T_OP_NUM(-9, *, int32_t, 3); + TEST_T_OP_NUM(-9, *, uint32_t, 3); TEST_T_OP_NUM(-9, *, float, 3); TEST_T_OP_NUM(-9, *, double, 3); // Test positive vs. negative multiplication across types. - TEST_T_OP_NUM(9, *, int32, -3); + TEST_T_OP_NUM(9, *, int32_t, -3); // Don't cover unsigneds that are initialized from negative values. TEST_T_OP_NUM(9, *, float, -3); TEST_T_OP_NUM(9, *, double, -3); // Test negative vs. negative multiplication across types. - TEST_T_OP_NUM(-9, *, int32, -3); + TEST_T_OP_NUM(-9, *, int32_t, -3); // Don't cover unsigneds that are initialized from negative values. TEST_T_OP_NUM(-9, *, float, -3); TEST_T_OP_NUM(-9, *, double, -3); // Test negative vs. zero multiplication commutatively across types. - TEST_T_OP_NUM(-93, *, int32, 0); - TEST_T_OP_NUM(-93, *, uint32, 0); + TEST_T_OP_NUM(-93, *, int32_t, 0); + TEST_T_OP_NUM(-93, *, uint32_t, 0); TEST_T_OP_NUM(-93, *, float, 0); TEST_T_OP_NUM(-93, *, double, 0); - TEST_T_OP_NUM(0, *, int32, -76); - TEST_T_OP_NUM(0, *, uint32, -76); + TEST_T_OP_NUM(0, *, int32_t, -76); + TEST_T_OP_NUM(0, *, uint32_t, -76); TEST_T_OP_NUM(0, *, float, -76); TEST_T_OP_NUM(0, *, double, -76); @@ -600,24 +600,24 @@ TYPED_TEST(SignedSafeIntTest, TestDivide) { typedef typename T::ValueType V; // Test negative vs. positive division across types. - TEST_T_OP_NUM(-9, /, int32, 3); - TEST_T_OP_NUM(-9, /, uint32, 3); + TEST_T_OP_NUM(-9, /, int32_t, 3); + TEST_T_OP_NUM(-9, /, uint32_t, 3); TEST_T_OP_NUM(-9, /, float, 3); TEST_T_OP_NUM(-9, /, double, 3); // Test positive vs. negative division across types. - TEST_T_OP_NUM(9, /, int32, -3); - TEST_T_OP_NUM(9, /, uint32, -3); + TEST_T_OP_NUM(9, /, int32_t, -3); + TEST_T_OP_NUM(9, /, uint32_t, -3); TEST_T_OP_NUM(9, /, float, -3); TEST_T_OP_NUM(9, /, double, -3); // Test negative vs. negative division across types. - TEST_T_OP_NUM(-9, /, int32, -3); - TEST_T_OP_NUM(-9, /, uint32, -3); + TEST_T_OP_NUM(-9, /, int32_t, -3); + TEST_T_OP_NUM(-9, /, uint32_t, -3); TEST_T_OP_NUM(-9, /, float, -3); TEST_T_OP_NUM(-9, /, double, -3); // Test zero vs. negative division across types. - TEST_T_OP_NUM(0, /, int32, -76); - TEST_T_OP_NUM(0, /, uint32, -76); + TEST_T_OP_NUM(0, /, int32_t, -76); + TEST_T_OP_NUM(0, /, uint32_t, -76); TEST_T_OP_NUM(0, /, float, -76); TEST_T_OP_NUM(0, /, double, -76); } @@ -638,18 +638,18 @@ TYPED_TEST(SignedSafeIntTest, TestModulo) { typedef typename T::ValueType V; // Test negative vs. positive modulo across signedness. 
-  TEST_T_OP_NUM(-7, %, int32, 6);
-  TEST_T_OP_NUM(-7, %, uint32, 6);
+  TEST_T_OP_NUM(-7, %, int32_t, 6);
+  TEST_T_OP_NUM(-7, %, uint32_t, 6);

   // Test positive vs. negative modulo across signedness.
-  TEST_T_OP_NUM(7, %, int32, -6);
-  TEST_T_OP_NUM(7, %, uint32, -6);
+  TEST_T_OP_NUM(7, %, int32_t, -6);
+  TEST_T_OP_NUM(7, %, uint32_t, -6);

   // Test negative vs. negative modulo across signedness.
-  TEST_T_OP_NUM(-7, %, int32, -6);
-  TEST_T_OP_NUM(-7, %, uint32, -6);
+  TEST_T_OP_NUM(-7, %, int32_t, -6);
+  TEST_T_OP_NUM(-7, %, uint32_t, -6);

   // Test zero vs. negative modulo across signedness.
-  TEST_T_OP_NUM(0, %, int32, -6);
-  TEST_T_OP_NUM(0, %, uint32, -6);
+  TEST_T_OP_NUM(0, %, int32_t, -6);
+  TEST_T_OP_NUM(0, %, uint32_t, -6);
 }

 TYPED_TEST(SignedSafeIntTest, TestModuloFailures) {

From 7417e48da42948ff142acc08d02a2643e8831949 Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Wed, 5 Apr 2023 00:03:00 -0700
Subject: [PATCH 44/63] Internal change

PiperOrigin-RevId: 521970274
---
 .../autoflip/calculators/border_detection_calculator.cc  | 2 +-
 .../autoflip/calculators/content_zooming_calculator.cc   | 8 ++++----
 .../calculators/content_zooming_calculator_test.cc       | 4 ++--
 .../autoflip/calculators/scene_cropping_calculator.cc    | 6 +++---
 .../calculators/scene_cropping_calculator_test.cc        | 8 ++++----
 5 files changed, 14 insertions(+), 14 deletions(-)

diff --git a/mediapipe/examples/desktop/autoflip/calculators/border_detection_calculator.cc b/mediapipe/examples/desktop/autoflip/calculators/border_detection_calculator.cc
index caaa368a7..238bcf8be 100644
--- a/mediapipe/examples/desktop/autoflip/calculators/border_detection_calculator.cc
+++ b/mediapipe/examples/desktop/autoflip/calculators/border_detection_calculator.cc
@@ -214,7 +214,7 @@ double BorderDetectionCalculator::ColorCount(const Color& mask_color,
                                              const cv::Mat& image) const {
   int background_count = 0;
   for (int i = 0; i < image.rows; i++) {
-    const uint8* row_ptr = image.ptr<uint8>(i);
+    const uint8_t* row_ptr = image.ptr<uint8_t>(i);
     for (int j = 0; j < image.cols * 3; j += 3) {
       if (std::abs(mask_color.r() - static_cast<int>(row_ptr[j + 2])) <=
               options_.color_tolerance() &&
diff --git a/mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator.cc b/mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator.cc
index 823080786..5241f56e4 100644
--- a/mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator.cc
+++ b/mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator.cc
@@ -142,7 +142,7 @@ class ContentZoomingCalculator : public CalculatorBase {
   // Stores the first crop rectangle.
   mediapipe::NormalizedRect first_rect_;
   // Stores the time of the last "only_required" input.
-  int64 last_only_required_detection_;
+  int64_t last_only_required_detection_;
   // Rect values of last message with detection(s).
 int last_measured_height_;
 int last_measured_x_offset_;
@@ -500,7 +500,7 @@ bool ContentZoomingCalculator::IsAnimatingToFirstRect(
     return false;
   }

-  const int64 delta_us = (timestamp - first_rect_timestamp_).Value();
+  const int64_t delta_us = (timestamp - first_rect_timestamp_).Value();
   return (0 <= delta_us && delta_us <= options_.us_to_first_rect());
 }

@@ -522,8 +522,8 @@ absl::StatusOr ContentZoomingCalculator::GetAnimationRect(
   RET_CHECK(IsAnimatingToFirstRect(timestamp))
      << "Must only be called if animating to first rect.";

-  const int64 delta_us = (timestamp - first_rect_timestamp_).Value();
-  const int64 delay = options_.us_to_first_rect_delay();
+  const int64_t delta_us = (timestamp - first_rect_timestamp_).Value();
+  const int64_t delay = options_.us_to_first_rect_delay();
   const double interpolation = easeInOutQuad(std::max(
       0.0, (delta_us - delay) /
                static_cast<double>(options_.us_to_first_rect() - delay)));
diff --git a/mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator_test.cc b/mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator_test.cc
index 48e4a28a8..0e817b260 100644
--- a/mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator_test.cc
+++ b/mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator_test.cc
@@ -226,7 +226,7 @@ struct AddDetectionFlags {
   std::optional<int> max_zoom_factor_percent;
 };

-void AddDetectionFrameSize(const cv::Rect_<float>& position, const int64 time,
+void AddDetectionFrameSize(const cv::Rect_<float>& position, const int64_t time,
                            const int width, const int height,
                            CalculatorRunner* runner,
                            const AddDetectionFlags& flags = {}) {
@@ -275,7 +275,7 @@ void AddDetectionFrameSize(const cv::Rect_<float>& position, const int64 time,
   }
 }

-void AddDetection(const cv::Rect_<float>& position, const int64 time,
+void AddDetection(const cv::Rect_<float>& position, const int64_t time,
                   CalculatorRunner* runner) {
   AddDetectionFrameSize(position, time, 1000, 1000, runner);
 }
diff --git a/mediapipe/examples/desktop/autoflip/calculators/scene_cropping_calculator.cc b/mediapipe/examples/desktop/autoflip/calculators/scene_cropping_calculator.cc
index 7e286b743..f4cc98674 100644
--- a/mediapipe/examples/desktop/autoflip/calculators/scene_cropping_calculator.cc
+++ b/mediapipe/examples/desktop/autoflip/calculators/scene_cropping_calculator.cc
@@ -200,7 +200,7 @@ absl::Status ParseAspectRatioString(const std::string& aspect_ratio_string,
 }
 void ConstructExternalRenderMessage(
     const cv::Rect& crop_from_location, const cv::Rect& render_to_location,
-    const cv::Scalar& padding_color, const uint64 timestamp_us,
+    const cv::Scalar& padding_color, const uint64_t timestamp_us,
     ExternalRenderFrame* external_render_message, int frame_width,
     int frame_height) {
   auto crop_from_message =
@@ -717,7 +717,7 @@ absl::Status SceneCroppingCalculator::FormatAndOutputCroppedFrames(
   for (int i = 0; i < num_frames; ++i) {
     // Set default padding color to white.
     cv::Scalar padding_color_to_add = cv::Scalar(255, 255, 255);
-    const int64 time_ms = scene_frame_timestamps_[i];
+    const int64_t time_ms = scene_frame_timestamps_[i];
     if (*apply_padding) {
       if (has_solid_background_) {
         double lab[3];
@@ -747,7 +747,7 @@ absl::Status SceneCroppingCalculator::FormatAndOutputCroppedFrames(
   // Resizes cropped frames, pads frames, and output frames.
   for (int i = 0; i < num_frames; ++i) {
-    const int64 time_ms = scene_frame_timestamps_[i];
+    const int64_t time_ms = scene_frame_timestamps_[i];
     const Timestamp timestamp(time_ms);
     auto scaled_frame = absl::make_unique<ImageFrame>(
         frame_format_, scaled_width, scaled_height);
diff --git a/mediapipe/examples/desktop/autoflip/calculators/scene_cropping_calculator_test.cc b/mediapipe/examples/desktop/autoflip/calculators/scene_cropping_calculator_test.cc
index c3285ea58..74535022d 100644
--- a/mediapipe/examples/desktop/autoflip/calculators/scene_cropping_calculator_test.cc
+++ b/mediapipe/examples/desktop/autoflip/calculators/scene_cropping_calculator_test.cc
@@ -175,7 +175,7 @@ constexpr int kMinNumDetections = 0;
 constexpr int kMaxNumDetections = 10;
 constexpr int kDownSampleRate = 4;
-constexpr int64 kTimestampDiff = 20000;
+constexpr int64_t kTimestampDiff = 20000;

 // Returns a singleton random engine for generating random values. The seed is
 // fixed for reproducibility.
@@ -254,7 +254,7 @@ std::unique_ptr<ImageFrame> MakeImageFrameFromColor(const cv::Scalar& color,
 // Randomly generates a number of detections in the range of kMinNumDetections
 // and kMaxNumDetections. Optionally add a key image frame of random solid color
 // and given size.
-void AddKeyFrameFeatures(const int64 time_ms, const int key_frame_width,
+void AddKeyFrameFeatures(const int64_t time_ms, const int key_frame_width,
                          const int key_frame_height, bool randomize,
                          CalculatorRunner::StreamContentsSet* inputs) {
   Timestamp timestamp(time_ms);
@@ -286,7 +286,7 @@ void AddScene(const int start_frame_index, const int num_scene_frames,
              const int key_frame_width, const int key_frame_height,
              const int DownSampleRate,
              CalculatorRunner::StreamContentsSet* inputs) {
-  int64 time_ms = start_frame_index * kTimestampDiff;
+  int64_t time_ms = start_frame_index * kTimestampDiff;
   for (int i = 0; i < num_scene_frames; ++i) {
     Timestamp timestamp(time_ms);
     if (inputs->HasTag(kVideoFramesTag)) {
@@ -657,7 +657,7 @@ TEST(SceneCroppingCalculatorTest, PadsWithSolidColorFromStaticFeatures) {
   // Add inputs.
   auto* inputs = runner->MutableInputs();
-  int64 time_ms = 0;
+  int64_t time_ms = 0;
   int num_static_features = 0;
   for (int i = 0; i < kSceneSize; ++i) {
     Timestamp timestamp(time_ms);

From 46f92707880c78f45ea9c855d55edabbaa5e07d6 Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Wed, 5 Apr 2023 01:07:18 -0700
Subject: [PATCH 45/63] Internal change

PiperOrigin-RevId: 521980958
---
 mediapipe/framework/formats/image_frame.cc     | 73 ++++++++++---------
 .../formats/image_frame_opencv_test.cc         | 12 +--
 mediapipe/framework/formats/image_opencv.cc    |  2 +-
 .../framework/formats/location_opencv.cc       |  4 +-
 .../framework/formats/location_opencv_test.cc  | 12 +--
 5 files changed, 52 insertions(+), 51 deletions(-)

diff --git a/mediapipe/framework/formats/image_frame.cc b/mediapipe/framework/formats/image_frame.cc
index 913ffae24..772c91014 100644
--- a/mediapipe/framework/formats/image_frame.cc
+++ b/mediapipe/framework/formats/image_frame.cc
@@ -33,7 +33,7 @@ namespace mediapipe {

 namespace {

-int CountOnes(uint32 n) {
+int CountOnes(uint32_t n) {
 #if (defined(__i386__) || defined(__x86_64__)) && defined(__POPCNT__) && \
     defined(__GNUC__)
   return __builtin_popcount(n);
@@ -47,20 +47,21 @@ int CountOnes(uint32 n) {
 }  // namespace

 const ImageFrame::Deleter ImageFrame::PixelDataDeleter::kArrayDelete =
-    std::default_delete<uint8[]>();
+    std::default_delete<uint8_t[]>();
 const ImageFrame::Deleter ImageFrame::PixelDataDeleter::kFree = free;
 const ImageFrame::Deleter ImageFrame::PixelDataDeleter::kAlignedFree =
     aligned_free;
-const ImageFrame::Deleter ImageFrame::PixelDataDeleter::kNone = [](uint8* x) {};
+const ImageFrame::Deleter ImageFrame::PixelDataDeleter::kNone = [](uint8_t* x) {
+};

-const uint32 ImageFrame::kDefaultAlignmentBoundary;
-const uint32 ImageFrame::kGlDefaultAlignmentBoundary;
+const uint32_t ImageFrame::kDefaultAlignmentBoundary;
+const uint32_t ImageFrame::kGlDefaultAlignmentBoundary;

 ImageFrame::ImageFrame()
     : format_(ImageFormat::UNKNOWN), width_(0), height_(0), width_step_(0) {}

 ImageFrame::ImageFrame(ImageFormat::Format format, int width, int height,
-                       uint32 alignment_boundary)
+                       uint32_t alignment_boundary)
     : format_(format), width_(width), height_(height) {
   Reset(format, width, height, alignment_boundary);
 }
@@ -71,7 +72,7 @@ ImageFrame::ImageFrame(ImageFormat::Format format, int width, int height)
 }

 ImageFrame::ImageFrame(ImageFormat::Format format, int width, int height,
-                       int width_step, uint8* pixel_data,
+                       int width_step, uint8_t* pixel_data,
                        ImageFrame::Deleter deleter) {
   AdoptPixelData(format, width, height, width_step, pixel_data, deleter);
 }
@@ -93,7 +94,7 @@ ImageFrame& ImageFrame::operator=(ImageFrame&& move_from) {
 }

 void ImageFrame::Reset(ImageFormat::Format format, int width, int height,
-                       uint32 alignment_boundary) {
+                       uint32_t alignment_boundary) {
   format_ = format;
   width_ = width;
   height_ = height;
@@ -101,7 +102,7 @@ void ImageFrame::Reset(ImageFormat::Format format, int width, int height,
   CHECK(IsValidAlignmentNumber(alignment_boundary));
   width_step_ = width * NumberOfChannels() * ByteDepth();
   if (alignment_boundary == 1) {
-    pixel_data_ = {new uint8[height * width_step_],
+    pixel_data_ = {new uint8_t[height * width_step_],
                    PixelDataDeleter::kArrayDelete};
   } else {
     // Increase width_step_ to the smallest multiple of alignment_boundary
    // which is large enough to hold all the data.  This is done by
    // twiddling bits.  alignment_boundary - 1 is a mask which sets all
    // the low order bits.
     width_step_ = ((width_step_ - 1) | (alignment_boundary - 1)) + 1;
-    pixel_data_ = {reinterpret_cast<uint8*>(aligned_malloc(height * width_step_,
-                                                           alignment_boundary)),
+    pixel_data_ = {reinterpret_cast<uint8_t*>(aligned_malloc(
+                       height * width_step_, alignment_boundary)),
                    PixelDataDeleter::kAlignedFree};
   }
 }

 void ImageFrame::AdoptPixelData(ImageFormat::Format format, int width,
-                                int height, int width_step, uint8* pixel_data,
+                                int height, int width_step, uint8_t* pixel_data,
                                 ImageFrame::Deleter deleter) {
   format_ = format;
   width_ = width;
@@ -129,12 +130,12 @@ void ImageFrame::AdoptPixelData(ImageFormat::Format format, int width,
   pixel_data_ = {pixel_data, deleter};
 }

-std::unique_ptr<uint8[], ImageFrame::Deleter> ImageFrame::Release() {
+std::unique_ptr<uint8_t[], ImageFrame::Deleter> ImageFrame::Release() {
   return std::move(pixel_data_);
 }

 void ImageFrame::InternalCopyFrom(int width, int height, int width_step,
-                                  int channel_size, const uint8* pixel_data) {
+                                  int channel_size, const uint8_t* pixel_data) {
   CHECK_EQ(width_, width);
   CHECK_EQ(height_, height);
   // row_bytes = channel_size * num_channels * width
@@ -192,9 +193,9 @@ void ImageFrame::SetAlignmentPaddingAreas() {
   const int pixel_size = ByteDepth() * NumberOfChannels();
   const int padding_size = width_step_ - width_ * pixel_size;
   for (int row = 0; row < height_; ++row) {
-    uint8* row_start = pixel_data_.get() + width_step_ * row;
-    uint8* last_pixel_in_row = row_start + (width_ - 1) * pixel_size;
-    uint8* padding = row_start + width_ * pixel_size;
+    uint8_t* row_start = pixel_data_.get() + width_step_ * row;
+    uint8_t* last_pixel_in_row = row_start + (width_ - 1) * pixel_size;
+    uint8_t* padding = row_start + width_ * pixel_size;
     int padding_index = 0;
     while (padding_index + pixel_size - 1 < padding_size) {
       // Copy the entire last pixel in the row into this padding pixel.
@@ -220,7 +221,7 @@ bool ImageFrame::IsContiguous() const { return width_step_ == width_ * NumberOfChannels() * ByteDepth(); } -bool ImageFrame::IsAligned(uint32 alignment_boundary) const { +bool ImageFrame::IsAligned(uint32_t alignment_boundary) const { CHECK(IsValidAlignmentNumber(alignment_boundary)); if (!pixel_data_) { return false; @@ -236,7 +237,7 @@ bool ImageFrame::IsAligned(uint32 alignment_boundary) const { } // static -bool ImageFrame::IsValidAlignmentNumber(uint32 alignment_boundary) { +bool ImageFrame::IsValidAlignmentNumber(uint32_t alignment_boundary) { return CountOnes(alignment_boundary) == 1; } @@ -293,25 +294,25 @@ int ImageFrame::ChannelSize() const { return ChannelSizeForFormat(format_); } int ImageFrame::ChannelSizeForFormat(ImageFormat::Format format) { switch (format) { case ImageFormat::GRAY8: - return sizeof(uint8); + return sizeof(uint8_t); case ImageFormat::SRGB: - return sizeof(uint8); + return sizeof(uint8_t); case ImageFormat::SRGBA: - return sizeof(uint8); + return sizeof(uint8_t); case ImageFormat::GRAY16: - return sizeof(uint16); + return sizeof(uint16_t); case ImageFormat::SRGB48: - return sizeof(uint16); + return sizeof(uint16_t); case ImageFormat::SRGBA64: - return sizeof(uint16); + return sizeof(uint16_t); case ImageFormat::VEC32F1: return sizeof(float); case ImageFormat::VEC32F2: return sizeof(float); case ImageFormat::LAB8: - return sizeof(uint8); + return sizeof(uint8_t); case ImageFormat::SBGRA: - return sizeof(uint8); + return sizeof(uint8_t); default: LOG(FATAL) << InvalidFormatString(format); } @@ -347,7 +348,7 @@ int ImageFrame::ByteDepthForFormat(ImageFormat::Format format) { } void ImageFrame::CopyFrom(const ImageFrame& image_frame, - uint32 alignment_boundary) { + uint32_t alignment_boundary) { // Reset the current image. Reset(image_frame.Format(), image_frame.Width(), image_frame.Height(), alignment_boundary); @@ -359,29 +360,29 @@ void ImageFrame::CopyFrom(const ImageFrame& image_frame, } void ImageFrame::CopyPixelData(ImageFormat::Format format, int width, - int height, const uint8* pixel_data, - uint32 alignment_boundary) { + int height, const uint8_t* pixel_data, + uint32_t alignment_boundary) { CopyPixelData(format, width, height, 0 /* contiguous storage */, pixel_data, alignment_boundary); } void ImageFrame::CopyPixelData(ImageFormat::Format format, int width, int height, int width_step, - const uint8* pixel_data, - uint32 alignment_boundary) { + const uint8_t* pixel_data, + uint32_t alignment_boundary) { Reset(format, width, height, alignment_boundary); InternalCopyFrom(width, height, width_step, ChannelSizeForFormat(format), pixel_data); } -void ImageFrame::CopyToBuffer(uint8* buffer, int buffer_size) const { +void ImageFrame::CopyToBuffer(uint8_t* buffer, int buffer_size) const { CHECK(buffer); CHECK_EQ(1, ByteDepth()); const int data_size = width_ * height_ * NumberOfChannels(); CHECK_LE(data_size, buffer_size); if (IsContiguous()) { // The data is stored contiguously, we can just copy. 
-    const uint8* src = reinterpret_cast<const uint8*>(pixel_data_.get());
+    const uint8_t* src = reinterpret_cast<const uint8_t*>(pixel_data_.get());
     std::copy_n(src, data_size, buffer);
   } else {
     InternalCopyToBuffer(0 /* contiguous storage */,
@@ -389,14 +390,14 @@ void ImageFrame::CopyToBuffer(uint8* buffer, int buffer_size) const {
   }
 }

-void ImageFrame::CopyToBuffer(uint16* buffer, int buffer_size) const {
+void ImageFrame::CopyToBuffer(uint16_t* buffer, int buffer_size) const {
   CHECK(buffer);
   CHECK_EQ(2, ByteDepth());
   const int data_size = width_ * height_ * NumberOfChannels();
   CHECK_LE(data_size, buffer_size);
   if (IsContiguous()) {
     // The data is stored contiguously, we can just copy.
-    const uint16* src = reinterpret_cast<const uint16*>(pixel_data_.get());
+    const uint16_t* src = reinterpret_cast<const uint16_t*>(pixel_data_.get());
     std::copy_n(src, data_size, buffer);
   } else {
     InternalCopyToBuffer(0 /* contiguous storage */,
diff --git a/mediapipe/framework/formats/image_frame_opencv_test.cc b/mediapipe/framework/formats/image_frame_opencv_test.cc
index f75915d06..ae6f90f81 100644
--- a/mediapipe/framework/formats/image_frame_opencv_test.cc
+++ b/mediapipe/framework/formats/image_frame_opencv_test.cc
@@ -51,8 +51,8 @@ TEST(ImageFrameOpencvTest, ConvertToMat) {
   // Check adding constant images.
   const uint8_t frame1_val = 12;
   const uint8_t frame2_val = 34;
-  SetToColor<uint8>(&frame1_val, &frame1);
-  SetToColor<uint8>(&frame2_val, &frame2);
+  SetToColor<uint8_t>(&frame1_val, &frame1);
+  SetToColor<uint8_t>(&frame2_val, &frame2);
   // Get Mat wrapper around ImageFrame memory (zero copy).
   cv::Mat frame1_mat = formats::MatView(&frame1);
   cv::Mat frame2_mat = formats::MatView(&frame2);
@@ -62,7 +62,7 @@ TEST(ImageFrameOpencvTest, ConvertToMat) {
   EXPECT_EQ(frame_avg, frame1_val + frame2_val);

   // Check setting min/max pixels.
-  uint8* frame1_ptr = frame1.MutablePixelData();
+  uint8_t* frame1_ptr = frame1.MutablePixelData();
   frame1_ptr[(i_width - 5) + (i_height - 5) * frame1.WidthStep()] = 1;
   frame1_ptr[(i_width - 6) + (i_height - 6) * frame1.WidthStep()] = 100;
   double min, max;
@@ -84,8 +84,8 @@ TEST(ImageFrameOpencvTest, ConvertToIpl) {
   // Check adding constant images.
   const uint8_t frame1_val = 12;
   const uint8_t frame2_val = 34;
-  SetToColor<uint8>(&frame1_val, &frame1);
-  SetToColor<uint8>(&frame2_val, &frame2);
+  SetToColor<uint8_t>(&frame1_val, &frame1);
+  SetToColor<uint8_t>(&frame2_val, &frame2);
   const cv::Mat frame1_mat = formats::MatView(&frame1);
   const cv::Mat frame2_mat = formats::MatView(&frame2);
   const cv::Mat frame_sum = frame1_mat + frame2_mat;
@@ -93,7 +93,7 @@ TEST(ImageFrameOpencvTest, ConvertToIpl) {
   EXPECT_EQ(frame_avg, frame1_val + frame2_val);

   // Check setting min/max pixels.
-  uint8* frame1_ptr = frame1.MutablePixelData();
+  uint8_t* frame1_ptr = frame1.MutablePixelData();
   frame1_ptr[(i_width - 5) + (i_height - 5) * frame1.WidthStep()] = 1;
   frame1_ptr[(i_width - 6) + (i_height - 6) * frame1.WidthStep()] = 100;
   double min, max;
diff --git a/mediapipe/framework/formats/image_opencv.cc b/mediapipe/framework/formats/image_opencv.cc
index 9ccaa632b..7d9ce4a13 100644
--- a/mediapipe/framework/formats/image_opencv.cc
+++ b/mediapipe/framework/formats/image_opencv.cc
@@ -96,7 +96,7 @@ std::shared_ptr<cv::Mat> MatView(const mediapipe::Image* image) {
                                               image->image_format()))};
   auto owner =
       std::make_shared(const_cast<mediapipe::Image*>(image));
-  uint8* data_ptr = owner->lock.Pixels();
+  uint8_t* data_ptr = owner->lock.Pixels();
   CHECK(data_ptr != nullptr);
   // Use Image to initialize in-place. Image still owns memory.
   if (steps[0] == sizes[1] * image->channels() *
diff --git a/mediapipe/framework/formats/location_opencv.cc b/mediapipe/framework/formats/location_opencv.cc
index de59633ca..6e15b299a 100644
--- a/mediapipe/framework/formats/location_opencv.cc
+++ b/mediapipe/framework/formats/location_opencv.cc
@@ -91,7 +91,7 @@ std::unique_ptr<cv::Mat> GetCvMask(const Location& location) {
       new cv::Mat(mask.height(), mask.width(), CV_8UC1, cv::Scalar(0)));
   for (const auto& interval : location_data.mask().rasterization().interval()) {
     for (int x = interval.left_x(); x <= interval.right_x(); ++x) {
-      mat->at<uint8>(interval.y(), x) = 255;
+      mat->at<uint8_t>(interval.y(), x) = 255;
     }
   }
   return mat;
@@ -174,7 +174,7 @@ void EnlargeLocation(Location& location, const float factor) {
      } else {
        cv::erode(*mask, *mask, morph_element);
      }
-      CreateCvMaskLocation<uint8>(*mask).ConvertToProto(&location_data);
+      CreateCvMaskLocation<uint8_t>(*mask).ConvertToProto(&location_data);
      break;
    }
  }
diff --git a/mediapipe/framework/formats/location_opencv_test.cc b/mediapipe/framework/formats/location_opencv_test.cc
index 5740d2b17..6e3a89b58 100644
--- a/mediapipe/framework/formats/location_opencv_test.cc
+++ b/mediapipe/framework/formats/location_opencv_test.cc
@@ -25,8 +25,8 @@ namespace mediapipe {
 // segments per row.
 static const int kWidth = 7;
 static const int kHeight = 3;
-const std::vector<uint8> kTestPatternVector = {0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0,
-                                               0, 0, 0, 1, 0, 1, 0, 1, 0, 0};
+const std::vector<uint8_t> kTestPatternVector = {
+    0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0};

 // Interval {y, x_start, x_end} representation of kTestPatternVector.
 const std::vector<std::vector<int>> kTestPatternIntervals = {
@@ -67,8 +67,8 @@ TEST(LocationOpencvTest, CreateBBoxLocation) {
 }

 TEST(LocationOpencvTest, CreateCvMaskLocation) {
-  cv::Mat_<uint8> test_mask(kHeight, kWidth,
-                            const_cast<uint8*>(kTestPatternVector.data()));
+  cv::Mat_<uint8_t> test_mask(kHeight, kWidth,
+                              const_cast<uint8_t*>(kTestPatternVector.data()));
   Location location = CreateCvMaskLocation(test_mask);
   auto intervals = location.ConvertToProto().mask().rasterization().interval();
   EXPECT_EQ(intervals.size(), kTestPatternIntervals.size());
@@ -157,8 +157,8 @@ TEST(LocationOpenCvTest, GetCvMask) {
   auto cv_mask = *GetCvMask(test_location);
   EXPECT_EQ(cv_mask.cols * cv_mask.rows, kTestPatternVector.size());
   int flat_idx = 0;
-  for (auto it = cv_mask.begin<uint8>(); it != cv_mask.end<uint8>(); ++it) {
-    const uint8 expected_value = kTestPatternVector[flat_idx] == 0 ? 0 : 255;
+  for (auto it = cv_mask.begin<uint8_t>(); it != cv_mask.end<uint8_t>(); ++it) {
+    const uint8_t expected_value = kTestPatternVector[flat_idx] == 0 ? 0 : 255;
     EXPECT_EQ(*it, expected_value);
     flat_idx++;
   }

From 05801b99453159ec2cf76b0a13858c25ed9e2f3c Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Wed, 5 Apr 2023 01:09:59 -0700
Subject: [PATCH 46/63] Internal change

PiperOrigin-RevId: 521981387
---
 .../java/com/google/mediapipe/framework/jni/graph.cc       | 2 +-
 .../google/mediapipe/framework/jni/packet_creator_jni.cc   | 3 ++-
 .../google/mediapipe/framework/jni/packet_getter_jni.cc    | 8 ++++----
 3 files changed, 7 insertions(+), 6 deletions(-)

diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/graph.cc b/mediapipe/java/com/google/mediapipe/framework/jni/graph.cc
index 23bd553af..d565187d9 100644
--- a/mediapipe/java/com/google/mediapipe/framework/jni/graph.cc
+++ b/mediapipe/java/com/google/mediapipe/framework/jni/graph.cc
@@ -578,7 +578,7 @@ mediapipe::GpuResources* Graph::GetGpuResources() const {
 }
 #endif  // !MEDIAPIPE_DISABLE_GPU

-absl::Status Graph::SetParentGlContext(int64 java_gl_context) {
+absl::Status Graph::SetParentGlContext(int64_t java_gl_context) {
 #if MEDIAPIPE_DISABLE_GPU
   LOG(FATAL) << "GPU support has been disabled in this build!";
 #else
diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/packet_creator_jni.cc b/mediapipe/java/com/google/mediapipe/framework/jni/packet_creator_jni.cc
index 46ea1ce41..f7430e6e8 100644
--- a/mediapipe/java/com/google/mediapipe/framework/jni/packet_creator_jni.cc
+++ b/mediapipe/java/com/google/mediapipe/framework/jni/packet_creator_jni.cc
@@ -132,7 +132,8 @@ CreateImageFrameFromByteBuffer(JNIEnv* env, jobject byte_buffer, jint width,
   // code might expect to be able to overwrite the buffer after creating an
   // ImageFrame from it.
   image_frame->CopyPixelData(
-      format, width, height, width_step, static_cast<const uint8*>(buffer_data),
+      format, width, height, width_step,
+      static_cast<const uint8_t*>(buffer_data),
       mediapipe::ImageFrame::kGlDefaultAlignmentBoundary);

   return image_frame;
diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.cc b/mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.cc
index d5bd773f3..cc273bca4 100644
--- a/mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.cc
+++ b/mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.cc
@@ -65,12 +65,12 @@ bool CopyImageDataToByteBuffer(JNIEnv* env, const mediapipe::ImageFrame& image,

   switch (image.ByteDepth()) {
     case 1: {
-      uint8* data = static_cast<uint8*>(buffer_data);
+      uint8_t* data = static_cast<uint8_t*>(buffer_data);
       image.CopyToBuffer(data, expected_buffer_size);
       break;
     }
     case 2: {
-      uint16* data = static_cast<uint16*>(buffer_data);
+      uint16_t* data = static_cast<uint16_t*>(buffer_data);
       image.CopyToBuffer(data, expected_buffer_size);
       break;
     }
@@ -503,8 +503,8 @@ JNIEXPORT jbyteArray JNICALL PACKET_GETTER_METHOD(nativeGetAudioData)(
   int offset = 0;
   for (int sample = 0; sample < num_samples; ++sample) {
     for (int channel = 0; channel < num_channels; ++channel) {
-      int16 value =
-          static_cast<int16>(audio_mat(channel, sample) * kMultiplier);
+      int16_t value =
+          static_cast<int16_t>(audio_mat(channel, sample) * kMultiplier);
       // The java and native has the same byte order, by default is little
       // Endian, we can safely copy data directly, we have tests to cover
      // this.
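[Aside, not part of the patch: the audio-getter hunk above scales each normalized float sample into a 16-bit PCM value before copying it into the Java byte array. A hedged sketch of that per-sample step; the exact value of kMultiplier is an assumption here, not taken from the patch:

    #include <cstdint>

    // Convert one normalized float sample in [-1.0, 1.0] to int16 PCM,
    // mirroring the cast in the hunk above. Assumes kMultiplier = 32767
    // (the int16 maximum); the real constant in the JNI code may differ.
    int16_t FloatSampleToPcm16(float sample) {
      constexpr float kMultiplier = 32767.0f;
      return static_cast<int16_t>(sample * kMultiplier);
    }
]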
From 91264eab1f0a8e6d142d9c0fdcec1dbe07155e3c Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Wed, 5 Apr 2023 01:14:32 -0700
Subject: [PATCH 47/63] Internal change

PiperOrigin-RevId: 521982139
---
 .../audio/basic_time_series_calculators.cc        |  7 ++++---
 .../calculators/audio/spectrogram_calculator.cc   |  6 +++---
 .../audio/spectrogram_calculator_test.cc          | 14 +++++++-------
 .../audio/stabilized_log_calculator_test.cc       |  6 ++++--
 .../audio/time_series_framer_calculator.cc        | 10 +++++-----
 .../audio/time_series_framer_calculator_test.cc   |  6 +++---
 6 files changed, 26 insertions(+), 23 deletions(-)

diff --git a/mediapipe/calculators/audio/basic_time_series_calculators.cc b/mediapipe/calculators/audio/basic_time_series_calculators.cc
index f7b24f6f6..5006a0b54 100644
--- a/mediapipe/calculators/audio/basic_time_series_calculators.cc
+++ b/mediapipe/calculators/audio/basic_time_series_calculators.cc
@@ -26,10 +26,11 @@ namespace mediapipe {
 namespace {
 static bool SafeMultiply(int x, int y, int* result) {
-  static_assert(sizeof(int64) >= 2 * sizeof(int),
+  static_assert(sizeof(int64_t) >= 2 * sizeof(int),
                 "Unable to detect overflow after multiplication");
-  const int64 big = static_cast<int64>(x) * static_cast<int64>(y);
-  if (big > static_cast<int64>(INT_MIN) && big < static_cast<int64>(INT_MAX)) {
+  const int64_t big = static_cast<int64_t>(x) * static_cast<int64_t>(y);
+  if (big > static_cast<int64_t>(INT_MIN) &&
+      big < static_cast<int64_t>(INT_MAX)) {
     if (result != nullptr) *result = static_cast<int>(big);
     return true;
   } else {
diff --git a/mediapipe/calculators/audio/spectrogram_calculator.cc b/mediapipe/calculators/audio/spectrogram_calculator.cc
index bd4d8f3bf..939e721ab 100644
--- a/mediapipe/calculators/audio/spectrogram_calculator.cc
+++ b/mediapipe/calculators/audio/spectrogram_calculator.cc
@@ -182,12 +182,12 @@ class SpectrogramCalculator : public CalculatorBase {
   int frame_duration_samples_;
   int frame_overlap_samples_;
   // How many samples we've been passed, used for checking input time stamps.
-  int64 cumulative_input_samples_;
+  int64_t cumulative_input_samples_;
   // How many frames we've emitted, used for calculating output time stamps.
-  int64 cumulative_completed_frames_;
+  int64_t cumulative_completed_frames_;
   // How many frames were emitted last, used for estimating the timestamp on
   // Close when use_local_timestamp_ is true;
-  int64 last_completed_frames_;
+  int64_t last_completed_frames_;
   Timestamp initial_input_timestamp_;
   int num_input_channels_;
   // How many frequency bins we emit (=N_FFT/2 + 1).
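[Aside, not part of the patch: the cumulative counters above are the basis for output timestamps. Deriving each timestamp from the total frame count emitted so far, rather than repeatedly adding a per-frame delta, keeps rounding error bounded over long streams. A small illustration of the idea, assuming MediaPipe's microsecond timestamp units; the function and parameter names here are illustrative, not from the library:

    #include <cmath>
    #include <cstdint>

    // Timestamp of the Nth output frame after a base timestamp, computed
    // from the cumulative frame count so rounding error does not accumulate.
    int64_t FrameTimestampMicros(int64_t base_micros, int64_t frame_index,
                                 double frame_step_seconds) {
      return base_micros + static_cast<int64_t>(
                               std::round(frame_index * frame_step_seconds * 1e6));
    }
]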
diff --git a/mediapipe/calculators/audio/spectrogram_calculator_test.cc b/mediapipe/calculators/audio/spectrogram_calculator_test.cc
index 3c2b8435d..b35f30583 100644
--- a/mediapipe/calculators/audio/spectrogram_calculator_test.cc
+++ b/mediapipe/calculators/audio/spectrogram_calculator_test.cc
@@ -92,8 +92,8 @@ class SpectrogramCalculatorTest
               .cos()
               .transpose();
     }
-    int64 input_timestamp = round(packet_start_time_seconds *
-                                  Timestamp::kTimestampUnitsPerSecond);
+    int64_t input_timestamp = round(packet_start_time_seconds *
+                                    Timestamp::kTimestampUnitsPerSecond);
     AppendInputPacket(packet_data, input_timestamp);
     total_num_input_samples += packet_size_samples;
   }
@@ -116,8 +116,8 @@ class SpectrogramCalculatorTest
     double packet_start_time_seconds =
         kInitialTimestampOffsetMicroseconds * 1e-6 +
         total_num_input_samples / input_sample_rate_;
-    int64 input_timestamp = round(packet_start_time_seconds *
-                                  Timestamp::kTimestampUnitsPerSecond);
+    int64_t input_timestamp = round(packet_start_time_seconds *
+                                    Timestamp::kTimestampUnitsPerSecond);
     std::unique_ptr<Matrix> impulse(
         new Matrix(Matrix::Zero(1, packet_sizes_samples[i])));
     (*impulse)(0, impulse_offsets_samples[i]) = 1.0;
@@ -157,8 +157,8 @@ class SpectrogramCalculatorTest
               .cos()
               .transpose();
     }
-    int64 input_timestamp = round(packet_start_time_seconds *
-                                  Timestamp::kTimestampUnitsPerSecond);
+    int64_t input_timestamp = round(packet_start_time_seconds *
+                                    Timestamp::kTimestampUnitsPerSecond);
     AppendInputPacket(packet_data, input_timestamp);
     total_num_input_samples += packet_size_samples;
   }
@@ -218,7 +218,7 @@ class SpectrogramCalculatorTest
     const double expected_timestamp_seconds =
         packet_timestamp_offset_seconds +
         cumulative_output_frames * frame_step_seconds;
-    const int64 expected_timestamp_ticks =
+    const int64_t expected_timestamp_ticks =
         expected_timestamp_seconds * Timestamp::kTimestampUnitsPerSecond;
     EXPECT_EQ(expected_timestamp_ticks, packet.Timestamp().Value());
     // Accept the timestamp of the first packet as the baseline for checking
diff --git a/mediapipe/calculators/audio/stabilized_log_calculator_test.cc b/mediapipe/calculators/audio/stabilized_log_calculator_test.cc
index e6e0b5c6f..f04202676 100644
--- a/mediapipe/calculators/audio/stabilized_log_calculator_test.cc
+++ b/mediapipe/calculators/audio/stabilized_log_calculator_test.cc
@@ -54,7 +54,8 @@ TEST_F(StabilizedLogCalculatorTest, BasicOperation) {

   std::vector<Matrix> input_data_matrices;
   for (int input_packet = 0; input_packet < kNumPackets; ++input_packet) {
-    const int64 timestamp = input_packet * Timestamp::kTimestampUnitsPerSecond;
+    const int64_t timestamp =
+        input_packet * Timestamp::kTimestampUnitsPerSecond;
     Matrix input_data_matrix =
         Matrix::Random(kNumChannels, kNumSamples).array().abs();
     input_data_matrices.push_back(input_data_matrix);
@@ -80,7 +81,8 @@ TEST_F(StabilizedLogCalculatorTest, OutputScaleWorks) {

   std::vector<Matrix> input_data_matrices;
   for (int input_packet = 0; input_packet < kNumPackets; ++input_packet) {
-    const int64 timestamp = input_packet * Timestamp::kTimestampUnitsPerSecond;
+    const int64_t timestamp =
+        input_packet * Timestamp::kTimestampUnitsPerSecond;
     Matrix input_data_matrix =
         Matrix::Random(kNumChannels, kNumSamples).array().abs();
     input_data_matrices.push_back(input_data_matrix);
diff --git a/mediapipe/calculators/audio/time_series_framer_calculator.cc b/mediapipe/calculators/audio/time_series_framer_calculator.cc
index fbbf34226..a200b898a 100644
--- a/mediapipe/calculators/audio/time_series_framer_calculator.cc
+++ b/mediapipe/calculators/audio/time_series_framer_calculator.cc
@@ -109,7 +109,7 @@ class TimeSeriesFramerCalculator : public CalculatorBase {
   // Returns the timestamp of a sample on a base, which is usually the time
   // stamp of a packet.
   Timestamp CurrentSampleTimestamp(const Timestamp& timestamp_base,
-                                   int64 number_of_samples) {
+                                   int64_t number_of_samples) {
     return timestamp_base + round(number_of_samples / sample_rate_ *
                                   Timestamp::kTimestampUnitsPerSecond);
   }
@@ -118,10 +118,10 @@ class TimeSeriesFramerCalculator : public CalculatorBase {
   // emitted.
   int next_frame_step_samples() const {
     // All numbers are in input samples.
-    const int64 current_output_frame_start = static_cast<int64>(
+    const int64_t current_output_frame_start = static_cast<int64_t>(
         round(cumulative_output_frames_ * average_frame_step_samples_));
     CHECK_EQ(current_output_frame_start, cumulative_completed_samples_);
-    const int64 next_output_frame_start = static_cast<int64>(
+    const int64_t next_output_frame_start = static_cast<int64_t>(
         round((cumulative_output_frames_ + 1) * average_frame_step_samples_));
     return next_output_frame_start - current_output_frame_start;
   }
@@ -134,11 +134,11 @@ class TimeSeriesFramerCalculator : public CalculatorBase {
   // emulate_fractional_frame_overlap is true.
   double average_frame_step_samples_;
   int samples_still_to_drop_;
-  int64 cumulative_output_frames_;
+  int64_t cumulative_output_frames_;
   // "Completed" samples are samples that are no longer needed because
   // the framer has completely stepped past them (taking into account
   // any overlap).
-  int64 cumulative_completed_samples_;
+  int64_t cumulative_completed_samples_;
   Timestamp initial_input_timestamp_;
   // The current timestamp is updated along with the incoming packets.
   Timestamp current_timestamp_;
diff --git a/mediapipe/calculators/audio/time_series_framer_calculator_test.cc b/mediapipe/calculators/audio/time_series_framer_calculator_test.cc
index ca88cebb5..72e9c88f7 100644
--- a/mediapipe/calculators/audio/time_series_framer_calculator_test.cc
+++ b/mediapipe/calculators/audio/time_series_framer_calculator_test.cc
@@ -49,7 +49,7 @@ class TimeSeriesFramerCalculatorTest

   // Returns a float value with the channel and timestamp separated by
   // an order of magnitude, for easy parsing by humans.
-  float TestValue(int64 timestamp_in_microseconds, int channel) {
+  float TestValue(int64_t timestamp_in_microseconds, int channel) {
     return timestamp_in_microseconds + channel / 10.0;
   }

@@ -59,7 +59,7 @@ class TimeSeriesFramerCalculatorTest
     auto matrix = new Matrix(num_channels, num_samples);
     for (int c = 0; c < num_channels; ++c) {
       for (int i = 0; i < num_samples; ++i) {
-        int64 timestamp = time_series_util::SecondsToSamples(
+        int64_t timestamp = time_series_util::SecondsToSamples(
             starting_timestamp_seconds + i / input_sample_rate_,
             Timestamp::kTimestampUnitsPerSecond);
         (*matrix)(c, i) = TestValue(timestamp, c);
@@ -429,7 +429,7 @@ class TimeSeriesFramerCalculatorTimestampingTest
       num_full_packets -= 1;
     }

-    int64 num_samples = 0;
+    int64_t num_samples = 0;
     for (int packet_num = 0; packet_num < num_full_packets; ++packet_num) {
       const Packet& packet = output().packets[packet_num];
       num_samples += FrameDurationSamples();

From 425a3ee3f69e2641407263524f99c5ce5ac51862 Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Wed, 5 Apr 2023 02:23:50 -0700
Subject: [PATCH 48/63] Internal change

PiperOrigin-RevId: 521993439
---
 mediapipe/util/sequence/media_sequence.cc     | 30 +++---
 .../util/sequence/media_sequence_test.cc      | 92 +++++++++----------
 .../util/sequence/media_sequence_util_test.cc | 34 +++----
 3 files changed, 78 insertions(+), 78 deletions(-)

diff --git a/mediapipe/util/sequence/media_sequence.cc b/mediapipe/util/sequence/media_sequence.cc
index f76c53295..287db6181 100644
--- a/mediapipe/util/sequence/media_sequence.cc
+++ b/mediapipe/util/sequence/media_sequence.cc
@@ -57,13 +57,13 @@ bool ImageMetadata(const std::string& image_str, std::string* format_string,

 // Finds the nearest timestamp in a FeatureList of timestamps. The FeatureList
 // must contain int64 values and only the first value at each step is used.
 int NearestIndex(int64_t timestamp,
                  const tensorflow::FeatureList& int64_feature_list) {
-  int64 closest_distance = std::numeric_limits<int64>::max();
+  int64_t closest_distance = std::numeric_limits<int64_t>::max();
   int index = -1;
   for (int i = 0; i < int64_feature_list.feature_size(); ++i) {
-    int64 current_value = int64_feature_list.feature(i).int64_list().value(0);
-    int64 current_distance = std::abs(current_value - timestamp);
+    int64_t current_value = int64_feature_list.feature(i).int64_list().value(0);
+    int64_t current_distance = std::abs(current_value - timestamp);
     if (current_distance < closest_distance) {
       index = i;
       closest_distance = current_distance;
@@ -74,8 +74,8 @@ int NearestIndex(int64 timestamp,

 // Find the numerical sampling rate between two values in seconds if the input
 // timestamps are in microseconds.
-float TimestampsToRate(int64 first_timestamp, int64 second_timestamp) {
-  int64 timestamp_diff = second_timestamp - first_timestamp;
+float TimestampsToRate(int64_t first_timestamp, int64_t second_timestamp) {
+  int64_t timestamp_diff = second_timestamp - first_timestamp;
   // convert from microseconds to seconds.
   float rate = 1.0 / (static_cast<double>(timestamp_diff) / 1000000);
   return rate;
@@ -100,18 +100,18 @@ absl::Status ReconcileAnnotationIndicesByImageTimestamps(
         << "start: " << segment_size
         << ", end: " << GetSegmentEndTimestampSize(*sequence);

-    std::vector<int64> start_indices;
+    std::vector<int64_t> start_indices;
     start_indices.reserve(segment_size);
-    for (const int64& timestamp : GetSegmentStartTimestamp(*sequence)) {
+    for (const int64_t& timestamp : GetSegmentStartTimestamp(*sequence)) {
       index = NearestIndex(timestamp,
                            GetFeatureList(*sequence, kImageTimestampKey));
       start_indices.push_back(index);
     }
     SetSegmentStartIndex(start_indices, sequence);

-    std::vector<int64> end_indices;
+    std::vector<int64_t> end_indices;
     end_indices.reserve(segment_size);
-    for (const int64& timestamp : GetSegmentEndTimestamp(*sequence)) {
+    for (const int64_t& timestamp : GetSegmentEndTimestamp(*sequence)) {
       index = NearestIndex(timestamp,
                            GetFeatureList(*sequence, kImageTimestampKey));
       end_indices.push_back(index);
@@ -167,8 +167,8 @@ absl::Status ReconcileMetadataFeatureFloats(
     int number_of_elements = GetFeatureFloatsAt(prefix, *sequence, 0).size();
     if (HasFeatureDimensions(prefix, *sequence) &&
         !GetFeatureDimensions(prefix, *sequence).empty()) {
-      int64 product = 1;
-      for (int64 value : GetFeatureDimensions(prefix, *sequence)) {
+      int64_t product = 1;
+      for (int64_t value : GetFeatureDimensions(prefix, *sequence)) {
        product *= value;
      }
      RET_CHECK_EQ(number_of_elements, product)
@@ -249,14 +249,14 @@ absl::Status ReconcileMetadataBoxAnnotations(
   // Collect which timestamps currently match to which indices in timestamps.
   // skip empty timestamps.
   // Requires sorted indices.
-  ::std::vector<int64> box_timestamps(num_bboxes);
+  ::std::vector<int64_t> box_timestamps(num_bboxes);
   int bbox_index = 0;
   std::string timestamp_key = merge_prefix(prefix, kRegionTimestampKey);
   for (auto& feature : GetFeatureList(*sequence, timestamp_key).feature()) {
     box_timestamps[bbox_index] = feature.int64_list().value(0);
     ++bbox_index;
   }
-  ::std::vector<int64> box_is_annotated(num_bboxes);
+  ::std::vector<int64_t> box_is_annotated(num_bboxes);
   bbox_index = 0;
   std::string is_annotated_key = merge_prefix(prefix, kRegionIsAnnotatedKey);
   for (auto& feature :
       GetFeatureList(*sequence, is_annotated_key).feature()) {
     box_is_annotated[bbox_index] = feature.int64_list().value(0);
     ++bbox_index;
   }
-  ::std::vector<int64> image_timestamps(num_frames);
+  ::std::vector<int64_t> image_timestamps(num_frames);
   int frame_index = 0;
   for (auto& feature :
       GetFeatureList(*sequence, kImageTimestampKey).feature()) {
diff --git a/mediapipe/util/sequence/media_sequence_test.cc b/mediapipe/util/sequence/media_sequence_test.cc
index 0797ed472..e220eace0 100644
--- a/mediapipe/util/sequence/media_sequence_test.cc
+++ b/mediapipe/util/sequence/media_sequence_test.cc
@@ -67,7 +67,7 @@ TEST(MediaSequenceTest, RoundTripEncodedMediaBytes) {

 TEST(MediaSequenceTest, RoundTripEncodedVideoStartTimestamp) {
   tensorflow::SequenceExample sequence;
-  int64 data = 47;
+  int64_t data = 47;
   SetClipEncodedMediaStartTimestamp(data, &sequence);
   ASSERT_EQ(GetClipEncodedMediaStartTimestamp(sequence), data);
 }
@@ -92,7 +92,7 @@ TEST(MediaSequenceTest, RoundTripClipEndTimestamp) {

 TEST(MediaSequenceTest, RoundTripClipLabelIndex) {
   tensorflow::SequenceExample sequence;
-  std::vector<int64> label = {5, 3};
+  std::vector<int64_t> label = {5, 3};
   SetClipLabelIndex(label, &sequence);
   ASSERT_THAT(GetClipLabelIndex(sequence), testing::ElementsAreArray(label));
 }
@@ -115,46 +115,46 @@ TEST(MediaSequenceTest, RoundTripFloatListFrameRate) {

 TEST(MediaSequenceTest, RoundTripSegmentStartTimestamp) {
   tensorflow::SequenceExample sequence;
   EXPECT_FALSE(HasContext(sequence, kSegmentStartTimestampKey));
-  SetSegmentStartTimestamp(::std::vector<int64>({123, 456}), &sequence);
+  SetSegmentStartTimestamp(::std::vector<int64_t>({123, 456}), &sequence);
   ASSERT_EQ(2, GetSegmentStartTimestampSize(sequence));
   ASSERT_THAT(GetSegmentStartTimestamp(sequence),
-              testing::ElementsAreArray(::std::vector<int64>({123, 456})));
+              testing::ElementsAreArray(::std::vector<int64_t>({123, 456})));
 }

 TEST(MediaSequenceTest, RoundTripSegmentEndTimestamp) {
   tensorflow::SequenceExample sequence;
   EXPECT_FALSE(HasContext(sequence, kSegmentEndTimestampKey));
-  SetSegmentEndTimestamp(::std::vector<int64>({123, 456}), &sequence);
+  SetSegmentEndTimestamp(::std::vector<int64_t>({123, 456}), &sequence);
   ASSERT_EQ(2, GetSegmentEndTimestampSize(sequence));
   ASSERT_THAT(GetSegmentEndTimestamp(sequence),
-              testing::ElementsAreArray(::std::vector<int64>({123, 456})));
+              testing::ElementsAreArray(::std::vector<int64_t>({123, 456})));
 }

 TEST(MediaSequenceTest, RoundTripSegmentStartIndex) {
   tensorflow::SequenceExample sequence;
   EXPECT_FALSE(HasContext(sequence, kSegmentStartIndexKey));
-  SetSegmentStartIndex(::std::vector<int64>({123, 456}), &sequence);
+  SetSegmentStartIndex(::std::vector<int64_t>({123, 456}), &sequence);
   ASSERT_EQ(2, GetSegmentStartIndexSize(sequence));
   ASSERT_THAT(GetSegmentStartIndex(sequence),
-              testing::ElementsAreArray(::std::vector<int64>({123, 456})));
+              testing::ElementsAreArray(::std::vector<int64_t>({123, 456})));
 }

 TEST(MediaSequenceTest, RoundTripSegmentEndIndex) {
   tensorflow::SequenceExample sequence;
   EXPECT_FALSE(HasContext(sequence, kSegmentEndIndexKey));
-  SetSegmentEndIndex(::std::vector<int64>({123, 456}), &sequence);
+  SetSegmentEndIndex(::std::vector<int64_t>({123, 456}), &sequence);
   ASSERT_EQ(2, GetSegmentEndIndexSize(sequence));
   ASSERT_THAT(GetSegmentEndIndex(sequence),
-              testing::ElementsAreArray(::std::vector<int64>({123, 456})));
+              testing::ElementsAreArray(::std::vector<int64_t>({123, 456})));
 }

 TEST(MediaSequenceTest, RoundTripSegmentLabelIndex) {
   tensorflow::SequenceExample sequence;
   EXPECT_FALSE(HasContext(sequence, kSegmentLabelIndexKey));
-  SetSegmentLabelIndex(::std::vector<int64>({5, 7}), &sequence);
+  SetSegmentLabelIndex(::std::vector<int64_t>({5, 7}), &sequence);
   ASSERT_EQ(2, GetSegmentLabelIndexSize(sequence));
   ASSERT_THAT(GetSegmentLabelIndex(sequence),
-              testing::ElementsAreArray(::std::vector<int64>({5, 7})));
+              testing::ElementsAreArray(::std::vector<int64_t>({5, 7})));
 }

 TEST(MediaSequenceTest, RoundTripSegmentLabelString) {
@@ -180,8 +180,8 @@ TEST(MediaSequenceTest, RoundTripSegmentLabelConfidence) {

 TEST(MediaSequenceTest, RoundTripImageWidthHeight) {
   tensorflow::SequenceExample sequence;
-  int64 height = 2;
-  int64 width = 3;
+  int64_t height = 2;
+  int64_t width = 3;
   SetImageHeight(height, &sequence);
   ASSERT_EQ(GetImageHeight(sequence), height);
   SetImageWidth(width, &sequence);
@@ -190,8 +190,8 @@ TEST(MediaSequenceTest, RoundTripImageWidthHeight) {

 TEST(MediaSequenceTest, RoundTripForwardFlowWidthHeight) {
   tensorflow::SequenceExample sequence;
-  int64 height = 2;
-  int64 width = 3;
+  int64_t height = 2;
+  int64_t width = 3;
   SetForwardFlowHeight(height, &sequence);
   ASSERT_EQ(GetForwardFlowHeight(sequence), height);
   SetForwardFlowWidth(width, &sequence);
@@ -200,8 +200,8 @@ TEST(MediaSequenceTest, RoundTripForwardFlowWidthHeight) {

 TEST(MediaSequenceTest, RoundTripClassSegmentationWidthHeightFormat) {
   tensorflow::SequenceExample sequence;
-  int64 height = 2;
-  int64 width = 3;
+  int64_t height = 2;
+  int64_t width = 3;
   std::string format = "JPEG";
   SetClassSegmentationHeight(height, &sequence);
   EXPECT_EQ(GetClassSegmentationHeight(sequence), height);
@@ -213,7 +213,7 @@ TEST(MediaSequenceTest, RoundTripClassSegmentationWidthHeightFormat) {

 TEST(MediaSequenceTest, RoundTripClassSegmentationLabelIndex) {
   tensorflow::SequenceExample sequence;
-  std::vector<int64> classes = {5, 3};
+  std::vector<int64_t> classes = {5, 3};
   SetClassSegmentationClassLabelIndex(classes, &sequence);
   ASSERT_THAT(GetClassSegmentationClassLabelIndex(sequence),
               testing::ElementsAreArray({5, 3}));
@@ -233,8 +233,8 @@ TEST(MediaSequenceTest, RoundTripClassSegmentationLabelString) {

 TEST(MediaSequenceTest, RoundTripInstanceSegmentationWidthHeightFormat) {
   tensorflow::SequenceExample sequence;
-  int64 height = 2;
-  int64 width = 3;
+  int64_t height = 2;
+  int64_t width = 3;
   std::string format = "JPEG";
   SetInstanceSegmentationHeight(height, &sequence);
   EXPECT_EQ(GetInstanceSegmentationHeight(sequence), height);
@@ -246,7 +246,7 @@ TEST(MediaSequenceTest, RoundTripInstanceSegmentationWidthHeightFormat) {

 TEST(MediaSequenceTest, RoundTripInstanceSegmentationClass) {
   tensorflow::SequenceExample sequence;
-  std::vector<int64> classes = {5, 3};
+  std::vector<int64_t> classes = {5, 3};
   SetInstanceSegmentationObjectClassIndex(classes, &sequence);
   ASSERT_THAT(GetInstanceSegmentationObjectClassIndex(sequence),
               testing::ElementsAreArray({5, 3}));
@@ -286,7 +286,7 @@ TEST(MediaSequenceTest, RoundTripBBoxNumRegions) {

 TEST(MediaSequenceTest, RoundTripBBoxLabelIndex) {
   tensorflow::SequenceExample sequence;
-  std::vector<std::vector<int64>> labels = {{5, 3}, {1, 2}};
+  std::vector<std::vector<int64_t>> labels = {{5, 3}, {1, 2}};
   for (int i = 0; i < labels.size(); ++i) {
     AddBBoxLabelIndex(labels[i], &sequence);
     ASSERT_EQ(GetBBoxLabelIndexSize(sequence), i + 1);
@@ -312,7 +312,7 @@ TEST(MediaSequenceTest, RoundTripBBoxLabelString) {

 TEST(MediaSequenceTest, RoundTripBBoxClassIndex) {
   tensorflow::SequenceExample sequence;
-  std::vector<std::vector<int64>> classes = {{5, 3}, {1, 2}};
+  std::vector<std::vector<int64_t>> classes = {{5, 3}, {1, 2}};
   for (int i = 0; i < classes.size(); ++i) {
     AddBBoxClassIndex(classes[i], &sequence);
     ASSERT_EQ(GetBBoxClassIndexSize(sequence), i + 1);
@@ -338,7 +338,7 @@ TEST(MediaSequenceTest, RoundTripBBoxClassString) {

 TEST(MediaSequenceTest, RoundTripBBoxTrackIndex) {
   tensorflow::SequenceExample sequence;
-  std::vector<std::vector<int64>> tracks = {{5, 3}, {1, 2}};
+  std::vector<std::vector<int64_t>> tracks = {{5, 3}, {1, 2}};
   for (int i = 0; i < tracks.size(); ++i) {
     AddBBoxTrackIndex(tracks[i], &sequence);
     ASSERT_EQ(GetBBoxTrackIndexSize(sequence), i + 1);
@@ -499,7 +499,7 @@ TEST(MediaSequenceTest, RoundTripPredictedBBox) {

 TEST(MediaSequenceTest, RoundTripPredictedBBoxTimestamp) {
   tensorflow::SequenceExample sequence;
-  std::vector<int64> timestamps = {3, 6};
+  std::vector<int64_t> timestamps = {3, 6};
   for (int i = 0; i < timestamps.size(); ++i) {
     AddPredictedBBoxTimestamp(timestamps[i], &sequence);
     ASSERT_EQ(GetPredictedBBoxTimestampSize(sequence), i + 1);
@@ -659,7 +659,7 @@ TEST(MediaSequenceTest, RoundTripContextFeatureBytes) {

 TEST(MediaSequenceTest, RoundTripContextFeatureInts) {
   tensorflow::SequenceExample sequence;
   std::string feature_key = "TEST";
-  std::vector<int64> vi = {0, 1, 2, 4};
+  std::vector<int64_t> vi = {0, 1, 2, 4};
   SetContextFeatureInts(feature_key, vi, &sequence);
   ASSERT_EQ(GetContextFeatureInts(feature_key, sequence).size(), vi.size());
   ASSERT_EQ(GetContextFeatureInts(feature_key, sequence)[3], vi[3]);
@@ -725,7 +725,7 @@ TEST(MediaSequenceTest, RoundTripTextContent) {

 TEST(MediaSequenceTest, RoundTripTextDuration) {
   tensorflow::SequenceExample sequence;
-  std::vector<int64> timestamps = {4, 7};
+  std::vector<int64_t> timestamps = {4, 7};
   for (int i = 0; i < timestamps.size(); ++i) {
@@ -725,7 +725,7 @@ TEST(MediaSequenceTest, RoundTripTextContent) {
 
 TEST(MediaSequenceTest, RoundTripTextDuration) {
   tensorflow::SequenceExample sequence;
-  std::vector<int64> timestamps = {4, 7};
+  std::vector<int64_t> timestamps = {4, 7};
   for (int i = 0; i < timestamps.size(); ++i) {
     AddTextTimestamp(timestamps[i], &sequence);
     ASSERT_EQ(GetTextTimestampSize(sequence), i + 1);
@@ -765,7 +765,7 @@ TEST(MediaSequenceTest, RoundTripTextEmbedding) {
 
 TEST(MediaSequenceTest, RoundTripTextTokenId) {
   tensorflow::SequenceExample sequence;
-  std::vector<int64> ids = {4, 7};
+  std::vector<int64_t> ids = {4, 7};
   for (int i = 0; i < ids.size(); ++i) {
     AddTextTokenId(ids[i], &sequence);
     ASSERT_EQ(GetTextTokenIdSize(sequence), i + 1);
@@ -783,8 +783,8 @@ TEST(MediaSequenceTest, ReconcileMetadataOnEmptySequence) {
 TEST(MediaSequenceTest, ReconcileMetadataImagestoLabels) {
   // Need image timestamps and label timestamps.
   tensorflow::SequenceExample sequence;
-  SetSegmentStartTimestamp(::std::vector<int64>({3, 4}), &sequence);
-  SetSegmentEndTimestamp(::std::vector<int64>({4, 5}), &sequence);
+  SetSegmentStartTimestamp(::std::vector<int64_t>({3, 4}), &sequence);
+  SetSegmentEndTimestamp(::std::vector<int64_t>({4, 5}), &sequence);
 
   // Skip 0, so the indices are the timestamp - 1
   AddImageTimestamp(1, &sequence);
@@ -1027,20 +1027,20 @@ TEST(MediaSequenceTest, ReconcileMetadataBoxAnnotationsUpdatesAllFeatures) {
   AddBBoxNumRegions(1, &sequence);
   AddBBoxNumRegions(1, &sequence);
 
-  AddBBoxLabelIndex(::std::vector<int64>({1}), &sequence);
-  AddBBoxLabelIndex(::std::vector<int64>({2}), &sequence);
+  AddBBoxLabelIndex(::std::vector<int64_t>({1}), &sequence);
+  AddBBoxLabelIndex(::std::vector<int64_t>({2}), &sequence);
 
   AddBBoxLabelString(::std::vector<std::string>({"one"}), &sequence);
   AddBBoxLabelString(::std::vector<std::string>({"two"}), &sequence);
 
-  AddBBoxClassIndex(::std::vector<int64>({1}), &sequence);
-  AddBBoxClassIndex(::std::vector<int64>({2}), &sequence);
+  AddBBoxClassIndex(::std::vector<int64_t>({1}), &sequence);
+  AddBBoxClassIndex(::std::vector<int64_t>({2}), &sequence);
 
   AddBBoxClassString(::std::vector<std::string>({"one"}), &sequence);
   AddBBoxClassString(::std::vector<std::string>({"two"}), &sequence);
 
-  AddBBoxTrackIndex(::std::vector<int64>({1}), &sequence);
-  AddBBoxTrackIndex(::std::vector<int64>({2}), &sequence);
+  AddBBoxTrackIndex(::std::vector<int64_t>({1}), &sequence);
+  AddBBoxTrackIndex(::std::vector<int64_t>({2}), &sequence);
 
   AddBBoxTrackString(::std::vector<std::string>({"one"}), &sequence);
   AddBBoxTrackString(::std::vector<std::string>({"two"}), &sequence);
@@ -1083,11 +1083,11 @@ TEST(MediaSequenceTest, ReconcileMetadataBoxAnnotationsUpdatesAllFeatures) {
   ASSERT_THAT(GetBBoxLabelIndexAt(sequence, 1),
               ::testing::ElementsAreArray({2}));
   ASSERT_THAT(GetBBoxLabelIndexAt(sequence, 2),
-              ::testing::ElementsAreArray(::std::vector<int64>()));
+              ::testing::ElementsAreArray(::std::vector<int64_t>()));
   ASSERT_THAT(GetBBoxLabelIndexAt(sequence, 3),
-              ::testing::ElementsAreArray(::std::vector<int64>()));
+              ::testing::ElementsAreArray(::std::vector<int64_t>()));
   ASSERT_THAT(GetBBoxLabelIndexAt(sequence, 4),
-              ::testing::ElementsAreArray(::std::vector<int64>()));
+              ::testing::ElementsAreArray(::std::vector<int64_t>()));
 
   ASSERT_THAT(GetBBoxLabelStringAt(sequence, 0),
               ::testing::ElementsAreArray({"one"}));
@@ -1105,11 +1105,11 @@ TEST(MediaSequenceTest, ReconcileMetadataBoxAnnotationsUpdatesAllFeatures) {
   ASSERT_THAT(GetBBoxClassIndexAt(sequence, 1),
               ::testing::ElementsAreArray({2}));
   ASSERT_THAT(GetBBoxClassIndexAt(sequence, 2),
-              ::testing::ElementsAreArray(::std::vector<int64>()));
+              ::testing::ElementsAreArray(::std::vector<int64_t>()));
   ASSERT_THAT(GetBBoxClassIndexAt(sequence, 3),
-              ::testing::ElementsAreArray(::std::vector<int64>()));
+              ::testing::ElementsAreArray(::std::vector<int64_t>()));
   ASSERT_THAT(GetBBoxClassIndexAt(sequence, 4),
-              ::testing::ElementsAreArray(::std::vector<int64>()));
+              ::testing::ElementsAreArray(::std::vector<int64_t>()));
 
   ASSERT_THAT(GetBBoxClassStringAt(sequence, 0),
               ::testing::ElementsAreArray({"one"}));
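The box-annotation hunks above rely on MediaPipe's generated media-sequence accessors, where each Add call appends one frame's worth of values and the *At getters read a single frame back. A hedged sketch of that round trip, assuming these helpers live in the usual mediapipe::mediasequence namespace:

    #include "mediapipe/util/sequence/media_sequence.h"
    #include "tensorflow/core/example/example.pb.h"

    void BBoxLabelRoundTrip() {
      tensorflow::SequenceExample sequence;
      // Each Add call appends one frame's worth of label indices.
      mediapipe::mediasequence::AddBBoxLabelIndex({5, 3}, &sequence);
      mediapipe::mediasequence::AddBBoxLabelIndex({1, 2}, &sequence);
      // The size is now 2, and frame 1 reads back {1, 2}, matching the
      // assertions in the hunks above.
      auto frame1 = mediapipe::mediasequence::GetBBoxLabelIndexAt(sequence, 1);
      (void)frame1;
    }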
@@ -1127,11 +1127,11 @@ TEST(MediaSequenceTest, ReconcileMetadataBoxAnnotationsUpdatesAllFeatures) {
   ASSERT_THAT(GetBBoxTrackIndexAt(sequence, 1),
               ::testing::ElementsAreArray({2}));
   ASSERT_THAT(GetBBoxTrackIndexAt(sequence, 2),
-              ::testing::ElementsAreArray(::std::vector<int64>()));
+              ::testing::ElementsAreArray(::std::vector<int64_t>()));
   ASSERT_THAT(GetBBoxTrackIndexAt(sequence, 3),
-              ::testing::ElementsAreArray(::std::vector<int64>()));
+              ::testing::ElementsAreArray(::std::vector<int64_t>()));
   ASSERT_THAT(GetBBoxTrackIndexAt(sequence, 4),
-              ::testing::ElementsAreArray(::std::vector<int64>()));
+              ::testing::ElementsAreArray(::std::vector<int64_t>()));
 
   ASSERT_THAT(GetBBoxTrackStringAt(sequence, 0),
               ::testing::ElementsAreArray({"one"}));

diff --git a/mediapipe/util/sequence/media_sequence_util_test.cc b/mediapipe/util/sequence/media_sequence_util_test.cc
index 56d3b4868..8709165e3 100644
--- a/mediapipe/util/sequence/media_sequence_util_test.cc
+++ b/mediapipe/util/sequence/media_sequence_util_test.cc
@@ -253,7 +253,7 @@ TEST_F(MediaSequenceUtilTest, RoundTripFloatList) {
 TEST_F(MediaSequenceUtilTest, RoundTripInt64List) {
   tensorflow::SequenceExample sequence_example;
   std::string key = "key";
-  std::vector<int64> expected_values{1, 3};
+  std::vector<int64_t> expected_values{1, 3};
   AddInt64Container(key, expected_values, &sequence_example);
   auto values = GetInt64sAt(sequence_example, key, 0);
   ASSERT_EQ(expected_values.size(), values.size());
@@ -302,7 +302,7 @@ TEST_F(MediaSequenceUtilTest, RoundTripContextFeatureList) {
   }
   // Test context in64 list.
   std::string clip_label_index_key = "clip_label_index";
-  std::vector<int64> clip_label_indices{2, 0};
+  std::vector<int64_t> clip_label_indices{2, 0};
   SetContextInt64List(clip_label_index_key, clip_label_indices,
                       &sequence_example);
   for (int i = 0; i < clip_label_indices.size(); ++i) {
@@ -333,7 +333,7 @@ TEST_F(MediaSequenceUtilTest, ContextKeyMissing) {
 TEST_F(MediaSequenceUtilTest, RoundTripFeatureListsFeature) {
   tensorflow::SequenceExample sequence_example;
   std::string timestamp_key = "timestamp";
-  int64 timestamp = 1000;
+  int64_t timestamp = 1000;
   MutableFeatureList(timestamp_key, &sequence_example)
       ->add_feature()
       ->mutable_int64_list()
@@ -413,7 +413,7 @@ TEST_F(MediaSequenceUtilTest, StringFeature) {
 
 TEST_F(MediaSequenceUtilTest, Int64Feature) {
   tensorflow::SequenceExample example;
-  int64 test_value = 47;
+  int64_t test_value = 47;
 
   ASSERT_FALSE(HasInt64Feature(example));
   SetInt64Feature(test_value, &example);
@@ -426,7 +426,7 @@ TEST_F(MediaSequenceUtilTest, Int64Feature) {
 
 TEST_F(MediaSequenceUtilTest, FloatFeature) {
   tensorflow::SequenceExample example;
-  int64 test_value = 47.0f;
+  int64_t test_value = 47.0f;
 
   ASSERT_FALSE(HasFloatFeature(example));
   SetFloatFeature(test_value, &example);
@@ -464,7 +464,7 @@ TEST_F(MediaSequenceUtilTest, StringVectorFeature) {
 
 TEST_F(MediaSequenceUtilTest, Int64VectorFeature) {
   tensorflow::SequenceExample example;
-  ::std::vector<int64> test_value = {47, 42};
+  ::std::vector<int64_t> test_value = {47, 42};
 
   ASSERT_FALSE(HasInt64VectorFeature(example));
   ASSERT_EQ(0, GetInt64VectorFeatureSize(example));
@@ -535,7 +535,7 @@ TEST_F(MediaSequenceUtilTest, StringFeatureList) {
 
 TEST_F(MediaSequenceUtilTest, Int64FeatureList) {
   tensorflow::SequenceExample example;
-  ::std::vector<int64> test_value = {47, 42};
+  ::std::vector<int64_t> test_value = {47, 42};
 
   ASSERT_FALSE(HasInt64FeatureList(example));
   ASSERT_EQ(0, GetInt64FeatureListSize(example));
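One pre-existing quirk the rename preserves: the FloatFeature tests in this file declare their values as int64 (now int64_t) yet initialize them from float literals. That compiles because the conversion is implicit, as this one-line sketch illustrates:

    #include <cstdint>

    // 47.0f converts implicitly; any fractional part (none here) would be
    // discarded, so test_value holds exactly 47.
    int64_t test_value = 47.0f;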
@@ -602,7 +602,7 @@ TEST_F(MediaSequenceUtilTest, VectorStringFeatureList) {
 
 TEST_F(MediaSequenceUtilTest, VectorInt64FeatureList) {
   tensorflow::SequenceExample example;
-  ::std::vector<::std::vector<int64>> test_value = {{47, 42}, {3, 5}};
+  ::std::vector<::std::vector<int64_t>> test_value = {{47, 42}, {3, 5}};
 
   ASSERT_FALSE(HasVectorInt64FeatureList(example));
   ASSERT_EQ(0, GetVectorInt64FeatureListSize(example));
@@ -704,8 +704,8 @@ TEST_F(MediaSequenceUtilTest, VariablePrefixStringFeature) {
 
 TEST_F(MediaSequenceUtilTest, FixedPrefixInt64Feature) {
   tensorflow::SequenceExample example;
-  int64 test_value_1 = 47;
-  int64 test_value_2 = 49;
+  int64_t test_value_1 = 47;
+  int64_t test_value_2 = 49;
 
   ASSERT_FALSE(HasOneInt64Feature(example));
   SetOneInt64Feature(test_value_1, &example);
@@ -727,8 +727,8 @@ TEST_F(MediaSequenceUtilTest, FixedPrefixInt64Feature) {
 
 TEST_F(MediaSequenceUtilTest, FixedPrefixFloatFeature) {
   tensorflow::SequenceExample example;
-  int64 test_value_1 = 47.0f;
-  int64 test_value_2 = 49.0f;
+  int64_t test_value_1 = 47.0f;
+  int64_t test_value_2 = 49.0f;
 
   ASSERT_FALSE(HasOneFloatFeature(example));
   SetOneFloatFeature(test_value_1, &example);
@@ -795,8 +795,8 @@ TEST_F(MediaSequenceUtilTest, FixedPrefixStringVectorFeature) {
 
 TEST_F(MediaSequenceUtilTest, FixedPrefixInt64VectorFeature) {
   tensorflow::SequenceExample example;
-  ::std::vector<int64> test_value_1 = {47, 42};
-  ::std::vector<int64> test_value_2 = {49, 47};
+  ::std::vector<int64_t> test_value_1 = {47, 42};
+  ::std::vector<int64_t> test_value_2 = {49, 47};
 
   ASSERT_FALSE(HasOneInt64VectorFeature(example));
   ASSERT_EQ(0, GetOneInt64VectorFeatureSize(example));
@@ -905,7 +905,7 @@ TEST_F(MediaSequenceUtilTest, FixedPrefixStringFeatureList) {
 
 TEST_F(MediaSequenceUtilTest, FixedPrefixInt64FeatureList) {
   tensorflow::SequenceExample example;
-  ::std::vector<int64> test_value = {47, 42};
+  ::std::vector<int64_t> test_value = {47, 42};
 
   ASSERT_FALSE(HasInt64FeatureList(example));
   ASSERT_EQ(0, GetInt64FeatureListSize(example));
@@ -990,8 +990,8 @@ TEST_F(MediaSequenceUtilTest, FixedPrefixVectorStringFeatureList) {
 
 TEST_F(MediaSequenceUtilTest, FixedPrefixVectorInt64FeatureList) {
   tensorflow::SequenceExample example;
-  ::std::vector<::std::vector<int64>> test_value_1 = {{47, 42}, {3, 5}};
-  ::std::vector<::std::vector<int64>> test_value_2 = {{49, 47}, {3, 5}};
+  ::std::vector<::std::vector<int64_t>> test_value_1 = {{47, 42}, {3, 5}};
+  ::std::vector<::std::vector<int64_t>> test_value_2 = {{49, 47}, {3, 5}};
 
   ASSERT_FALSE(HasOneVectorInt64FeatureList(example));
   ASSERT_EQ(0, GetOneVectorInt64FeatureListSize(example));

From c5bd34ddb0dc46a79d2002e755a5984d0650a8a6 Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Wed, 5 Apr 2023 04:23:26 -0700
Subject: [PATCH 49/63] Internal change

PiperOrigin-RevId: 522014435
---
 mediapipe/gpu/gpu_shared_data_internal.cc | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mediapipe/gpu/gpu_shared_data_internal.cc b/mediapipe/gpu/gpu_shared_data_internal.cc
index 49e9cf22a..f542f0bb2 100644
--- a/mediapipe/gpu/gpu_shared_data_internal.cc
+++ b/mediapipe/gpu/gpu_shared_data_internal.cc
@@ -119,7 +119,7 @@ GpuResources::~GpuResources() {
 extern const GraphService<GpuResources> kGpuService;
 
 absl::Status GpuResources::PrepareGpuNode(CalculatorNode* node) {
-  CHECK(ContainsKey(node->Contract().ServiceRequests(), kGpuService.key));
+  CHECK(node->Contract().ServiceRequests().contains(kGpuService.key));
   std::string node_id = node->GetCalculatorState().NodeName();
   std::string node_type = node->GetCalculatorState().CalculatorType();
   std::string context_key;
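The single functional change above swaps the map-util ContainsKey helper for the member lookup that Abseil containers (and C++20 standard containers) provide. A minimal sketch of the pattern, independent of the GPU service types in this file:

    #include <string>

    #include "absl/container/flat_hash_set.h"

    bool HasService(const absl::flat_hash_set<std::string>& requests,
                    const std::string& key) {
      // Member contains() replaces the free function ContainsKey(requests, key).
      return requests.contains(key);
    }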
From 5615c1e459daf86e7794fcafde744b744605522f Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Wed, 5 Apr 2023 10:56:03 -0700
Subject: [PATCH 50/63] Delete duplicate public APIs in object detector

PiperOrigin-RevId: 522098326
---
 .../python/vision/object_detector/__init__.py | 11 +++++++++++
 1 file changed, 11 insertions(+)

diff --git a/mediapipe/model_maker/python/vision/object_detector/__init__.py b/mediapipe/model_maker/python/vision/object_detector/__init__.py
index 6b60760d4..ef7a92010 100644
--- a/mediapipe/model_maker/python/vision/object_detector/__init__.py
+++ b/mediapipe/model_maker/python/vision/object_detector/__init__.py
@@ -28,3 +28,14 @@ HParams = hyperparameters.HParams
 QATHParams = hyperparameters.QATHParams
 Dataset = dataset.Dataset
 ObjectDetectorOptions = object_detector_options.ObjectDetectorOptions
+
+# Remove duplicated and non-public API
+del dataset
+del dataset_util  # pylint: disable=undefined-variable
+del hyperparameters
+del model  # pylint: disable=undefined-variable
+del model_options
+del model_spec
+del object_detector
+del object_detector_options
+del preprocessor  # pylint: disable=undefined-variable

From 6605f551e77689b5a1e91ac48114c9bad9ae831f Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Wed, 5 Apr 2023 14:09:08 -0700
Subject: [PATCH 51/63] Object Detector add batch_size and train_data to
 get_steps_per_epoch.

PiperOrigin-RevId: 522149938
---
 .../python/vision/object_detector/object_detector.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/mediapipe/model_maker/python/vision/object_detector/object_detector.py b/mediapipe/model_maker/python/vision/object_detector/object_detector.py
index 316df85a9..2d1d92ef3 100644
--- a/mediapipe/model_maker/python/vision/object_detector/object_detector.py
+++ b/mediapipe/model_maker/python/vision/object_detector/object_detector.py
@@ -105,7 +105,9 @@ class ObjectDetector(classifier.Classifier):
     """
     self._optimizer = self._create_optimizer(
         model_util.get_steps_per_epoch(
-            self._hparams.steps_per_epoch,
+            steps_per_epoch=self._hparams.steps_per_epoch,
+            batch_size=self._hparams.batch_size,
+            train_data=train_data,
         )
     )
     self._create_model()

From 065d7507818e9bf6f93ae38cbab805c6a4414907 Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Wed, 5 Apr 2023 14:22:26 -0700
Subject: [PATCH 52/63] Add VEC32F4 support to ImageFrame

PiperOrigin-RevId: 522153305
---
 mediapipe/framework/formats/image_format.proto         | 3 +++
 mediapipe/framework/formats/image_frame.cc             | 6 ++++++
 mediapipe/framework/formats/image_frame_opencv.cc      | 3 +++
 mediapipe/framework/formats/image_frame_opencv_test.cc | 3 +++
 mediapipe/framework/formats/image_opencv.cc            | 3 +++
 mediapipe/framework/tool/test_util.cc                  | 1 +
 mediapipe/gpu/gpu_buffer_format.cc                     | 5 ++++-
 mediapipe/objc/util.cc                                 | 8 ++++++++
 mediapipe/python/pybind/image.cc                       | 9 +++++----
 mediapipe/python/pybind/packet_creator.cc              | 6 ++++--
 10 files changed, 40 insertions(+), 7 deletions(-)

diff --git a/mediapipe/framework/formats/image_format.proto b/mediapipe/framework/formats/image_format.proto
index 61e004ac6..e9b69a4c1 100644
--- a/mediapipe/framework/formats/image_format.proto
+++ b/mediapipe/framework/formats/image_format.proto
@@ -69,6 +69,9 @@ message ImageFormat {
   // Two floats per pixel.
   VEC32F2 = 12;
 
+  // Four floats per pixel.
+  VEC32F4 = 13;
+
   // LAB, interleaved: one byte for L, then one byte for a, then one
   // byte for b for each pixel.
LAB8 = 10; diff --git a/mediapipe/framework/formats/image_frame.cc b/mediapipe/framework/formats/image_frame.cc index 772c91014..2de819a35 100644 --- a/mediapipe/framework/formats/image_frame.cc +++ b/mediapipe/framework/formats/image_frame.cc @@ -280,6 +280,8 @@ int ImageFrame::NumberOfChannelsForFormat(ImageFormat::Format format) { return 1; case ImageFormat::VEC32F2: return 2; + case ImageFormat::VEC32F4: + return 4; case ImageFormat::LAB8: return 3; case ImageFormat::SBGRA: @@ -309,6 +311,8 @@ int ImageFrame::ChannelSizeForFormat(ImageFormat::Format format) { return sizeof(float); case ImageFormat::VEC32F2: return sizeof(float); + case ImageFormat::VEC32F4: + return sizeof(float); case ImageFormat::LAB8: return sizeof(uint8_t); case ImageFormat::SBGRA: @@ -338,6 +342,8 @@ int ImageFrame::ByteDepthForFormat(ImageFormat::Format format) { return 4; case ImageFormat::VEC32F2: return 4; + case ImageFormat::VEC32F4: + return 4; case ImageFormat::LAB8: return 1; case ImageFormat::SBGRA: diff --git a/mediapipe/framework/formats/image_frame_opencv.cc b/mediapipe/framework/formats/image_frame_opencv.cc index 940e18263..1ba8c719f 100644 --- a/mediapipe/framework/formats/image_frame_opencv.cc +++ b/mediapipe/framework/formats/image_frame_opencv.cc @@ -59,6 +59,9 @@ int GetMatType(const mediapipe::ImageFormat::Format format) { case mediapipe::ImageFormat::VEC32F2: type = CV_32FC2; break; + case mediapipe::ImageFormat::VEC32F4: + type = CV_32FC4; + break; case mediapipe::ImageFormat::LAB8: type = CV_8U; break; diff --git a/mediapipe/framework/formats/image_frame_opencv_test.cc b/mediapipe/framework/formats/image_frame_opencv_test.cc index ae6f90f81..87d2ffb36 100644 --- a/mediapipe/framework/formats/image_frame_opencv_test.cc +++ b/mediapipe/framework/formats/image_frame_opencv_test.cc @@ -113,6 +113,7 @@ TEST(ImageFrameOpencvTest, ImageFormats) { ImageFrame frame_g16(ImageFormat::GRAY16, i_width, i_height); ImageFrame frame_v32f1(ImageFormat::VEC32F1, i_width, i_height); ImageFrame frame_v32f2(ImageFormat::VEC32F2, i_width, i_height); + ImageFrame frame_v32f4(ImageFormat::VEC32F4, i_width, i_height); ImageFrame frame_c3(ImageFormat::SRGB, i_width, i_height); ImageFrame frame_c4(ImageFormat::SRGBA, i_width, i_height); @@ -120,6 +121,7 @@ TEST(ImageFrameOpencvTest, ImageFormats) { cv::Mat mat_g16 = formats::MatView(&frame_g16); cv::Mat mat_v32f1 = formats::MatView(&frame_v32f1); cv::Mat mat_v32f2 = formats::MatView(&frame_v32f2); + cv::Mat mat_v32f4 = formats::MatView(&frame_v32f4); cv::Mat mat_c3 = formats::MatView(&frame_c3); cv::Mat mat_c4 = formats::MatView(&frame_c4); @@ -127,6 +129,7 @@ TEST(ImageFrameOpencvTest, ImageFormats) { EXPECT_EQ(mat_g16.type(), CV_16UC1); EXPECT_EQ(mat_v32f1.type(), CV_32FC1); EXPECT_EQ(mat_v32f2.type(), CV_32FC2); + EXPECT_EQ(mat_v32f4.type(), CV_32FC4); EXPECT_EQ(mat_c3.type(), CV_8UC3); EXPECT_EQ(mat_c4.type(), CV_8UC4); } diff --git a/mediapipe/framework/formats/image_opencv.cc b/mediapipe/framework/formats/image_opencv.cc index 7d9ce4a13..498c7831f 100644 --- a/mediapipe/framework/formats/image_opencv.cc +++ b/mediapipe/framework/formats/image_opencv.cc @@ -60,6 +60,9 @@ int GetMatType(const mediapipe::ImageFormat::Format format) { case mediapipe::ImageFormat::VEC32F2: type = CV_32FC2; break; + case mediapipe::ImageFormat::VEC32F4: + type = CV_32FC4; + break; case mediapipe::ImageFormat::LAB8: type = CV_8U; break; diff --git a/mediapipe/framework/tool/test_util.cc b/mediapipe/framework/tool/test_util.cc index d05171d20..5642941e9 100644 --- 
a/mediapipe/framework/tool/test_util.cc +++ b/mediapipe/framework/tool/test_util.cc @@ -191,6 +191,7 @@ absl::Status CompareImageFrames(const ImageFrame& image1, max_alpha_diff, max_avg_diff, diff_image); case ImageFormat::VEC32F1: case ImageFormat::VEC32F2: + case ImageFormat::VEC32F4: return CompareDiff(image1, image2, max_color_diff, max_alpha_diff, max_avg_diff, diff_image); default: diff --git a/mediapipe/gpu/gpu_buffer_format.cc b/mediapipe/gpu/gpu_buffer_format.cc index 8e2e3858e..a820f04d6 100644 --- a/mediapipe/gpu/gpu_buffer_format.cc +++ b/mediapipe/gpu/gpu_buffer_format.cc @@ -204,6 +204,8 @@ ImageFormat::Format ImageFormatForGpuBufferFormat(GpuBufferFormat format) { return ImageFormat::SRGB; case GpuBufferFormat::kTwoComponentFloat32: return ImageFormat::VEC32F2; + case GpuBufferFormat::kRGBAFloat128: + return ImageFormat::VEC32F4; case GpuBufferFormat::kRGBA32: // TODO: this likely maps to ImageFormat::SRGBA case GpuBufferFormat::kGrayHalf16: @@ -211,7 +213,6 @@ ImageFormat::Format ImageFormatForGpuBufferFormat(GpuBufferFormat format) { case GpuBufferFormat::kTwoComponent8: case GpuBufferFormat::kTwoComponentHalf16: case GpuBufferFormat::kRGBAHalf64: - case GpuBufferFormat::kRGBAFloat128: case GpuBufferFormat::kNV12: case GpuBufferFormat::kNV21: case GpuBufferFormat::kI420: @@ -232,6 +233,8 @@ GpuBufferFormat GpuBufferFormatForImageFormat(ImageFormat::Format format) { return GpuBufferFormat::kGrayFloat32; case ImageFormat::VEC32F2: return GpuBufferFormat::kTwoComponentFloat32; + case ImageFormat::VEC32F4: + return GpuBufferFormat::kRGBAFloat128; case ImageFormat::GRAY8: return GpuBufferFormat::kOneComponent8; case ImageFormat::YCBCR420P: diff --git a/mediapipe/objc/util.cc b/mediapipe/objc/util.cc index 895463060..36ad4e195 100644 --- a/mediapipe/objc/util.cc +++ b/mediapipe/objc/util.cc @@ -365,6 +365,10 @@ absl::StatusOr> CreateCVPixelBufferForImageFrame( pixel_format = kCVPixelFormatType_TwoComponent32Float; break; + case mediapipe::ImageFormat::VEC32F4: + pixel_format = kCVPixelFormatType_128RGBAFloat; + break; + default: return ::mediapipe::UnknownErrorBuilder(MEDIAPIPE_LOC) << "unsupported ImageFrame format: " << image_format; @@ -440,6 +444,10 @@ absl::StatusOr> CreateCVPixelBufferCopyingImageFrame( pixel_format = kCVPixelFormatType_TwoComponent32Float; break; + case mediapipe::ImageFormat::VEC32F4: + pixel_format = kCVPixelFormatType_128RGBAFloat; + break; + default: return ::mediapipe::UnknownErrorBuilder(MEDIAPIPE_LOC) << "unsupported ImageFrame format: " << image_format; diff --git a/mediapipe/python/pybind/image.cc b/mediapipe/python/pybind/image.cc index 6049abfae..5e5ba7530 100644 --- a/mediapipe/python/pybind/image.cc +++ b/mediapipe/python/pybind/image.cc @@ -120,16 +120,17 @@ void ImageSubmodule(pybind11::module* module) { py::init([](mediapipe::ImageFormat::Format format, const py::array_t& data) { if (format != mediapipe::ImageFormat::VEC32F1 && - format != mediapipe::ImageFormat::VEC32F2) { + format != mediapipe::ImageFormat::VEC32F2 && + format != mediapipe::ImageFormat::VEC32F4) { throw RaisePyError( PyExc_RuntimeError, - "float image data should be either VEC32F1 or VEC32F2 " - "MediaPipe image formats."); + "float image data should be either VEC32F1, VEC32F2, or " + "VEC32F4 MediaPipe image formats."); } return Image(std::shared_ptr( CreateImageFrame(format, data))); }), - R"doc(For float data type, valid ImageFormat are VEC32F1 and VEC32F2.)doc", + R"doc(For float data type, valid ImageFormat are VEC32F1, VEC32F2, and VEC32F4.)doc", 
py::arg("image_format"), py::arg("data").noconvert()); image.def( diff --git a/mediapipe/python/pybind/packet_creator.cc b/mediapipe/python/pybind/packet_creator.cc index 92e695020..c8ae7c259 100644 --- a/mediapipe/python/pybind/packet_creator.cc +++ b/mediapipe/python/pybind/packet_creator.cc @@ -42,7 +42,8 @@ Packet CreateImageFramePacket(mediapipe::ImageFormat::Format format, format == mediapipe::ImageFormat::SRGBA64) { return Adopt(CreateImageFrame(format, data, copy).release()); } else if (format == mediapipe::ImageFormat::VEC32F1 || - format == mediapipe::ImageFormat::VEC32F2) { + format == mediapipe::ImageFormat::VEC32F2 || + format == mediapipe::ImageFormat::VEC32F4) { return Adopt(CreateImageFrame(format, data, copy).release()); } throw RaisePyError(PyExc_RuntimeError, @@ -63,7 +64,8 @@ Packet CreateImagePacket(mediapipe::ImageFormat::Format format, return MakePacket(std::shared_ptr( CreateImageFrame(format, data, copy))); } else if (format == mediapipe::ImageFormat::VEC32F1 || - format == mediapipe::ImageFormat::VEC32F2) { + format == mediapipe::ImageFormat::VEC32F2 || + format == mediapipe::ImageFormat::VEC32F4) { return MakePacket(std::shared_ptr( CreateImageFrame(format, data, copy))); } From f67007d07767c03462de009e5a8003a1d121af40 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Wed, 5 Apr 2023 16:18:27 -0700 Subject: [PATCH 53/63] Remove platform information for x86 PiperOrigin-RevId: 522182423 --- setup.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/setup.py b/setup.py index fe8f2dc68..6e7493caa 100644 --- a/setup.py +++ b/setup.py @@ -348,10 +348,7 @@ class BuildExtension(build_ext.build_ext): for ext in self.extensions: target_name = self.get_ext_fullpath(ext.name) # Build x86 - self._build_binary( - ext, - ['--cpu=darwin_x86_64', '--ios_multi_cpus=i386,x86_64,armv7,arm64'], - ) + self._build_binary(ext) x86_name = self.get_ext_fullpath(ext.name) # Build Arm64 ext.name = ext.name + '.arm64' From 7fe87936e5b6166d9ddb4a217768cb3f6fb84e34 Mon Sep 17 00:00:00 2001 From: Hadon Nash Date: Wed, 5 Apr 2023 18:13:07 -0700 Subject: [PATCH 54/63] Internal change PiperOrigin-RevId: 522206591 --- mediapipe/framework/api2/builder.h | 7 ++++++- mediapipe/framework/deps/registration.h | 12 ++++++++++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/mediapipe/framework/api2/builder.h b/mediapipe/framework/api2/builder.h index da09acc83..ee9796e49 100644 --- a/mediapipe/framework/api2/builder.h +++ b/mediapipe/framework/api2/builder.h @@ -425,7 +425,10 @@ using GenericNode = Node; template class Node : public NodeBase { public: - Node() : NodeBase(std::string(Calc::kCalculatorName)) {} + Node() + : NodeBase( + FunctionRegistry::GetLookupName(Calc::kCalculatorName)) {} + // Overrides the built-in calculator type string with the provided argument. // Can be used to create nodes from pure interfaces. // TODO: only use this for pure interfaces @@ -546,6 +549,7 @@ class Graph { // Creates a node of a specific type. Should be used for pure interfaces, // which do not have a built-in type string. + // `type` is a calculator type-name with dot-separated namespaces. template Node& AddNode(absl::string_view type) { auto node = @@ -557,6 +561,7 @@ class Graph { // Creates a generic node, with no compile-time checking of inputs and // outputs. This can be used for calculators whose contract is not visible. + // `type` is a calculator type-name with dot-separated namespaces. 
GenericNode& AddNode(absl::string_view type) { auto node = std::make_unique(std::string(type.data(), type.size())); diff --git a/mediapipe/framework/deps/registration.h b/mediapipe/framework/deps/registration.h index cc8ba03fe..7965539b6 100644 --- a/mediapipe/framework/deps/registration.h +++ b/mediapipe/framework/deps/registration.h @@ -301,6 +301,18 @@ class FunctionRegistry { return cxx_name; } + // Returns a type name with '.' separated namespaces. + static std::string GetLookupName(const absl::string_view cxx_type_name) { + constexpr absl::string_view kCxxSep = "::"; + constexpr absl::string_view kNameSep = "."; + std::vector names = + absl::StrSplit(cxx_type_name, kCxxSep); + if (names[0].empty()) { + names.erase(names.begin()); + } + return absl::StrJoin(names, kNameSep); + } + private: mutable absl::Mutex lock_; absl::flat_hash_map functions_ ABSL_GUARDED_BY(lock_); From d5def9e24dd4d769759d763b62232ff726716789 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Wed, 5 Apr 2023 20:30:05 -0700 Subject: [PATCH 55/63] Image segmenter output both confidence masks and category mask optionally. PiperOrigin-RevId: 522227345 --- .../tensors_to_segmentation_calculator.cc | 31 ++++-- .../vision/image_segmenter/image_segmenter.cc | 49 +++++++--- .../vision/image_segmenter/image_segmenter.h | 10 +- .../image_segmenter/image_segmenter_graph.cc | 95 ++++++++++++------- .../image_segmenter/image_segmenter_result.h | 2 +- .../image_segmenter/image_segmenter_test.cc | 25 ++--- 6 files changed, 139 insertions(+), 73 deletions(-) diff --git a/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.cc b/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.cc index 49ad18029..790285546 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.cc @@ -32,6 +32,7 @@ limitations under the License. #include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/image_frame_opencv.h" #include "mediapipe/framework/formats/tensor.h" +#include "mediapipe/framework/port/canonical_errors.h" #include "mediapipe/framework/port/opencv_core_inc.h" #include "mediapipe/framework/port/opencv_imgproc_inc.h" #include "mediapipe/framework/port/status_macros.h" @@ -210,8 +211,9 @@ std::vector ProcessForConfidenceMaskCpu(const Shape& input_shape, } // namespace // Converts Tensors from a vector of Tensor to Segmentation masks. The -// calculator always output confidence masks, and an optional category mask if -// CATEGORY_MASK is connected. +// calculator can output optional confidence masks if CONFIDENCE_MASK is +// connected, and an optional category mask if CATEGORY_MASK is connected. At +// least one of CONFIDENCE_MASK and CATEGORY_MASK must be connected. // // Performs optional resizing to OUTPUT_SIZE dimension if provided, // otherwise the segmented masks is the same size as input tensor. 
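Stepping back to the GetLookupName helper added in the registration.h hunk above, which is what lets the templated Node constructor register calculators under dot-separated names: a standalone approximation of its behavior, using only the Abseil string utilities it already depends on (the function name here is illustrative, not the MediaPipe symbol):

    #include <string>
    #include <vector>

    #include "absl/strings/str_join.h"
    #include "absl/strings/str_split.h"
    #include "absl/strings/string_view.h"

    std::string GetLookupNameSketch(absl::string_view cxx_type_name) {
      // Split on the C++ scope separator and drop a leading "::", if any.
      std::vector<std::string> names = absl::StrSplit(cxx_type_name, "::");
      if (names[0].empty()) names.erase(names.begin());
      return absl::StrJoin(names, ".");
    }

    // GetLookupNameSketch("::mediapipe::FooBarCalculator") yields
    // "mediapipe.FooBarCalculator".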
@@ -296,6 +298,13 @@ absl::Status TensorsToSegmentationCalculator::Open( SegmenterOptions::UNSPECIFIED) << "Must specify output_type as one of " "[CONFIDENCE_MASK|CATEGORY_MASK]."; + } else { + if (!cc->Outputs().HasTag("CONFIDENCE_MASK") && + !cc->Outputs().HasTag("CATEGORY_MASK")) { + return absl::InvalidArgumentError( + "At least one of CONFIDENCE_MASK and CATEGORY_MASK must be " + "connected."); + } } #ifdef __EMSCRIPTEN__ MP_RETURN_IF_ERROR(postprocessor_.Initialize(cc, options_)); @@ -366,14 +375,16 @@ absl::Status TensorsToSegmentationCalculator::Process( return absl::OkStatus(); } - std::vector confidence_masks = - ProcessForConfidenceMaskCpu(input_shape, - {/* height= */ output_height, - /* width= */ output_width, - /* channels= */ input_shape.channels}, - options_.segmenter_options(), tensors_buffer); - for (int i = 0; i < confidence_masks.size(); ++i) { - kConfidenceMaskOut(cc)[i].Send(std::move(confidence_masks[i])); + if (cc->Outputs().HasTag("CONFIDENCE_MASK")) { + std::vector confidence_masks = ProcessForConfidenceMaskCpu( + input_shape, + {/* height= */ output_height, + /* width= */ output_width, + /* channels= */ input_shape.channels}, + options_.segmenter_options(), tensors_buffer); + for (int i = 0; i < confidence_masks.size(); ++i) { + kConfidenceMaskOut(cc)[i].Send(std::move(confidence_masks[i])); + } } if (cc->Outputs().HasTag("CATEGORY_MASK")) { kCategoryMaskOut(cc).Send(ProcessForCategoryMaskCpu( diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc index 8f03ff086..33c868e05 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc @@ -60,15 +60,19 @@ using ImageSegmenterGraphOptionsProto = ::mediapipe::tasks::vision:: // "mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph". 
CalculatorGraphConfig CreateGraphConfig(
     std::unique_ptr<ImageSegmenterGraphOptionsProto> options,
-    bool output_category_mask, bool enable_flow_limiting) {
+    bool output_confidence_masks, bool output_category_mask,
+    bool enable_flow_limiting) {
   api2::builder::Graph graph;
   auto& task_subgraph = graph.AddNode(kSubgraphTypeName);
   task_subgraph.GetOptions<ImageSegmenterGraphOptionsProto>().Swap(
       options.get());
   graph.In(kImageTag).SetName(kImageInStreamName);
   graph.In(kNormRectTag).SetName(kNormRectStreamName);
-  task_subgraph.Out(kConfidenceMasksTag).SetName(kConfidenceMasksStreamName) >>
-      graph.Out(kConfidenceMasksTag);
+  if (output_confidence_masks) {
+    task_subgraph.Out(kConfidenceMasksTag)
+            .SetName(kConfidenceMasksStreamName) >>
+        graph.Out(kConfidenceMasksTag);
+  }
   if (output_category_mask) {
     task_subgraph.Out(kCategoryMaskTag).SetName(kCategoryMaskStreamName) >>
         graph.Out(kCategoryMaskTag);
@@ -135,11 +139,17 @@ absl::StatusOr<std::vector<std::string>> GetLabelsFromGraphConfig(
 
 absl::StatusOr<std::unique_ptr<ImageSegmenter>> ImageSegmenter::Create(
     std::unique_ptr<ImageSegmenterOptions> options) {
+  if (!options->output_confidence_masks && !options->output_category_mask) {
+    return absl::InvalidArgumentError(
+        "At least one of `output_confidence_masks` and `output_category_mask` "
+        "must be set.");
+  }
   auto options_proto = ConvertImageSegmenterOptionsToProto(options.get());
   tasks::core::PacketsCallback packets_callback = nullptr;
   if (options->result_callback) {
     auto result_callback = options->result_callback;
     bool output_category_mask = options->output_category_mask;
+    bool output_confidence_masks = options->output_confidence_masks;
     packets_callback = [=](absl::StatusOr<tasks::core::PacketMap>
                                status_or_packets) {
       if (!status_or_packets.ok()) {
@@ -151,8 +161,12 @@ absl::StatusOr<std::unique_ptr<ImageSegmenter>> ImageSegmenter::Create(
       if (status_or_packets.value()[kImageOutStreamName].IsEmpty()) {
         return;
       }
-      Packet confidence_masks =
-          status_or_packets.value()[kConfidenceMasksStreamName];
+      std::optional<std::vector<Image>> confidence_masks;
+      if (output_confidence_masks) {
+        confidence_masks =
+            status_or_packets.value()[kConfidenceMasksStreamName]
+                .Get<std::vector<Image>>();
+      }
       std::optional<Image> category_mask;
       if (output_category_mask) {
         category_mask =
@@ -160,23 +174,24 @@ absl::StatusOr<std::unique_ptr<ImageSegmenter>> ImageSegmenter::Create(
       }
       Packet image_packet = status_or_packets.value()[kImageOutStreamName];
       result_callback(
-          {{confidence_masks.Get<std::vector<Image>>(), category_mask}},
-          image_packet.Get<Image>(),
-          confidence_masks.Timestamp().Value() /
-              kMicroSecondsPerMilliSecond);
+          {{confidence_masks, category_mask}}, image_packet.Get<Image>(),
+          image_packet.Timestamp().Value() / kMicroSecondsPerMilliSecond);
     };
   }
   auto image_segmenter = core::VisionTaskApiFactory::Create(
          CreateGraphConfig(
-              std::move(options_proto), options->output_category_mask,
+              std::move(options_proto), options->output_confidence_masks,
+              options->output_category_mask,
               options->running_mode == core::RunningMode::LIVE_STREAM),
          std::move(options->base_options.op_resolver), options->running_mode,
          std::move(packets_callback));
   if (!image_segmenter.ok()) {
     return image_segmenter.status();
   }
+  image_segmenter.value()->output_confidence_masks_ =
+      options->output_confidence_masks;
   image_segmenter.value()->output_category_mask_ =
       options->output_category_mask;
   ASSIGN_OR_RETURN(
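To make the reworked options concrete before the Segment() hunks below, a hedged usage sketch (the model path is a placeholder, and the namespace is assumed from the graph name above): callers who want only a category mask must now disable confidence masks explicitly, because output_confidence_masks defaults to true.

    #include <memory>
    #include <utility>

    #include "absl/status/statusor.h"
    #include "mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h"

    namespace seg = mediapipe::tasks::vision::image_segmenter;

    // Sketch: request only the category mask. Disabling both outputs would
    // trip the InvalidArgumentError added in Create() above.
    absl::StatusOr<std::unique_ptr<seg::ImageSegmenter>>
    MakeCategoryOnlySegmenter() {
      auto options = std::make_unique<seg::ImageSegmenterOptions>();
      // Placeholder path, not from this patch.
      options->base_options.model_asset_path = "/path/to/model.tflite";
      options->output_confidence_masks = false;  // defaults to true
      options->output_category_mask = true;      // defaults to false
      return seg::ImageSegmenter::Create(std::move(options));
    }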
@@ -203,8 +218,11 @@ absl::StatusOr<ImageSegmenterResult> ImageSegmenter::Segment(
       {{kImageInStreamName,
         mediapipe::MakePacket<Image>(std::move(image))},
        {kNormRectStreamName,
         MakePacket<NormalizedRect>(std::move(norm_rect))}}));
-  std::vector<Image> confidence_masks =
-      output_packets[kConfidenceMasksStreamName].Get<std::vector<Image>>();
+  std::optional<std::vector<Image>> confidence_masks;
+  if (output_confidence_masks_) {
+    confidence_masks =
+        output_packets[kConfidenceMasksStreamName].Get<std::vector<Image>>();
+  }
   std::optional<Image> category_mask;
   if (output_category_mask_) {
     category_mask = output_packets[kCategoryMaskStreamName].Get<Image>();
@@ -233,8 +251,11 @@ absl::StatusOr<ImageSegmenterResult> ImageSegmenter::SegmentForVideo(
        {kNormRectStreamName,
         MakePacket<NormalizedRect>(std::move(norm_rect))
             .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}}));
-  std::vector<Image> confidence_masks =
-      output_packets[kConfidenceMasksStreamName].Get<std::vector<Image>>();
+  std::optional<std::vector<Image>> confidence_masks;
+  if (output_confidence_masks_) {
+    confidence_masks =
+        output_packets[kConfidenceMasksStreamName].Get<std::vector<Image>>();
+  }
   std::optional<Image> category_mask;
   if (output_category_mask_) {
     category_mask = output_packets[kCategoryMaskStreamName].Get<Image>();
diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h
index 1d18e3903..352d6b273 100644
--- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h
+++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h
@@ -53,6 +53,9 @@ struct ImageSegmenterOptions {
   // Metadata, if any. Defaults to English.
   std::string display_names_locale = "en";
 
+  // Whether to output confidence masks.
+  bool output_confidence_masks = true;
+
   // Whether to output category mask.
   bool output_category_mask = false;
 
@@ -77,8 +80,10 @@ struct ImageSegmenterOptions {
 //    - if type is kTfLiteFloat32, NormalizationOptions are required to be
 //      attached to the metadata for input normalization.
 // Output ImageSegmenterResult:
-//   Provides confidence masks and an optional category mask if
-//   `output_category_mask` is set true.
+//   Provides optional confidence masks if `output_confidence_masks` is set
+//   true, and an optional category mask if `output_category_mask` is set
+//   true. At least one of `output_confidence_masks` and `output_category_mask`
+//   must be set to true.
 // An example of such model can be found at:
 // https://tfhub.dev/tensorflow/lite-model/deeplabv3/1/metadata/2
 class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi {
@@ -167,6 +172,7 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi {
 
  private:
   std::vector<std::string> labels_;
+  bool output_confidence_masks_;
   bool output_category_mask_;
 };
 
diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc
index 4b9e7618b..840e7933a 100644
--- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc
+++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc
@@ -326,8 +326,10 @@ absl::StatusOr<ImageAndTensorsOnDevice> ConvertImageToTensors(
 }
 
 // An "mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph" performs
-// semantic segmentation. The graph always output confidence masks, and an
-// optional category mask if CATEGORY_MASK is connected.
+// semantic segmentation. The graph can output optional confidence masks if
+// CONFIDENCE_MASKS is connected, and an optional category mask if CATEGORY_MASK
+// is connected. At least one of CONFIDENCE_MASK, CONFIDENCE_MASKS and
+// CATEGORY_MASK must be connected.
 //
 // Two kinds of outputs for confidence mask are provided: CONFIDENCE_MASK and
 // CONFIDENCE_MASKS. Users can retrieve segmented mask of only particular
@@ -347,7 +349,7 @@ absl::StatusOr<ImageAndTensorsOnDevice> ConvertImageToTensors(
 //   CONFIDENCE_MASK - mediapipe::Image @Multiple
 //     Confidence masks for individual category. Confidence mask of single
 //     category can be accessed by index based output stream.
-//   CONFIDENCE_MASKS - std::vector<Image>
+//   CONFIDENCE_MASKS - std::vector<Image> @Optional
 //     The output confidence masks grouped in a vector.
// CATEGORY_MASK - mediapipe::Image @Optional // Optional Category mask. @@ -356,7 +358,7 @@ absl::StatusOr ConvertImageToTensors( // // Example: // node { -// calculator: "mediapipe.tasks.vision.ImageSegmenterGraph" +// calculator: "mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph" // input_stream: "IMAGE:image" // output_stream: "SEGMENTATION:segmented_masks" // options { @@ -382,17 +384,20 @@ class ImageSegmenterGraph : public core::ModelTaskGraph { CreateModelResources(sc)); Graph graph; const auto& options = sc->Options(); + // TODO: remove deprecated output type support. + if (!options.segmenter_options().has_output_type()) { + MP_RETURN_IF_ERROR(SanityCheck(sc)); + } ASSIGN_OR_RETURN( auto output_streams, BuildSegmentationTask( options, *model_resources, graph[Input(kImageTag)], - graph[Input::Optional(kNormRectTag)], - HasOutput(sc->OriginalNode(), kCategoryMaskTag), graph)); + graph[Input::Optional(kNormRectTag)], graph)); - auto& merge_images_to_vector = - graph.AddNode("MergeImagesToVectorCalculator"); // TODO: remove deprecated output type support. if (options.segmenter_options().has_output_type()) { + auto& merge_images_to_vector = + graph.AddNode("MergeImagesToVectorCalculator"); for (int i = 0; i < output_streams.segmented_masks->size(); ++i) { output_streams.segmented_masks->at(i) >> merge_images_to_vector[Input::Multiple("")][i]; @@ -402,14 +407,18 @@ class ImageSegmenterGraph : public core::ModelTaskGraph { merge_images_to_vector.Out("") >> graph[Output>(kGroupedSegmentationTag)]; } else { - for (int i = 0; i < output_streams.confidence_masks->size(); ++i) { - output_streams.confidence_masks->at(i) >> - merge_images_to_vector[Input::Multiple("")][i]; - output_streams.confidence_masks->at(i) >> - graph[Output::Multiple(kConfidenceMaskTag)][i]; + if (output_streams.confidence_masks) { + auto& merge_images_to_vector = + graph.AddNode("MergeImagesToVectorCalculator"); + for (int i = 0; i < output_streams.confidence_masks->size(); ++i) { + output_streams.confidence_masks->at(i) >> + merge_images_to_vector[Input::Multiple("")][i]; + output_streams.confidence_masks->at(i) >> + graph[Output::Multiple(kConfidenceMaskTag)][i]; + } + merge_images_to_vector.Out("") >> + graph[Output>::Optional(kConfidenceMasksTag)]; } - merge_images_to_vector.Out("") >> - graph[Output>(kConfidenceMasksTag)]; if (output_streams.category_mask) { *output_streams.category_mask >> graph[Output(kCategoryMaskTag)]; } @@ -419,6 +428,19 @@ class ImageSegmenterGraph : public core::ModelTaskGraph { } private: + absl::Status SanityCheck(mediapipe::SubgraphContext* sc) { + const auto& node = sc->OriginalNode(); + output_confidence_masks_ = HasOutput(node, kConfidenceMaskTag) || + HasOutput(node, kConfidenceMasksTag); + output_category_mask_ = HasOutput(node, kCategoryMaskTag); + if (!output_confidence_masks_ && !output_category_mask_) { + return absl::InvalidArgumentError( + "At least one of CONFIDENCE_MASK, CONFIDENCE_MASKS and CATEGORY_MASK " + "must be connected."); + } + return absl::OkStatus(); + } + // Adds a mediapipe image segmentation task pipeline graph into the provided // builder::Graph instance. The segmentation pipeline takes images // (mediapipe::Image) as the input and returns segmented image mask as output. 
@@ -431,8 +453,7 @@ class ImageSegmenterGraph : public core::ModelTaskGraph { absl::StatusOr BuildSegmentationTask( const ImageSegmenterGraphOptions& task_options, const core::ModelResources& model_resources, Source image_in, - Source norm_rect_in, bool output_category_mask, - Graph& graph) { + Source norm_rect_in, Graph& graph) { MP_RETURN_IF_ERROR(SanityCheckOptions(task_options)); // Adds preprocessing calculators and connects them to the graph input image @@ -485,26 +506,32 @@ class ImageSegmenterGraph : public core::ModelTaskGraph { /*category_mask=*/std::nullopt, /*image=*/image_and_tensors.image}; } else { - ASSIGN_OR_RETURN(const tflite::Tensor* output_tensor, - GetOutputTensor(model_resources)); - int segmentation_streams_num = *output_tensor->shape()->rbegin(); - std::vector> confidence_masks; - confidence_masks.reserve(segmentation_streams_num); - for (int i = 0; i < segmentation_streams_num; ++i) { - confidence_masks.push_back(Source( - tensor_to_images[Output::Multiple(kConfidenceMaskTag)][i])); + std::optional>> confidence_masks; + if (output_confidence_masks_) { + ASSIGN_OR_RETURN(const tflite::Tensor* output_tensor, + GetOutputTensor(model_resources)); + int segmentation_streams_num = *output_tensor->shape()->rbegin(); + confidence_masks = std::vector>(); + confidence_masks->reserve(segmentation_streams_num); + for (int i = 0; i < segmentation_streams_num; ++i) { + confidence_masks->push_back(Source( + tensor_to_images[Output::Multiple(kConfidenceMaskTag)] + [i])); + } } - return ImageSegmenterOutputs{ - /*segmented_masks=*/std::nullopt, - /*confidence_masks=*/confidence_masks, - /*category_mask=*/ - output_category_mask - ? std::make_optional( - tensor_to_images[Output(kCategoryMaskTag)]) - : std::nullopt, - /*image=*/image_and_tensors.image}; + std::optional> category_mask; + if (output_category_mask_) { + category_mask = tensor_to_images[Output(kCategoryMaskTag)]; + } + return ImageSegmenterOutputs{/*segmented_masks=*/std::nullopt, + /*confidence_masks=*/confidence_masks, + /*category_mask=*/category_mask, + /*image=*/image_and_tensors.image}; } } + + bool output_confidence_masks_ = false; + bool output_category_mask_ = false; }; REGISTER_MEDIAPIPE_GRAPH( diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_result.h b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_result.h index fb2ec05f1..f14ee7a90 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_result.h +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_result.h @@ -29,7 +29,7 @@ namespace image_segmenter { struct ImageSegmenterResult { // Multiple masks of float image in VEC32F1 format where, for each mask, each // pixel represents the prediction confidence, usually in the [0, 1] range. - std::vector confidence_masks; + std::optional> confidence_masks; // A category mask of uint8 image in GRAY8 format where each pixel represents // the class which the pixel in the original image was predicted to belong to. 
std::optional category_mask; diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc index 1e4387491..0c5a61486 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc @@ -278,6 +278,7 @@ TEST_F(ImageModeTest, SucceedsWithCategoryMask) { auto options = std::make_unique(); options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata); + options->output_confidence_masks = false; options->output_category_mask = true; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, ImageSegmenter::Create(std::move(options))); @@ -306,7 +307,7 @@ TEST_F(ImageModeTest, SucceedsWithConfidenceMask) { MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, ImageSegmenter::Create(std::move(options))); MP_ASSERT_OK_AND_ASSIGN(auto result, segmenter->Segment(image)); - EXPECT_EQ(result.confidence_masks.size(), 21); + EXPECT_EQ(result.confidence_masks->size(), 21); cv::Mat expected_mask = cv::imread( JoinPath("./", kTestDataDirectory, "cat_mask.jpg"), cv::IMREAD_GRAYSCALE); @@ -315,7 +316,7 @@ TEST_F(ImageModeTest, SucceedsWithConfidenceMask) { // Cat category index 8. cv::Mat cat_mask = mediapipe::formats::MatView( - result.confidence_masks[8].GetImageFrameSharedPtr().get()); + result.confidence_masks->at(8).GetImageFrameSharedPtr().get()); EXPECT_THAT(cat_mask, SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity)); } @@ -336,7 +337,7 @@ TEST_F(ImageModeTest, DISABLED_SucceedsWithRotation) { image_processing_options.rotation_degrees = -90; MP_ASSERT_OK_AND_ASSIGN(auto result, segmenter->Segment(image, image_processing_options)); - EXPECT_EQ(result.confidence_masks.size(), 21); + EXPECT_EQ(result.confidence_masks->size(), 21); cv::Mat expected_mask = cv::imread(JoinPath("./", kTestDataDirectory, "cat_rotated_mask.jpg"), @@ -346,7 +347,7 @@ TEST_F(ImageModeTest, DISABLED_SucceedsWithRotation) { // Cat category index 8. cv::Mat cat_mask = mediapipe::formats::MatView( - result.confidence_masks[8].GetImageFrameSharedPtr().get()); + result.confidence_masks->at(8).GetImageFrameSharedPtr().get()); EXPECT_THAT(cat_mask, SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity)); } @@ -384,7 +385,7 @@ TEST_F(ImageModeTest, SucceedsSelfie128x128Segmentation) { MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, ImageSegmenter::Create(std::move(options))); MP_ASSERT_OK_AND_ASSIGN(auto result, segmenter->Segment(image)); - EXPECT_EQ(result.confidence_masks.size(), 2); + EXPECT_EQ(result.confidence_masks->size(), 2); cv::Mat expected_mask = cv::imread(JoinPath("./", kTestDataDirectory, @@ -395,7 +396,7 @@ TEST_F(ImageModeTest, SucceedsSelfie128x128Segmentation) { // Selfie category index 1. 
cv::Mat selfie_mask = mediapipe::formats::MatView( - result.confidence_masks[1].GetImageFrameSharedPtr().get()); + result.confidence_masks->at(1).GetImageFrameSharedPtr().get()); EXPECT_THAT(selfie_mask, SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity)); } @@ -409,7 +410,7 @@ TEST_F(ImageModeTest, SucceedsSelfie144x256Segmentations) { MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, ImageSegmenter::Create(std::move(options))); MP_ASSERT_OK_AND_ASSIGN(auto result, segmenter->Segment(image)); - EXPECT_EQ(result.confidence_masks.size(), 1); + EXPECT_EQ(result.confidence_masks->size(), 1); cv::Mat expected_mask = cv::imread(JoinPath("./", kTestDataDirectory, @@ -419,7 +420,7 @@ TEST_F(ImageModeTest, SucceedsSelfie144x256Segmentations) { expected_mask.convertTo(expected_mask_float, CV_32FC1, 1 / 255.f); cv::Mat selfie_mask = mediapipe::formats::MatView( - result.confidence_masks[0].GetImageFrameSharedPtr().get()); + result.confidence_masks->at(0).GetImageFrameSharedPtr().get()); EXPECT_THAT(selfie_mask, SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity)); } @@ -434,7 +435,7 @@ TEST_F(ImageModeTest, SucceedsPortraitSelfieSegmentationConfidenceMask) { MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, ImageSegmenter::Create(std::move(options))); MP_ASSERT_OK_AND_ASSIGN(auto result, segmenter->Segment(image)); - EXPECT_EQ(result.confidence_masks.size(), 1); + EXPECT_EQ(result.confidence_masks->size(), 1); MP_ASSERT_OK(segmenter->Close()); cv::Mat expected_mask = cv::imread( @@ -445,7 +446,7 @@ TEST_F(ImageModeTest, SucceedsPortraitSelfieSegmentationConfidenceMask) { expected_mask.convertTo(expected_mask_float, CV_32FC1, 1 / 255.f); cv::Mat selfie_mask = mediapipe::formats::MatView( - result.confidence_masks[0].GetImageFrameSharedPtr().get()); + result.confidence_masks->at(0).GetImageFrameSharedPtr().get()); EXPECT_THAT(selfie_mask, SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity)); } @@ -506,10 +507,10 @@ TEST_F(ImageModeTest, SucceedsHairSegmentation) { MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, ImageSegmenter::Create(std::move(options))); MP_ASSERT_OK_AND_ASSIGN(auto result, segmenter->Segment(image)); - EXPECT_EQ(result.confidence_masks.size(), 2); + EXPECT_EQ(result.confidence_masks->size(), 2); cv::Mat hair_mask = mediapipe::formats::MatView( - result.confidence_masks[1].GetImageFrameSharedPtr().get()); + result.confidence_masks->at(1).GetImageFrameSharedPtr().get()); MP_ASSERT_OK(segmenter->Close()); cv::Mat expected_mask = cv::imread( JoinPath("./", kTestDataDirectory, "portrait_hair_expected_mask.jpg"), From 7ae4d0175a3461dbf6d25d93fb5a047a8db4d592 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Wed, 5 Apr 2023 21:40:17 -0700 Subject: [PATCH 56/63] CL will fix the typos in the tasks files PiperOrigin-RevId: 522240681 --- .../model_maker/python/core/data/dataset.py | 2 +- .../model_maker/python/core/hyperparameters.py | 2 +- .../python/vision/gesture_recognizer/BUILD | 2 +- .../python/vision/object_detector/dataset.py | 4 ++-- .../vision/object_detector/dataset_util.py | 2 +- .../cc/metadata/tests/metadata_version_test.cc | 8 ++++---- .../tasks/cc/vision/core/base_vision_task_api.h | 2 +- .../hand_gesture_recognizer_graph.cc | 8 ++++---- .../vision/utils/sources/MPPImage+TestUtils.h | 4 ++-- .../google/mediapipe/tasks/core/TaskRunner.java | 2 +- .../tasks/vision/facestylizer/FaceStylizer.java | 8 ++++---- .../processors/classifier_result.test.ts | 16 ++++++++-------- 12 files changed, 30 insertions(+), 30 deletions(-) diff --git 
a/mediapipe/model_maker/python/core/data/dataset.py b/mediapipe/model_maker/python/core/data/dataset.py index 113969384..3b4182c14 100644 --- a/mediapipe/model_maker/python/core/data/dataset.py +++ b/mediapipe/model_maker/python/core/data/dataset.py @@ -84,7 +84,7 @@ class Dataset(object): create randomness during model training. preprocess: A function taking three arguments in order, feature, label and boolean is_training. - drop_remainder: boolean, whether the finaly batch drops remainder. + drop_remainder: boolean, whether the finally batch drops remainder. Returns: A TF dataset ready to be consumed by Keras model. diff --git a/mediapipe/model_maker/python/core/hyperparameters.py b/mediapipe/model_maker/python/core/hyperparameters.py index 5cff30930..e6848e0de 100644 --- a/mediapipe/model_maker/python/core/hyperparameters.py +++ b/mediapipe/model_maker/python/core/hyperparameters.py @@ -32,7 +32,7 @@ class BaseHParams: epochs: Number of training iterations over the dataset. steps_per_epoch: An optional integer indicate the number of training steps per epoch. If not set, the training pipeline calculates the default steps - per epoch as the training dataset size devided by batch size. + per epoch as the training dataset size divided by batch size. shuffle: True if the dataset is shuffled before training. export_dir: The location of the model checkpoint files. distribution_strategy: A string specifying which Distribution Strategy to diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/BUILD b/mediapipe/model_maker/python/vision/gesture_recognizer/BUILD index 77ed2e016..e96421593 100644 --- a/mediapipe/model_maker/python/vision/gesture_recognizer/BUILD +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/BUILD @@ -21,7 +21,7 @@ package( default_visibility = ["//mediapipe:__subpackages__"], ) -# TODO: Remove the unncessary test data once the demo data are moved to an open-sourced +# TODO: Remove the unnecessary test data once the demo data are moved to an open-sourced # directory. filegroup( name = "testdata", diff --git a/mediapipe/model_maker/python/vision/object_detector/dataset.py b/mediapipe/model_maker/python/vision/object_detector/dataset.py index 741263129..f260c82c5 100644 --- a/mediapipe/model_maker/python/vision/object_detector/dataset.py +++ b/mediapipe/model_maker/python/vision/object_detector/dataset.py @@ -155,8 +155,8 @@ class Dataset(classification_dataset.ClassificationDataset): ObjectDetectorDataset object. 
""" # Get TFRecord Files - tfrecord_file_patten = cache_prefix + '*.tfrecord' - matched_files = tf.io.gfile.glob(tfrecord_file_patten) + tfrecord_file_pattern = cache_prefix + '*.tfrecord' + matched_files = tf.io.gfile.glob(tfrecord_file_pattern) if not matched_files: raise ValueError('TFRecord files are empty.') diff --git a/mediapipe/model_maker/python/vision/object_detector/dataset_util.py b/mediapipe/model_maker/python/vision/object_detector/dataset_util.py index 020c94501..440d45945 100644 --- a/mediapipe/model_maker/python/vision/object_detector/dataset_util.py +++ b/mediapipe/model_maker/python/vision/object_detector/dataset_util.py @@ -345,7 +345,7 @@ def _coco_annotations_to_lists( Args: bbox_annotations: List of dicts with keys ['bbox', 'category_id'] image_height: Height of image - image_width: Width of iamge + image_width: Width of image Returns: (data, num_annotations_skipped) tuple where data contains the keys: diff --git a/mediapipe/tasks/cc/metadata/tests/metadata_version_test.cc b/mediapipe/tasks/cc/metadata/tests/metadata_version_test.cc index 32ff51482..63cd2ff9c 100644 --- a/mediapipe/tasks/cc/metadata/tests/metadata_version_test.cc +++ b/mediapipe/tasks/cc/metadata/tests/metadata_version_test.cc @@ -111,7 +111,7 @@ TEST(MetadataVersionTest, TEST(MetadataVersionTest, GetMinimumMetadataParserVersionForModelMetadataVocabAssociatedFiles) { // Creates a metadata flatbuffer with the field, - // ModelMetadata.associated_fiels, populated with the vocabulary file type. + // ModelMetadata.associated_fields, populated with the vocabulary file type. FlatBufferBuilder builder(1024); AssociatedFileBuilder associated_file_builder(builder); associated_file_builder.add_type(tflite::AssociatedFileType_VOCABULARY); @@ -159,8 +159,8 @@ TEST(MetadataVersionTest, TEST(MetadataVersionTest, GetMinimumMetadataParserVersionForInputMetadataVocabAssociatedFiles) { // Creates a metadata flatbuffer with the field, - // SubGraphMetadata.input_tensor_metadata.associated_fiels, populated with the - // vocabulary file type. + // SubGraphMetadata.input_tensor_metadata.associated_fields, populated with + // the vocabulary file type. FlatBufferBuilder builder(1024); AssociatedFileBuilder associated_file_builder(builder); associated_file_builder.add_type(tflite::AssociatedFileType_VOCABULARY); @@ -184,7 +184,7 @@ TEST(MetadataVersionTest, TEST(MetadataVersionTest, GetMinimumMetadataParserVersionForOutputMetadataVocabAssociatedFiles) { // Creates a metadata flatbuffer with the field, - // SubGraphMetadata.output_tensor_metadata.associated_fiels, populated with + // SubGraphMetadata.output_tensor_metadata.associated_fields, populated with // the vocabulary file type. FlatBufferBuilder builder(1024); AssociatedFileBuilder associated_file_builder(builder); diff --git a/mediapipe/tasks/cc/vision/core/base_vision_task_api.h b/mediapipe/tasks/cc/vision/core/base_vision_task_api.h index c56f350b2..8e6105e18 100644 --- a/mediapipe/tasks/cc/vision/core/base_vision_task_api.h +++ b/mediapipe/tasks/cc/vision/core/base_vision_task_api.h @@ -188,7 +188,7 @@ class BaseVisionTaskApi : public tasks::core::BaseTaskApi { // For 90° and 270° rotations, we need to swap width and height. // This is due to the internal behavior of ImageToTensorCalculator, which: // - first denormalizes the provided rect by multiplying the rect width or - // height by the image width or height, repectively. + // height by the image width or height, respectively. 
     // - then rotates this denormalized rect by the provided rotation, and
     //   uses this for cropping,
     // - then finally rotates this back.
diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/hand_gesture_recognizer_graph.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/hand_gesture_recognizer_graph.cc
index 3fe999937..527363d1f 100644
--- a/mediapipe/tasks/cc/vision/gesture_recognizer/hand_gesture_recognizer_graph.cc
+++ b/mediapipe/tasks/cc/vision/gesture_recognizer/hand_gesture_recognizer_graph.cc
@@ -374,22 +374,22 @@ class SingleHandGestureRecognizerGraph : public core::ModelTaskGraph {
     // Inference for custom gesture classifier if it exists.
     if (has_custom_gesture_classifier) {
       ASSIGN_OR_RETURN(
-          auto gesture_clasification_list,
+          auto gesture_classification_list,
           GetGestureClassificationList(
               sub_task_model_resources.custom_gesture_classifier_model_resource,
              graph_options.custom_gesture_classifier_graph_options(),
              embedding_tensors, graph));
-      gesture_clasification_list >> combine_predictions.In(classifier_nums++);
+      gesture_classification_list >> combine_predictions.In(classifier_nums++);
     }
 
     // Inference for canned gesture classifier.
     ASSIGN_OR_RETURN(
-        auto gesture_clasification_list,
+        auto gesture_classification_list,
         GetGestureClassificationList(
             sub_task_model_resources.canned_gesture_classifier_model_resource,
             graph_options.canned_gesture_classifier_graph_options(),
             embedding_tensors, graph));
-    gesture_clasification_list >> combine_predictions.In(classifier_nums++);
+    gesture_classification_list >> combine_predictions.In(classifier_nums++);
 
     auto combined_classification_list =
         combine_predictions.Out(kPredictionTag).Cast<ClassificationList>();
diff --git a/mediapipe/tasks/ios/test/vision/utils/sources/MPPImage+TestUtils.h b/mediapipe/tasks/ios/test/vision/utils/sources/MPPImage+TestUtils.h
index 9dfe29fd3..8cd1c6a67 100644
--- a/mediapipe/tasks/ios/test/vision/utils/sources/MPPImage+TestUtils.h
+++ b/mediapipe/tasks/ios/test/vision/utils/sources/MPPImage+TestUtils.h
@@ -29,7 +29,7 @@ NS_ASSUME_NONNULL_BEGIN
 * @param classObject The specified class associated with the bundle containing the file to be
 * loaded.
 * @param name Name of the image file.
- * @param type Extenstion of the image file.
+ * @param type Extension of the image file.
 *
 * @return The `MPPImage` object contains the loaded image. This method returns
 * nil if it cannot load the image.
@@ -46,7 +46,7 @@ NS_ASSUME_NONNULL_BEGIN
 * @param classObject The specified class associated with the bundle containing the file to be
 * loaded.
 * @param name Name of the image file.
- * @param type Extenstion of the image file.
+ * @param type Extension of the image file.
 * @param orientation Orientation of the image.
 *
 * @return The `MPPImage` object contains the loaded image. This method returns
diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskRunner.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskRunner.java
index 51735ff76..155536a4e 100644
--- a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskRunner.java
+++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskRunner.java
@@ -94,7 +94,7 @@ public class TaskRunner implements AutoCloseable {
   *
   * <p>Note: This method is designed for processing batch data such as unrelated images and texts.
   * The call blocks the current thread until a failure status or a successful result is returned.
-  * An internal timestamp will be assigend per invocation. This method is thread-safe and allows
+  * An internal timestamp will be assigned per invocation. This method is thread-safe and allows
   * clients to call it from different threads.
   *
   * @param inputs a map contains (input stream {@link String}, data {@link Packet}) pairs.
diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/facestylizer/FaceStylizer.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/facestylizer/FaceStylizer.java
index 9a52d114d..a6e246f1d 100644
--- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/facestylizer/FaceStylizer.java
+++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/facestylizer/FaceStylizer.java
@@ -254,7 +254,7 @@ public final class FaceStylizer extends BaseVisionTaskApi {
   * @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a
   *     region-of-interest.
   * @throws MediaPipeException if there is an internal error. Or if {@link FaceStylizer} is not
-  *     created wtih {@link ResultListener} set in {@link FaceStylizerOptions}.
+  *     created with {@link ResultListener} set in {@link FaceStylizerOptions}.
   */
  public void stylizeWithResultListener(MPImage image) {
    stylizeWithResultListener(image, ImageProcessingOptions.builder().build());
@@ -283,7 +283,7 @@ public final class FaceStylizer extends BaseVisionTaskApi {
   * @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a
   *     region-of-interest.
   * @throws MediaPipeException if there is an internal error. Or if {@link FaceStylizer} is not
-  *     created wtih {@link ResultListener} set in {@link FaceStylizerOptions}.
+  *     created with {@link ResultListener} set in {@link FaceStylizerOptions}.
   */
  public void stylizeWithResultListener(
      MPImage image, ImageProcessingOptions imageProcessingOptions) {
@@ -384,7 +384,7 @@ public final class FaceStylizer extends BaseVisionTaskApi {
   * @param image a MediaPipe {@link MPImage} object for processing.
   * @param timestampMs the input timestamp (in milliseconds).
   * @throws MediaPipeException if there is an internal error. Or if {@link FaceStylizer} is not
-  *     created wtih {@link ResultListener} set in {@link FaceStylizerOptions}.
+  *     created with {@link ResultListener} set in {@link FaceStylizerOptions}.
   */
  public void stylizeForVideoWithResultListener(MPImage image, long timestampMs) {
    stylizeForVideoWithResultListener(image, ImageProcessingOptions.builder().build(), timestampMs);
@@ -411,7 +411,7 @@ public final class FaceStylizer extends BaseVisionTaskApi {
   * @param image a MediaPipe {@link MPImage} object for processing.
   * @param timestampMs the input timestamp (in milliseconds).
   * @throws MediaPipeException if there is an internal error. Or if {@link FaceStylizer} is not
-  *     created wtih {@link ResultListener} set in {@link FaceStylizerOptions}.
+  *     created with {@link ResultListener} set in {@link FaceStylizerOptions}.
   */
  public void stylizeForVideoWithResultListener(
      MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) {
diff --git a/mediapipe/tasks/web/components/processors/classifier_result.test.ts b/mediapipe/tasks/web/components/processors/classifier_result.test.ts
index 4b93d0a76..2e8f9956c 100644
--- a/mediapipe/tasks/web/components/processors/classifier_result.test.ts
+++ b/mediapipe/tasks/web/components/processors/classifier_result.test.ts
@@ -32,12 +32,12 @@ describe('convertFromClassificationResultProto()', () => {
     classifcations.setHeadIndex(1);
     classifcations.setHeadName('headName');
     const classificationList = new ClassificationList();
-    const clasification = new Classification();
-    clasification.setIndex(2);
-    clasification.setScore(0.3);
-    clasification.setDisplayName('displayName');
-    clasification.setLabel('categoryName');
-    classificationList.addClassification(clasification);
+    const classification = new Classification();
+    classification.setIndex(2);
+    classification.setScore(0.3);
+    classification.setDisplayName('displayName');
+    classification.setLabel('categoryName');
+    classificationList.addClassification(classification);
     classifcations.setClassificationList(classificationList);
     classificationResult.addClassifications(classifcations);
 
@@ -62,8 +62,8 @@ describe('convertFromClassificationResultProto()', () => {
     const classificationResult = new ClassificationResult();
     const classifcations = new Classifications();
     const classificationList = new ClassificationList();
-    const clasification = new Classification();
-    classificationList.addClassification(clasification);
+    const classification = new Classification();
+    classificationList.addClassification(classification);
     classifcations.setClassificationList(classificationList);
     classificationResult.addClassifications(classifcations);

From 5a1a9269e6cd8a6a0167b43d0ba17327362bf13f Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Wed, 5 Apr 2023 22:29:23 -0700
Subject: [PATCH 57/63] Internal Changes

PiperOrigin-RevId: 522247775
---
 mediapipe/model_maker/python/core/utils/file_util.py |  2 ++
 .../gesture_recognizer/gesture_recognizer_test.py    | 10 ++++------
 2 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/mediapipe/model_maker/python/core/utils/file_util.py b/mediapipe/model_maker/python/core/utils/file_util.py
index 7871d90cb..221df94fd 100644
--- a/mediapipe/model_maker/python/core/utils/file_util.py
+++ b/mediapipe/model_maker/python/core/utils/file_util.py
@@ -94,4 +94,6 @@ class DownloadedFiles:
       pathlib.Path.mkdir(absolute_path.parent, parents=True, exist_ok=True)
       with open(absolute_path, 'wb') as f:
         f.write(r.content)
+    else:
+      print(f'Using existing files at {absolute_path}')
     return str(absolute_path)
diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py b/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py
index ad2f211f5..11b4f9759 100644
--- a/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py
+++ b/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py
@@ -15,7 +15,6 @@
 import io
 import os
 import tempfile
-import unittest
 from unittest import mock as unittest_mock
 import zipfile
 
@@ -32,7 +31,6 @@ _TEST_DATA_DIR = 'mediapipe/model_maker/python/vision/gesture_recognizer/testdat
 tf.keras.backend.experimental.enable_tf_random_generator()
 
 
-@unittest.skip('b/273818271')
 class GestureRecognizerTest(tf.test.TestCase):
 
   def _load_data(self):
@@ -47,9 +45,6 @@ class GestureRecognizerTest(tf.test.TestCase):
   def setUp(self):
     super().setUp()
     tf.keras.utils.set_random_seed(87654321)
-    all_data = self._load_data()
-    # Splits data, 90% data for training, 10% for validation
-    self._train_data, self._validation_data = all_data.split(0.9)
     # Mock tempfile.gettempdir() to be unique for each test to avoid race
     # condition when downloading model since these tests may run in parallel.
     mock_gettempdir = unittest_mock.patch.object(
@@ -60,6 +55,10 @@ class GestureRecognizerTest(tf.test.TestCase):
     )
     self.mock_gettempdir = mock_gettempdir.start()
     self.addCleanup(mock_gettempdir.stop)
+    # Load dataset used by tests
+    all_data = self._load_data()
+    # Splits data, 90% data for training, 10% for validation
+    self._train_data, self._validation_data = all_data.split(0.9)
 
   def test_gesture_recognizer_model(self):
     mo = gesture_recognizer.ModelOptions()
@@ -74,7 +73,6 @@ class GestureRecognizerTest(tf.test.TestCase):
 
     self._test_accuracy(model)
 
-  @unittest.skip('b/273818271')
   @unittest_mock.patch.object(
       tf.keras.layers, 'Dense', wraps=tf.keras.layers.Dense
   )

From 0067a1b5c233895ced9579efdf3ed02c5e2fb338 Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Wed, 5 Apr 2023 22:35:14 -0700
Subject: [PATCH 58/63] Internal changes

PiperOrigin-RevId: 522248624
---
 .../python/vision/object_detector/object_detector_test.py | 8 ++------
 1 file changed, 2 insertions(+), 6 deletions(-)

diff --git a/mediapipe/model_maker/python/vision/object_detector/object_detector_test.py b/mediapipe/model_maker/python/vision/object_detector/object_detector_test.py
index 3feb75f2e..df6b58a07 100644
--- a/mediapipe/model_maker/python/vision/object_detector/object_detector_test.py
+++ b/mediapipe/model_maker/python/vision/object_detector/object_detector_test.py
@@ -14,7 +14,6 @@
 
 import os
 import tempfile
-import unittest  # pylint:disable=unused-import
 from unittest import mock as unittest_mock
 
 from absl.testing import parameterized
@@ -28,7 +27,6 @@ from mediapipe.model_maker.python.vision.object_detector import object_detector_
 from mediapipe.tasks.python.test import test_utils as task_test_utils
 
 
-@unittest.skip('b/275624089')
 class ObjectDetectorTest(tf.test.TestCase, parameterized.TestCase):
 
   def setUp(self):
@@ -51,7 +49,7 @@ class ObjectDetectorTest(tf.test.TestCase, parameterized.TestCase):
 
   def test_object_detector(self):
     hparams = hyperparameters.HParams(
-        epochs=10,
+        epochs=1,
         batch_size=2,
         learning_rate=0.9,
         shuffle=False,
@@ -75,7 +73,6 @@ class ObjectDetectorTest(tf.test.TestCase, parameterized.TestCase):
     output_tflite_file = os.path.join(
         options.hparams.export_dir, 'model.tflite'
     )
-    print('ASDF float', os.path.getsize(output_tflite_file))
     self.assertTrue(os.path.exists(output_tflite_file))
     self.assertGreater(os.path.getsize(output_tflite_file), 0)
     self.assertTrue(os.path.exists(output_metadata_file))
@@ -85,7 +82,7 @@ class ObjectDetectorTest(tf.test.TestCase, parameterized.TestCase):
     qat_hparams = hyperparameters.QATHParams(
         learning_rate=0.9,
         batch_size=2,
-        epochs=5,
+        epochs=1,
         decay_steps=6,
         decay_rate=0.96,
     )
@@ -101,7 +98,6 @@ class ObjectDetectorTest(tf.test.TestCase, parameterized.TestCase):
     output_tflite_file = os.path.join(
         options.hparams.export_dir, 'model_qat.tflite'
     )
-    print('ASDF qat', os.path.getsize(output_tflite_file))
     self.assertTrue(os.path.exists(output_tflite_file))
     self.assertGreater(os.path.getsize(output_tflite_file), 0)
     self.assertLess(os.path.getsize(output_tflite_file), 3500000)

From 56b3cd4350304a1d9ae1963808b39fce1860a457 Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Wed, 5 Apr 2023 23:12:46 -0700
Subject: [PATCH 59/63] Internal change

PiperOrigin-RevId: 522253757
---
 mediapipe/tasks/cc/core/external_file_handler.cc | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mediapipe/tasks/cc/core/external_file_handler.cc b/mediapipe/tasks/cc/core/external_file_handler.cc
index c271f3dac..a56f03d55 100644
--- a/mediapipe/tasks/cc/core/external_file_handler.cc
+++ b/mediapipe/tasks/cc/core/external_file_handler.cc
@@ -66,13 +66,13 @@ using ::absl::StatusCode;
 // Gets the offset aligned to page size for mapping given files into memory by
 // file descriptor correctly, as according to mmap(2), the offset used in mmap
 // must be a multiple of sysconf(_SC_PAGE_SIZE).
-int64 GetPageSizeAlignedOffset(int64 offset) {
+int64_t GetPageSizeAlignedOffset(int64_t offset) {
 #ifdef _WIN32
   // mmap is not used on Windows
   return 0;
 #else
-  int64 aligned_offset = offset;
-  int64 page_size = sysconf(_SC_PAGE_SIZE);
+  int64_t aligned_offset = offset;
+  int64_t page_size = sysconf(_SC_PAGE_SIZE);
   if (offset % page_size != 0) {
     aligned_offset = offset / page_size * page_size;
   }

From 12ecc8139fe996fe5c2cfd289fcb38d1c77f80eb Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Wed, 5 Apr 2023 23:22:52 -0700
Subject: [PATCH 60/63] Internal change

PiperOrigin-RevId: 522255287
---
 .../tensor/image_to_tensor_converter_opencv.cc     | 19 ++++++++++---------
 .../tensor/tensor_converter_calculator_test.cc     |  8 ++++----
 .../tensor/tensors_to_classification_calculator.cc |  7 ++++---
 3 files changed, 18 insertions(+), 16 deletions(-)

diff --git a/mediapipe/calculators/tensor/image_to_tensor_converter_opencv.cc b/mediapipe/calculators/tensor/image_to_tensor_converter_opencv.cc
index 95e38f89c..bb4c6de79 100644
--- a/mediapipe/calculators/tensor/image_to_tensor_converter_opencv.cc
+++ b/mediapipe/calculators/tensor/image_to_tensor_converter_opencv.cc
@@ -92,13 +92,14 @@ class OpenCvProcessor : public ImageToTensorConverter {
     const int dst_data_type = output_channels == 1 ? mat_gray_type_ : mat_type_;
     switch (tensor_type_) {
       case Tensor::ElementType::kInt8:
-        RET_CHECK_GE(output_shape.num_elements(),
-                     tensor_buffer_offset / sizeof(int8) + num_elements_per_img)
+        RET_CHECK_GE(
+            output_shape.num_elements(),
+            tensor_buffer_offset / sizeof(int8_t) + num_elements_per_img)
             << "The buffer offset + the input image size is larger than the "
               "allocated tensor buffer.";
-        dst = cv::Mat(
-            output_height, output_width, dst_data_type,
-            buffer_view.buffer<int8>() + tensor_buffer_offset / sizeof(int8));
+        dst = cv::Mat(output_height, output_width, dst_data_type,
+                      buffer_view.buffer<int8_t>() +
+                          tensor_buffer_offset / sizeof(int8_t));
         break;
       case Tensor::ElementType::kFloat32:
         RET_CHECK_GE(
@@ -113,12 +114,12 @@ class OpenCvProcessor : public ImageToTensorConverter {
       case Tensor::ElementType::kUInt8:
         RET_CHECK_GE(
             output_shape.num_elements(),
-            tensor_buffer_offset / sizeof(uint8) + num_elements_per_img)
+            tensor_buffer_offset / sizeof(uint8_t) + num_elements_per_img)
             << "The buffer offset + the input image size is larger than the "
              "allocated tensor buffer.";
-        dst = cv::Mat(
-            output_height, output_width, dst_data_type,
-            buffer_view.buffer<uint8>() + tensor_buffer_offset / sizeof(uint8));
+        dst = cv::Mat(output_height, output_width, dst_data_type,
+                      buffer_view.buffer<uint8_t>() +
+                          tensor_buffer_offset / sizeof(uint8_t));
         break;
       default:
         return InvalidArgumentError(
diff --git a/mediapipe/calculators/tensor/tensor_converter_calculator_test.cc b/mediapipe/calculators/tensor/tensor_converter_calculator_test.cc
index bdea0795e..2cfbd3d1e 100644
--- a/mediapipe/calculators/tensor/tensor_converter_calculator_test.cc
+++ b/mediapipe/calculators/tensor/tensor_converter_calculator_test.cc
@@ -41,7 +41,7 @@ constexpr char kTransposeOptionsString[] =
 using RandomEngine = std::mt19937_64;
 using testing::Eq;
 
-const uint32 kSeed = 1234;
+const uint32_t kSeed = 1234;
 const int kNumSizes = 8;
 const int sizes[kNumSizes][2] = {{1, 1}, {12, 1}, {1, 9}, {2, 2},
                                  {5, 3}, {7, 13}, {16, 32}, {101, 2}};
@@ -49,7 +49,7 @@ const int sizes[kNumSizes][2] = {{1, 1}, {12, 1}, {1, 9}, {2, 2},
 
 class TensorConverterCalculatorTest : public ::testing::Test {
  protected:
  // Adds a packet with a matrix filled with random values in [0,1].
-  void AddRandomMatrix(int num_rows, int num_columns, uint32 seed,
+  void AddRandomMatrix(int num_rows, int num_columns, uint32_t seed,
                        bool row_major_matrix = false) {
     RandomEngine random(kSeed);
     std::uniform_real_distribution<> uniform_dist(0, 1.0);
@@ -229,7 +229,7 @@ TEST_F(TensorConverterCalculatorTest, CustomDivAndSub) {
   MP_ASSERT_OK(graph.StartRun({}));
   auto input_image = absl::make_unique<ImageFrame>(ImageFormat::GRAY8, 1, 1);
   cv::Mat mat = mediapipe::formats::MatView(input_image.get());
-  mat.at<uint8>(0, 0) = 200;
+  mat.at<uint8_t>(0, 0) = 200;
   MP_ASSERT_OK(graph.AddPacketToInputStream(
       "input_image", Adopt(input_image.release()).At(Timestamp(0))));
 
@@ -286,7 +286,7 @@ TEST_F(TensorConverterCalculatorTest, SetOutputRange) {
   MP_ASSERT_OK(graph.StartRun({}));
   auto input_image = absl::make_unique<ImageFrame>(ImageFormat::GRAY8, 1, 1);
   cv::Mat mat = mediapipe::formats::MatView(input_image.get());
-  mat.at<uint8>(0, 0) = 200;
+  mat.at<uint8_t>(0, 0) = 200;
   MP_ASSERT_OK(graph.AddPacketToInputStream(
       "input_image", Adopt(input_image.release()).At(Timestamp(0))));
 
diff --git a/mediapipe/calculators/tensor/tensors_to_classification_calculator.cc b/mediapipe/calculators/tensor/tensors_to_classification_calculator.cc
index 5bfc00ed7..7041c02e4 100644
--- a/mediapipe/calculators/tensor/tensors_to_classification_calculator.cc
+++ b/mediapipe/calculators/tensor/tensors_to_classification_calculator.cc
@@ -84,7 +84,7 @@ class TensorsToClassificationCalculator : public Node {
  private:
   int top_k_ = 0;
   bool sort_by_descending_score_ = false;
-  proto_ns::Map<int64, LabelMapItem> local_label_map_;
+  proto_ns::Map<int64_t, LabelMapItem> local_label_map_;
   bool label_map_loaded_ = false;
   bool is_binary_classification_ = false;
   float min_score_threshold_ = std::numeric_limits<float>::lowest();
@@ -98,7 +98,8 @@ class TensorsToClassificationCalculator : public Node {
   // These are used to filter out the output classification results.
   ClassIndexSet class_index_set_;
   bool IsClassIndexAllowed(int class_index);
-  const proto_ns::Map<int64, LabelMapItem>& GetLabelMap(CalculatorContext* cc);
+  const proto_ns::Map<int64_t, LabelMapItem>& GetLabelMap(
+      CalculatorContext* cc);
 };
 MEDIAPIPE_REGISTER_NODE(TensorsToClassificationCalculator);
 
@@ -252,7 +253,7 @@ bool TensorsToClassificationCalculator::IsClassIndexAllowed(int class_index) {
   }
 }
 
-const proto_ns::Map<int64, LabelMapItem>&
+const proto_ns::Map<int64_t, LabelMapItem>&
 TensorsToClassificationCalculator::GetLabelMap(CalculatorContext* cc) {
   return !local_label_map_.empty()
              ? local_label_map_

From 56552dbfb5818b6926d5b1c23f1f0e718a624623 Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Wed, 5 Apr 2023 23:23:29 -0700
Subject: [PATCH 61/63] Internal change

PiperOrigin-RevId: 522255364
---
 .../framework/calculator_graph_bounds_test.cc |  2 +-
 .../calculator_graph_side_packet_test.cc      | 12 +++----
 mediapipe/framework/calculator_graph_test.cc  | 36 +++++++++----------
 mediapipe/framework/calculator_runner.cc      |  2 +-
 mediapipe/framework/counter_factory.cc        |  8 ++---
 5 files changed, 30 insertions(+), 30 deletions(-)

diff --git a/mediapipe/framework/calculator_graph_bounds_test.cc b/mediapipe/framework/calculator_graph_bounds_test.cc
index d149337cc..81ce9902c 100644
--- a/mediapipe/framework/calculator_graph_bounds_test.cc
+++ b/mediapipe/framework/calculator_graph_bounds_test.cc
@@ -679,7 +679,7 @@ REGISTER_CALCULATOR(BoundToPacketCalculator);
 
 // A Calculator that produces packets at timestamps beyond the input timestamp.
 class FuturePacketCalculator : public CalculatorBase {
  public:
-  static constexpr int64 kOutputFutureMicros = 3;
+  static constexpr int64_t kOutputFutureMicros = 3;
 
   static absl::Status GetContract(CalculatorContract* cc) {
     cc->Inputs().Index(0).Set<int>();
diff --git a/mediapipe/framework/calculator_graph_side_packet_test.cc b/mediapipe/framework/calculator_graph_side_packet_test.cc
index 57fcff866..a9567c805 100644
--- a/mediapipe/framework/calculator_graph_side_packet_test.cc
+++ b/mediapipe/framework/calculator_graph_side_packet_test.cc
@@ -188,21 +188,21 @@ class Uint64PacketGenerator : public PacketGenerator {
   static absl::Status FillExpectations(
       const PacketGeneratorOptions& extendable_options,
       PacketTypeSet* input_side_packets, PacketTypeSet* output_side_packets) {
-    output_side_packets->Index(0).Set<uint64>();
+    output_side_packets->Index(0).Set<uint64_t>();
     return absl::OkStatus();
   }
 
   static absl::Status Generate(const PacketGeneratorOptions& extendable_options,
                                const PacketSet& input_side_packets,
                                PacketSet* output_side_packets) {
-    output_side_packets->Index(0) = Adopt(new uint64(15LL << 32 | 5));
+    output_side_packets->Index(0) = Adopt(new uint64_t(15LL << 32 | 5));
     return absl::OkStatus();
   }
 };
 REGISTER_PACKET_GENERATOR(Uint64PacketGenerator);
 
 TEST(CalculatorGraph, OutputSidePacketInProcess) {
-  const int64 offset = 100;
+  const int64_t offset = 100;
   CalculatorGraphConfig config =
       mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
         input_stream: "offset"
@@ -400,7 +400,7 @@ TEST(CalculatorGraph, SharePacketGeneratorGraph) {
 }
 
 TEST(CalculatorGraph, OutputSidePacketAlreadySet) {
-  const int64 offset = 100;
+  const int64_t offset = 100;
   CalculatorGraphConfig config =
       mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
         input_stream: "offset"
@@ -427,7 +427,7 @@ TEST(CalculatorGraph, OutputSidePacketAlreadySet) {
 }
 
 TEST(CalculatorGraph, OutputSidePacketWithTimestamp) {
-  const int64 offset = 100;
+  const int64_t offset = 100;
   CalculatorGraphConfig config =
       mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
         input_stream: "offset"
@@ -716,7 +716,7 @@ TEST(CalculatorGraph, GetOutputSidePacket) {
   // Run the graph twice.
   int max_count = 100;
   std::map<std::string, Packet> extra_side_packets;
-  extra_side_packets.insert({"input_uint64", MakePacket<uint64>(1123)});
+  extra_side_packets.insert({"input_uint64", MakePacket<uint64_t>(1123)});
   for (int run = 0; run < 1; ++run) {
     MP_ASSERT_OK(graph.StartRun(extra_side_packets));
     status_or_packet = graph.GetOutputSidePacket("output_uint32_pair");
diff --git a/mediapipe/framework/calculator_graph_test.cc b/mediapipe/framework/calculator_graph_test.cc
index 6ca206ab1..2e7d99ef6 100644
--- a/mediapipe/framework/calculator_graph_test.cc
+++ b/mediapipe/framework/calculator_graph_test.cc
@@ -439,7 +439,7 @@ class GlobalCountSourceCalculator : public CalculatorBase {
     ++local_count_;
   }
 
-  int64 local_count_ = 0;
+  int64_t local_count_ = 0;
 };
 const int GlobalCountSourceCalculator::kNumOutputPackets = 5;
 REGISTER_CALCULATOR(GlobalCountSourceCalculator);
@@ -765,7 +765,7 @@ class TypedStatusHandler : public StatusHandler {
   }
 };
 typedef TypedStatusHandler<std::string> StringStatusHandler;
-typedef TypedStatusHandler<uint32> Uint32StatusHandler;
+typedef TypedStatusHandler<uint32_t> Uint32StatusHandler;
 REGISTER_STATUS_HANDLER(StringStatusHandler);
 REGISTER_STATUS_HANDLER(Uint32StatusHandler);
 
@@ -1398,9 +1398,9 @@ void RunComprehensiveTest(CalculatorGraph* graph,
   MP_ASSERT_OK(graph->Initialize(proto));
 
   std::map<std::string, Packet> extra_side_packets;
-  extra_side_packets.emplace("node_3", Adopt(new uint64((15LL << 32) | 3)));
+  extra_side_packets.emplace("node_3", Adopt(new uint64_t((15LL << 32) | 3)));
   if (define_node_5) {
-    extra_side_packets.emplace("node_5", Adopt(new uint64((15LL << 32) | 5)));
+    extra_side_packets.emplace("node_5", Adopt(new uint64_t((15LL << 32) | 5)));
   }
 
   // Call graph->Run() several times, to make sure that the appropriate
@@ -1452,9 +1452,9 @@ void RunComprehensiveTest(CalculatorGraph* graph,
   // Verify that the graph can still run (but not successfully) when
   // one of the nodes is caused to fail.
   extra_side_packets.clear();
-  extra_side_packets.emplace("node_3", Adopt(new uint64((15LL << 32) | 0)));
+  extra_side_packets.emplace("node_3", Adopt(new uint64_t((15LL << 32) | 0)));
   if (define_node_5) {
-    extra_side_packets.emplace("node_5", Adopt(new uint64((15LL << 32) | 5)));
+    extra_side_packets.emplace("node_5", Adopt(new uint64_t((15LL << 32) | 5)));
   }
   dumped_final_sum_packet = Packet();
   dumped_final_stddev_packet = Packet();
@@ -1579,14 +1579,14 @@ class Uint64PacketGenerator : public PacketGenerator {
   static absl::Status FillExpectations(
       const PacketGeneratorOptions& extendable_options,
       PacketTypeSet* input_side_packets, PacketTypeSet* output_side_packets) {
-    output_side_packets->Index(0).Set<uint64>();
+    output_side_packets->Index(0).Set<uint64_t>();
     return absl::OkStatus();
   }
 
   static absl::Status Generate(const PacketGeneratorOptions& extendable_options,
                                const PacketSet& input_side_packets,
                                PacketSet* output_side_packets) {
-    output_side_packets->Index(0) = Adopt(new uint64(15LL << 32 | 5));
+    output_side_packets->Index(0) = Adopt(new uint64_t(15LL << 32 | 5));
     return absl::OkStatus();
   }
 };
@@ -1759,7 +1759,7 @@ TEST(CalculatorGraph, StatusHandlerInputVerification) {
   )pb");
   MP_ASSERT_OK(graph->Initialize(config));
   Packet extra_string = Adopt(new std::string("foo"));
-  Packet a_uint64 = Adopt(new uint64(0));
+  Packet a_uint64 = Adopt(new uint64_t(0));
   MP_EXPECT_OK(
       graph->Run({{"extra_string", extra_string}, {"a_uint64", a_uint64}}));
 
@@ -1789,7 +1789,7 @@ TEST(CalculatorGraph, StatusHandlerInputVerification) {
                   testing::HasSubstr("string"),
                   // Expected type.
                   testing::HasSubstr(
-                      MediaPipeTypeStringOrDemangled<uint32>())));
+                      MediaPipeTypeStringOrDemangled<uint32_t>())));
 
   // Should fail verification when the type of a to-be-generated packet is
   // wrong. The added handler now expects a string but will receive the uint32
@@ -1802,14 +1802,14 @@ TEST(CalculatorGraph, StatusHandlerInputVerification) {
 
   status = graph->Initialize(config);
   EXPECT_THAT(status.message(),
-              testing::AllOf(
-                  testing::HasSubstr("StringStatusHandler"),
-                  // The problematic input side packet.
-                  testing::HasSubstr("generated_by_generator"),
-                  // Actual type.
-                  testing::HasSubstr(MediaPipeTypeStringOrDemangled<uint32>()),
-                  // Expected type.
-                  testing::HasSubstr("string")));
+              testing::AllOf(testing::HasSubstr("StringStatusHandler"),
+                             // The problematic input side packet.
+                             testing::HasSubstr("generated_by_generator"),
+                             // Actual type.
+                             testing::HasSubstr(
+                                 MediaPipeTypeStringOrDemangled<uint32_t>()),
+                             // Expected type.
+                             testing::HasSubstr("string")));
 }
 
 TEST(CalculatorGraph, GenerateInInitialize) {
diff --git a/mediapipe/framework/calculator_runner.cc b/mediapipe/framework/calculator_runner.cc
index 833797483..1bd3211ed 100644
--- a/mediapipe/framework/calculator_runner.cc
+++ b/mediapipe/framework/calculator_runner.cc
@@ -216,7 +216,7 @@ mediapipe::Counter* CalculatorRunner::GetCounter(const std::string& name) {
   return graph_->GetCounterFactory()->GetCounter(name);
 }
 
-std::map<std::string, int64> CalculatorRunner::GetCountersValues() {
+std::map<std::string, int64_t> CalculatorRunner::GetCountersValues() {
   return graph_->GetCounterFactory()->GetCounterSet()->GetCountersValues();
 }
 
diff --git a/mediapipe/framework/counter_factory.cc b/mediapipe/framework/counter_factory.cc
index 94a6a4213..895b44ea6 100644
--- a/mediapipe/framework/counter_factory.cc
+++ b/mediapipe/framework/counter_factory.cc
@@ -39,14 +39,14 @@ class BasicCounter : public Counter {
     value_ += amount;
   }
 
-  int64 Get() ABSL_LOCKS_EXCLUDED(mu_) override {
+  int64_t Get() ABSL_LOCKS_EXCLUDED(mu_) override {
     absl::ReaderMutexLock lock(&mu_);
     return value_;
   }
 
 private:
  absl::Mutex mu_;
-  int64 value_ ABSL_GUARDED_BY(mu_);
+  int64_t value_ ABSL_GUARDED_BY(mu_);
 };
 
 }  // namespace
@@ -73,10 +73,10 @@ Counter* CounterSet::Get(const std::string& name) ABSL_LOCKS_EXCLUDED(mu_) {
   return counters_[name].get();
 }
 
-std::map<std::string, int64> CounterSet::GetCountersValues()
+std::map<std::string, int64_t> CounterSet::GetCountersValues()
     ABSL_LOCKS_EXCLUDED(mu_) {
   absl::ReaderMutexLock lock(&mu_);
-  std::map<std::string, int64> result;
+  std::map<std::string, int64_t> result;
   for (const auto& it : counters_) {
     result[it.first] = it.second->Get();
   }

From d05508cb7bbdad4064ddaadf03bb14d10c3a1904 Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Wed, 5 Apr 2023 23:58:33 -0700
Subject: [PATCH 62/63] Internal change

PiperOrigin-RevId: 522260226
---
 mediapipe/framework/scheduler_queue.cc | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mediapipe/framework/scheduler_queue.cc b/mediapipe/framework/scheduler_queue.cc
index efad97282..33214cf64 100644
--- a/mediapipe/framework/scheduler_queue.cc
+++ b/mediapipe/framework/scheduler_queue.cc
@@ -240,7 +240,7 @@ void SchedulerQueue::RunCalculatorNode(CalculatorNode* node,
   // we should not run any more sources. Close the node if it is a source.
   if (shared_->stopping && node->IsSource()) {
     VLOG(4) << "Closing " << node->DebugName() << " due to StatusStop().";
-    int64 start_time = shared_->timer.StartNode();
+    int64_t start_time = shared_->timer.StartNode();
     // It's OK to not reset/release the prepared CalculatorContext since a
     // source node always reuses the same CalculatorContext and Close() doesn't
     // access any inputs.
@@ -256,7 +256,7 @@ void SchedulerQueue::RunCalculatorNode(CalculatorNode* node,
   } else {
     // Note that we don't need a lock because only one thread can execute this
     // due to the lock on running_nodes.
-    int64 start_time = shared_->timer.StartNode();
+    int64_t start_time = shared_->timer.StartNode();
     const absl::Status result = node->ProcessNode(cc);
     shared_->timer.EndNode(start_time);
 
@@ -283,7 +283,7 @@ void SchedulerQueue::RunCalculatorNode(CalculatorNode* node,
 
 void SchedulerQueue::OpenCalculatorNode(CalculatorNode* node) {
   VLOG(3) << "Opening " << node->DebugName();
-  int64 start_time = shared_->timer.StartNode();
+  int64_t start_time = shared_->timer.StartNode();
   const absl::Status result = node->OpenNode();
   shared_->timer.EndNode(start_time);
   if (!result.ok()) {

From 22186299c485135e9298c7cdba52580bf9dadee2 Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Thu, 6 Apr 2023 00:19:23 -0700
Subject: [PATCH 63/63] Internal change

PiperOrigin-RevId: 522263621
---
 .../util/clock_latency_calculator.cc          |  8 ++---
 ...collection_has_min_size_calculator_test.cc |  2 +-
 .../detection_label_id_to_text_calculator.cc  |  9 ++---
 ...ction_letterbox_removal_calculator_test.cc |  2 +-
 ...etection_transformation_calculator_test.cc |  4 +--
 .../util/detection_unique_id_calculator.cc    |  2 +-
 .../detections_to_rects_calculator_test.cc    |  4 +--
 ...tections_to_render_data_calculator_test.cc |  6 ++--
 .../util/filter_collection_calculator.cc      |  2 +-
 .../calculators/util/from_image_calculator.cc |  4 +--
 .../util/packet_frequency_calculator.cc       | 18 +++++-----
 .../util/packet_latency_calculator.cc         | 36 +++++++++----------
 .../util/packet_latency_calculator_test.cc    |  6 ++--
 13 files changed, 53 insertions(+), 50 deletions(-)

diff --git a/mediapipe/calculators/util/clock_latency_calculator.cc b/mediapipe/calculators/util/clock_latency_calculator.cc
index 5c5711731..beaa41e66 100644
--- a/mediapipe/calculators/util/clock_latency_calculator.cc
+++ b/mediapipe/calculators/util/clock_latency_calculator.cc
@@ -66,17 +66,17 @@ class ClockLatencyCalculator : public CalculatorBase {
   absl::Status Process(CalculatorContext* cc) override;
 
  private:
-  int64 num_packet_streams_ = -1;
+  int64_t num_packet_streams_ = -1;
 };
 REGISTER_CALCULATOR(ClockLatencyCalculator);
 
 absl::Status ClockLatencyCalculator::GetContract(CalculatorContract* cc) {
   RET_CHECK_GT(cc->Inputs().NumEntries(), 1);
 
-  int64 num_packet_streams = cc->Inputs().NumEntries() - 1;
+  int64_t num_packet_streams = cc->Inputs().NumEntries() - 1;
   RET_CHECK_EQ(cc->Outputs().NumEntries(), num_packet_streams);
 
-  for (int64 i = 0; i < num_packet_streams; ++i) {
+  for (int64_t i = 0; i < num_packet_streams; ++i) {
     cc->Inputs().Index(i).Set<absl::Time>();
     cc->Outputs().Index(i).Set<absl::Duration>();
   }
@@ -99,7 +99,7 @@ absl::Status ClockLatencyCalculator::Process(CalculatorContext* cc) {
       cc->Inputs().Tag(kReferenceTag).Get<absl::Time>();
 
   // Push Duration packets for every input stream we have.
-  for (int64 i = 0; i < num_packet_streams_; ++i) {
+  for (int64_t i = 0; i < num_packet_streams_; ++i) {
     if (!cc->Inputs().Index(i).IsEmpty()) {
       const absl::Time& input_stream_time =
           cc->Inputs().Index(i).Get<absl::Time>();
diff --git a/mediapipe/calculators/util/collection_has_min_size_calculator_test.cc b/mediapipe/calculators/util/collection_has_min_size_calculator_test.cc
index 62eb1d8ae..71cba9430 100644
--- a/mediapipe/calculators/util/collection_has_min_size_calculator_test.cc
+++ b/mediapipe/calculators/util/collection_has_min_size_calculator_test.cc
@@ -33,7 +33,7 @@ typedef CollectionHasMinSizeCalculator<std::vector<int>>
     TestIntCollectionHasMinSizeCalculator;
 REGISTER_CALCULATOR(TestIntCollectionHasMinSizeCalculator);
 
-void AddInputVector(const std::vector<int>& input, int64 timestamp,
+void AddInputVector(const std::vector<int>& input, int64_t timestamp,
                     CalculatorRunner* runner) {
   runner->MutableInputs()
       ->Tag(kIterableTag)
diff --git a/mediapipe/calculators/util/detection_label_id_to_text_calculator.cc b/mediapipe/calculators/util/detection_label_id_to_text_calculator.cc
index 0b8dde20d..0c1d6892e 100644
--- a/mediapipe/calculators/util/detection_label_id_to_text_calculator.cc
+++ b/mediapipe/calculators/util/detection_label_id_to_text_calculator.cc
@@ -57,9 +57,10 @@ class DetectionLabelIdToTextCalculator : public CalculatorBase {
  private:
   // Local label map built from the calculator options' `label_map_path` or
   // `label` field.
-  proto_ns::Map<int64, LabelMapItem> local_label_map_;
+  proto_ns::Map<int64_t, LabelMapItem> local_label_map_;
   bool keep_label_id_;
-  const proto_ns::Map<int64, LabelMapItem>& GetLabelMap(CalculatorContext* cc);
+  const proto_ns::Map<int64_t, LabelMapItem>& GetLabelMap(
+      CalculatorContext* cc);
 };
 REGISTER_CALCULATOR(DetectionLabelIdToTextCalculator);
 
@@ -115,7 +116,7 @@ absl::Status DetectionLabelIdToTextCalculator::Process(CalculatorContext* cc) {
     output_detections.push_back(input_detection);
     Detection& output_detection = output_detections.back();
     bool has_text_label = false;
-    for (const int32 label_id : output_detection.label_id()) {
+    for (const int32_t label_id : output_detection.label_id()) {
       if (GetLabelMap(cc).contains(label_id)) {
         auto item = GetLabelMap(cc).at(label_id);
         output_detection.add_label(item.name());
@@ -136,7 +137,7 @@ absl::Status DetectionLabelIdToTextCalculator::Process(CalculatorContext* cc) {
   return absl::OkStatus();
 }
 
-const proto_ns::Map<int64, LabelMapItem>&
+const proto_ns::Map<int64_t, LabelMapItem>&
 DetectionLabelIdToTextCalculator::GetLabelMap(CalculatorContext* cc) {
   return !local_label_map_.empty()
              ? local_label_map_
diff --git a/mediapipe/calculators/util/detection_letterbox_removal_calculator_test.cc b/mediapipe/calculators/util/detection_letterbox_removal_calculator_test.cc
index c4f084363..75dd93cc3 100644
--- a/mediapipe/calculators/util/detection_letterbox_removal_calculator_test.cc
+++ b/mediapipe/calculators/util/detection_letterbox_removal_calculator_test.cc
@@ -40,7 +40,7 @@ LocationData CreateRelativeLocationData(double xmin, double ymin, double width,
 }
 
 Detection CreateDetection(const std::vector<std::string>& labels,
-                          const std::vector<int32>& label_ids,
+                          const std::vector<int32_t>& label_ids,
                           const std::vector<float>& scores,
                           const LocationData& location_data,
                           const std::string& feature_tag) {
diff --git a/mediapipe/calculators/util/detection_transformation_calculator_test.cc b/mediapipe/calculators/util/detection_transformation_calculator_test.cc
index e280b5153..30d1bc64b 100644
--- a/mediapipe/calculators/util/detection_transformation_calculator_test.cc
+++ b/mediapipe/calculators/util/detection_transformation_calculator_test.cc
@@ -39,8 +39,8 @@ constexpr char kPixelDetectionsTag[] = "PIXEL_DETECTIONS";
 constexpr char kRelativeDetectionListTag[] = "RELATIVE_DETECTION_LIST";
 constexpr char kRelativeDetectionsTag[] = "RELATIVE_DETECTIONS";
 
-Detection DetectionWithBoundingBox(int32 xmin, int32 ymin, int32 width,
-                                   int32 height) {
+Detection DetectionWithBoundingBox(int32_t xmin, int32_t ymin, int32_t width,
+                                   int32_t height) {
   Detection detection;
   LocationData* location_data = detection.mutable_location_data();
   location_data->set_format(LocationData::BOUNDING_BOX);
diff --git a/mediapipe/calculators/util/detection_unique_id_calculator.cc b/mediapipe/calculators/util/detection_unique_id_calculator.cc
index ac8889ffb..d5b1cffa3 100644
--- a/mediapipe/calculators/util/detection_unique_id_calculator.cc
+++ b/mediapipe/calculators/util/detection_unique_id_calculator.cc
@@ -26,7 +26,7 @@ constexpr char kDetectionListTag[] = "DETECTION_LIST";
 // Each detection processed by DetectionUniqueIDCalculator will be assigned an
 // unique id that starts from 1. If a detection already has an ID other than 0,
 // the ID will be overwritten.
-static int64 detection_id = 0;
+static int64_t detection_id = 0;
 
 inline int GetNextDetectionId() { return ++detection_id; }
 
diff --git a/mediapipe/calculators/util/detections_to_rects_calculator_test.cc b/mediapipe/calculators/util/detections_to_rects_calculator_test.cc
index 63de60a60..95e18e90c 100644
--- a/mediapipe/calculators/util/detections_to_rects_calculator_test.cc
+++ b/mediapipe/calculators/util/detections_to_rects_calculator_test.cc
@@ -56,8 +56,8 @@ MATCHER_P4(NormRectEq, x_center, y_center, width, height, "") {
          testing::Value(arg.height(), testing::FloatEq(height));
 }
 
-Detection DetectionWithLocationData(int32 xmin, int32 ymin, int32 width,
-                                    int32 height) {
+Detection DetectionWithLocationData(int32_t xmin, int32_t ymin, int32_t width,
+                                    int32_t height) {
   Detection detection;
   LocationData* location_data = detection.mutable_location_data();
   location_data->set_format(LocationData::BOUNDING_BOX);
diff --git a/mediapipe/calculators/util/detections_to_render_data_calculator_test.cc b/mediapipe/calculators/util/detections_to_render_data_calculator_test.cc
index 0d0da2350..6da8c449a 100644
--- a/mediapipe/calculators/util/detections_to_render_data_calculator_test.cc
+++ b/mediapipe/calculators/util/detections_to_render_data_calculator_test.cc
@@ -43,8 +43,8 @@ void VerifyRenderAnnotationColorThickness(
   EXPECT_EQ(annotation.thickness(), options.thickness());
 }
 
-LocationData CreateLocationData(int32 xmin, int32 ymin, int32 width,
-                                int32 height) {
+LocationData CreateLocationData(int32_t xmin, int32_t ymin, int32_t width,
+                                int32_t height) {
   LocationData location_data;
   location_data.set_format(LocationData::BOUNDING_BOX);
   location_data.mutable_bounding_box()->set_xmin(xmin);
@@ -66,7 +66,7 @@ LocationData CreateRelativeLocationData(double xmin, double ymin, double width,
 }
 
 Detection CreateDetection(const std::vector<std::string>& labels,
-                          const std::vector<int32>& label_ids,
+                          const std::vector<int32_t>& label_ids,
                           const std::vector<float>& scores,
                           const LocationData& location_data,
                           const std::string& feature_tag) {
diff --git a/mediapipe/calculators/util/filter_collection_calculator.cc b/mediapipe/calculators/util/filter_collection_calculator.cc
index ab361f450..2cf41ead8 100644
--- a/mediapipe/calculators/util/filter_collection_calculator.cc
+++ b/mediapipe/calculators/util/filter_collection_calculator.cc
@@ -24,7 +24,7 @@
 
 namespace mediapipe {
 
-typedef FilterCollectionCalculator<std::vector<uint64>>
+typedef FilterCollectionCalculator<std::vector<uint64_t>>
     FilterUInt64CollectionCalculator;
 REGISTER_CALCULATOR(FilterUInt64CollectionCalculator);
 
diff --git a/mediapipe/calculators/util/from_image_calculator.cc b/mediapipe/calculators/util/from_image_calculator.cc
index 0ddb342eb..706f8727b 100644
--- a/mediapipe/calculators/util/from_image_calculator.cc
+++ b/mediapipe/calculators/util/from_image_calculator.cc
@@ -163,8 +163,8 @@ absl::Status FromImageCalculator::Process(CalculatorContext* cc) {
     std::unique_ptr<ImageFrame> output = std::make_unique<ImageFrame>(
         input.image_format(), input.width(), input.height(), input.step(),
-        const_cast<uint8*>(input.GetImageFrameSharedPtr()->PixelData()),
-        [packet_copy_ptr](uint8*) { delete packet_copy_ptr; });
+        const_cast<uint8_t*>(input.GetImageFrameSharedPtr()->PixelData()),
+        [packet_copy_ptr](uint8_t*) { delete packet_copy_ptr; });
     cc->Outputs()
         .Tag(kImageFrameTag)
         .Add(output.release(), cc->InputTimestamp());
diff --git a/mediapipe/calculators/util/packet_frequency_calculator.cc b/mediapipe/calculators/util/packet_frequency_calculator.cc
index 19ffae70e..f9c28f5ff 100644
--- a/mediapipe/calculators/util/packet_frequency_calculator.cc
+++ b/mediapipe/calculators/util/packet_frequency_calculator.cc
@@ -84,23 +84,24 @@ class PacketFrequencyCalculator : public CalculatorBase {
                                     const Timestamp& input_timestamp);
 
   // Adds the input timestamp in the particular stream's timestamp buffer.
-  absl::Status AddPacketTimestampForStream(int stream_id, int64 timestamp);
+  absl::Status AddPacketTimestampForStream(int stream_id, int64_t timestamp);
 
   // For the specified input stream, clears timestamps from buffer that are
   // older than the configured time_window_sec.
-  absl::Status ClearOldpacketTimestamps(int stream_id, int64 current_timestamp);
+  absl::Status ClearOldpacketTimestamps(int stream_id,
+                                        int64_t current_timestamp);
 
   // Options for the calculator.
   PacketFrequencyCalculatorOptions options_;
 
   // Map where key is the input stream ID and value is the timestamp of the
   // first packet received on that stream.
-  std::map<int, int64> first_timestamp_for_stream_id_usec_;
+  std::map<int, int64_t> first_timestamp_for_stream_id_usec_;
 
   // Map where key is the input stream ID and value is a vector that stores
   // timestamps of recently received packets on the stream. Timestamps older
   // than the time_window_sec are continuously deleted for all the streams.
-  std::map<int, std::vector<int64>> previous_timestamps_for_stream_id_;
+  std::map<int, std::vector<int64_t>> previous_timestamps_for_stream_id_;
 };
 REGISTER_CALCULATOR(PacketFrequencyCalculator);
 
@@ -166,7 +167,7 @@ absl::Status PacketFrequencyCalculator::Process(CalculatorContext* cc) {
 }
 
 absl::Status PacketFrequencyCalculator::AddPacketTimestampForStream(
-    int stream_id, int64 timestamp_usec) {
+    int stream_id, int64_t timestamp_usec) {
   if (previous_timestamps_for_stream_id_.find(stream_id) ==
       previous_timestamps_for_stream_id_.end()) {
     return absl::InvalidArgumentError("Input stream id is invalid");
@@ -178,19 +179,20 @@ absl::Status PacketFrequencyCalculator::AddPacketTimestampForStream(
 }
 
 absl::Status PacketFrequencyCalculator::ClearOldpacketTimestamps(
-    int stream_id, int64 current_timestamp_usec) {
+    int stream_id, int64_t current_timestamp_usec) {
   if (previous_timestamps_for_stream_id_.find(stream_id) ==
       previous_timestamps_for_stream_id_.end()) {
     return absl::InvalidArgumentError("Input stream id is invalid");
   }
 
   auto& timestamps_buffer = previous_timestamps_for_stream_id_[stream_id];
-  int64 time_window_usec = options_.time_window_sec() * kSecondsToMicroseconds;
+  int64_t time_window_usec =
+      options_.time_window_sec() * kSecondsToMicroseconds;
 
   timestamps_buffer.erase(
       std::remove_if(timestamps_buffer.begin(), timestamps_buffer.end(),
                      [&time_window_usec,
-                      &current_timestamp_usec](const int64 timestamp_usec) {
+                      &current_timestamp_usec](const int64_t timestamp_usec) {
                        return current_timestamp_usec - timestamp_usec >
                               time_window_usec;
                      }),
diff --git a/mediapipe/calculators/util/packet_latency_calculator.cc b/mediapipe/calculators/util/packet_latency_calculator.cc
index 0e5b2e885..6509f016f 100644
--- a/mediapipe/calculators/util/packet_latency_calculator.cc
+++ b/mediapipe/calculators/util/packet_latency_calculator.cc
@@ -118,24 +118,24 @@ class PacketLatencyCalculator : public CalculatorBase {
   std::shared_ptr<::mediapipe::Clock> clock_;
 
   // Clock time when the first reference packet was received.
-  int64 first_process_time_usec_ = -1;
+  int64_t first_process_time_usec_ = -1;
 
   // Timestamp of the first reference packet received.
-  int64 first_reference_timestamp_usec_ = -1;
+  int64_t first_reference_timestamp_usec_ = -1;
 
   // Number of packet streams.
-  int64 num_packet_streams_ = -1;
+  int64_t num_packet_streams_ = -1;
 
   // Latency output for each packet stream.
   std::vector<PacketLatency> packet_latencies_;
 
   // Running sum and count of latencies for each packet stream. This is required
   // to compute the average latency.
-  std::vector<int64> sum_latencies_usec_;
-  std::vector<int64> num_latencies_;
+  std::vector<int64_t> sum_latencies_usec_;
+  std::vector<int64_t> num_latencies_;
 
   // Clock time when last reset was done for histogram and running average.
-  int64 last_reset_time_usec_ = -1;
+  int64_t last_reset_time_usec_ = -1;
 };
 REGISTER_CALCULATOR(PacketLatencyCalculator);
 
@@ -143,9 +143,9 @@ absl::Status PacketLatencyCalculator::GetContract(CalculatorContract* cc) {
   RET_CHECK_GT(cc->Inputs().NumEntries(), 1);
 
   // Input and output streams.
-  int64 num_packet_streams = cc->Inputs().NumEntries() - 1;
+  int64_t num_packet_streams = cc->Inputs().NumEntries() - 1;
   RET_CHECK_EQ(cc->Outputs().NumEntries(), num_packet_streams);
-  for (int64 i = 0; i < num_packet_streams; ++i) {
+  for (int64_t i = 0; i < num_packet_streams; ++i) {
     cc->Inputs().Index(i).SetAny();
     cc->Outputs().Index(i).Set<PacketLatency>();
   }
@@ -165,8 +165,8 @@ absl::Status PacketLatencyCalculator::GetContract(CalculatorContract* cc) {
 
 void PacketLatencyCalculator::ResetStatistics() {
   // Initialize histogram with zero counts and set running average to zero.
-  for (int64 i = 0; i < num_packet_streams_; ++i) {
-    for (int64 interval_index = 0; interval_index < options_.num_intervals();
+  for (int64_t i = 0; i < num_packet_streams_; ++i) {
+    for (int64_t interval_index = 0; interval_index < options_.num_intervals();
          ++interval_index) {
       packet_latencies_[i].set_counts(interval_index, 0);
     }
@@ -196,7 +196,7 @@ absl::Status PacketLatencyCalculator::Open(CalculatorContext* cc) {
   packet_latencies_.resize(num_packet_streams_);
   sum_latencies_usec_.resize(num_packet_streams_);
   num_latencies_.resize(num_packet_streams_);
-  for (int64 i = 0; i < num_packet_streams_; ++i) {
+  for (int64_t i = 0; i < num_packet_streams_; ++i) {
     // Initialize latency histograms with zero counts.
     packet_latencies_[i].set_num_intervals(options_.num_intervals());
     packet_latencies_[i].set_interval_size_usec(options_.interval_size_usec());
@@ -208,7 +208,7 @@ absl::Status PacketLatencyCalculator::Open(CalculatorContext* cc) {
     if (labels_provided) {
       packet_latencies_[i].set_label(options_.packet_labels(i));
     } else {
-      int64 input_stream_index = cc->Inputs().TagMap()->GetId("", i).value();
+      int64_t input_stream_index = cc->Inputs().TagMap()->GetId("", i).value();
       packet_latencies_[i].set_label(
           cc->Inputs().TagMap()->Names()[input_stream_index]);
     }
@@ -242,7 +242,7 @@ absl::Status PacketLatencyCalculator::Process(CalculatorContext* cc) {
   }
 
   if (options_.reset_duration_usec() > 0) {
-    const int64 time_now_usec = absl::ToUnixMicros(clock_->TimeNow());
+    const int64_t time_now_usec = absl::ToUnixMicros(clock_->TimeNow());
     if (time_now_usec - last_reset_time_usec_ >=
         options_.reset_duration_usec()) {
       ResetStatistics();
@@ -251,16 +251,16 @@ absl::Status PacketLatencyCalculator::Process(CalculatorContext* cc) {
 
   // Update latency info if there is any incoming packet.
-  for (int64 i = 0; i < num_packet_streams_; ++i) {
+  for (int64_t i = 0; i < num_packet_streams_; ++i) {
     if (!cc->Inputs().Index(i).IsEmpty()) {
       const auto& packet_timestamp_usec = cc->InputTimestamp().Value();
 
       // Update latency statistics for this stream.
-      int64 current_clock_time_usec = absl::ToUnixMicros(clock_->TimeNow());
-      int64 current_calibrated_timestamp_usec =
+      int64_t current_clock_time_usec = absl::ToUnixMicros(clock_->TimeNow());
+      int64_t current_calibrated_timestamp_usec =
           (current_clock_time_usec - first_process_time_usec_) +
           first_reference_timestamp_usec_;
-      int64 packet_latency_usec =
+      int64_t packet_latency_usec =
           current_calibrated_timestamp_usec - packet_timestamp_usec;
 
       // Invalid timestamps in input signals could result in negative latencies.
@@ -270,7 +270,7 @@ absl::Status PacketLatencyCalculator::Process(CalculatorContext* cc) {
 
       // Update the latency, running average and histogram for this stream.
       packet_latencies_[i].set_current_latency_usec(packet_latency_usec);
-      int64 interval_index =
+      int64_t interval_index =
           packet_latency_usec / packet_latencies_[i].interval_size_usec();
       if (interval_index >= packet_latencies_[i].num_intervals()) {
         interval_index = packet_latencies_[i].num_intervals() - 1;
diff --git a/mediapipe/calculators/util/packet_latency_calculator_test.cc b/mediapipe/calculators/util/packet_latency_calculator_test.cc
index 6f03f2e75..d323a14f9 100644
--- a/mediapipe/calculators/util/packet_latency_calculator_test.cc
+++ b/mediapipe/calculators/util/packet_latency_calculator_test.cc
@@ -169,10 +169,10 @@ class PacketLatencyCalculatorTest : public ::testing::Test {
   }
 
   PacketLatency CreatePacketLatency(const double latency_usec,
-                                    const int64 num_intervals,
-                                    const int64 interval_size_usec,
+                                    const int64_t num_intervals,
+                                    const int64_t interval_size_usec,
                                     const std::vector& counts,
-                                    const int64 avg_latency_usec,
+                                    const int64_t avg_latency_usec,
                                     const std::string& label) {
     PacketLatency latency_info;
     latency_info.set_current_latency_usec(latency_usec);