Merge branch 'master' into gesture-recognizer-python
commit 3a2f30185f
@@ -328,6 +328,7 @@ cc_library(
         ":concatenate_vector_calculator_cc_proto",
         "//mediapipe/framework:calculator_framework",
         "//mediapipe/framework/api2:node",
+        "//mediapipe/framework/formats:classification_cc_proto",
        "//mediapipe/framework/formats:landmark_cc_proto",
        "//mediapipe/framework/port:ret_check",
        "//mediapipe/framework/port:status",

@@ -344,6 +345,7 @@ cc_test(
         "//mediapipe/framework:calculator_framework",
         "//mediapipe/framework:calculator_runner",
         "//mediapipe/framework:timestamp",
+        "//mediapipe/framework/formats:classification_cc_proto",
        "//mediapipe/framework/formats:landmark_cc_proto",
        "//mediapipe/framework/port:gtest_main",
        "//mediapipe/framework/port:parse_text_proto",

@@ -18,6 +18,7 @@
 #include "mediapipe/calculators/core/concatenate_vector_calculator.pb.h"
 #include "mediapipe/framework/api2/node.h"
 #include "mediapipe/framework/calculator_framework.h"
+#include "mediapipe/framework/formats/classification.pb.h"
 #include "mediapipe/framework/formats/landmark.pb.h"
 #include "mediapipe/framework/port/canonical_errors.h"
 #include "mediapipe/framework/port/ret_check.h"

@@ -111,6 +112,22 @@ class ConcatenateLandmarkListCalculator
 };
 MEDIAPIPE_REGISTER_NODE(ConcatenateLandmarkListCalculator);
 
+class ConcatenateClassificationListCalculator
+    : public ConcatenateListsCalculator<Classification, ClassificationList> {
+ protected:
+  int ListSize(const ClassificationList& list) const override {
+    return list.classification_size();
+  }
+  const Classification GetItem(const ClassificationList& list,
+                               int idx) const override {
+    return list.classification(idx);
+  }
+  Classification* AddItem(ClassificationList& list) const override {
+    return list.add_classification();
+  }
+};
+MEDIAPIPE_REGISTER_NODE(ConcatenateClassificationListCalculator);
+
 }  // namespace api2
 }  // namespace mediapipe
 

@@ -18,6 +18,7 @@
 
 #include "mediapipe/framework/calculator_framework.h"
 #include "mediapipe/framework/calculator_runner.h"
+#include "mediapipe/framework/formats/classification.pb.h"
 #include "mediapipe/framework/formats/landmark.pb.h"
 #include "mediapipe/framework/port/gmock.h"
 #include "mediapipe/framework/port/gtest.h"

@@ -70,6 +71,16 @@ void AddInputLandmarkLists(
   }
 }
 
+void AddInputClassificationLists(
+    const std::vector<ClassificationList>& input_classifications_vec,
+    int64 timestamp, CalculatorRunner* runner) {
+  for (int i = 0; i < input_classifications_vec.size(); ++i) {
+    runner->MutableInputs()->Index(i).packets.push_back(
+        MakePacket<ClassificationList>(input_classifications_vec[i])
+            .At(Timestamp(timestamp)));
+  }
+}
+
 TEST(ConcatenateNormalizedLandmarkListCalculatorTest, EmptyVectorInputs) {
   CalculatorRunner runner("ConcatenateNormalizedLandmarkListCalculator",
                           /*options_string=*/"", /*num_inputs=*/3,

@@ -181,4 +192,39 @@ TEST(ConcatenateNormalizedLandmarkListCalculatorTest, OneEmptyStreamNoOutput) {
   EXPECT_EQ(0, outputs.size());
 }
 
+TEST(ConcatenateClassificationListCalculatorTest, OneTimestamp) {
+  CalculatorRunner runner("ConcatenateClassificationListCalculator",
+                          /*options_string=*/
+                          "[mediapipe.ConcatenateVectorCalculatorOptions.ext]: "
+                          "{only_emit_if_all_present: true}",
+                          /*num_inputs=*/2,
+                          /*num_outputs=*/1, /*num_side_packets=*/0);
+
+  auto input_0 = ParseTextProtoOrDie<ClassificationList>(R"pb(
+    classification: { index: 0 score: 0.2 label: "test_0" }
+    classification: { index: 1 score: 0.3 label: "test_1" }
+    classification: { index: 2 score: 0.4 label: "test_2" }
+  )pb");
+  auto input_1 = ParseTextProtoOrDie<ClassificationList>(R"pb(
+    classification: { index: 3 score: 0.2 label: "test_3" }
+    classification: { index: 4 score: 0.3 label: "test_4" }
+  )pb");
+  std::vector<ClassificationList> inputs = {input_0, input_1};
+  AddInputClassificationLists(inputs, /*timestamp=*/1, &runner);
+  MP_ASSERT_OK(runner.Run());
+
+  const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;
+  EXPECT_EQ(1, outputs.size());
+  EXPECT_EQ(Timestamp(1), outputs[0].Timestamp());
+  auto result = outputs[0].Get<ClassificationList>();
+  EXPECT_THAT(ParseTextProtoOrDie<ClassificationList>(R"pb(
+                classification: { index: 0 score: 0.2 label: "test_0" }
+                classification: { index: 1 score: 0.3 label: "test_1" }
+                classification: { index: 2 score: 0.4 label: "test_2" }
+                classification: { index: 3 score: 0.2 label: "test_3" }
+                classification: { index: 4 score: 0.3 label: "test_4" }
+              )pb"),
+              EqualsProto(result));
+}
+
 }  // namespace mediapipe

mediapipe/tasks/python/audio/core/BUILD (new file, 37 lines)
@@ -0,0 +1,37 @@
+# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Placeholder for internal Python strict library and test compatibility macro.
+
+package(default_visibility = ["//mediapipe/tasks:internal"])
+
+licenses(["notice"])
+
+py_library(
+    name = "audio_task_running_mode",
+    srcs = ["audio_task_running_mode.py"],
+)
+
+py_library(
+    name = "base_audio_task_api",
+    srcs = [
+        "base_audio_task_api.py",
+    ],
+    deps = [
+        ":audio_task_running_mode",
+        "//mediapipe/framework:calculator_py_pb2",
+        "//mediapipe/python:_framework_bindings",
+        "//mediapipe/tasks/python/core:optional_dependencies",
+    ],
+)

mediapipe/tasks/python/audio/core/__init__.py (new file, 16 lines)
@@ -0,0 +1,16 @@
+"""Copyright 2022 The MediaPipe Authors.
+
+All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""

mediapipe/tasks/python/audio/core/audio_task_running_mode.py (new file, 29 lines)
@@ -0,0 +1,29 @@
+# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""The running mode of MediaPipe Audio Tasks."""
+
+import enum
+
+
+class AudioTaskRunningMode(enum.Enum):
+  """MediaPipe audio task running mode.
+
+  Attributes:
+    AUDIO_CLIPS: The mode for running a mediapipe audio task on independent
+      audio clips.
+    AUDIO_STREAM: The mode for running a mediapipe audio task on an audio
+      stream, such as from microphone.
+  """
+  AUDIO_CLIPS = 'AUDIO_CLIPS'
+  AUDIO_STREAM = 'AUDIO_STREAM'

mediapipe/tasks/python/audio/core/base_audio_task_api.py (new file, 123 lines)
@@ -0,0 +1,123 @@
+# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""MediaPipe audio task base api."""
+
+from typing import Callable, Mapping, Optional
+
+from mediapipe.framework import calculator_pb2
+from mediapipe.python._framework_bindings import packet as packet_module
+from mediapipe.python._framework_bindings import task_runner as task_runner_module
+from mediapipe.tasks.python.audio.core import audio_task_running_mode as running_mode_module
+from mediapipe.tasks.python.core.optional_dependencies import doc_controls
+
+_TaskRunner = task_runner_module.TaskRunner
+_Packet = packet_module.Packet
+_RunningMode = running_mode_module.AudioTaskRunningMode
+
+
+class BaseAudioTaskApi(object):
+  """The base class of the user-facing mediapipe audio task api classes."""
+
+  def __init__(
+      self,
+      graph_config: calculator_pb2.CalculatorGraphConfig,
+      running_mode: _RunningMode,
+      packet_callback: Optional[Callable[[Mapping[str, packet_module.Packet]],
+                                         None]] = None
+  ) -> None:
+    """Initializes the `BaseAudioTaskApi` object.
+
+    Args:
+      graph_config: The mediapipe audio task graph config proto.
+      running_mode: The running mode of the mediapipe audio task.
+      packet_callback: The optional packet callback for getting results
+        asynchronously in the audio stream mode.
+
+    Raises:
+      ValueError: The packet callback is not properly set based on the task's
+        running mode.
+    """
+    if running_mode == _RunningMode.AUDIO_STREAM:
+      if packet_callback is None:
+        raise ValueError(
+            'The audio task is in audio stream mode, a user-defined result '
+            'callback must be provided.')
+    elif packet_callback:
+      raise ValueError(
+          'The audio task is in audio clips mode, a user-defined result '
+          'callback should not be provided.')
+    self._runner = _TaskRunner.create(graph_config, packet_callback)
+    self._running_mode = running_mode
+
+  def _process_audio_clip(
+      self, inputs: Mapping[str, _Packet]) -> Mapping[str, _Packet]:
+    """A synchronous method to process independent audio clips.
+
+    The call blocks the current thread until a failure status or a successful
+    result is returned.
+
+    Args:
+      inputs: A dict containing (input stream name, data packet) pairs.
+
+    Returns:
+      A dict containing (output stream name, data packet) pairs.
+
+    Raises:
+      ValueError: If the task's running mode is not set to audio clips mode.
+    """
+    if self._running_mode != _RunningMode.AUDIO_CLIPS:
+      raise ValueError(
+          'Task is not initialized with the audio clips mode. Current running mode:'
+          + self._running_mode.name)
+    return self._runner.process(inputs)
+
+  def _send_audio_stream_data(self, inputs: Mapping[str, _Packet]) -> None:
+    """An asynchronous method to send audio stream data to the runner.
+
+    The results will be available in the user-defined results callback.
+
+    Args:
+      inputs: A dict containing (input stream name, data packet) pairs.
+
+    Raises:
+      ValueError: If the task's running mode is not set to the audio stream
+        mode.
+    """
+    if self._running_mode != _RunningMode.AUDIO_STREAM:
+      raise ValueError(
+          'Task is not initialized with the audio stream mode. Current running mode:'
+          + self._running_mode.name)
+    self._runner.send(inputs)
+
+  def close(self) -> None:
+    """Shuts down the mediapipe audio task instance.
+
+    Raises:
+      RuntimeError: If the mediapipe audio task failed to close.
+    """
+    self._runner.close()
+
+  @doc_controls.do_not_generate_docs
+  def __enter__(self):
+    """Return `self` upon entering the runtime context."""
+    return self
+
+  @doc_controls.do_not_generate_docs
+  def __exit__(self, unused_exc_type, unused_exc_value, unused_traceback):
+    """Shuts down the mediapipe audio task instance on exit of the context manager.
+
+    Raises:
+      RuntimeError: If the mediapipe audio task failed to close.
+    """
+    self.close()

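To make the mode/callback contract above concrete: a task built on this base class picks a running mode at construction time, and the checks in `__init__` enforce that a callback is provided exactly when streaming. A minimal sketch of a subclass in audio clips mode (the `_DemoAudioTask` class and the 'audio_in' stream name are hypothetical, for illustration only; they are not part of this commit):

    from typing import Mapping

    from mediapipe.framework import calculator_pb2
    from mediapipe.python._framework_bindings import packet as packet_module
    from mediapipe.tasks.python.audio.core import audio_task_running_mode
    from mediapipe.tasks.python.audio.core import base_audio_task_api


    class _DemoAudioTask(base_audio_task_api.BaseAudioTaskApi):
      """Hypothetical task wrapping the new base class."""

      def classify(
          self, audio_packet: packet_module.Packet
      ) -> Mapping[str, packet_module.Packet]:
        # 'audio_in' is an assumed input stream name for illustration.
        return self._process_audio_clip({'audio_in': audio_packet})


    # A real task would build this graph config from its task options.
    config = calculator_pb2.CalculatorGraphConfig()
    # AUDIO_CLIPS mode: passing a packet_callback here would raise ValueError.
    with _DemoAudioTask(
        config, audio_task_running_mode.AudioTaskRunningMode.AUDIO_CLIPS) as task:
      ...  # task.classify(packet) blocks until the graph returns a result.
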
@@ -197,6 +197,37 @@ class ScoreCalibrationMd:
                          self._FILE_TYPE)
 
 
+class ScoreThresholdingMd:
+  """A container for score thresholding [1] metadata information.
+
+  [1]:
+    https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L468
+  """
+
+  def __init__(self, global_score_threshold: float) -> None:
+    """Creates a ScoreThresholdingMd object.
+
+    Args:
+      global_score_threshold: The recommended global threshold below which
+        results are considered low-confidence and should be filtered out.
+    """
+    self._global_score_threshold = global_score_threshold
+
+  def create_metadata(self) -> _metadata_fb.ProcessUnitT:
+    """Creates the score thresholding metadata based on the information.
+
+    Returns:
+      A Flatbuffers Python object of the score thresholding metadata.
+    """
+    score_thresholding = _metadata_fb.ProcessUnitT()
+    score_thresholding.optionsType = (
+        _metadata_fb.ProcessUnitOptions.ScoreThresholdingOptions)
+    options = _metadata_fb.ScoreThresholdingOptionsT()
+    options.globalScoreThreshold = self._global_score_threshold
+    score_thresholding.options = options
+    return score_thresholding
+
+
 class TensorMd:
   """A container for common tensor metadata information.
 

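A quick sanity check of the new container (a sketch; the `_metadata_fb` flatbuffer module path is assumed to match the one metadata_info.py already imports):

    from mediapipe.tasks.metadata import metadata_schema_py_generated as _metadata_fb
    from mediapipe.tasks.python.metadata.metadata_writers import metadata_info

    md = metadata_info.ScoreThresholdingMd(global_score_threshold=0.5)
    unit = md.create_metadata()
    # The container yields a single ProcessUnitT carrying ScoreThresholdingOptions.
    assert unit.optionsType == _metadata_fb.ProcessUnitOptions.ScoreThresholdingOptions
    assert unit.options.globalScoreThreshold == 0.5
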
@@ -374,23 +405,29 @@ class ClassificationTensorMd(TensorMd):
       tensor.
     score_calibration_md: information of the score calibration operation [2] in
       the classification tensor.
+    score_thresholding_md: information of the score thresholding [3] in the
+      classification tensor.
     [1]:
       https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L99
     [2]:
       https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L456
+    [3]:
+      https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L468
   """
 
   # Min and max float values for classification results.
   _MIN_FLOAT = 0.0
   _MAX_FLOAT = 1.0
 
-  def __init__(self,
-               name: Optional[str] = None,
-               description: Optional[str] = None,
-               label_files: Optional[List[LabelFileMd]] = None,
-               tensor_type: Optional[int] = None,
-               score_calibration_md: Optional[ScoreCalibrationMd] = None,
-               tensor_name: Optional[str] = None) -> None:
+  def __init__(
+      self,
+      name: Optional[str] = None,
+      description: Optional[str] = None,
+      label_files: Optional[List[LabelFileMd]] = None,
+      tensor_type: Optional[int] = None,
+      score_calibration_md: Optional[ScoreCalibrationMd] = None,
+      tensor_name: Optional[str] = None,
+      score_thresholding_md: Optional[ScoreThresholdingMd] = None) -> None:
     """Initializes the instance of ClassificationTensorMd.
 
     Args:

@@ -404,6 +441,8 @@ class ClassificationTensorMd(TensorMd):
       tensor_name: name of the corresponding tensor [3] in the TFLite model. It
         is used to locate the corresponding classification tensor and decide the
         order of the tensor metadata [4] when populating model metadata.
+      score_thresholding_md: information of the score thresholding [5] in the
+        classification tensor.
       [1]:
         https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L99
       [2]:

@@ -412,8 +451,11 @@ class ClassificationTensorMd(TensorMd):
         https://github.com/tensorflow/tensorflow/blob/cb67fef35567298b40ac166b0581cd8ad68e5a3a/tensorflow/lite/schema/schema.fbs#L1129-L1136
       [4]:
         https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L623-L640
+      [5]:
+        https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L468
     """
     self.score_calibration_md = score_calibration_md
+    self.score_thresholding_md = score_thresholding_md
 
     if tensor_type is _schema_fb.TensorType.UINT8:
       min_values = [_MIN_UINT8]

@@ -443,4 +485,12 @@ class ClassificationTensorMd(TensorMd):
       tensor_metadata.processUnits = [
           self.score_calibration_md.create_metadata()
       ]
+    if self.score_thresholding_md:
+      if tensor_metadata.processUnits:
+        tensor_metadata.processUnits.append(
+            self.score_thresholding_md.create_metadata())
+      else:
+        tensor_metadata.processUnits = [
+            self.score_thresholding_md.create_metadata()
+        ]
     return tensor_metadata

@@ -70,6 +70,18 @@ class LabelItem:
   locale: Optional[str] = None
 
 
+@dataclasses.dataclass
+class ScoreThresholding:
+  """Parameters to perform thresholding on output tensor values [1].
+
+  Attributes:
+    global_score_threshold: The recommended global threshold below which results
+      are considered low-confidence and should be filtered out. [1]:
+      https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L468
+  """
+  global_score_threshold: float
+
+
 class Labels(object):
   """Simple container holding classification labels of a particular tensor.
 

@@ -407,6 +419,7 @@ class MetadataWriter(object):
       self,
       labels: Optional[Labels] = None,
       score_calibration: Optional[ScoreCalibration] = None,
+      score_thresholding: Optional[ScoreThresholding] = None,
       name: str = _OUTPUT_CLASSIFICATION_NAME,
       description: str = _OUTPUT_CLASSIFICATION_DESCRIPTION
   ) -> 'MetadataWriter':

@@ -423,6 +436,7 @@ class MetadataWriter(object):
     Args:
       labels: an instance of Labels helper class.
       score_calibration: an instance of ScoreCalibration helper class.
+      score_thresholding: an instance of ScoreThresholding.
       name: Metadata name of the tensor. Note that this is different from tensor
         name in the flatbuffer.
       description: human readable description of what the output is.

@@ -437,6 +451,10 @@ class MetadataWriter(object):
           default_score=score_calibration.default_score,
           file_path=self._export_calibration_file('score_calibration.txt',
                                                   score_calibration.parameters))
+    score_thresholding_md = None
+    if score_thresholding:
+      score_thresholding_md = metadata_info.ScoreThresholdingMd(
+          score_thresholding.global_score_threshold)
 
     label_files = None
     if labels:

@@ -453,6 +471,7 @@ class MetadataWriter(object):
         label_files=label_files,
         tensor_type=self._output_tensor_type(len(self._output_mds)),
         score_calibration_md=calibration_md,
+        score_thresholding_md=score_thresholding_md,
     )
     self._output_mds.append(output_md)
     return self

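Taken together, the writer-level flow for the new option looks like this (a sketch that mirrors the unit test added later in this commit; the model file name is a placeholder):

    from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer

    with open('classifier.tflite', 'rb') as f:  # placeholder model file
      model_buffer = f.read()

    writer = metadata_writer.MetadataWriter.create(model_buffer)
    writer.add_classification_output(
        labels=metadata_writer.Labels().add(['a', 'b', 'c']),
        score_thresholding=metadata_writer.ScoreThresholding(
            global_score_threshold=0.5))
    # Returns the model bytes with metadata attached plus the metadata as JSON.
    tflite_with_metadata, metadata_json = writer.populate()
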
@@ -545,4 +564,3 @@ class MetadataWriterBase:
       A tuple of (model_with_metadata_in_bytes, metdata_json_content)
     """
     return self.writer.populate()
-

@@ -14,7 +14,8 @@
 # ==============================================================================
 """Helper methods for writing metadata into TFLite models."""
 
-from typing import List
+from typing import Dict, List
+import zipfile
 
 from mediapipe.tasks.metadata import schema_py_generated as _schema_fb
 

@@ -83,3 +84,20 @@ def get_subgraph(model_buffer: bytearray) -> _schema_fb.SubGraph:
   # multiple subgraphs yet, but models with mini-benchmark may have multiple
   # subgraphs for acceleration evaluation purpose.
   return model.Subgraphs(0)
+
+
+def create_model_asset_bundle(input_models: Dict[str, bytes],
+                              output_path: str) -> None:
+  """Creates the model asset bundle.
+
+  Args:
+    input_models: A dict of input models with key as the model file name and
+      value as the model content.
+    output_path: The output file path to save the model asset bundle.
+  """
+  if not input_models or len(input_models) < 2:
+    raise ValueError("Needs at least two input models for model asset bundle.")
+
+  with zipfile.ZipFile(output_path, mode="w") as zf:
+    for file_name, file_buffer in input_models.items():
+      zf.writestr(file_name, file_buffer)

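For instance, a task that ships several sub-models can package them into one `.task` file; a usage sketch (file names are hypothetical, and the helper is assumed to be imported from the writer-utils module above):

    # Bundle two sub-models into a single task asset (names are illustrative).
    input_models = {}
    for name in ('hand_detector.tflite', 'gesture_classifier.tflite'):
      with open(name, 'rb') as f:
        input_models[name] = f.read()

    create_model_asset_bundle(input_models, 'gesture_recognizer.task')
    # The bundle is an ordinary zip archive, so it can be inspected with zipfile.
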
@@ -27,5 +27,8 @@ py_library(
         "//mediapipe/model_maker/python/vision/gesture_recognizer:__pkg__",
         "//mediapipe/tasks:internal",
     ],
-    deps = ["//mediapipe/python:_framework_bindings"],
+    deps = [
+        "//mediapipe/python:_framework_bindings",
+        "@com_google_protobuf//:protobuf_python",
+    ],
 )

@@ -307,6 +307,25 @@ class ScoreCalibrationMdTest(absltest.TestCase):
                         malformed_calibration_file)
 
 
+class ScoreThresholdingMdTest(absltest.TestCase):
+  _DEFAULT_GLOBAL_THRESHOLD = 0.5
+  _EXPECTED_TENSOR_JSON = test_utils.get_test_data_path(
+      "score_thresholding_meta.json")
+
+  def test_create_metadata_should_succeed(self):
+    score_thresholding_md = metadata_info.ScoreThresholdingMd(
+        global_score_threshold=self._DEFAULT_GLOBAL_THRESHOLD)
+
+    score_thresholding_metadata = score_thresholding_md.create_metadata()
+
+    metadata_json = _metadata.convert_to_json(
+        _create_dummy_model_metadata_with_process_uint(
+            score_thresholding_metadata))
+    with open(self._EXPECTED_TENSOR_JSON, "r") as f:
+      expected_json = f.read()
+    self.assertEqual(metadata_json, expected_json)
+
+
 def _create_dummy_model_metadata_with_tensor(
     tensor_metadata: _metadata_fb.TensorMetadataT) -> bytes:
   # Create a dummy model using the tensor metadata.

@@ -405,6 +405,64 @@ class MetadataWriterForTaskTest(absltest.TestCase):
   }
   """)
 
+  def test_add_classification_output_with_score_thresholding(self):
+    writer = metadata_writer.MetadataWriter.create(
+        self.image_classifier_model_buffer)
+    writer.add_classification_output(
+        labels=metadata_writer.Labels().add(['a', 'b', 'c']),
+        score_thresholding=metadata_writer.ScoreThresholding(
+            global_score_threshold=0.5))
+    _, metadata_json = writer.populate()
+    print(metadata_json)
+    self.assertJsonEqual(
+        metadata_json, """{
+  "subgraph_metadata": [
+    {
+      "input_tensor_metadata": [
+        {
+          "name": "input"
+        }
+      ],
+      "output_tensor_metadata": [
+        {
+          "name": "score",
+          "description": "Score of the labels respectively.",
+          "content": {
+            "content_properties_type": "FeatureProperties",
+            "content_properties": {
+            }
+          },
+          "process_units": [
+            {
+              "options_type": "ScoreThresholdingOptions",
+              "options": {
+                "global_score_threshold": 0.5
+              }
+            }
+          ],
+          "stats": {
+            "max": [
+              1.0
+            ],
+            "min": [
+              0.0
+            ]
+          },
+          "associated_files": [
+            {
+              "name": "labels.txt",
+              "description": "Labels for categories that the model can recognize.",
+              "type": "TENSOR_AXIS_LABELS"
+            }
+          ]
+        }
+      ]
+    }
+  ],
+  "min_parser_version": "1.0.0"
+}
+""")
+
 
 if __name__ == '__main__':
   absltest.main()

@@ -13,9 +13,15 @@
 # limitations under the License.
 """Test util for MediaPipe Tasks."""
 
+import difflib
 import os
 
 from absl import flags
+import six
+
+from google.protobuf import descriptor
+from google.protobuf import descriptor_pool
+from google.protobuf import text_format
 
 from mediapipe.python._framework_bindings import image as image_module
 from mediapipe.python._framework_bindings import image_frame as image_frame_module

@@ -53,3 +59,126 @@ def create_calibration_file(file_dir: str,
   with open(calibration_file, mode="w") as file:
     file.write(content)
   return calibration_file
+
+
+def assert_proto_equals(self,
+                        a,
+                        b,
+                        check_initialized=True,
+                        normalize_numbers=True,
+                        msg=None):
+  """assert_proto_equals() is useful for unit tests.
+
+  It produces much more helpful output than assertEqual() for proto2 messages.
+  Fails with a useful error if a and b aren't equal. Comparison of repeated
+  fields matches the semantics of unittest.TestCase.assertEqual(), i.e. order
+  and extra duplicate fields matter.
+
+  This is a fork of https://github.com/tensorflow/tensorflow/blob/
+  master/tensorflow/python/util/protobuf/compare.py#L73. We use slightly
+  different rounding cutoffs to support Mac usage.
+
+  Args:
+    self: absltest.testing.parameterized.TestCase
+    a: proto2 PB instance, or text string representing one.
+    b: proto2 PB instance -- message.Message or subclass thereof.
+    check_initialized: boolean, whether to fail if either a or b isn't
+      initialized.
+    normalize_numbers: boolean, whether to normalize types and precision of
+      numbers before comparison.
+    msg: if specified, is used as the error message on failure.
+  """
+  pool = descriptor_pool.Default()
+  if isinstance(a, six.string_types):
+    a = text_format.Parse(a, b.__class__(), descriptor_pool=pool)
+
+  for pb in a, b:
+    if check_initialized:
+      errors = pb.FindInitializationErrors()
+      if errors:
+        self.fail("Initialization errors: %s\n%s" % (errors, pb))
+    if normalize_numbers:
+      _normalize_number_fields(pb)
+
+  a_str = text_format.MessageToString(a, descriptor_pool=pool)
+  b_str = text_format.MessageToString(b, descriptor_pool=pool)
+
+  # Some Python versions would perform regular diff instead of multi-line
+  # diff if string is longer than 2**16. We substitute this behavior
+  # with a call to unified_diff instead to have easier-to-read diffs.
+  # For context, see: https://bugs.python.org/issue11763.
+  if len(a_str) < 2**16 and len(b_str) < 2**16:
+    self.assertMultiLineEqual(a_str, b_str, msg=msg)
+  else:
+    diff = "".join(
+        difflib.unified_diff(a_str.splitlines(True), b_str.splitlines(True)))
+    if diff:
+      self.fail("%s :\n%s" % (msg, diff))
+
+
+def _normalize_number_fields(pb):
+  """Normalizes types and precisions of number fields in a protocol buffer.
+
+  Due to subtleties in the python protocol buffer implementation, it is possible
+  for values to have different types and precision depending on whether they
+  were set and retrieved directly or deserialized from a protobuf. This function
+  normalizes integer values to ints and longs based on width, 32-bit floats to
+  five digits of precision to account for python always storing them as 64-bit,
+  and ensures doubles are floating point for when they're set to integers.
+  Modifies pb in place. Recurses into nested objects.
+  https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/util/protobuf/compare.py#L118
+
+  Args:
+    pb: proto2 message.
+
+  Returns:
+    the given pb, modified in place.
+  """
+  for desc, values in pb.ListFields():
+    is_repeated = True
+    if desc.label != descriptor.FieldDescriptor.LABEL_REPEATED:
+      is_repeated = False
+      values = [values]
+
+    normalized_values = None
+
+    # We force 32-bit values to int and 64-bit values to long to make
+    # alternate implementations where the distinction is more significant
+    # (e.g. the C++ implementation) simpler.
+    if desc.type in (descriptor.FieldDescriptor.TYPE_INT64,
+                     descriptor.FieldDescriptor.TYPE_UINT64,
+                     descriptor.FieldDescriptor.TYPE_SINT64):
+      normalized_values = [int(x) for x in values]
+    elif desc.type in (descriptor.FieldDescriptor.TYPE_INT32,
+                       descriptor.FieldDescriptor.TYPE_UINT32,
+                       descriptor.FieldDescriptor.TYPE_SINT32,
+                       descriptor.FieldDescriptor.TYPE_ENUM):
+      normalized_values = [int(x) for x in values]
+    elif desc.type == descriptor.FieldDescriptor.TYPE_FLOAT:
+      normalized_values = [round(x, 5) for x in values]
+    elif desc.type == descriptor.FieldDescriptor.TYPE_DOUBLE:
+      normalized_values = [round(float(x), 6) for x in values]
+
+    if normalized_values is not None:
+      if is_repeated:
+        pb.ClearField(desc.name)
+        getattr(pb, desc.name).extend(normalized_values)
+      else:
+        setattr(pb, desc.name, normalized_values[0])
+
+    if (desc.type == descriptor.FieldDescriptor.TYPE_MESSAGE or
+        desc.type == descriptor.FieldDescriptor.TYPE_GROUP):
+      if (desc.type == descriptor.FieldDescriptor.TYPE_MESSAGE and
+          desc.message_type.has_options and
+          desc.message_type.GetOptions().map_entry):
+        # This is a map, only recurse if the values have a message type.
+        if (desc.message_type.fields_by_number[2].type ==
+            descriptor.FieldDescriptor.TYPE_MESSAGE):
+          for v in six.itervalues(values):
+            _normalize_number_fields(v)
+      else:
+        for v in values:
+          # recursive step
+          _normalize_number_fields(v)
+
+  return pb

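The TYPE_FLOAT rounding above is what lets the image classifier test below shorten its expected scores: a score stored as a 32-bit float deserializes in Python as a long double-precision literal, but both spellings agree once rounded to five digits. A plain-Python illustration of the same rule:

    # repr() of the float32-stored score, as it comes back from the proto:
    deserialized = 0.7939587831497192
    # the hand-written expected value used in the updated test:
    expected = 0.793959
    # both normalize to the same number under the round(x, 5) rule above:
    assert round(deserialized, 5) == round(expected, 5) == 0.79396
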
@@ -53,11 +53,6 @@ _SCORE_THRESHOLD = 0.5
 _MAX_RESULTS = 3
 
 
-# TODO: Port assertProtoEquals
-def _assert_proto_equals(expected, actual):  # pylint: disable=unused-argument
-  pass
-
-
 def _generate_empty_results(timestamp_ms: int) -> _ClassificationResult:
   return _ClassificationResult(classifications=[
       _Classifications(

@@ -77,22 +72,22 @@ def _generate_burger_results(timestamp_ms: int) -> _ClassificationResult:
           categories=[
               _Category(
                   index=934,
-                  score=0.7939587831497192,
+                  score=0.793959,
                   display_name='',
                   category_name='cheeseburger'),
               _Category(
                   index=932,
-                  score=0.02739289402961731,
+                  score=0.0273929,
                   display_name='',
                   category_name='bagel'),
               _Category(
                   index=925,
-                  score=0.01934075355529785,
+                  score=0.0193408,
                   display_name='',
                   category_name='guacamole'),
               _Category(
                   index=963,
-                  score=0.006327860057353973,
+                  score=0.00632786,
                   display_name='',
                   category_name='meat loaf')
           ],

@@ -111,7 +106,7 @@ def _generate_soccer_ball_results(timestamp_ms: int) -> _ClassificationResult:
           categories=[
               _Category(
                   index=806,
-                  score=0.9965274930000305,
+                  score=0.996527,
                   display_name='',
                   category_name='soccer ball')
           ],

@@ -189,7 +184,7 @@ class ImageClassifierTest(parameterized.TestCase):
       # Performs image classification on the input.
       image_result = classifier.classify(self.test_image)
       # Comparing results.
-      _assert_proto_equals(image_result.to_pb2(),
-                           expected_classification_result.to_pb2())
+      test_utils.assert_proto_equals(self, image_result.to_pb2(),
+                                     expected_classification_result.to_pb2())
       # Closes the classifier explicitly when the classifier is not used in
       # a context.

@@ -217,7 +212,7 @@ class ImageClassifierTest(parameterized.TestCase):
       # Performs image classification on the input.
       image_result = classifier.classify(self.test_image)
       # Comparing results.
-      _assert_proto_equals(image_result.to_pb2(),
-                           expected_classification_result.to_pb2())
+      test_utils.assert_proto_equals(self, image_result.to_pb2(),
+                                     expected_classification_result.to_pb2())
 
   def test_classify_succeeds_with_region_of_interest(self):

||||||
# Performs image classification on the input.
|
# Performs image classification on the input.
|
||||||
image_result = classifier.classify(test_image, image_processing_options)
|
image_result = classifier.classify(test_image, image_processing_options)
|
||||||
# Comparing results.
|
# Comparing results.
|
||||||
_assert_proto_equals(image_result.to_pb2(),
|
test_utils.assert_proto_equals(self, image_result.to_pb2(),
|
||||||
_generate_soccer_ball_results(0).to_pb2())
|
_generate_soccer_ball_results(0).to_pb2())
|
||||||
|
|
||||||
def test_score_threshold_option(self):
|
def test_score_threshold_option(self):
|
||||||
|
@@ -404,7 +399,8 @@ class ImageClassifierTest(parameterized.TestCase):
       for timestamp in range(0, 300, 30):
         classification_result = classifier.classify_for_video(
            self.test_image, timestamp)
-        _assert_proto_equals(classification_result.to_pb2(),
-                             _generate_burger_results(timestamp).to_pb2())
+        test_utils.assert_proto_equals(
+            self, classification_result.to_pb2(),
+            _generate_burger_results(timestamp).to_pb2())
 
   def test_classify_for_video_succeeds_with_region_of_interest(self):

@@ -423,8 +419,9 @@ class ImageClassifierTest(parameterized.TestCase):
       for timestamp in range(0, 300, 30):
         classification_result = classifier.classify_for_video(
             test_image, timestamp, image_processing_options)
-        self.assertEqual(classification_result,
-                         _generate_soccer_ball_results(timestamp))
+        test_utils.assert_proto_equals(
+            self, classification_result.to_pb2(),
+            _generate_soccer_ball_results(timestamp).to_pb2())
 
   def test_calling_classify_in_live_stream_mode(self):
     options = _ImageClassifierOptions(

@@ -466,7 +463,7 @@ class ImageClassifierTest(parameterized.TestCase):
 
     def check_result(result: _ClassificationResult, output_image: _Image,
                      timestamp_ms: int):
-      _assert_proto_equals(result.to_pb2(),
-                           expected_result_fn(timestamp_ms).to_pb2())
+      test_utils.assert_proto_equals(self, result.to_pb2(),
+                                     expected_result_fn(timestamp_ms).to_pb2())
       self.assertTrue(
           np.array_equal(output_image.numpy_view(),

@@ -496,7 +493,8 @@ class ImageClassifierTest(parameterized.TestCase):
 
     def check_result(result: _ClassificationResult, output_image: _Image,
                      timestamp_ms: int):
-      _assert_proto_equals(result.to_pb2(),
-                           _generate_soccer_ball_results(timestamp_ms).to_pb2())
+      test_utils.assert_proto_equals(
+          self, result.to_pb2(),
+          _generate_soccer_ball_results(timestamp_ms).to_pb2())
       self.assertEqual(output_image.width, test_image.width)
       self.assertEqual(output_image.height, test_image.height)

mediapipe/tasks/testdata/metadata/BUILD (vendored, 2 lines changed)
@@ -50,6 +50,7 @@ exports_files([
     "score_calibration.txt",
     "score_calibration_file_meta.json",
     "score_calibration_tensor_meta.json",
+    "score_thresholding_meta.json",
     "labels.txt",
     "mobilenet_v2_1.0_224.json",
     "mobilenet_v2_1.0_224_quant.json",

@@ -91,5 +92,6 @@ filegroup(
         "score_calibration.txt",
         "score_calibration_file_meta.json",
         "score_calibration_tensor_meta.json",
+        "score_thresholding_meta.json",
     ],
 )

mediapipe/tasks/testdata/metadata/score_thresholding_meta.json (vendored, new file, 14 lines)
@@ -0,0 +1,14 @@
+{
+  "subgraph_metadata": [
+    {
+      "input_process_units": [
+        {
+          "options_type": "ScoreThresholdingOptions",
+          "options": {
+            "global_score_threshold": 0.5
+          }
+        }
+      ]
+    }
+  ]
+}

@@ -1323,10 +1323,9 @@ void MotionBox::GetSpatialGaussWeights(const MotionBoxState& box_state,
   const float space_sigma_x = std::max(
       options_.spatial_sigma(), box_state.inlier_width() * inv_box_domain.x() *
                                     0.5f * box_state.prior_weight() / 1.65f);
-  const float space_sigma_y = options_.spatial_sigma();
-      std::max(options_.spatial_sigma(), box_state.inlier_height() *
-                                             inv_box_domain.y() * 0.5f *
-                                             box_state.prior_weight() / 1.65f);
+  const float space_sigma_y = std::max(
+      options_.spatial_sigma(), box_state.inlier_height() * inv_box_domain.y() *
+                                    0.5f * box_state.prior_weight() / 1.65f);
 
   *spatial_gauss_x = -0.5f / (space_sigma_x * space_sigma_x);
   *spatial_gauss_y = -0.5f / (space_sigma_y * space_sigma_y);

third_party/external_files.bzl (vendored, 6 lines changed)
@@ -694,6 +694,12 @@ def external_files():
         urls = ["https://storage.googleapis.com/mediapipe-assets/score_calibration.txt?generation=1665422847392804"],
     )
 
+    http_file(
+        name = "com_google_mediapipe_score_thresholding_meta_json",
+        sha256 = "7bb74f21c2d7f0237675ed7c09d7b7afd3507c8373f51dc75fa0507852f6ee19",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/score_thresholding_meta.json?generation=1667273953630766"],
+    )
+
     http_file(
         name = "com_google_mediapipe_segmentation_golden_rotation0_png",
         sha256 = "9ee993919b753118928ba2d14f7c5c83a6cfc23355e6943dac4ad81eedd73069",