diff --git a/mediapipe/calculators/core/BUILD b/mediapipe/calculators/core/BUILD index 9c4b3af7c..74398be42 100644 --- a/mediapipe/calculators/core/BUILD +++ b/mediapipe/calculators/core/BUILD @@ -328,6 +328,7 @@ cc_library( ":concatenate_vector_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:node", + "//mediapipe/framework/formats:classification_cc_proto", "//mediapipe/framework/formats:landmark_cc_proto", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", @@ -344,6 +345,7 @@ cc_test( "//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_runner", "//mediapipe/framework:timestamp", + "//mediapipe/framework/formats:classification_cc_proto", "//mediapipe/framework/formats:landmark_cc_proto", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:parse_text_proto", diff --git a/mediapipe/calculators/core/concatenate_proto_list_calculator.cc b/mediapipe/calculators/core/concatenate_proto_list_calculator.cc index 9dd0dfd99..6c58e1110 100644 --- a/mediapipe/calculators/core/concatenate_proto_list_calculator.cc +++ b/mediapipe/calculators/core/concatenate_proto_list_calculator.cc @@ -18,6 +18,7 @@ #include "mediapipe/calculators/core/concatenate_vector_calculator.pb.h" #include "mediapipe/framework/api2/node.h" #include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/classification.pb.h" #include "mediapipe/framework/formats/landmark.pb.h" #include "mediapipe/framework/port/canonical_errors.h" #include "mediapipe/framework/port/ret_check.h" @@ -111,6 +112,22 @@ class ConcatenateLandmarkListCalculator }; MEDIAPIPE_REGISTER_NODE(ConcatenateLandmarkListCalculator); +class ConcatenateClassificationListCalculator + : public ConcatenateListsCalculator<Classification, ClassificationList> { + protected: + int ListSize(const ClassificationList& list) const override { + return list.classification_size(); + } + const Classification GetItem(const ClassificationList& list, 
+ int idx) const override { + return list.classification(idx); + } + Classification* AddItem(ClassificationList& list) const override { + return list.add_classification(); + } +}; +MEDIAPIPE_REGISTER_NODE(ConcatenateClassificationListCalculator); + } // namespace api2 } // namespace mediapipe diff --git a/mediapipe/calculators/core/concatenate_proto_list_calculator_test.cc b/mediapipe/calculators/core/concatenate_proto_list_calculator_test.cc index fd116ece7..2167cd9d1 100644 --- a/mediapipe/calculators/core/concatenate_proto_list_calculator_test.cc +++ b/mediapipe/calculators/core/concatenate_proto_list_calculator_test.cc @@ -18,6 +18,7 @@ #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_runner.h" +#include "mediapipe/framework/formats/classification.pb.h" #include "mediapipe/framework/formats/landmark.pb.h" #include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gtest.h" @@ -70,6 +71,16 @@ void AddInputLandmarkLists( } } +void AddInputClassificationLists( + const std::vector<ClassificationList>& input_classifications_vec, + int64 timestamp, CalculatorRunner* runner) { + for (int i = 0; i < input_classifications_vec.size(); ++i) { + runner->MutableInputs()->Index(i).packets.push_back( + MakePacket<ClassificationList>(input_classifications_vec[i]) + .At(Timestamp(timestamp))); + } +} + TEST(ConcatenateNormalizedLandmarkListCalculatorTest, EmptyVectorInputs) { CalculatorRunner runner("ConcatenateNormalizedLandmarkListCalculator", /*options_string=*/"", /*num_inputs=*/3, @@ -181,4 +192,39 @@ TEST(ConcatenateNormalizedLandmarkListCalculatorTest, OneEmptyStreamNoOutput) { EXPECT_EQ(0, outputs.size()); } +TEST(ConcatenateClassificationListCalculatorTest, OneTimestamp) { + CalculatorRunner runner("ConcatenateClassificationListCalculator", + /*options_string=*/ + "[mediapipe.ConcatenateVectorCalculatorOptions.ext]: " + "{only_emit_if_all_present: true}", + /*num_inputs=*/2, + /*num_outputs=*/1, /*num_side_packets=*/0); + + auto input_0 = 
ParseTextProtoOrDie<ClassificationList>(R"pb( + classification: { index: 0 score: 0.2 label: "test_0" } + classification: { index: 1 score: 0.3 label: "test_1" } + classification: { index: 2 score: 0.4 label: "test_2" } + )pb"); + auto input_1 = ParseTextProtoOrDie<ClassificationList>(R"pb( + classification: { index: 3 score: 0.2 label: "test_3" } + classification: { index: 4 score: 0.3 label: "test_4" } + )pb"); + std::vector<ClassificationList> inputs = {input_0, input_1}; + AddInputClassificationLists(inputs, /*timestamp=*/1, &runner); + MP_ASSERT_OK(runner.Run()); + + const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets; + EXPECT_EQ(1, outputs.size()); + EXPECT_EQ(Timestamp(1), outputs[0].Timestamp()); + auto result = outputs[0].Get<ClassificationList>(); + EXPECT_THAT(ParseTextProtoOrDie<ClassificationList>(R"pb( + classification: { index: 0 score: 0.2 label: "test_0" } + classification: { index: 1 score: 0.3 label: "test_1" } + classification: { index: 2 score: 0.4 label: "test_2" } + classification: { index: 3 score: 0.2 label: "test_3" } + classification: { index: 4 score: 0.3 label: "test_4" } + )pb"), + EqualsProto(result)); +} + } // namespace mediapipe diff --git a/mediapipe/tasks/python/audio/core/BUILD b/mediapipe/tasks/python/audio/core/BUILD new file mode 100644 index 000000000..d9f169c65 --- /dev/null +++ b/mediapipe/tasks/python/audio/core/BUILD @@ -0,0 +1,37 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# Placeholder for internal Python strict library and test compatibility macro. + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +py_library( + name = "audio_task_running_mode", + srcs = ["audio_task_running_mode.py"], +) + +py_library( + name = "base_audio_task_api", + srcs = [ + "base_audio_task_api.py", + ], + deps = [ + ":audio_task_running_mode", + "//mediapipe/framework:calculator_py_pb2", + "//mediapipe/python:_framework_bindings", + "//mediapipe/tasks/python/core:optional_dependencies", + ], +) diff --git a/mediapipe/tasks/python/audio/core/__init__.py b/mediapipe/tasks/python/audio/core/__init__.py new file mode 100644 index 000000000..6a8405189 --- /dev/null +++ b/mediapipe/tasks/python/audio/core/__init__.py @@ -0,0 +1,16 @@ +"""Copyright 2022 The MediaPipe Authors. + +All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" diff --git a/mediapipe/tasks/python/audio/core/audio_task_running_mode.py b/mediapipe/tasks/python/audio/core/audio_task_running_mode.py new file mode 100644 index 000000000..0fa36d40e --- /dev/null +++ b/mediapipe/tasks/python/audio/core/audio_task_running_mode.py @@ -0,0 +1,29 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""The running mode of MediaPipe Audio Tasks.""" + +import enum + + +class AudioTaskRunningMode(enum.Enum): + """MediaPipe audio task running mode. + + Attributes: + AUDIO_CLIPS: The mode for running a mediapipe audio task on independent + audio clips. + AUDIO_STREAM: The mode for running a mediapipe audio task on an audio + stream, such as from microphone. + """ + AUDIO_CLIPS = 'AUDIO_CLIPS' + AUDIO_STREAM = 'AUDIO_STREAM' diff --git a/mediapipe/tasks/python/audio/core/base_audio_task_api.py b/mediapipe/tasks/python/audio/core/base_audio_task_api.py new file mode 100644 index 000000000..b6a2e0e40 --- /dev/null +++ b/mediapipe/tasks/python/audio/core/base_audio_task_api.py @@ -0,0 +1,123 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""MediaPipe audio task base api.""" + +from typing import Callable, Mapping, Optional + +from mediapipe.framework import calculator_pb2 +from mediapipe.python._framework_bindings import packet as packet_module +from mediapipe.python._framework_bindings import task_runner as task_runner_module +from mediapipe.tasks.python.audio.core import audio_task_running_mode as running_mode_module +from mediapipe.tasks.python.core.optional_dependencies import doc_controls + +_TaskRunner = task_runner_module.TaskRunner +_Packet = packet_module.Packet +_RunningMode = running_mode_module.AudioTaskRunningMode + + +class BaseAudioTaskApi(object): + """The base class of the user-facing mediapipe audio task api classes.""" + + def __init__( + self, + graph_config: calculator_pb2.CalculatorGraphConfig, + running_mode: _RunningMode, + packet_callback: Optional[Callable[[Mapping[str, packet_module.Packet]], + None]] = None + ) -> None: + """Initializes the `BaseAudioTaskApi` object. + + Args: + graph_config: The mediapipe audio task graph config proto. + running_mode: The running mode of the mediapipe audio task. + packet_callback: The optional packet callback for getting results + asynchronously in the audio stream mode. + + Raises: + ValueError: The packet callback is not properly set based on the task's + running mode. + """ + if running_mode == _RunningMode.AUDIO_STREAM: + if packet_callback is None: + raise ValueError( + 'The audio task is in audio stream mode, a user-defined result ' + 'callback must be provided.') + elif packet_callback: + raise ValueError( + 'The audio task is in audio clips mode, a user-defined result ' + 'callback should not be provided.') + self._runner = _TaskRunner.create(graph_config, packet_callback) + self._running_mode = running_mode + + def _process_audio_clip( + self, inputs: Mapping[str, _Packet]) -> Mapping[str, _Packet]: + """A synchronous method to process independent audio clips. 
+ + The call blocks the current thread until a failure status or a successful + result is returned. + + Args: + inputs: A dict contains (input stream name, data packet) pairs. + + Returns: + A dict contains (output stream name, data packet) pairs. + + Raises: + ValueError: If the task's running mode is not set to audio clips mode. + """ + if self._running_mode != _RunningMode.AUDIO_CLIPS: + raise ValueError( + 'Task is not initialized with the audio clips mode. Current running mode:' + + self._running_mode.name) + return self._runner.process(inputs) + + def _send_audio_stream_data(self, inputs: Mapping[str, _Packet]) -> None: + """An asynchronous method to send audio stream data to the runner. + + The results will be available in the user-defined results callback. + + Args: + inputs: A dict contains (input stream name, data packet) pairs. + + Raises: + ValueError: If the task's running mode is not set to the audio stream + mode. + """ + if self._running_mode != _RunningMode.AUDIO_STREAM: + raise ValueError( + 'Task is not initialized with the audio stream mode. Current running mode:' + + self._running_mode.name) + self._runner.send(inputs) + + def close(self) -> None: + """Shuts down the mediapipe audio task instance. + + Raises: + RuntimeError: If the mediapipe audio task failed to close. + """ + self._runner.close() + + @doc_controls.do_not_generate_docs + def __enter__(self): + """Return `self` upon entering the runtime context.""" + return self + + @doc_controls.do_not_generate_docs + def __exit__(self, unused_exc_type, unused_exc_value, unused_traceback): + """Shuts down the mediapipe audio task instance on exit of the context manager. + + Raises: + RuntimeError: If the mediapipe audio task failed to close. 
+ """ + self.close() diff --git a/mediapipe/tasks/python/metadata/metadata_writers/metadata_info.py b/mediapipe/tasks/python/metadata/metadata_writers/metadata_info.py index 07938d863..76b572e86 100644 --- a/mediapipe/tasks/python/metadata/metadata_writers/metadata_info.py +++ b/mediapipe/tasks/python/metadata/metadata_writers/metadata_info.py @@ -197,6 +197,37 @@ class ScoreCalibrationMd: self._FILE_TYPE) +class ScoreThresholdingMd: + """A container for score thresholding [1] metadata information. + + [1]: + https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L468 + """ + + def __init__(self, global_score_threshold: float) -> None: + """Creates a ScoreThresholdingMd object. + + Args: + global_score_threshold: The recommended global threshold below which + results are considered low-confidence and should be filtered out. + """ + self._global_score_threshold = global_score_threshold + + def create_metadata(self) -> _metadata_fb.ProcessUnitT: + """Creates the score thresholding metadata based on the information. + + Returns: + A Flatbuffers Python object of the score thresholding metadata. + """ + score_thresholding = _metadata_fb.ProcessUnitT() + score_thresholding.optionsType = ( + _metadata_fb.ProcessUnitOptions.ScoreThresholdingOptions) + options = _metadata_fb.ScoreThresholdingOptionsT() + options.globalScoreThreshold = self._global_score_threshold + score_thresholding.options = options + return score_thresholding + + class TensorMd: """A container for common tensor metadata information. @@ -374,23 +405,29 @@ class ClassificationTensorMd(TensorMd): tensor. score_calibration_md: information of the score calibration operation [2] in the classification tensor. + score_thresholding_md: information of the score thresholding [3] in the + classification tensor. 
[1]: https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L99 [2]: https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L456 + [3]: + https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L468 """ # Min and max float values for classification results. _MIN_FLOAT = 0.0 _MAX_FLOAT = 1.0 - def __init__(self, - name: Optional[str] = None, - description: Optional[str] = None, - label_files: Optional[List[LabelFileMd]] = None, - tensor_type: Optional[int] = None, - score_calibration_md: Optional[ScoreCalibrationMd] = None, - tensor_name: Optional[str] = None) -> None: + def __init__( + self, + name: Optional[str] = None, + description: Optional[str] = None, + label_files: Optional[List[LabelFileMd]] = None, + tensor_type: Optional[int] = None, + score_calibration_md: Optional[ScoreCalibrationMd] = None, + tensor_name: Optional[str] = None, + score_thresholding_md: Optional[ScoreThresholdingMd] = None) -> None: """Initializes the instance of ClassificationTensorMd. Args: @@ -404,6 +441,8 @@ class ClassificationTensorMd(TensorMd): tensor_name: name of the corresponding tensor [3] in the TFLite model. It is used to locate the corresponding classification tensor and decide the order of the tensor metadata [4] when populating model metadata. + score_thresholding_md: information of the score thresholding [5] in the + classification tensor. 
[1]: https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L99 [2]: @@ -412,8 +451,11 @@ class ClassificationTensorMd(TensorMd): https://github.com/tensorflow/tensorflow/blob/cb67fef35567298b40ac166b0581cd8ad68e5a3a/tensorflow/lite/schema/schema.fbs#L1129-L1136 [4]: https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L623-L640 + [5]: + https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L468 """ self.score_calibration_md = score_calibration_md + self.score_thresholding_md = score_thresholding_md if tensor_type is _schema_fb.TensorType.UINT8: min_values = [_MIN_UINT8] @@ -443,4 +485,12 @@ class ClassificationTensorMd(TensorMd): tensor_metadata.processUnits = [ self.score_calibration_md.create_metadata() ] + if self.score_thresholding_md: + if tensor_metadata.processUnits: + tensor_metadata.processUnits.append( + self.score_thresholding_md.create_metadata()) + else: + tensor_metadata.processUnits = [ + self.score_thresholding_md.create_metadata() + ] return tensor_metadata diff --git a/mediapipe/tasks/python/metadata/metadata_writers/metadata_writer.py b/mediapipe/tasks/python/metadata/metadata_writers/metadata_writer.py index 5a2eaba07..3a9f91239 100644 --- a/mediapipe/tasks/python/metadata/metadata_writers/metadata_writer.py +++ b/mediapipe/tasks/python/metadata/metadata_writers/metadata_writer.py @@ -70,6 +70,18 @@ class LabelItem: locale: Optional[str] = None +@dataclasses.dataclass +class ScoreThresholding: + """Parameters to performs thresholding on output tensor values [1]. + + Attributes: + global_score_threshold: The recommended global threshold below which results + are considered low-confidence and should be filtered out. 
[1]: + https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L468 + """ + global_score_threshold: float + + class Labels(object): """Simple container holding classification labels of a particular tensor. @@ -407,6 +419,7 @@ class MetadataWriter(object): self, labels: Optional[Labels] = None, score_calibration: Optional[ScoreCalibration] = None, + score_thresholding: Optional[ScoreThresholding] = None, name: str = _OUTPUT_CLASSIFICATION_NAME, description: str = _OUTPUT_CLASSIFICATION_DESCRIPTION ) -> 'MetadataWriter': @@ -423,6 +436,7 @@ class MetadataWriter(object): Args: labels: an instance of Labels helper class. score_calibration: an instance of ScoreCalibration helper class. + score_thresholding: an instance of ScoreThresholding. name: Metadata name of the tensor. Note that this is different from tensor name in the flatbuffer. description: human readable description of what the output is. @@ -437,6 +451,10 @@ class MetadataWriter(object): default_score=score_calibration.default_score, file_path=self._export_calibration_file('score_calibration.txt', score_calibration.parameters)) + score_thresholding_md = None + if score_thresholding: + score_thresholding_md = metadata_info.ScoreThresholdingMd( + score_thresholding.global_score_threshold) label_files = None if labels: @@ -453,6 +471,7 @@ class MetadataWriter(object): label_files=label_files, tensor_type=self._output_tensor_type(len(self._output_mds)), score_calibration_md=calibration_md, + score_thresholding_md=score_thresholding_md, ) self._output_mds.append(output_md) return self @@ -545,4 +564,3 @@ class MetadataWriterBase: A tuple of (model_with_metadata_in_bytes, metdata_json_content) """ return self.writer.populate() - diff --git a/mediapipe/tasks/python/metadata/metadata_writers/writer_utils.py b/mediapipe/tasks/python/metadata/metadata_writers/writer_utils.py index 0a054812b..eff5f553e 100644 --- 
a/mediapipe/tasks/python/metadata/metadata_writers/writer_utils.py +++ b/mediapipe/tasks/python/metadata/metadata_writers/writer_utils.py @@ -14,7 +14,8 @@ # ============================================================================== """Helper methods for writing metadata into TFLite models.""" -from typing import List +from typing import Dict, List +import zipfile from mediapipe.tasks.metadata import schema_py_generated as _schema_fb @@ -83,3 +84,20 @@ def get_subgraph(model_buffer: bytearray) -> _schema_fb.SubGraph: # multiple subgraphs yet, but models with mini-benchmark may have multiple # subgraphs for acceleration evaluation purpose. return model.Subgraphs(0) + + +def create_model_asset_bundle(input_models: Dict[str, bytes], + output_path: str) -> None: + """Creates the model asset bundle. + + Args: + input_models: A dict of input models with key as the model file name and + value as the model content. + output_path: The output file path to save the model asset bundle. + """ + if not input_models or len(input_models) < 2: + raise ValueError("Needs at least two input models for model asset bundle.") + + with zipfile.ZipFile(output_path, mode="w") as zf: + for file_name, file_buffer in input_models.items(): + zf.writestr(file_name, file_buffer) diff --git a/mediapipe/tasks/python/test/BUILD b/mediapipe/tasks/python/test/BUILD index 92c5f4038..5ad057983 100644 --- a/mediapipe/tasks/python/test/BUILD +++ b/mediapipe/tasks/python/test/BUILD @@ -27,5 +27,8 @@ py_library( "//mediapipe/model_maker/python/vision/gesture_recognizer:__pkg__", "//mediapipe/tasks:internal", ], - deps = ["//mediapipe/python:_framework_bindings"], + deps = [ + "//mediapipe/python:_framework_bindings", + "@com_google_protobuf//:protobuf_python", + ], ) diff --git a/mediapipe/tasks/python/test/metadata/metadata_writers/metadata_info_test.py b/mediapipe/tasks/python/test/metadata/metadata_writers/metadata_info_test.py index 75602c83c..33e162607 100644 --- 
a/mediapipe/tasks/python/test/metadata/metadata_writers/metadata_info_test.py +++ b/mediapipe/tasks/python/test/metadata/metadata_writers/metadata_info_test.py @@ -307,6 +307,25 @@ class ScoreCalibrationMdTest(absltest.TestCase): malformed_calibration_file) +class ScoreThresholdingMdTest(absltest.TestCase): + _DEFAULT_GLOBAL_THRESHOLD = 0.5 + _EXPECTED_TENSOR_JSON = test_utils.get_test_data_path( + "score_thresholding_meta.json") + + def test_create_metadata_should_succeed(self): + score_thresholding_md = metadata_info.ScoreThresholdingMd( + global_score_threshold=self._DEFAULT_GLOBAL_THRESHOLD) + + score_thresholding_metadata = score_thresholding_md.create_metadata() + + metadata_json = _metadata.convert_to_json( + _create_dummy_model_metadata_with_process_uint( + score_thresholding_metadata)) + with open(self._EXPECTED_TENSOR_JSON, "r") as f: + expected_json = f.read() + self.assertEqual(metadata_json, expected_json) + + def _create_dummy_model_metadata_with_tensor( tensor_metadata: _metadata_fb.TensorMetadataT) -> bytes: # Create a dummy model using the tensor metadata. 
diff --git a/mediapipe/tasks/python/test/metadata/metadata_writers/metadata_writer_test.py b/mediapipe/tasks/python/test/metadata/metadata_writers/metadata_writer_test.py index 51b043c7d..c59f19519 100644 --- a/mediapipe/tasks/python/test/metadata/metadata_writers/metadata_writer_test.py +++ b/mediapipe/tasks/python/test/metadata/metadata_writers/metadata_writer_test.py @@ -405,6 +405,64 @@ class MetadataWriterForTaskTest(absltest.TestCase): } """) + def test_add_classification_output_with_score_thresholding(self): + writer = metadata_writer.MetadataWriter.create( + self.image_classifier_model_buffer) + writer.add_classification_output( + labels=metadata_writer.Labels().add(['a', 'b', 'c']), + score_thresholding=metadata_writer.ScoreThresholding( + global_score_threshold=0.5)) + _, metadata_json = writer.populate() + print(metadata_json) + self.assertJsonEqual( + metadata_json, """{ + "subgraph_metadata": [ + { + "input_tensor_metadata": [ + { + "name": "input" + } + ], + "output_tensor_metadata": [ + { + "name": "score", + "description": "Score of the labels respectively.", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + } + }, + "process_units": [ + { + "options_type": "ScoreThresholdingOptions", + "options": { + "global_score_threshold": 0.5 + } + } + ], + "stats": { + "max": [ + 1.0 + ], + "min": [ + 0.0 + ] + }, + "associated_files": [ + { + "name": "labels.txt", + "description": "Labels for categories that the model can recognize.", + "type": "TENSOR_AXIS_LABELS" + } + ] + } + ] + } + ], + "min_parser_version": "1.0.0" + } + """) + if __name__ == '__main__': absltest.main() diff --git a/mediapipe/tasks/python/test/test_utils.py b/mediapipe/tasks/python/test/test_utils.py index b428f8302..d2e76c57b 100644 --- a/mediapipe/tasks/python/test/test_utils.py +++ b/mediapipe/tasks/python/test/test_utils.py @@ -13,9 +13,15 @@ # limitations under the License. 
"""Test util for MediaPipe Tasks.""" +import difflib import os from absl import flags +import six + +from google.protobuf import descriptor +from google.protobuf import descriptor_pool +from google.protobuf import text_format from mediapipe.python._framework_bindings import image as image_module from mediapipe.python._framework_bindings import image_frame as image_frame_module @@ -53,3 +59,126 @@ def create_calibration_file(file_dir: str, with open(calibration_file, mode="w") as file: file.write(content) return calibration_file + + +def assert_proto_equals(self, + a, + b, + check_initialized=True, + normalize_numbers=True, + msg=None): + """assert_proto_equals() is useful for unit tests. + + It produces much more helpful output than assertEqual() for proto2 messages. + Fails with a useful error if a and b aren't equal. Comparison of repeated + fields matches the semantics of unittest.TestCase.assertEqual(), ie order and + extra duplicates fields matter. + + This is a fork of https://github.com/tensorflow/tensorflow/blob/ + master/tensorflow/python/util/protobuf/compare.py#L73. We use slightly + different rounding cutoffs to support Mac usage. + + Args: + self: absltest.testing.parameterized.TestCase + a: proto2 PB instance, or text string representing one. + b: proto2 PB instance -- message.Message or subclass thereof. + check_initialized: boolean, whether to fail if either a or b isn't + initialized. + normalize_numbers: boolean, whether to normalize types and precision of + numbers before comparison. + msg: if specified, is used as the error message on failure. 
+ """ + pool = descriptor_pool.Default() + if isinstance(a, six.string_types): + a = text_format.Parse(a, b.__class__(), descriptor_pool=pool) + + for pb in a, b: + if check_initialized: + errors = pb.FindInitializationErrors() + if errors: + self.fail("Initialization errors: %s\n%s" % (errors, pb)) + if normalize_numbers: + _normalize_number_fields(pb) + + a_str = text_format.MessageToString(a, descriptor_pool=pool) + b_str = text_format.MessageToString(b, descriptor_pool=pool) + + # Some Python versions would perform regular diff instead of multi-line + # diff if string is longer than 2**16. We substitute this behavior + # with a call to unified_diff instead to have easier-to-read diffs. + # For context, see: https://bugs.python.org/issue11763. + if len(a_str) < 2**16 and len(b_str) < 2**16: + self.assertMultiLineEqual(a_str, b_str, msg=msg) + else: + diff = "".join( + difflib.unified_diff(a_str.splitlines(True), b_str.splitlines(True))) + if diff: + self.fail("%s :\n%s" % (msg, diff)) + + +def _normalize_number_fields(pb): + """Normalizes types and precisions of number fields in a protocol buffer. + + Due to subtleties in the python protocol buffer implementation, it is possible + for values to have different types and precision depending on whether they + were set and retrieved directly or deserialized from a protobuf. This function + normalizes integer values to ints and longs based on width, 32-bit floats to + five digits of precision to account for python always storing them as 64-bit, + and ensures doubles are floating point for when they're set to integers. + Modifies pb in place. Recurses into nested objects. https://github.com/tensorf + low/tensorflow/blob/master/tensorflow/python/util/protobuf/compare.py#L118 + + Args: + pb: proto2 message. + + Returns: + the given pb, modified in place. 
+ """ + for desc, values in pb.ListFields(): + is_repeated = True + if desc.label != descriptor.FieldDescriptor.LABEL_REPEATED: + is_repeated = False + values = [values] + + normalized_values = None + + # We force 32-bit values to int and 64-bit values to long to make + # alternate implementations where the distinction is more significant + # (e.g. the C++ implementation) simpler. + if desc.type in (descriptor.FieldDescriptor.TYPE_INT64, + descriptor.FieldDescriptor.TYPE_UINT64, + descriptor.FieldDescriptor.TYPE_SINT64): + normalized_values = [int(x) for x in values] + elif desc.type in (descriptor.FieldDescriptor.TYPE_INT32, + descriptor.FieldDescriptor.TYPE_UINT32, + descriptor.FieldDescriptor.TYPE_SINT32, + descriptor.FieldDescriptor.TYPE_ENUM): + normalized_values = [int(x) for x in values] + elif desc.type == descriptor.FieldDescriptor.TYPE_FLOAT: + normalized_values = [round(x, 5) for x in values] + elif desc.type == descriptor.FieldDescriptor.TYPE_DOUBLE: + normalized_values = [round(float(x), 6) for x in values] + + if normalized_values is not None: + if is_repeated: + pb.ClearField(desc.name) + getattr(pb, desc.name).extend(normalized_values) + else: + setattr(pb, desc.name, normalized_values[0]) + + if (desc.type == descriptor.FieldDescriptor.TYPE_MESSAGE or + desc.type == descriptor.FieldDescriptor.TYPE_GROUP): + if (desc.type == descriptor.FieldDescriptor.TYPE_MESSAGE and + desc.message_type.has_options and + desc.message_type.GetOptions().map_entry): + # This is a map, only recurse if the values have a message type. 
+ if (desc.message_type.fields_by_number[2].type == + descriptor.FieldDescriptor.TYPE_MESSAGE): + for v in six.itervalues(values): + _normalize_number_fields(v) + else: + for v in values: + # recursive step + _normalize_number_fields(v) + + return pb diff --git a/mediapipe/tasks/python/test/vision/image_classifier_test.py b/mediapipe/tasks/python/test/vision/image_classifier_test.py index e56bcdea0..7fa487069 100644 --- a/mediapipe/tasks/python/test/vision/image_classifier_test.py +++ b/mediapipe/tasks/python/test/vision/image_classifier_test.py @@ -53,11 +53,6 @@ _SCORE_THRESHOLD = 0.5 _MAX_RESULTS = 3 -# TODO: Port assertProtoEquals -def _assert_proto_equals(expected, actual): # pylint: disable=unused-argument - pass - - def _generate_empty_results(timestamp_ms: int) -> _ClassificationResult: return _ClassificationResult(classifications=[ _Classifications( @@ -77,22 +72,22 @@ def _generate_burger_results(timestamp_ms: int) -> _ClassificationResult: categories=[ _Category( index=934, - score=0.7939587831497192, + score=0.793959, display_name='', category_name='cheeseburger'), _Category( index=932, - score=0.02739289402961731, + score=0.0273929, display_name='', category_name='bagel'), _Category( index=925, - score=0.01934075355529785, + score=0.0193408, display_name='', category_name='guacamole'), _Category( index=963, - score=0.006327860057353973, + score=0.00632786, display_name='', category_name='meat loaf') ], @@ -111,7 +106,7 @@ def _generate_soccer_ball_results(timestamp_ms: int) -> _ClassificationResult: categories=[ _Category( index=806, - score=0.9965274930000305, + score=0.996527, display_name='', category_name='soccer ball') ], @@ -189,8 +184,8 @@ class ImageClassifierTest(parameterized.TestCase): # Performs image classification on the input. image_result = classifier.classify(self.test_image) # Comparing results. 
- _assert_proto_equals(image_result.to_pb2(), - expected_classification_result.to_pb2()) + test_utils.assert_proto_equals(self, image_result.to_pb2(), + expected_classification_result.to_pb2()) # Closes the classifier explicitly when the classifier is not used in # a context. classifier.close() @@ -217,8 +212,8 @@ class ImageClassifierTest(parameterized.TestCase): # Performs image classification on the input. image_result = classifier.classify(self.test_image) # Comparing results. - _assert_proto_equals(image_result.to_pb2(), - expected_classification_result.to_pb2()) + test_utils.assert_proto_equals(self, image_result.to_pb2(), + expected_classification_result.to_pb2()) def test_classify_succeeds_with_region_of_interest(self): base_options = _BaseOptions(model_asset_path=self.model_path) @@ -235,8 +230,8 @@ class ImageClassifierTest(parameterized.TestCase): # Performs image classification on the input. image_result = classifier.classify(test_image, image_processing_options) # Comparing results. 
- _assert_proto_equals(image_result.to_pb2(), - _generate_soccer_ball_results(0).to_pb2()) + test_utils.assert_proto_equals(self, image_result.to_pb2(), + _generate_soccer_ball_results(0).to_pb2()) def test_score_threshold_option(self): custom_classifier_options = _ClassifierOptions( @@ -404,8 +399,9 @@ class ImageClassifierTest(parameterized.TestCase): for timestamp in range(0, 300, 30): classification_result = classifier.classify_for_video( self.test_image, timestamp) - _assert_proto_equals(classification_result.to_pb2(), - _generate_burger_results(timestamp).to_pb2()) + test_utils.assert_proto_equals( + self, classification_result.to_pb2(), + _generate_burger_results(timestamp).to_pb2()) def test_classify_for_video_succeeds_with_region_of_interest(self): custom_classifier_options = _ClassifierOptions(max_results=1) @@ -423,8 +419,9 @@ class ImageClassifierTest(parameterized.TestCase): for timestamp in range(0, 300, 30): classification_result = classifier.classify_for_video( test_image, timestamp, image_processing_options) - self.assertEqual(classification_result, - _generate_soccer_ball_results(timestamp)) + test_utils.assert_proto_equals( + self, classification_result.to_pb2(), + _generate_soccer_ball_results(timestamp).to_pb2()) def test_calling_classify_in_live_stream_mode(self): options = _ImageClassifierOptions( @@ -466,8 +463,8 @@ class ImageClassifierTest(parameterized.TestCase): def check_result(result: _ClassificationResult, output_image: _Image, timestamp_ms: int): - _assert_proto_equals(result.to_pb2(), - expected_result_fn(timestamp_ms).to_pb2()) + test_utils.assert_proto_equals(self, result.to_pb2(), + expected_result_fn(timestamp_ms).to_pb2()) self.assertTrue( np.array_equal(output_image.numpy_view(), self.test_image.numpy_view())) @@ -496,8 +493,9 @@ class ImageClassifierTest(parameterized.TestCase): def check_result(result: _ClassificationResult, output_image: _Image, timestamp_ms: int): - _assert_proto_equals(result.to_pb2(), - 
_generate_soccer_ball_results(timestamp_ms).to_pb2()) + test_utils.assert_proto_equals( + self, result.to_pb2(), + _generate_soccer_ball_results(timestamp_ms).to_pb2()) self.assertEqual(output_image.width, test_image.width) self.assertEqual(output_image.height, test_image.height) self.assertLess(observed_timestamp_ms, timestamp_ms) diff --git a/mediapipe/tasks/testdata/metadata/BUILD b/mediapipe/tasks/testdata/metadata/BUILD index 6d7bbab6a..8ed6e1caa 100644 --- a/mediapipe/tasks/testdata/metadata/BUILD +++ b/mediapipe/tasks/testdata/metadata/BUILD @@ -50,6 +50,7 @@ exports_files([ "score_calibration.txt", "score_calibration_file_meta.json", "score_calibration_tensor_meta.json", + "score_thresholding_meta.json", "labels.txt", "mobilenet_v2_1.0_224.json", "mobilenet_v2_1.0_224_quant.json", @@ -91,5 +92,6 @@ filegroup( "score_calibration.txt", "score_calibration_file_meta.json", "score_calibration_tensor_meta.json", + "score_thresholding_meta.json", ], ) diff --git a/mediapipe/tasks/testdata/metadata/score_thresholding_meta.json b/mediapipe/tasks/testdata/metadata/score_thresholding_meta.json new file mode 100644 index 000000000..d67a1aae8 --- /dev/null +++ b/mediapipe/tasks/testdata/metadata/score_thresholding_meta.json @@ -0,0 +1,14 @@ +{ + "subgraph_metadata": [ + { + "input_process_units": [ + { + "options_type": "ScoreThresholdingOptions", + "options": { + "global_score_threshold": 0.5 + } + } + ] + } + ] +} diff --git a/mediapipe/util/time_series_util.h b/mediapipe/util/time_series_util.h index a6a5911a6..afa66acc6 100644 --- a/mediapipe/util/time_series_util.h +++ b/mediapipe/util/time_series_util.h @@ -58,7 +58,7 @@ absl::Status FillTimeSeriesHeaderIfValid(const Packet& header_packet, absl::Status FillMultiStreamTimeSeriesHeaderIfValid( const Packet& header_packet, MultiStreamTimeSeriesHeader* header); -// Returnsabsl::Status::OK iff options contains an extension of type +// Returns absl::Status::OK iff options contains an extension of type // OptionsClass. 
template absl::Status HasOptionsExtension(const CalculatorOptions& options) { @@ -75,7 +75,7 @@ absl::Status HasOptionsExtension(const CalculatorOptions& options) { return absl::InvalidArgumentError(error_message); } -// Returnsabsl::Status::OK if the shape of 'matrix' is consistent +// Returns absl::Status::OK if the shape of 'matrix' is consistent // with the num_samples and num_channels fields present in 'header'. // The corresponding matrix dimensions of unset header fields are // ignored, so e.g. an empty header (which is not valid according to diff --git a/mediapipe/util/tracking/tracking.cc b/mediapipe/util/tracking/tracking.cc index 7e80cd5ce..88ba39807 100644 --- a/mediapipe/util/tracking/tracking.cc +++ b/mediapipe/util/tracking/tracking.cc @@ -1323,10 +1323,9 @@ void MotionBox::GetSpatialGaussWeights(const MotionBoxState& box_state, const float space_sigma_x = std::max( options_.spatial_sigma(), box_state.inlier_width() * inv_box_domain.x() * 0.5f * box_state.prior_weight() / 1.65f); - const float space_sigma_y = options_.spatial_sigma(); - std::max(options_.spatial_sigma(), box_state.inlier_height() * - inv_box_domain.y() * 0.5f * - box_state.prior_weight() / 1.65f); + const float space_sigma_y = std::max( + options_.spatial_sigma(), box_state.inlier_height() * inv_box_domain.y() * + 0.5f * box_state.prior_weight() / 1.65f); *spatial_gauss_x = -0.5f / (space_sigma_x * space_sigma_x); *spatial_gauss_y = -0.5f / (space_sigma_y * space_sigma_y); diff --git a/third_party/external_files.bzl b/third_party/external_files.bzl index 4dcbc3bd9..e47dc9812 100644 --- a/third_party/external_files.bzl +++ b/third_party/external_files.bzl @@ -694,6 +694,12 @@ def external_files(): urls = ["https://storage.googleapis.com/mediapipe-assets/score_calibration.txt?generation=1665422847392804"], ) + http_file( + name = "com_google_mediapipe_score_thresholding_meta_json", + sha256 = "7bb74f21c2d7f0237675ed7c09d7b7afd3507c8373f51dc75fa0507852f6ee19", + urls = 
["https://storage.googleapis.com/mediapipe-assets/score_thresholding_meta.json?generation=1667273953630766"], + ) + http_file( name = "com_google_mediapipe_segmentation_golden_rotation0_png", sha256 = "9ee993919b753118928ba2d14f7c5c83a6cfc23355e6943dac4ad81eedd73069",