Merge branch 'master' into gesture-recognizer-python
This commit is contained in:
commit
3a2f30185f
|
@ -328,6 +328,7 @@ cc_library(
|
|||
":concatenate_vector_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/api2:node",
|
||||
"//mediapipe/framework/formats:classification_cc_proto",
|
||||
"//mediapipe/framework/formats:landmark_cc_proto",
|
||||
"//mediapipe/framework/port:ret_check",
|
||||
"//mediapipe/framework/port:status",
|
||||
|
@ -344,6 +345,7 @@ cc_test(
|
|||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework:calculator_runner",
|
||||
"//mediapipe/framework:timestamp",
|
||||
"//mediapipe/framework/formats:classification_cc_proto",
|
||||
"//mediapipe/framework/formats:landmark_cc_proto",
|
||||
"//mediapipe/framework/port:gtest_main",
|
||||
"//mediapipe/framework/port:parse_text_proto",
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
#include "mediapipe/calculators/core/concatenate_vector_calculator.pb.h"
|
||||
#include "mediapipe/framework/api2/node.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/formats/classification.pb.h"
|
||||
#include "mediapipe/framework/formats/landmark.pb.h"
|
||||
#include "mediapipe/framework/port/canonical_errors.h"
|
||||
#include "mediapipe/framework/port/ret_check.h"
|
||||
|
@ -111,6 +112,22 @@ class ConcatenateLandmarkListCalculator
|
|||
};
|
||||
MEDIAPIPE_REGISTER_NODE(ConcatenateLandmarkListCalculator);
|
||||
|
||||
class ConcatenateClassificationListCalculator
|
||||
: public ConcatenateListsCalculator<Classification, ClassificationList> {
|
||||
protected:
|
||||
int ListSize(const ClassificationList& list) const override {
|
||||
return list.classification_size();
|
||||
}
|
||||
const Classification GetItem(const ClassificationList& list,
|
||||
int idx) const override {
|
||||
return list.classification(idx);
|
||||
}
|
||||
Classification* AddItem(ClassificationList& list) const override {
|
||||
return list.add_classification();
|
||||
}
|
||||
};
|
||||
MEDIAPIPE_REGISTER_NODE(ConcatenateClassificationListCalculator);
|
||||
|
||||
} // namespace api2
|
||||
} // namespace mediapipe
|
||||
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/calculator_runner.h"
|
||||
#include "mediapipe/framework/formats/classification.pb.h"
|
||||
#include "mediapipe/framework/formats/landmark.pb.h"
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
|
@ -70,6 +71,16 @@ void AddInputLandmarkLists(
|
|||
}
|
||||
}
|
||||
|
||||
void AddInputClassificationLists(
|
||||
const std::vector<ClassificationList>& input_classifications_vec,
|
||||
int64 timestamp, CalculatorRunner* runner) {
|
||||
for (int i = 0; i < input_classifications_vec.size(); ++i) {
|
||||
runner->MutableInputs()->Index(i).packets.push_back(
|
||||
MakePacket<ClassificationList>(input_classifications_vec[i])
|
||||
.At(Timestamp(timestamp)));
|
||||
}
|
||||
}
|
||||
|
||||
TEST(ConcatenateNormalizedLandmarkListCalculatorTest, EmptyVectorInputs) {
|
||||
CalculatorRunner runner("ConcatenateNormalizedLandmarkListCalculator",
|
||||
/*options_string=*/"", /*num_inputs=*/3,
|
||||
|
@ -181,4 +192,39 @@ TEST(ConcatenateNormalizedLandmarkListCalculatorTest, OneEmptyStreamNoOutput) {
|
|||
EXPECT_EQ(0, outputs.size());
|
||||
}
|
||||
|
||||
TEST(ConcatenateClassificationListCalculatorTest, OneTimestamp) {
|
||||
CalculatorRunner runner("ConcatenateClassificationListCalculator",
|
||||
/*options_string=*/
|
||||
"[mediapipe.ConcatenateVectorCalculatorOptions.ext]: "
|
||||
"{only_emit_if_all_present: true}",
|
||||
/*num_inputs=*/2,
|
||||
/*num_outputs=*/1, /*num_side_packets=*/0);
|
||||
|
||||
auto input_0 = ParseTextProtoOrDie<ClassificationList>(R"pb(
|
||||
classification: { index: 0 score: 0.2 label: "test_0" }
|
||||
classification: { index: 1 score: 0.3 label: "test_1" }
|
||||
classification: { index: 2 score: 0.4 label: "test_2" }
|
||||
)pb");
|
||||
auto input_1 = ParseTextProtoOrDie<ClassificationList>(R"pb(
|
||||
classification: { index: 3 score: 0.2 label: "test_3" }
|
||||
classification: { index: 4 score: 0.3 label: "test_4" }
|
||||
)pb");
|
||||
std::vector<ClassificationList> inputs = {input_0, input_1};
|
||||
AddInputClassificationLists(inputs, /*timestamp=*/1, &runner);
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
|
||||
const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;
|
||||
EXPECT_EQ(1, outputs.size());
|
||||
EXPECT_EQ(Timestamp(1), outputs[0].Timestamp());
|
||||
auto result = outputs[0].Get<ClassificationList>();
|
||||
EXPECT_THAT(ParseTextProtoOrDie<ClassificationList>(R"pb(
|
||||
classification: { index: 0 score: 0.2 label: "test_0" }
|
||||
classification: { index: 1 score: 0.3 label: "test_1" }
|
||||
classification: { index: 2 score: 0.4 label: "test_2" }
|
||||
classification: { index: 3 score: 0.2 label: "test_3" }
|
||||
classification: { index: 4 score: 0.3 label: "test_4" }
|
||||
)pb"),
|
||||
EqualsProto(result));
|
||||
}
|
||||
|
||||
} // namespace mediapipe
|
||||
|
|
37
mediapipe/tasks/python/audio/core/BUILD
Normal file
37
mediapipe/tasks/python/audio/core/BUILD
Normal file
|
@ -0,0 +1,37 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Placeholder for internal Python strict library and test compatibility macro.
|
||||
|
||||
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
py_library(
|
||||
name = "audio_task_running_mode",
|
||||
srcs = ["audio_task_running_mode.py"],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "base_audio_task_api",
|
||||
srcs = [
|
||||
"base_audio_task_api.py",
|
||||
],
|
||||
deps = [
|
||||
":audio_task_running_mode",
|
||||
"//mediapipe/framework:calculator_py_pb2",
|
||||
"//mediapipe/python:_framework_bindings",
|
||||
"//mediapipe/tasks/python/core:optional_dependencies",
|
||||
],
|
||||
)
|
16
mediapipe/tasks/python/audio/core/__init__.py
Normal file
16
mediapipe/tasks/python/audio/core/__init__.py
Normal file
|
@ -0,0 +1,16 @@
|
|||
"""Copyright 2022 The MediaPipe Authors.
|
||||
|
||||
All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
29
mediapipe/tasks/python/audio/core/audio_task_running_mode.py
Normal file
29
mediapipe/tasks/python/audio/core/audio_task_running_mode.py
Normal file
|
@ -0,0 +1,29 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""The running mode of MediaPipe Audio Tasks."""
|
||||
|
||||
import enum
|
||||
|
||||
|
||||
class AudioTaskRunningMode(enum.Enum):
|
||||
"""MediaPipe audio task running mode.
|
||||
|
||||
Attributes:
|
||||
AUDIO_CLIPS: The mode for running a mediapipe audio task on independent
|
||||
audio clips.
|
||||
AUDIO_STREAM: The mode for running a mediapipe audio task on an audio
|
||||
stream, such as from microphone.
|
||||
"""
|
||||
AUDIO_CLIPS = 'AUDIO_CLIPS'
|
||||
AUDIO_STREAM = 'AUDIO_STREAM'
|
123
mediapipe/tasks/python/audio/core/base_audio_task_api.py
Normal file
123
mediapipe/tasks/python/audio/core/base_audio_task_api.py
Normal file
|
@ -0,0 +1,123 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""MediaPipe audio task base api."""
|
||||
|
||||
from typing import Callable, Mapping, Optional
|
||||
|
||||
from mediapipe.framework import calculator_pb2
|
||||
from mediapipe.python._framework_bindings import packet as packet_module
|
||||
from mediapipe.python._framework_bindings import task_runner as task_runner_module
|
||||
from mediapipe.tasks.python.audio.core import audio_task_running_mode as running_mode_module
|
||||
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
|
||||
|
||||
_TaskRunner = task_runner_module.TaskRunner
|
||||
_Packet = packet_module.Packet
|
||||
_RunningMode = running_mode_module.AudioTaskRunningMode
|
||||
|
||||
|
||||
class BaseAudioTaskApi(object):
|
||||
"""The base class of the user-facing mediapipe audio task api classes."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
graph_config: calculator_pb2.CalculatorGraphConfig,
|
||||
running_mode: _RunningMode,
|
||||
packet_callback: Optional[Callable[[Mapping[str, packet_module.Packet]],
|
||||
None]] = None
|
||||
) -> None:
|
||||
"""Initializes the `BaseAudioTaskApi` object.
|
||||
|
||||
Args:
|
||||
graph_config: The mediapipe audio task graph config proto.
|
||||
running_mode: The running mode of the mediapipe audio task.
|
||||
packet_callback: The optional packet callback for getting results
|
||||
asynchronously in the audio stream mode.
|
||||
|
||||
Raises:
|
||||
ValueError: The packet callback is not properly set based on the task's
|
||||
running mode.
|
||||
"""
|
||||
if running_mode == _RunningMode.AUDIO_STREAM:
|
||||
if packet_callback is None:
|
||||
raise ValueError(
|
||||
'The audio task is in audio stream mode, a user-defined result '
|
||||
'callback must be provided.')
|
||||
elif packet_callback:
|
||||
raise ValueError(
|
||||
'The audio task is in audio clips mode, a user-defined result '
|
||||
'callback should not be provided.')
|
||||
self._runner = _TaskRunner.create(graph_config, packet_callback)
|
||||
self._running_mode = running_mode
|
||||
|
||||
def _process_audio_clip(
|
||||
self, inputs: Mapping[str, _Packet]) -> Mapping[str, _Packet]:
|
||||
"""A synchronous method to process independent audio clips.
|
||||
|
||||
The call blocks the current thread until a failure status or a successful
|
||||
result is returned.
|
||||
|
||||
Args:
|
||||
inputs: A dict contains (input stream name, data packet) pairs.
|
||||
|
||||
Returns:
|
||||
A dict contains (output stream name, data packet) pairs.
|
||||
|
||||
Raises:
|
||||
ValueError: If the task's running mode is not set to audio clips mode.
|
||||
"""
|
||||
if self._running_mode != _RunningMode.AUDIO_CLIPS:
|
||||
raise ValueError(
|
||||
'Task is not initialized with the audio clips mode. Current running mode:'
|
||||
+ self._running_mode.name)
|
||||
return self._runner.process(inputs)
|
||||
|
||||
def _send_audio_stream_data(self, inputs: Mapping[str, _Packet]) -> None:
|
||||
"""An asynchronous method to send audio stream data to the runner.
|
||||
|
||||
The results will be available in the user-defined results callback.
|
||||
|
||||
Args:
|
||||
inputs: A dict contains (input stream name, data packet) pairs.
|
||||
|
||||
Raises:
|
||||
ValueError: If the task's running mode is not set to the audio stream
|
||||
mode.
|
||||
"""
|
||||
if self._running_mode != _RunningMode.AUDIO_STREAM:
|
||||
raise ValueError(
|
||||
'Task is not initialized with the audio stream mode. Current running mode:'
|
||||
+ self._running_mode.name)
|
||||
self._runner.send(inputs)
|
||||
|
||||
def close(self) -> None:
|
||||
"""Shuts down the mediapipe audio task instance.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the mediapipe audio task failed to close.
|
||||
"""
|
||||
self._runner.close()
|
||||
|
||||
@doc_controls.do_not_generate_docs
|
||||
def __enter__(self):
|
||||
"""Return `self` upon entering the runtime context."""
|
||||
return self
|
||||
|
||||
@doc_controls.do_not_generate_docs
|
||||
def __exit__(self, unused_exc_type, unused_exc_value, unused_traceback):
|
||||
"""Shuts down the mediapipe audio task instance on exit of the context manager.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the mediapipe audio task failed to close.
|
||||
"""
|
||||
self.close()
|
|
@ -197,6 +197,37 @@ class ScoreCalibrationMd:
|
|||
self._FILE_TYPE)
|
||||
|
||||
|
||||
class ScoreThresholdingMd:
|
||||
"""A container for score thresholding [1] metadata information.
|
||||
|
||||
[1]:
|
||||
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L468
|
||||
"""
|
||||
|
||||
def __init__(self, global_score_threshold: float) -> None:
|
||||
"""Creates a ScoreThresholdingMd object.
|
||||
|
||||
Args:
|
||||
global_score_threshold: The recommended global threshold below which
|
||||
results are considered low-confidence and should be filtered out.
|
||||
"""
|
||||
self._global_score_threshold = global_score_threshold
|
||||
|
||||
def create_metadata(self) -> _metadata_fb.ProcessUnitT:
|
||||
"""Creates the score thresholding metadata based on the information.
|
||||
|
||||
Returns:
|
||||
A Flatbuffers Python object of the score thresholding metadata.
|
||||
"""
|
||||
score_thresholding = _metadata_fb.ProcessUnitT()
|
||||
score_thresholding.optionsType = (
|
||||
_metadata_fb.ProcessUnitOptions.ScoreThresholdingOptions)
|
||||
options = _metadata_fb.ScoreThresholdingOptionsT()
|
||||
options.globalScoreThreshold = self._global_score_threshold
|
||||
score_thresholding.options = options
|
||||
return score_thresholding
|
||||
|
||||
|
||||
class TensorMd:
|
||||
"""A container for common tensor metadata information.
|
||||
|
||||
|
@ -374,23 +405,29 @@ class ClassificationTensorMd(TensorMd):
|
|||
tensor.
|
||||
score_calibration_md: information of the score calibration operation [2] in
|
||||
the classification tensor.
|
||||
score_thresholding_md: information of the score thresholding [3] in the
|
||||
classification tensor.
|
||||
[1]:
|
||||
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L99
|
||||
[2]:
|
||||
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L456
|
||||
[3]:
|
||||
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L468
|
||||
"""
|
||||
|
||||
# Min and max float values for classification results.
|
||||
_MIN_FLOAT = 0.0
|
||||
_MAX_FLOAT = 1.0
|
||||
|
||||
def __init__(self,
|
||||
name: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
label_files: Optional[List[LabelFileMd]] = None,
|
||||
tensor_type: Optional[int] = None,
|
||||
score_calibration_md: Optional[ScoreCalibrationMd] = None,
|
||||
tensor_name: Optional[str] = None) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
name: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
label_files: Optional[List[LabelFileMd]] = None,
|
||||
tensor_type: Optional[int] = None,
|
||||
score_calibration_md: Optional[ScoreCalibrationMd] = None,
|
||||
tensor_name: Optional[str] = None,
|
||||
score_thresholding_md: Optional[ScoreThresholdingMd] = None) -> None:
|
||||
"""Initializes the instance of ClassificationTensorMd.
|
||||
|
||||
Args:
|
||||
|
@ -404,6 +441,8 @@ class ClassificationTensorMd(TensorMd):
|
|||
tensor_name: name of the corresponding tensor [3] in the TFLite model. It
|
||||
is used to locate the corresponding classification tensor and decide the
|
||||
order of the tensor metadata [4] when populating model metadata.
|
||||
score_thresholding_md: information of the score thresholding [5] in the
|
||||
classification tensor.
|
||||
[1]:
|
||||
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L99
|
||||
[2]:
|
||||
|
@ -412,8 +451,11 @@ class ClassificationTensorMd(TensorMd):
|
|||
https://github.com/tensorflow/tensorflow/blob/cb67fef35567298b40ac166b0581cd8ad68e5a3a/tensorflow/lite/schema/schema.fbs#L1129-L1136
|
||||
[4]:
|
||||
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L623-L640
|
||||
[5]:
|
||||
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L468
|
||||
"""
|
||||
self.score_calibration_md = score_calibration_md
|
||||
self.score_thresholding_md = score_thresholding_md
|
||||
|
||||
if tensor_type is _schema_fb.TensorType.UINT8:
|
||||
min_values = [_MIN_UINT8]
|
||||
|
@ -443,4 +485,12 @@ class ClassificationTensorMd(TensorMd):
|
|||
tensor_metadata.processUnits = [
|
||||
self.score_calibration_md.create_metadata()
|
||||
]
|
||||
if self.score_thresholding_md:
|
||||
if tensor_metadata.processUnits:
|
||||
tensor_metadata.processUnits.append(
|
||||
self.score_thresholding_md.create_metadata())
|
||||
else:
|
||||
tensor_metadata.processUnits = [
|
||||
self.score_thresholding_md.create_metadata()
|
||||
]
|
||||
return tensor_metadata
|
||||
|
|
|
@ -70,6 +70,18 @@ class LabelItem:
|
|||
locale: Optional[str] = None
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ScoreThresholding:
|
||||
"""Parameters to performs thresholding on output tensor values [1].
|
||||
|
||||
Attributes:
|
||||
global_score_threshold: The recommended global threshold below which results
|
||||
are considered low-confidence and should be filtered out. [1]:
|
||||
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L468
|
||||
"""
|
||||
global_score_threshold: float
|
||||
|
||||
|
||||
class Labels(object):
|
||||
"""Simple container holding classification labels of a particular tensor.
|
||||
|
||||
|
@ -407,6 +419,7 @@ class MetadataWriter(object):
|
|||
self,
|
||||
labels: Optional[Labels] = None,
|
||||
score_calibration: Optional[ScoreCalibration] = None,
|
||||
score_thresholding: Optional[ScoreThresholding] = None,
|
||||
name: str = _OUTPUT_CLASSIFICATION_NAME,
|
||||
description: str = _OUTPUT_CLASSIFICATION_DESCRIPTION
|
||||
) -> 'MetadataWriter':
|
||||
|
@ -423,6 +436,7 @@ class MetadataWriter(object):
|
|||
Args:
|
||||
labels: an instance of Labels helper class.
|
||||
score_calibration: an instance of ScoreCalibration helper class.
|
||||
score_thresholding: an instance of ScoreThresholding.
|
||||
name: Metadata name of the tensor. Note that this is different from tensor
|
||||
name in the flatbuffer.
|
||||
description: human readable description of what the output is.
|
||||
|
@ -437,6 +451,10 @@ class MetadataWriter(object):
|
|||
default_score=score_calibration.default_score,
|
||||
file_path=self._export_calibration_file('score_calibration.txt',
|
||||
score_calibration.parameters))
|
||||
score_thresholding_md = None
|
||||
if score_thresholding:
|
||||
score_thresholding_md = metadata_info.ScoreThresholdingMd(
|
||||
score_thresholding.global_score_threshold)
|
||||
|
||||
label_files = None
|
||||
if labels:
|
||||
|
@ -453,6 +471,7 @@ class MetadataWriter(object):
|
|||
label_files=label_files,
|
||||
tensor_type=self._output_tensor_type(len(self._output_mds)),
|
||||
score_calibration_md=calibration_md,
|
||||
score_thresholding_md=score_thresholding_md,
|
||||
)
|
||||
self._output_mds.append(output_md)
|
||||
return self
|
||||
|
@ -545,4 +564,3 @@ class MetadataWriterBase:
|
|||
A tuple of (model_with_metadata_in_bytes, metdata_json_content)
|
||||
"""
|
||||
return self.writer.populate()
|
||||
|
||||
|
|
|
@ -14,7 +14,8 @@
|
|||
# ==============================================================================
|
||||
"""Helper methods for writing metadata into TFLite models."""
|
||||
|
||||
from typing import List
|
||||
from typing import Dict, List
|
||||
import zipfile
|
||||
|
||||
from mediapipe.tasks.metadata import schema_py_generated as _schema_fb
|
||||
|
||||
|
@ -83,3 +84,20 @@ def get_subgraph(model_buffer: bytearray) -> _schema_fb.SubGraph:
|
|||
# multiple subgraphs yet, but models with mini-benchmark may have multiple
|
||||
# subgraphs for acceleration evaluation purpose.
|
||||
return model.Subgraphs(0)
|
||||
|
||||
|
||||
def create_model_asset_bundle(input_models: Dict[str, bytes],
|
||||
output_path: str) -> None:
|
||||
"""Creates the model asset bundle.
|
||||
|
||||
Args:
|
||||
input_models: A dict of input models with key as the model file name and
|
||||
value as the model content.
|
||||
output_path: The output file path to save the model asset bundle.
|
||||
"""
|
||||
if not input_models or len(input_models) < 2:
|
||||
raise ValueError("Needs at least two input models for model asset bundle.")
|
||||
|
||||
with zipfile.ZipFile(output_path, mode="w") as zf:
|
||||
for file_name, file_buffer in input_models.items():
|
||||
zf.writestr(file_name, file_buffer)
|
||||
|
|
|
@ -27,5 +27,8 @@ py_library(
|
|||
"//mediapipe/model_maker/python/vision/gesture_recognizer:__pkg__",
|
||||
"//mediapipe/tasks:internal",
|
||||
],
|
||||
deps = ["//mediapipe/python:_framework_bindings"],
|
||||
deps = [
|
||||
"//mediapipe/python:_framework_bindings",
|
||||
"@com_google_protobuf//:protobuf_python",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -307,6 +307,25 @@ class ScoreCalibrationMdTest(absltest.TestCase):
|
|||
malformed_calibration_file)
|
||||
|
||||
|
||||
class ScoreThresholdingMdTest(absltest.TestCase):
|
||||
_DEFAULT_GLOBAL_THRESHOLD = 0.5
|
||||
_EXPECTED_TENSOR_JSON = test_utils.get_test_data_path(
|
||||
"score_thresholding_meta.json")
|
||||
|
||||
def test_create_metadata_should_succeed(self):
|
||||
score_thresholding_md = metadata_info.ScoreThresholdingMd(
|
||||
global_score_threshold=self._DEFAULT_GLOBAL_THRESHOLD)
|
||||
|
||||
score_thresholding_metadata = score_thresholding_md.create_metadata()
|
||||
|
||||
metadata_json = _metadata.convert_to_json(
|
||||
_create_dummy_model_metadata_with_process_uint(
|
||||
score_thresholding_metadata))
|
||||
with open(self._EXPECTED_TENSOR_JSON, "r") as f:
|
||||
expected_json = f.read()
|
||||
self.assertEqual(metadata_json, expected_json)
|
||||
|
||||
|
||||
def _create_dummy_model_metadata_with_tensor(
|
||||
tensor_metadata: _metadata_fb.TensorMetadataT) -> bytes:
|
||||
# Create a dummy model using the tensor metadata.
|
||||
|
|
|
@ -405,6 +405,64 @@ class MetadataWriterForTaskTest(absltest.TestCase):
|
|||
}
|
||||
""")
|
||||
|
||||
def test_add_classification_output_with_score_thresholding(self):
|
||||
writer = metadata_writer.MetadataWriter.create(
|
||||
self.image_classifier_model_buffer)
|
||||
writer.add_classification_output(
|
||||
labels=metadata_writer.Labels().add(['a', 'b', 'c']),
|
||||
score_thresholding=metadata_writer.ScoreThresholding(
|
||||
global_score_threshold=0.5))
|
||||
_, metadata_json = writer.populate()
|
||||
print(metadata_json)
|
||||
self.assertJsonEqual(
|
||||
metadata_json, """{
|
||||
"subgraph_metadata": [
|
||||
{
|
||||
"input_tensor_metadata": [
|
||||
{
|
||||
"name": "input"
|
||||
}
|
||||
],
|
||||
"output_tensor_metadata": [
|
||||
{
|
||||
"name": "score",
|
||||
"description": "Score of the labels respectively.",
|
||||
"content": {
|
||||
"content_properties_type": "FeatureProperties",
|
||||
"content_properties": {
|
||||
}
|
||||
},
|
||||
"process_units": [
|
||||
{
|
||||
"options_type": "ScoreThresholdingOptions",
|
||||
"options": {
|
||||
"global_score_threshold": 0.5
|
||||
}
|
||||
}
|
||||
],
|
||||
"stats": {
|
||||
"max": [
|
||||
1.0
|
||||
],
|
||||
"min": [
|
||||
0.0
|
||||
]
|
||||
},
|
||||
"associated_files": [
|
||||
{
|
||||
"name": "labels.txt",
|
||||
"description": "Labels for categories that the model can recognize.",
|
||||
"type": "TENSOR_AXIS_LABELS"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
"min_parser_version": "1.0.0"
|
||||
}
|
||||
""")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
|
|
|
@ -13,9 +13,15 @@
|
|||
# limitations under the License.
|
||||
"""Test util for MediaPipe Tasks."""
|
||||
|
||||
import difflib
|
||||
import os
|
||||
|
||||
from absl import flags
|
||||
import six
|
||||
|
||||
from google.protobuf import descriptor
|
||||
from google.protobuf import descriptor_pool
|
||||
from google.protobuf import text_format
|
||||
|
||||
from mediapipe.python._framework_bindings import image as image_module
|
||||
from mediapipe.python._framework_bindings import image_frame as image_frame_module
|
||||
|
@ -53,3 +59,126 @@ def create_calibration_file(file_dir: str,
|
|||
with open(calibration_file, mode="w") as file:
|
||||
file.write(content)
|
||||
return calibration_file
|
||||
|
||||
|
||||
def assert_proto_equals(self,
|
||||
a,
|
||||
b,
|
||||
check_initialized=True,
|
||||
normalize_numbers=True,
|
||||
msg=None):
|
||||
"""assert_proto_equals() is useful for unit tests.
|
||||
|
||||
It produces much more helpful output than assertEqual() for proto2 messages.
|
||||
Fails with a useful error if a and b aren't equal. Comparison of repeated
|
||||
fields matches the semantics of unittest.TestCase.assertEqual(), ie order and
|
||||
extra duplicates fields matter.
|
||||
|
||||
This is a fork of https://github.com/tensorflow/tensorflow/blob/
|
||||
master/tensorflow/python/util/protobuf/compare.py#L73. We use slightly
|
||||
different rounding cutoffs to support Mac usage.
|
||||
|
||||
Args:
|
||||
self: absltest.testing.parameterized.TestCase
|
||||
a: proto2 PB instance, or text string representing one.
|
||||
b: proto2 PB instance -- message.Message or subclass thereof.
|
||||
check_initialized: boolean, whether to fail if either a or b isn't
|
||||
initialized.
|
||||
normalize_numbers: boolean, whether to normalize types and precision of
|
||||
numbers before comparison.
|
||||
msg: if specified, is used as the error message on failure.
|
||||
"""
|
||||
pool = descriptor_pool.Default()
|
||||
if isinstance(a, six.string_types):
|
||||
a = text_format.Parse(a, b.__class__(), descriptor_pool=pool)
|
||||
|
||||
for pb in a, b:
|
||||
if check_initialized:
|
||||
errors = pb.FindInitializationErrors()
|
||||
if errors:
|
||||
self.fail("Initialization errors: %s\n%s" % (errors, pb))
|
||||
if normalize_numbers:
|
||||
_normalize_number_fields(pb)
|
||||
|
||||
a_str = text_format.MessageToString(a, descriptor_pool=pool)
|
||||
b_str = text_format.MessageToString(b, descriptor_pool=pool)
|
||||
|
||||
# Some Python versions would perform regular diff instead of multi-line
|
||||
# diff if string is longer than 2**16. We substitute this behavior
|
||||
# with a call to unified_diff instead to have easier-to-read diffs.
|
||||
# For context, see: https://bugs.python.org/issue11763.
|
||||
if len(a_str) < 2**16 and len(b_str) < 2**16:
|
||||
self.assertMultiLineEqual(a_str, b_str, msg=msg)
|
||||
else:
|
||||
diff = "".join(
|
||||
difflib.unified_diff(a_str.splitlines(True), b_str.splitlines(True)))
|
||||
if diff:
|
||||
self.fail("%s :\n%s" % (msg, diff))
|
||||
|
||||
|
||||
def _normalize_number_fields(pb):
|
||||
"""Normalizes types and precisions of number fields in a protocol buffer.
|
||||
|
||||
Due to subtleties in the python protocol buffer implementation, it is possible
|
||||
for values to have different types and precision depending on whether they
|
||||
were set and retrieved directly or deserialized from a protobuf. This function
|
||||
normalizes integer values to ints and longs based on width, 32-bit floats to
|
||||
five digits of precision to account for python always storing them as 64-bit,
|
||||
and ensures doubles are floating point for when they're set to integers.
|
||||
Modifies pb in place. Recurses into nested objects. https://github.com/tensorf
|
||||
low/tensorflow/blob/master/tensorflow/python/util/protobuf/compare.py#L118
|
||||
|
||||
Args:
|
||||
pb: proto2 message.
|
||||
|
||||
Returns:
|
||||
the given pb, modified in place.
|
||||
"""
|
||||
for desc, values in pb.ListFields():
|
||||
is_repeated = True
|
||||
if desc.label != descriptor.FieldDescriptor.LABEL_REPEATED:
|
||||
is_repeated = False
|
||||
values = [values]
|
||||
|
||||
normalized_values = None
|
||||
|
||||
# We force 32-bit values to int and 64-bit values to long to make
|
||||
# alternate implementations where the distinction is more significant
|
||||
# (e.g. the C++ implementation) simpler.
|
||||
if desc.type in (descriptor.FieldDescriptor.TYPE_INT64,
|
||||
descriptor.FieldDescriptor.TYPE_UINT64,
|
||||
descriptor.FieldDescriptor.TYPE_SINT64):
|
||||
normalized_values = [int(x) for x in values]
|
||||
elif desc.type in (descriptor.FieldDescriptor.TYPE_INT32,
|
||||
descriptor.FieldDescriptor.TYPE_UINT32,
|
||||
descriptor.FieldDescriptor.TYPE_SINT32,
|
||||
descriptor.FieldDescriptor.TYPE_ENUM):
|
||||
normalized_values = [int(x) for x in values]
|
||||
elif desc.type == descriptor.FieldDescriptor.TYPE_FLOAT:
|
||||
normalized_values = [round(x, 5) for x in values]
|
||||
elif desc.type == descriptor.FieldDescriptor.TYPE_DOUBLE:
|
||||
normalized_values = [round(float(x), 6) for x in values]
|
||||
|
||||
if normalized_values is not None:
|
||||
if is_repeated:
|
||||
pb.ClearField(desc.name)
|
||||
getattr(pb, desc.name).extend(normalized_values)
|
||||
else:
|
||||
setattr(pb, desc.name, normalized_values[0])
|
||||
|
||||
if (desc.type == descriptor.FieldDescriptor.TYPE_MESSAGE or
|
||||
desc.type == descriptor.FieldDescriptor.TYPE_GROUP):
|
||||
if (desc.type == descriptor.FieldDescriptor.TYPE_MESSAGE and
|
||||
desc.message_type.has_options and
|
||||
desc.message_type.GetOptions().map_entry):
|
||||
# This is a map, only recurse if the values have a message type.
|
||||
if (desc.message_type.fields_by_number[2].type ==
|
||||
descriptor.FieldDescriptor.TYPE_MESSAGE):
|
||||
for v in six.itervalues(values):
|
||||
_normalize_number_fields(v)
|
||||
else:
|
||||
for v in values:
|
||||
# recursive step
|
||||
_normalize_number_fields(v)
|
||||
|
||||
return pb
|
||||
|
|
|
@ -53,11 +53,6 @@ _SCORE_THRESHOLD = 0.5
|
|||
_MAX_RESULTS = 3
|
||||
|
||||
|
||||
# TODO: Port assertProtoEquals
|
||||
def _assert_proto_equals(expected, actual): # pylint: disable=unused-argument
|
||||
pass
|
||||
|
||||
|
||||
def _generate_empty_results(timestamp_ms: int) -> _ClassificationResult:
|
||||
return _ClassificationResult(classifications=[
|
||||
_Classifications(
|
||||
|
@ -77,22 +72,22 @@ def _generate_burger_results(timestamp_ms: int) -> _ClassificationResult:
|
|||
categories=[
|
||||
_Category(
|
||||
index=934,
|
||||
score=0.7939587831497192,
|
||||
score=0.793959,
|
||||
display_name='',
|
||||
category_name='cheeseburger'),
|
||||
_Category(
|
||||
index=932,
|
||||
score=0.02739289402961731,
|
||||
score=0.0273929,
|
||||
display_name='',
|
||||
category_name='bagel'),
|
||||
_Category(
|
||||
index=925,
|
||||
score=0.01934075355529785,
|
||||
score=0.0193408,
|
||||
display_name='',
|
||||
category_name='guacamole'),
|
||||
_Category(
|
||||
index=963,
|
||||
score=0.006327860057353973,
|
||||
score=0.00632786,
|
||||
display_name='',
|
||||
category_name='meat loaf')
|
||||
],
|
||||
|
@ -111,7 +106,7 @@ def _generate_soccer_ball_results(timestamp_ms: int) -> _ClassificationResult:
|
|||
categories=[
|
||||
_Category(
|
||||
index=806,
|
||||
score=0.9965274930000305,
|
||||
score=0.996527,
|
||||
display_name='',
|
||||
category_name='soccer ball')
|
||||
],
|
||||
|
@ -189,8 +184,8 @@ class ImageClassifierTest(parameterized.TestCase):
|
|||
# Performs image classification on the input.
|
||||
image_result = classifier.classify(self.test_image)
|
||||
# Comparing results.
|
||||
_assert_proto_equals(image_result.to_pb2(),
|
||||
expected_classification_result.to_pb2())
|
||||
test_utils.assert_proto_equals(self, image_result.to_pb2(),
|
||||
expected_classification_result.to_pb2())
|
||||
# Closes the classifier explicitly when the classifier is not used in
|
||||
# a context.
|
||||
classifier.close()
|
||||
|
@ -217,8 +212,8 @@ class ImageClassifierTest(parameterized.TestCase):
|
|||
# Performs image classification on the input.
|
||||
image_result = classifier.classify(self.test_image)
|
||||
# Comparing results.
|
||||
_assert_proto_equals(image_result.to_pb2(),
|
||||
expected_classification_result.to_pb2())
|
||||
test_utils.assert_proto_equals(self, image_result.to_pb2(),
|
||||
expected_classification_result.to_pb2())
|
||||
|
||||
def test_classify_succeeds_with_region_of_interest(self):
|
||||
base_options = _BaseOptions(model_asset_path=self.model_path)
|
||||
|
@ -235,8 +230,8 @@ class ImageClassifierTest(parameterized.TestCase):
|
|||
# Performs image classification on the input.
|
||||
image_result = classifier.classify(test_image, image_processing_options)
|
||||
# Comparing results.
|
||||
_assert_proto_equals(image_result.to_pb2(),
|
||||
_generate_soccer_ball_results(0).to_pb2())
|
||||
test_utils.assert_proto_equals(self, image_result.to_pb2(),
|
||||
_generate_soccer_ball_results(0).to_pb2())
|
||||
|
||||
def test_score_threshold_option(self):
|
||||
custom_classifier_options = _ClassifierOptions(
|
||||
|
@ -404,8 +399,9 @@ class ImageClassifierTest(parameterized.TestCase):
|
|||
for timestamp in range(0, 300, 30):
|
||||
classification_result = classifier.classify_for_video(
|
||||
self.test_image, timestamp)
|
||||
_assert_proto_equals(classification_result.to_pb2(),
|
||||
_generate_burger_results(timestamp).to_pb2())
|
||||
test_utils.assert_proto_equals(
|
||||
self, classification_result.to_pb2(),
|
||||
_generate_burger_results(timestamp).to_pb2())
|
||||
|
||||
def test_classify_for_video_succeeds_with_region_of_interest(self):
|
||||
custom_classifier_options = _ClassifierOptions(max_results=1)
|
||||
|
@ -423,8 +419,9 @@ class ImageClassifierTest(parameterized.TestCase):
|
|||
for timestamp in range(0, 300, 30):
|
||||
classification_result = classifier.classify_for_video(
|
||||
test_image, timestamp, image_processing_options)
|
||||
self.assertEqual(classification_result,
|
||||
_generate_soccer_ball_results(timestamp))
|
||||
test_utils.assert_proto_equals(
|
||||
self, classification_result.to_pb2(),
|
||||
_generate_soccer_ball_results(timestamp).to_pb2())
|
||||
|
||||
def test_calling_classify_in_live_stream_mode(self):
|
||||
options = _ImageClassifierOptions(
|
||||
|
@ -466,8 +463,8 @@ class ImageClassifierTest(parameterized.TestCase):
|
|||
|
||||
def check_result(result: _ClassificationResult, output_image: _Image,
|
||||
timestamp_ms: int):
|
||||
_assert_proto_equals(result.to_pb2(),
|
||||
expected_result_fn(timestamp_ms).to_pb2())
|
||||
test_utils.assert_proto_equals(self, result.to_pb2(),
|
||||
expected_result_fn(timestamp_ms).to_pb2())
|
||||
self.assertTrue(
|
||||
np.array_equal(output_image.numpy_view(),
|
||||
self.test_image.numpy_view()))
|
||||
|
@ -496,8 +493,9 @@ class ImageClassifierTest(parameterized.TestCase):
|
|||
|
||||
def check_result(result: _ClassificationResult, output_image: _Image,
|
||||
timestamp_ms: int):
|
||||
_assert_proto_equals(result.to_pb2(),
|
||||
_generate_soccer_ball_results(timestamp_ms).to_pb2())
|
||||
test_utils.assert_proto_equals(
|
||||
self, result.to_pb2(),
|
||||
_generate_soccer_ball_results(timestamp_ms).to_pb2())
|
||||
self.assertEqual(output_image.width, test_image.width)
|
||||
self.assertEqual(output_image.height, test_image.height)
|
||||
self.assertLess(observed_timestamp_ms, timestamp_ms)
|
||||
|
|
2
mediapipe/tasks/testdata/metadata/BUILD
vendored
2
mediapipe/tasks/testdata/metadata/BUILD
vendored
|
@ -50,6 +50,7 @@ exports_files([
|
|||
"score_calibration.txt",
|
||||
"score_calibration_file_meta.json",
|
||||
"score_calibration_tensor_meta.json",
|
||||
"score_thresholding_meta.json",
|
||||
"labels.txt",
|
||||
"mobilenet_v2_1.0_224.json",
|
||||
"mobilenet_v2_1.0_224_quant.json",
|
||||
|
@ -91,5 +92,6 @@ filegroup(
|
|||
"score_calibration.txt",
|
||||
"score_calibration_file_meta.json",
|
||||
"score_calibration_tensor_meta.json",
|
||||
"score_thresholding_meta.json",
|
||||
],
|
||||
)
|
||||
|
|
14
mediapipe/tasks/testdata/metadata/score_thresholding_meta.json
vendored
Normal file
14
mediapipe/tasks/testdata/metadata/score_thresholding_meta.json
vendored
Normal file
|
@ -0,0 +1,14 @@
|
|||
{
|
||||
"subgraph_metadata": [
|
||||
{
|
||||
"input_process_units": [
|
||||
{
|
||||
"options_type": "ScoreThresholdingOptions",
|
||||
"options": {
|
||||
"global_score_threshold": 0.5
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
|
@ -58,7 +58,7 @@ absl::Status FillTimeSeriesHeaderIfValid(const Packet& header_packet,
|
|||
absl::Status FillMultiStreamTimeSeriesHeaderIfValid(
|
||||
const Packet& header_packet, MultiStreamTimeSeriesHeader* header);
|
||||
|
||||
// Returnsabsl::Status::OK iff options contains an extension of type
|
||||
// Returns absl::Status::OK iff options contains an extension of type
|
||||
// OptionsClass.
|
||||
template <typename OptionsClass>
|
||||
absl::Status HasOptionsExtension(const CalculatorOptions& options) {
|
||||
|
@ -75,7 +75,7 @@ absl::Status HasOptionsExtension(const CalculatorOptions& options) {
|
|||
return absl::InvalidArgumentError(error_message);
|
||||
}
|
||||
|
||||
// Returnsabsl::Status::OK if the shape of 'matrix' is consistent
|
||||
// Returns absl::Status::OK if the shape of 'matrix' is consistent
|
||||
// with the num_samples and num_channels fields present in 'header'.
|
||||
// The corresponding matrix dimensions of unset header fields are
|
||||
// ignored, so e.g. an empty header (which is not valid according to
|
||||
|
|
|
@ -1323,10 +1323,9 @@ void MotionBox::GetSpatialGaussWeights(const MotionBoxState& box_state,
|
|||
const float space_sigma_x = std::max(
|
||||
options_.spatial_sigma(), box_state.inlier_width() * inv_box_domain.x() *
|
||||
0.5f * box_state.prior_weight() / 1.65f);
|
||||
const float space_sigma_y = options_.spatial_sigma();
|
||||
std::max(options_.spatial_sigma(), box_state.inlier_height() *
|
||||
inv_box_domain.y() * 0.5f *
|
||||
box_state.prior_weight() / 1.65f);
|
||||
const float space_sigma_y = std::max(
|
||||
options_.spatial_sigma(), box_state.inlier_height() * inv_box_domain.y() *
|
||||
0.5f * box_state.prior_weight() / 1.65f);
|
||||
|
||||
*spatial_gauss_x = -0.5f / (space_sigma_x * space_sigma_x);
|
||||
*spatial_gauss_y = -0.5f / (space_sigma_y * space_sigma_y);
|
||||
|
|
6
third_party/external_files.bzl
vendored
6
third_party/external_files.bzl
vendored
|
@ -694,6 +694,12 @@ def external_files():
|
|||
urls = ["https://storage.googleapis.com/mediapipe-assets/score_calibration.txt?generation=1665422847392804"],
|
||||
)
|
||||
|
||||
http_file(
|
||||
name = "com_google_mediapipe_score_thresholding_meta_json",
|
||||
sha256 = "7bb74f21c2d7f0237675ed7c09d7b7afd3507c8373f51dc75fa0507852f6ee19",
|
||||
urls = ["https://storage.googleapis.com/mediapipe-assets/score_thresholding_meta.json?generation=1667273953630766"],
|
||||
)
|
||||
|
||||
http_file(
|
||||
name = "com_google_mediapipe_segmentation_golden_rotation0_png",
|
||||
sha256 = "9ee993919b753118928ba2d14f7c5c83a6cfc23355e6943dac4ad81eedd73069",
|
||||
|
|
Loading…
Reference in New Issue
Block a user