Merge branch 'master' into gesture-recognizer-python

commit 3a2f30185f
Kinar R, 2022-11-02 04:11:18 +05:30, committed by GitHub
20 changed files with 624 additions and 40 deletions

View File

@@ -328,6 +328,7 @@ cc_library(
":concatenate_vector_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:node",
"//mediapipe/framework/formats:classification_cc_proto",
"//mediapipe/framework/formats:landmark_cc_proto",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
@@ -344,6 +345,7 @@ cc_test(
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner",
"//mediapipe/framework:timestamp",
"//mediapipe/framework/formats:classification_cc_proto",
"//mediapipe/framework/formats:landmark_cc_proto",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto",

View File

@@ -18,6 +18,7 @@
#include "mediapipe/calculators/core/concatenate_vector_calculator.pb.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/port/canonical_errors.h"
#include "mediapipe/framework/port/ret_check.h"
@@ -111,6 +112,22 @@ class ConcatenateLandmarkListCalculator
};
MEDIAPIPE_REGISTER_NODE(ConcatenateLandmarkListCalculator);
class ConcatenateClassificationListCalculator
: public ConcatenateListsCalculator<Classification, ClassificationList> {
protected:
int ListSize(const ClassificationList& list) const override {
return list.classification_size();
}
const Classification GetItem(const ClassificationList& list,
int idx) const override {
return list.classification(idx);
}
Classification* AddItem(ClassificationList& list) const override {
return list.add_classification();
}
};
MEDIAPIPE_REGISTER_NODE(ConcatenateClassificationListCalculator);
} // namespace api2
} // namespace mediapipe

View File

@@ -18,6 +18,7 @@
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
@@ -70,6 +71,16 @@ void AddInputLandmarkLists(
}
}
void AddInputClassificationLists(
const std::vector<ClassificationList>& input_classifications_vec,
int64 timestamp, CalculatorRunner* runner) {
for (int i = 0; i < input_classifications_vec.size(); ++i) {
runner->MutableInputs()->Index(i).packets.push_back(
MakePacket<ClassificationList>(input_classifications_vec[i])
.At(Timestamp(timestamp)));
}
}
TEST(ConcatenateNormalizedLandmarkListCalculatorTest, EmptyVectorInputs) {
CalculatorRunner runner("ConcatenateNormalizedLandmarkListCalculator",
/*options_string=*/"", /*num_inputs=*/3,
@@ -181,4 +192,39 @@ TEST(ConcatenateNormalizedLandmarkListCalculatorTest, OneEmptyStreamNoOutput) {
EXPECT_EQ(0, outputs.size());
}
TEST(ConcatenateClassificationListCalculatorTest, OneTimestamp) {
CalculatorRunner runner("ConcatenateClassificationListCalculator",
/*options_string=*/
"[mediapipe.ConcatenateVectorCalculatorOptions.ext]: "
"{only_emit_if_all_present: true}",
/*num_inputs=*/2,
/*num_outputs=*/1, /*num_side_packets=*/0);
auto input_0 = ParseTextProtoOrDie<ClassificationList>(R"pb(
classification: { index: 0 score: 0.2 label: "test_0" }
classification: { index: 1 score: 0.3 label: "test_1" }
classification: { index: 2 score: 0.4 label: "test_2" }
)pb");
auto input_1 = ParseTextProtoOrDie<ClassificationList>(R"pb(
classification: { index: 3 score: 0.2 label: "test_3" }
classification: { index: 4 score: 0.3 label: "test_4" }
)pb");
std::vector<ClassificationList> inputs = {input_0, input_1};
AddInputClassificationLists(inputs, /*timestamp=*/1, &runner);
MP_ASSERT_OK(runner.Run());
const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;
EXPECT_EQ(1, outputs.size());
EXPECT_EQ(Timestamp(1), outputs[0].Timestamp());
auto result = outputs[0].Get<ClassificationList>();
EXPECT_THAT(ParseTextProtoOrDie<ClassificationList>(R"pb(
classification: { index: 0 score: 0.2 label: "test_0" }
classification: { index: 1 score: 0.3 label: "test_1" }
classification: { index: 2 score: 0.4 label: "test_2" }
classification: { index: 3 score: 0.2 label: "test_3" }
classification: { index: 4 score: 0.3 label: "test_4" }
)pb"),
EqualsProto(result));
}
} // namespace mediapipe

View File

@@ -0,0 +1,37 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Placeholder for internal Python strict library and test compatibility macro.
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
py_library(
name = "audio_task_running_mode",
srcs = ["audio_task_running_mode.py"],
)
py_library(
name = "base_audio_task_api",
srcs = [
"base_audio_task_api.py",
],
deps = [
":audio_task_running_mode",
"//mediapipe/framework:calculator_py_pb2",
"//mediapipe/python:_framework_bindings",
"//mediapipe/tasks/python/core:optional_dependencies",
],
)

View File

@@ -0,0 +1,16 @@
"""Copyright 2022 The MediaPipe Authors.
All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

View File

@@ -0,0 +1,29 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The running mode of MediaPipe Audio Tasks."""
import enum
class AudioTaskRunningMode(enum.Enum):
"""MediaPipe audio task running mode.
Attributes:
AUDIO_CLIPS: The mode for running a mediapipe audio task on independent
audio clips.
AUDIO_STREAM: The mode for running a mediapipe audio task on an audio
stream, such as from a microphone.
"""
AUDIO_CLIPS = 'AUDIO_CLIPS'
AUDIO_STREAM = 'AUDIO_STREAM'

View File

@@ -0,0 +1,123 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MediaPipe audio task base api."""
from typing import Callable, Mapping, Optional
from mediapipe.framework import calculator_pb2
from mediapipe.python._framework_bindings import packet as packet_module
from mediapipe.python._framework_bindings import task_runner as task_runner_module
from mediapipe.tasks.python.audio.core import audio_task_running_mode as running_mode_module
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
_TaskRunner = task_runner_module.TaskRunner
_Packet = packet_module.Packet
_RunningMode = running_mode_module.AudioTaskRunningMode
class BaseAudioTaskApi(object):
"""The base class of the user-facing mediapipe audio task api classes."""
def __init__(
self,
graph_config: calculator_pb2.CalculatorGraphConfig,
running_mode: _RunningMode,
packet_callback: Optional[Callable[[Mapping[str, packet_module.Packet]],
None]] = None
) -> None:
"""Initializes the `BaseAudioTaskApi` object.
Args:
graph_config: The mediapipe audio task graph config proto.
running_mode: The running mode of the mediapipe audio task.
packet_callback: The optional packet callback for getting results
asynchronously in the audio stream mode.
Raises:
ValueError: The packet callback is not properly set based on the task's
running mode.
"""
if running_mode == _RunningMode.AUDIO_STREAM:
if packet_callback is None:
raise ValueError(
'The audio task is in audio stream mode, a user-defined result '
'callback must be provided.')
elif packet_callback:
raise ValueError(
'The audio task is in audio clips mode, a user-defined result '
'callback should not be provided.')
self._runner = _TaskRunner.create(graph_config, packet_callback)
self._running_mode = running_mode
def _process_audio_clip(
self, inputs: Mapping[str, _Packet]) -> Mapping[str, _Packet]:
"""A synchronous method to process independent audio clips.
The call blocks the current thread until a failure status or a successful
result is returned.
Args:
inputs: A dict containing (input stream name, data packet) pairs.
Returns:
A dict containing (output stream name, data packet) pairs.
Raises:
ValueError: If the task's running mode is not set to audio clips mode.
"""
if self._running_mode != _RunningMode.AUDIO_CLIPS:
raise ValueError(
'Task is not initialized with the audio clips mode. Current running mode: '
+ self._running_mode.name)
return self._runner.process(inputs)
def _send_audio_stream_data(self, inputs: Mapping[str, _Packet]) -> None:
"""An asynchronous method to send audio stream data to the runner.
The results will be available in the user-defined results callback.
Args:
inputs: A dict containing (input stream name, data packet) pairs.
Raises:
ValueError: If the task's running mode is not set to the audio stream
mode.
"""
if self._running_mode != _RunningMode.AUDIO_STREAM:
raise ValueError(
'Task is not initialized with the audio stream mode. Current running mode: '
+ self._running_mode.name)
self._runner.send(inputs)
def close(self) -> None:
"""Shuts down the mediapipe audio task instance.
Raises:
RuntimeError: If the mediapipe audio task failed to close.
"""
self._runner.close()
@doc_controls.do_not_generate_docs
def __enter__(self):
"""Return `self` upon entering the runtime context."""
return self
@doc_controls.do_not_generate_docs
def __exit__(self, unused_exc_type, unused_exc_value, unused_traceback):
"""Shuts down the mediapipe audio task instance on exit of the context manager.
Raises:
RuntimeError: If the mediapipe audio task failed to close.
"""
self.close()
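
A minimal sketch of how a concrete task can build on this base class. The class name, stream names, and method names below are hypothetical, not part of this commit:

class _ToyAudioClassifier(BaseAudioTaskApi):
  """Hypothetical subclass illustrating the two running modes."""

  def classify(self, audio_packet: _Packet) -> _Packet:
    # Legal only in AUDIO_CLIPS mode; blocks until the graph produces output.
    return self._process_audio_clip({'audio_in': audio_packet})['result_out']

  def classify_async(self, audio_packet: _Packet) -> None:
    # Legal only in AUDIO_STREAM mode; results arrive through the
    # packet_callback passed to __init__.
    self._send_audio_stream_data({'audio_in': audio_packet})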

View File

@@ -197,6 +197,37 @@ class ScoreCalibrationMd:
self._FILE_TYPE)
class ScoreThresholdingMd:
"""A container for score thresholding [1] metadata information.
[1]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L468
"""
def __init__(self, global_score_threshold: float) -> None:
"""Creates a ScoreThresholdingMd object.
Args:
global_score_threshold: The recommended global threshold below which
results are considered low-confidence and should be filtered out.
"""
self._global_score_threshold = global_score_threshold
def create_metadata(self) -> _metadata_fb.ProcessUnitT:
"""Creates the score thresholding metadata based on the information.
Returns:
A Flatbuffers Python object of the score thresholding metadata.
"""
score_thresholding = _metadata_fb.ProcessUnitT()
score_thresholding.optionsType = (
_metadata_fb.ProcessUnitOptions.ScoreThresholdingOptions)
options = _metadata_fb.ScoreThresholdingOptionsT()
options.globalScoreThreshold = self._global_score_threshold
score_thresholding.options = options
return score_thresholding
class TensorMd:
"""A container for common tensor metadata information.
@@ -374,23 +405,29 @@ class ClassificationTensorMd(TensorMd):
tensor.
score_calibration_md: information of the score calibration operation [2] in
the classification tensor.
score_thresholding_md: information of the score thresholding [3] in the
classification tensor.
[1]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L99
[2]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L456
[3]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L468
"""
# Min and max float values for classification results.
_MIN_FLOAT = 0.0
_MAX_FLOAT = 1.0
- def __init__(self,
- name: Optional[str] = None,
- description: Optional[str] = None,
- label_files: Optional[List[LabelFileMd]] = None,
- tensor_type: Optional[int] = None,
- score_calibration_md: Optional[ScoreCalibrationMd] = None,
- tensor_name: Optional[str] = None) -> None:
+ def __init__(
+ self,
+ name: Optional[str] = None,
+ description: Optional[str] = None,
+ label_files: Optional[List[LabelFileMd]] = None,
+ tensor_type: Optional[int] = None,
+ score_calibration_md: Optional[ScoreCalibrationMd] = None,
+ tensor_name: Optional[str] = None,
+ score_thresholding_md: Optional[ScoreThresholdingMd] = None) -> None:
"""Initializes the instance of ClassificationTensorMd.
Args:
@@ -404,6 +441,8 @@ class ClassificationTensorMd(TensorMd):
tensor_name: name of the corresponding tensor [3] in the TFLite model. It
is used to locate the corresponding classification tensor and decide the
order of the tensor metadata [4] when populating model metadata.
score_thresholding_md: information of the score thresholding [5] in the
classification tensor.
[1]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L99
[2]:
@@ -412,8 +451,11 @@ class ClassificationTensorMd(TensorMd):
https://github.com/tensorflow/tensorflow/blob/cb67fef35567298b40ac166b0581cd8ad68e5a3a/tensorflow/lite/schema/schema.fbs#L1129-L1136
[4]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L623-L640
[5]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L468
"""
self.score_calibration_md = score_calibration_md
self.score_thresholding_md = score_thresholding_md
if tensor_type is _schema_fb.TensorType.UINT8:
min_values = [_MIN_UINT8]
@@ -443,4 +485,12 @@ class ClassificationTensorMd(TensorMd):
tensor_metadata.processUnits = [
self.score_calibration_md.create_metadata()
]
if self.score_thresholding_md:
if tensor_metadata.processUnits:
tensor_metadata.processUnits.append(
self.score_thresholding_md.create_metadata())
else:
tensor_metadata.processUnits = [
self.score_thresholding_md.create_metadata()
]
return tensor_metadata
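
A minimal sketch of the new plumbing end to end, using only the classes shown above; the tensor name, description, and 0.5 threshold are illustrative values, not from this commit:

# Sketch: wiring score thresholding into a classification tensor's metadata.
score_thresholding_md = ScoreThresholdingMd(global_score_threshold=0.5)
tensor_md = ClassificationTensorMd(
    name='score',
    description='Score of the labels respectively.',
    tensor_type=_schema_fb.TensorType.FLOAT32,
    score_thresholding_md=score_thresholding_md)
# create_metadata() appends a ScoreThresholdingOptions process unit.
tensor_metadata = tensor_md.create_metadata()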

View File

@@ -70,6 +70,18 @@ class LabelItem:
locale: Optional[str] = None
@dataclasses.dataclass
class ScoreThresholding:
"""Parameters to performs thresholding on output tensor values [1].
Attributes:
global_score_threshold: The recommended global threshold below which results
are considered low-confidence and should be filtered out.
[1]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L468
"""
global_score_threshold: float
class Labels(object):
"""Simple container holding classification labels of a particular tensor.
@@ -407,6 +419,7 @@ class MetadataWriter(object):
self,
labels: Optional[Labels] = None,
score_calibration: Optional[ScoreCalibration] = None,
score_thresholding: Optional[ScoreThresholding] = None,
name: str = _OUTPUT_CLASSIFICATION_NAME,
description: str = _OUTPUT_CLASSIFICATION_DESCRIPTION
) -> 'MetadataWriter':
@@ -423,6 +436,7 @@ class MetadataWriter(object):
Args:
labels: an instance of Labels helper class.
score_calibration: an instance of ScoreCalibration helper class.
score_thresholding: an instance of ScoreThresholding.
name: Metadata name of the tensor. Note that this is different from tensor
name in the flatbuffer.
description: human readable description of what the output is.
@@ -437,6 +451,10 @@ class MetadataWriter(object):
default_score=score_calibration.default_score,
file_path=self._export_calibration_file('score_calibration.txt',
score_calibration.parameters))
score_thresholding_md = None
if score_thresholding:
score_thresholding_md = metadata_info.ScoreThresholdingMd(
score_thresholding.global_score_threshold)
label_files = None
if labels:
@@ -453,6 +471,7 @@ class MetadataWriter(object):
label_files=label_files,
tensor_type=self._output_tensor_type(len(self._output_mds)),
score_calibration_md=calibration_md,
score_thresholding_md=score_thresholding_md,
)
self._output_mds.append(output_md)
return self
@@ -545,4 +564,3 @@ class MetadataWriterBase:
A tuple of (model_with_metadata_in_bytes, metadata_json_content)
"""
return self.writer.populate()
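
At the writer level the new option is a single argument. This sketch mirrors the test added later in this commit; `model_buffer` is assumed to hold the bytes of a single-output classification TFLite model:

writer = metadata_writer.MetadataWriter.create(model_buffer)
writer.add_classification_output(
    labels=metadata_writer.Labels().add(['a', 'b', 'c']),
    score_thresholding=metadata_writer.ScoreThresholding(
        global_score_threshold=0.5))
tflite_with_metadata, metadata_json = writer.populate()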

View File

@@ -14,7 +14,8 @@
# ==============================================================================
"""Helper methods for writing metadata into TFLite models."""
- from typing import List
+ from typing import Dict, List
+ import zipfile
from mediapipe.tasks.metadata import schema_py_generated as _schema_fb
@@ -83,3 +84,20 @@ def get_subgraph(model_buffer: bytearray) -> _schema_fb.SubGraph:
# multiple subgraphs yet, but models with mini-benchmark may have multiple
# subgraphs for acceleration evaluation purpose.
return model.Subgraphs(0)
def create_model_asset_bundle(input_models: Dict[str, bytes],
output_path: str) -> None:
"""Creates the model asset bundle.
Args:
input_models: A dict of input models with key as the model file name and
value as the model content.
output_path: The output file path to save the model asset bundle.
"""
if not input_models or len(input_models) < 2:
raise ValueError("Needs at least two input models for model asset bundle.")
with zipfile.ZipFile(output_path, mode="w") as zf:
for file_name, file_buffer in input_models.items():
zf.writestr(file_name, file_buffer)
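
A usage sketch; the file names and model contents below are hypothetical, and any dict with at least two name-to-bytes entries works:

# Hypothetical inputs: two TFLite models read from disk.
with open("detector.tflite", "rb") as f:
  detector_bytes = f.read()
with open("landmarker.tflite", "rb") as f:
  landmarker_bytes = f.read()
create_model_asset_bundle(
    {"detector.tflite": detector_bytes, "landmarker.tflite": landmarker_bytes},
    "example.task")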

View File

@@ -27,5 +27,8 @@ py_library(
"//mediapipe/model_maker/python/vision/gesture_recognizer:__pkg__",
"//mediapipe/tasks:internal",
],
deps = ["//mediapipe/python:_framework_bindings"],
deps = [
"//mediapipe/python:_framework_bindings",
"@com_google_protobuf//:protobuf_python",
],
)

View File

@@ -307,6 +307,25 @@ class ScoreCalibrationMdTest(absltest.TestCase):
malformed_calibration_file)
class ScoreThresholdingMdTest(absltest.TestCase):
_DEFAULT_GLOBAL_THRESHOLD = 0.5
_EXPECTED_TENSOR_JSON = test_utils.get_test_data_path(
"score_thresholding_meta.json")
def test_create_metadata_should_succeed(self):
score_thresholding_md = metadata_info.ScoreThresholdingMd(
global_score_threshold=self._DEFAULT_GLOBAL_THRESHOLD)
score_thresholding_metadata = score_thresholding_md.create_metadata()
metadata_json = _metadata.convert_to_json(
_create_dummy_model_metadata_with_process_uint(
score_thresholding_metadata))
with open(self._EXPECTED_TENSOR_JSON, "r") as f:
expected_json = f.read()
self.assertEqual(metadata_json, expected_json)
def _create_dummy_model_metadata_with_tensor(
tensor_metadata: _metadata_fb.TensorMetadataT) -> bytes:
# Create a dummy model using the tensor metadata.

View File

@@ -405,6 +405,64 @@ class MetadataWriterForTaskTest(absltest.TestCase):
}
""")
def test_add_classification_output_with_score_thresholding(self):
writer = metadata_writer.MetadataWriter.create(
self.image_classifier_model_buffer)
writer.add_classification_output(
labels=metadata_writer.Labels().add(['a', 'b', 'c']),
score_thresholding=metadata_writer.ScoreThresholding(
global_score_threshold=0.5))
_, metadata_json = writer.populate()
print(metadata_json)
self.assertJsonEqual(
metadata_json, """{
"subgraph_metadata": [
{
"input_tensor_metadata": [
{
"name": "input"
}
],
"output_tensor_metadata": [
{
"name": "score",
"description": "Score of the labels respectively.",
"content": {
"content_properties_type": "FeatureProperties",
"content_properties": {
}
},
"process_units": [
{
"options_type": "ScoreThresholdingOptions",
"options": {
"global_score_threshold": 0.5
}
}
],
"stats": {
"max": [
1.0
],
"min": [
0.0
]
},
"associated_files": [
{
"name": "labels.txt",
"description": "Labels for categories that the model can recognize.",
"type": "TENSOR_AXIS_LABELS"
}
]
}
]
}
],
"min_parser_version": "1.0.0"
}
""")
if __name__ == '__main__':
absltest.main()

View File

@@ -13,9 +13,15 @@
# limitations under the License.
"""Test util for MediaPipe Tasks."""
import difflib
import os
from absl import flags
import six
from google.protobuf import descriptor
from google.protobuf import descriptor_pool
from google.protobuf import text_format
from mediapipe.python._framework_bindings import image as image_module
from mediapipe.python._framework_bindings import image_frame as image_frame_module
@@ -53,3 +59,126 @@ def create_calibration_file(file_dir: str,
with open(calibration_file, mode="w") as file:
file.write(content)
return calibration_file
def assert_proto_equals(self,
a,
b,
check_initialized=True,
normalize_numbers=True,
msg=None):
"""assert_proto_equals() is useful for unit tests.
It produces much more helpful output than assertEqual() for proto2 messages.
Fails with a useful error if a and b aren't equal. Comparison of repeated
fields matches the semantics of unittest.TestCase.assertEqual(), i.e. order
and extra duplicate fields matter.
This is a fork of
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/util/protobuf/compare.py#L73.
We use slightly different rounding cutoffs to support Mac usage.
Args:
self: absl.testing.parameterized.TestCase
a: proto2 PB instance, or text string representing one.
b: proto2 PB instance -- message.Message or subclass thereof.
check_initialized: boolean, whether to fail if either a or b isn't
initialized.
normalize_numbers: boolean, whether to normalize types and precision of
numbers before comparison.
msg: if specified, is used as the error message on failure.
"""
pool = descriptor_pool.Default()
if isinstance(a, six.string_types):
a = text_format.Parse(a, b.__class__(), descriptor_pool=pool)
for pb in a, b:
if check_initialized:
errors = pb.FindInitializationErrors()
if errors:
self.fail("Initialization errors: %s\n%s" % (errors, pb))
if normalize_numbers:
_normalize_number_fields(pb)
a_str = text_format.MessageToString(a, descriptor_pool=pool)
b_str = text_format.MessageToString(b, descriptor_pool=pool)
# Some Python versions would perform regular diff instead of multi-line
# diff if string is longer than 2**16. We substitute this behavior
# with a call to unified_diff instead to have easier-to-read diffs.
# For context, see: https://bugs.python.org/issue11763.
if len(a_str) < 2**16 and len(b_str) < 2**16:
self.assertMultiLineEqual(a_str, b_str, msg=msg)
else:
diff = "".join(
difflib.unified_diff(a_str.splitlines(True), b_str.splitlines(True)))
if diff:
self.fail("%s :\n%s" % (msg, diff))
def _normalize_number_fields(pb):
"""Normalizes types and precisions of number fields in a protocol buffer.
Due to subtleties in the python protocol buffer implementation, it is possible
for values to have different types and precision depending on whether they
were set and retrieved directly or deserialized from a protobuf. This function
normalizes integer values to ints and longs based on width, 32-bit floats to
five digits of precision to account for Python always storing them as 64-bit,
and ensures doubles are floating point for when they're set to integers.
Modifies pb in place. Recurses into nested objects. This is a fork of
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/util/protobuf/compare.py#L118.
Args:
pb: proto2 message.
Returns:
the given pb, modified in place.
"""
for desc, values in pb.ListFields():
is_repeated = True
if desc.label != descriptor.FieldDescriptor.LABEL_REPEATED:
is_repeated = False
values = [values]
normalized_values = None
# We force 32-bit values to int and 64-bit values to long to make
# alternate implementations where the distinction is more significant
# (e.g. the C++ implementation) simpler.
if desc.type in (descriptor.FieldDescriptor.TYPE_INT64,
descriptor.FieldDescriptor.TYPE_UINT64,
descriptor.FieldDescriptor.TYPE_SINT64):
normalized_values = [int(x) for x in values]
elif desc.type in (descriptor.FieldDescriptor.TYPE_INT32,
descriptor.FieldDescriptor.TYPE_UINT32,
descriptor.FieldDescriptor.TYPE_SINT32,
descriptor.FieldDescriptor.TYPE_ENUM):
normalized_values = [int(x) for x in values]
elif desc.type == descriptor.FieldDescriptor.TYPE_FLOAT:
normalized_values = [round(x, 5) for x in values]
elif desc.type == descriptor.FieldDescriptor.TYPE_DOUBLE:
normalized_values = [round(float(x), 6) for x in values]
if normalized_values is not None:
if is_repeated:
pb.ClearField(desc.name)
getattr(pb, desc.name).extend(normalized_values)
else:
setattr(pb, desc.name, normalized_values[0])
if (desc.type == descriptor.FieldDescriptor.TYPE_MESSAGE or
desc.type == descriptor.FieldDescriptor.TYPE_GROUP):
if (desc.type == descriptor.FieldDescriptor.TYPE_MESSAGE and
desc.message_type.has_options and
desc.message_type.GetOptions().map_entry):
# This is a map, only recurse if the values have a message type.
if (desc.message_type.fields_by_number[2].type ==
descriptor.FieldDescriptor.TYPE_MESSAGE):
for v in six.itervalues(values):
_normalize_number_fields(v)
else:
for v in values:
# recursive step
_normalize_number_fields(v)
return pb
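
To see the normalization in action, here is a sketch of a test built on the helpers above; the imports and values are illustrative, not part of this commit:

# Both scores normalize to round(x, 5) == 0.79396, so the comparison passes
# even though the literals differ.
from absl.testing import absltest
from mediapipe.framework.formats import classification_pb2

class NormalizationExample(absltest.TestCase):

  def test_float_rounding(self):
    actual = classification_pb2.Classification(
        index=934, score=0.7939587831497192)
    assert_proto_equals(self, "index: 934 score: 0.793959", actual)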

View File

@@ -53,11 +53,6 @@ _SCORE_THRESHOLD = 0.5
_MAX_RESULTS = 3
- # TODO: Port assertProtoEquals
- def _assert_proto_equals(expected, actual): # pylint: disable=unused-argument
- pass
def _generate_empty_results(timestamp_ms: int) -> _ClassificationResult:
return _ClassificationResult(classifications=[
_Classifications(
@@ -77,22 +72,22 @@ def _generate_burger_results(timestamp_ms: int) -> _ClassificationResult:
categories=[
_Category(
index=934,
- score=0.7939587831497192,
+ score=0.793959,
display_name='',
category_name='cheeseburger'),
_Category(
index=932,
- score=0.02739289402961731,
+ score=0.0273929,
display_name='',
category_name='bagel'),
_Category(
index=925,
- score=0.01934075355529785,
+ score=0.0193408,
display_name='',
category_name='guacamole'),
_Category(
index=963,
- score=0.006327860057353973,
+ score=0.00632786,
display_name='',
category_name='meat loaf')
],
@@ -111,7 +106,7 @@ def _generate_soccer_ball_results(timestamp_ms: int) -> _ClassificationResult:
categories=[
_Category(
index=806,
- score=0.9965274930000305,
+ score=0.996527,
display_name='',
category_name='soccer ball')
],
@@ -189,8 +184,8 @@ class ImageClassifierTest(parameterized.TestCase):
# Performs image classification on the input.
image_result = classifier.classify(self.test_image)
# Comparing results.
- _assert_proto_equals(image_result.to_pb2(),
- expected_classification_result.to_pb2())
+ test_utils.assert_proto_equals(self, image_result.to_pb2(),
+ expected_classification_result.to_pb2())
# Closes the classifier explicitly when the classifier is not used in
# a context.
classifier.close()
@@ -217,8 +212,8 @@ class ImageClassifierTest(parameterized.TestCase):
# Performs image classification on the input.
image_result = classifier.classify(self.test_image)
# Comparing results.
- _assert_proto_equals(image_result.to_pb2(),
- expected_classification_result.to_pb2())
+ test_utils.assert_proto_equals(self, image_result.to_pb2(),
+ expected_classification_result.to_pb2())
def test_classify_succeeds_with_region_of_interest(self):
base_options = _BaseOptions(model_asset_path=self.model_path)
@@ -235,8 +230,8 @@ class ImageClassifierTest(parameterized.TestCase):
# Performs image classification on the input.
image_result = classifier.classify(test_image, image_processing_options)
# Comparing results.
- _assert_proto_equals(image_result.to_pb2(),
- _generate_soccer_ball_results(0).to_pb2())
+ test_utils.assert_proto_equals(self, image_result.to_pb2(),
+ _generate_soccer_ball_results(0).to_pb2())
def test_score_threshold_option(self):
custom_classifier_options = _ClassifierOptions(
@@ -404,8 +399,9 @@ class ImageClassifierTest(parameterized.TestCase):
for timestamp in range(0, 300, 30):
classification_result = classifier.classify_for_video(
self.test_image, timestamp)
- _assert_proto_equals(classification_result.to_pb2(),
- _generate_burger_results(timestamp).to_pb2())
+ test_utils.assert_proto_equals(
+ self, classification_result.to_pb2(),
+ _generate_burger_results(timestamp).to_pb2())
def test_classify_for_video_succeeds_with_region_of_interest(self):
custom_classifier_options = _ClassifierOptions(max_results=1)
@@ -423,8 +419,9 @@ class ImageClassifierTest(parameterized.TestCase):
for timestamp in range(0, 300, 30):
classification_result = classifier.classify_for_video(
test_image, timestamp, image_processing_options)
- self.assertEqual(classification_result,
- _generate_soccer_ball_results(timestamp))
+ test_utils.assert_proto_equals(
+ self, classification_result.to_pb2(),
+ _generate_soccer_ball_results(timestamp).to_pb2())
def test_calling_classify_in_live_stream_mode(self):
options = _ImageClassifierOptions(
@@ -466,8 +463,8 @@ class ImageClassifierTest(parameterized.TestCase):
def check_result(result: _ClassificationResult, output_image: _Image,
timestamp_ms: int):
- _assert_proto_equals(result.to_pb2(),
- expected_result_fn(timestamp_ms).to_pb2())
+ test_utils.assert_proto_equals(self, result.to_pb2(),
+ expected_result_fn(timestamp_ms).to_pb2())
self.assertTrue(
np.array_equal(output_image.numpy_view(),
self.test_image.numpy_view()))
@@ -496,8 +493,9 @@ class ImageClassifierTest(parameterized.TestCase):
def check_result(result: _ClassificationResult, output_image: _Image,
timestamp_ms: int):
- _assert_proto_equals(result.to_pb2(),
- _generate_soccer_ball_results(timestamp_ms).to_pb2())
+ test_utils.assert_proto_equals(
+ self, result.to_pb2(),
+ _generate_soccer_ball_results(timestamp_ms).to_pb2())
self.assertEqual(output_image.width, test_image.width)
self.assertEqual(output_image.height, test_image.height)
self.assertLess(observed_timestamp_ms, timestamp_ms)

View File

@@ -50,6 +50,7 @@ exports_files([
"score_calibration.txt",
"score_calibration_file_meta.json",
"score_calibration_tensor_meta.json",
"score_thresholding_meta.json",
"labels.txt",
"mobilenet_v2_1.0_224.json",
"mobilenet_v2_1.0_224_quant.json",
@@ -91,5 +92,6 @@ filegroup(
"score_calibration.txt",
"score_calibration_file_meta.json",
"score_calibration_tensor_meta.json",
"score_thresholding_meta.json",
],
)

View File

@@ -0,0 +1,14 @@
{
"subgraph_metadata": [
{
"input_process_units": [
{
"options_type": "ScoreThresholdingOptions",
"options": {
"global_score_threshold": 0.5
}
}
]
}
]
}

View File

@@ -58,7 +58,7 @@ absl::Status FillTimeSeriesHeaderIfValid(const Packet& header_packet,
absl::Status FillMultiStreamTimeSeriesHeaderIfValid(
const Packet& header_packet, MultiStreamTimeSeriesHeader* header);
- // Returnsabsl::Status::OK iff options contains an extension of type
+ // Returns absl::Status::OK iff options contains an extension of type
// OptionsClass.
template <typename OptionsClass>
absl::Status HasOptionsExtension(const CalculatorOptions& options) {
@@ -75,7 +75,7 @@ absl::Status HasOptionsExtension(const CalculatorOptions& options) {
return absl::InvalidArgumentError(error_message);
}
- // Returnsabsl::Status::OK if the shape of 'matrix' is consistent
+ // Returns absl::Status::OK if the shape of 'matrix' is consistent
// with the num_samples and num_channels fields present in 'header'.
// The corresponding matrix dimensions of unset header fields are
// ignored, so e.g. an empty header (which is not valid according to

View File

@@ -1323,10 +1323,9 @@ void MotionBox::GetSpatialGaussWeights(const MotionBoxState& box_state,
const float space_sigma_x = std::max(
options_.spatial_sigma(), box_state.inlier_width() * inv_box_domain.x() *
0.5f * box_state.prior_weight() / 1.65f);
- const float space_sigma_y = options_.spatial_sigma();
- std::max(options_.spatial_sigma(), box_state.inlier_height() *
- inv_box_domain.y() * 0.5f *
- box_state.prior_weight() / 1.65f);
+ const float space_sigma_y = std::max(
+ options_.spatial_sigma(), box_state.inlier_height() * inv_box_domain.y() *
+ 0.5f * box_state.prior_weight() / 1.65f);
*spatial_gauss_x = -0.5f / (space_sigma_x * space_sigma_x);
*spatial_gauss_y = -0.5f / (space_sigma_y * space_sigma_y);
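
The fix makes the y-axis sigma mirror the x-axis computation; previously the std::max expression was a discarded statement, so space_sigma_y silently stayed at options_.spatial_sigma(). Reading the corrected code as a formula (notation ours, not from the source), with \sigma_s the spatial sigma option, w and h the inlier width and height, d_x and d_y the components of inv_box_domain, and p the prior weight:

\sigma_x = \max\left(\sigma_s,\; w \cdot d_x \cdot 0.5 \cdot p / 1.65\right),
\qquad
\sigma_y = \max\left(\sigma_s,\; h \cdot d_y \cdot 0.5 \cdot p / 1.65\right)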

View File

@@ -694,6 +694,12 @@ def external_files():
urls = ["https://storage.googleapis.com/mediapipe-assets/score_calibration.txt?generation=1665422847392804"],
)
http_file(
name = "com_google_mediapipe_score_thresholding_meta_json",
sha256 = "7bb74f21c2d7f0237675ed7c09d7b7afd3507c8373f51dc75fa0507852f6ee19",
urls = ["https://storage.googleapis.com/mediapipe-assets/score_thresholding_meta.json?generation=1667273953630766"],
)
http_file(
name = "com_google_mediapipe_segmentation_golden_rotation0_png",
sha256 = "9ee993919b753118928ba2d14f7c5c83a6cfc23355e6943dac4ad81eedd73069",