Model Maker Gesture Recognizer: add metadata writer and create model bundle.
PiperOrigin-RevId: 485426865
This commit is contained in:
parent
c6a64683f6
commit
e719f7b4e2
|
@ -197,6 +197,37 @@ class ScoreCalibrationMd:
|
|||
self._FILE_TYPE)
|
||||
|
||||
|
||||
class ScoreThresholdingMd:
|
||||
"""A container for score thresholding [1] metadata information.
|
||||
|
||||
[1]:
|
||||
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L468
|
||||
"""
|
||||
|
||||
def __init__(self, global_score_threshold: float) -> None:
|
||||
"""Creates a ScoreThresholdingMd object.
|
||||
|
||||
Args:
|
||||
global_score_threshold: The recommended global threshold below which
|
||||
results are considered low-confidence and should be filtered out.
|
||||
"""
|
||||
self._global_score_threshold = global_score_threshold
|
||||
|
||||
def create_metadata(self) -> _metadata_fb.ProcessUnitT:
|
||||
"""Creates the score thresholding metadata based on the information.
|
||||
|
||||
Returns:
|
||||
A Flatbuffers Python object of the score thresholding metadata.
|
||||
"""
|
||||
score_thresholding = _metadata_fb.ProcessUnitT()
|
||||
score_thresholding.optionsType = (
|
||||
_metadata_fb.ProcessUnitOptions.ScoreThresholdingOptions)
|
||||
options = _metadata_fb.ScoreThresholdingOptionsT()
|
||||
options.globalScoreThreshold = self._global_score_threshold
|
||||
score_thresholding.options = options
|
||||
return score_thresholding
|
||||
|
||||
|
||||
class TensorMd:
|
||||
"""A container for common tensor metadata information.
|
||||
|
||||
|
@ -374,23 +405,29 @@ class ClassificationTensorMd(TensorMd):
|
|||
tensor.
|
||||
score_calibration_md: information of the score calibration operation [2] in
|
||||
the classification tensor.
|
||||
score_thresholding_md: information of the score thresholding [3] in the
|
||||
classification tensor.
|
||||
[1]:
|
||||
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L99
|
||||
[2]:
|
||||
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L456
|
||||
[3]:
|
||||
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L468
|
||||
"""
|
||||
|
||||
# Min and max float values for classification results.
|
||||
_MIN_FLOAT = 0.0
|
||||
_MAX_FLOAT = 1.0
|
||||
|
||||
def __init__(self,
|
||||
def __init__(
|
||||
self,
|
||||
name: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
label_files: Optional[List[LabelFileMd]] = None,
|
||||
tensor_type: Optional[int] = None,
|
||||
score_calibration_md: Optional[ScoreCalibrationMd] = None,
|
||||
tensor_name: Optional[str] = None) -> None:
|
||||
tensor_name: Optional[str] = None,
|
||||
score_thresholding_md: Optional[ScoreThresholdingMd] = None) -> None:
|
||||
"""Initializes the instance of ClassificationTensorMd.
|
||||
|
||||
Args:
|
||||
|
@ -404,6 +441,8 @@ class ClassificationTensorMd(TensorMd):
|
|||
tensor_name: name of the corresponding tensor [3] in the TFLite model. It
|
||||
is used to locate the corresponding classification tensor and decide the
|
||||
order of the tensor metadata [4] when populating model metadata.
|
||||
score_thresholding_md: information of the score thresholding [5] in the
|
||||
classification tensor.
|
||||
[1]:
|
||||
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L99
|
||||
[2]:
|
||||
|
@ -412,8 +451,11 @@ class ClassificationTensorMd(TensorMd):
|
|||
https://github.com/tensorflow/tensorflow/blob/cb67fef35567298b40ac166b0581cd8ad68e5a3a/tensorflow/lite/schema/schema.fbs#L1129-L1136
|
||||
[4]:
|
||||
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L623-L640
|
||||
[5]:
|
||||
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L468
|
||||
"""
|
||||
self.score_calibration_md = score_calibration_md
|
||||
self.score_thresholding_md = score_thresholding_md
|
||||
|
||||
if tensor_type is _schema_fb.TensorType.UINT8:
|
||||
min_values = [_MIN_UINT8]
|
||||
|
@ -443,4 +485,12 @@ class ClassificationTensorMd(TensorMd):
|
|||
tensor_metadata.processUnits = [
|
||||
self.score_calibration_md.create_metadata()
|
||||
]
|
||||
if self.score_thresholding_md:
|
||||
if tensor_metadata.processUnits:
|
||||
tensor_metadata.processUnits.append(
|
||||
self.score_thresholding_md.create_metadata())
|
||||
else:
|
||||
tensor_metadata.processUnits = [
|
||||
self.score_thresholding_md.create_metadata()
|
||||
]
|
||||
return tensor_metadata
|
||||
|
|
|
@ -70,6 +70,18 @@ class LabelItem:
|
|||
locale: Optional[str] = None
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ScoreThresholding:
|
||||
"""Parameters to performs thresholding on output tensor values [1].
|
||||
|
||||
Attributes:
|
||||
global_score_threshold: The recommended global threshold below which results
|
||||
are considered low-confidence and should be filtered out. [1]:
|
||||
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L468
|
||||
"""
|
||||
global_score_threshold: float
|
||||
|
||||
|
||||
class Labels(object):
|
||||
"""Simple container holding classification labels of a particular tensor.
|
||||
|
||||
|
@ -407,6 +419,7 @@ class MetadataWriter(object):
|
|||
self,
|
||||
labels: Optional[Labels] = None,
|
||||
score_calibration: Optional[ScoreCalibration] = None,
|
||||
score_thresholding: Optional[ScoreThresholding] = None,
|
||||
name: str = _OUTPUT_CLASSIFICATION_NAME,
|
||||
description: str = _OUTPUT_CLASSIFICATION_DESCRIPTION
|
||||
) -> 'MetadataWriter':
|
||||
|
@ -423,6 +436,7 @@ class MetadataWriter(object):
|
|||
Args:
|
||||
labels: an instance of Labels helper class.
|
||||
score_calibration: an instance of ScoreCalibration helper class.
|
||||
score_thresholding: an instance of ScoreThresholding.
|
||||
name: Metadata name of the tensor. Note that this is different from tensor
|
||||
name in the flatbuffer.
|
||||
description: human readable description of what the output is.
|
||||
|
@ -437,6 +451,10 @@ class MetadataWriter(object):
|
|||
default_score=score_calibration.default_score,
|
||||
file_path=self._export_calibration_file('score_calibration.txt',
|
||||
score_calibration.parameters))
|
||||
score_thresholding_md = None
|
||||
if score_thresholding:
|
||||
score_thresholding_md = metadata_info.ScoreThresholdingMd(
|
||||
score_thresholding.global_score_threshold)
|
||||
|
||||
label_files = None
|
||||
if labels:
|
||||
|
@ -453,6 +471,7 @@ class MetadataWriter(object):
|
|||
label_files=label_files,
|
||||
tensor_type=self._output_tensor_type(len(self._output_mds)),
|
||||
score_calibration_md=calibration_md,
|
||||
score_thresholding_md=score_thresholding_md,
|
||||
)
|
||||
self._output_mds.append(output_md)
|
||||
return self
|
||||
|
@ -545,4 +564,3 @@ class MetadataWriterBase:
|
|||
A tuple of (model_with_metadata_in_bytes, metdata_json_content)
|
||||
"""
|
||||
return self.writer.populate()
|
||||
|
||||
|
|
|
@ -14,7 +14,8 @@
|
|||
# ==============================================================================
|
||||
"""Helper methods for writing metadata into TFLite models."""
|
||||
|
||||
from typing import List
|
||||
from typing import Dict, List
|
||||
import zipfile
|
||||
|
||||
from mediapipe.tasks.metadata import schema_py_generated as _schema_fb
|
||||
|
||||
|
@ -83,3 +84,20 @@ def get_subgraph(model_buffer: bytearray) -> _schema_fb.SubGraph:
|
|||
# multiple subgraphs yet, but models with mini-benchmark may have multiple
|
||||
# subgraphs for acceleration evaluation purpose.
|
||||
return model.Subgraphs(0)
|
||||
|
||||
|
||||
def create_model_asset_bundle(input_models: Dict[str, bytes],
|
||||
output_path: str) -> None:
|
||||
"""Creates the model asset bundle.
|
||||
|
||||
Args:
|
||||
input_models: A dict of input models with key as the model file name and
|
||||
value as the model content.
|
||||
output_path: The output file path to save the model asset bundle.
|
||||
"""
|
||||
if not input_models or len(input_models) < 2:
|
||||
raise ValueError("Needs at least two input models for model asset bundle.")
|
||||
|
||||
with zipfile.ZipFile(output_path, mode="w") as zf:
|
||||
for file_name, file_buffer in input_models.items():
|
||||
zf.writestr(file_name, file_buffer)
|
||||
|
|
|
@ -307,6 +307,25 @@ class ScoreCalibrationMdTest(absltest.TestCase):
|
|||
malformed_calibration_file)
|
||||
|
||||
|
||||
class ScoreThresholdingMdTest(absltest.TestCase):
|
||||
_DEFAULT_GLOBAL_THRESHOLD = 0.5
|
||||
_EXPECTED_TENSOR_JSON = test_utils.get_test_data_path(
|
||||
"score_thresholding_meta.json")
|
||||
|
||||
def test_create_metadata_should_succeed(self):
|
||||
score_thresholding_md = metadata_info.ScoreThresholdingMd(
|
||||
global_score_threshold=self._DEFAULT_GLOBAL_THRESHOLD)
|
||||
|
||||
score_thresholding_metadata = score_thresholding_md.create_metadata()
|
||||
|
||||
metadata_json = _metadata.convert_to_json(
|
||||
_create_dummy_model_metadata_with_process_uint(
|
||||
score_thresholding_metadata))
|
||||
with open(self._EXPECTED_TENSOR_JSON, "r") as f:
|
||||
expected_json = f.read()
|
||||
self.assertEqual(metadata_json, expected_json)
|
||||
|
||||
|
||||
def _create_dummy_model_metadata_with_tensor(
|
||||
tensor_metadata: _metadata_fb.TensorMetadataT) -> bytes:
|
||||
# Create a dummy model using the tensor metadata.
|
||||
|
|
|
@ -405,6 +405,64 @@ class MetadataWriterForTaskTest(absltest.TestCase):
|
|||
}
|
||||
""")
|
||||
|
||||
def test_add_classification_output_with_score_thresholding(self):
|
||||
writer = metadata_writer.MetadataWriter.create(
|
||||
self.image_classifier_model_buffer)
|
||||
writer.add_classification_output(
|
||||
labels=metadata_writer.Labels().add(['a', 'b', 'c']),
|
||||
score_thresholding=metadata_writer.ScoreThresholding(
|
||||
global_score_threshold=0.5))
|
||||
_, metadata_json = writer.populate()
|
||||
print(metadata_json)
|
||||
self.assertJsonEqual(
|
||||
metadata_json, """{
|
||||
"subgraph_metadata": [
|
||||
{
|
||||
"input_tensor_metadata": [
|
||||
{
|
||||
"name": "input"
|
||||
}
|
||||
],
|
||||
"output_tensor_metadata": [
|
||||
{
|
||||
"name": "score",
|
||||
"description": "Score of the labels respectively.",
|
||||
"content": {
|
||||
"content_properties_type": "FeatureProperties",
|
||||
"content_properties": {
|
||||
}
|
||||
},
|
||||
"process_units": [
|
||||
{
|
||||
"options_type": "ScoreThresholdingOptions",
|
||||
"options": {
|
||||
"global_score_threshold": 0.5
|
||||
}
|
||||
}
|
||||
],
|
||||
"stats": {
|
||||
"max": [
|
||||
1.0
|
||||
],
|
||||
"min": [
|
||||
0.0
|
||||
]
|
||||
},
|
||||
"associated_files": [
|
||||
{
|
||||
"name": "labels.txt",
|
||||
"description": "Labels for categories that the model can recognize.",
|
||||
"type": "TENSOR_AXIS_LABELS"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
"min_parser_version": "1.0.0"
|
||||
}
|
||||
""")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
|
|
2
mediapipe/tasks/testdata/metadata/BUILD
vendored
2
mediapipe/tasks/testdata/metadata/BUILD
vendored
|
@ -50,6 +50,7 @@ exports_files([
|
|||
"score_calibration.txt",
|
||||
"score_calibration_file_meta.json",
|
||||
"score_calibration_tensor_meta.json",
|
||||
"score_thresholding_meta.json",
|
||||
"labels.txt",
|
||||
"mobilenet_v2_1.0_224.json",
|
||||
"mobilenet_v2_1.0_224_quant.json",
|
||||
|
@ -91,5 +92,6 @@ filegroup(
|
|||
"score_calibration.txt",
|
||||
"score_calibration_file_meta.json",
|
||||
"score_calibration_tensor_meta.json",
|
||||
"score_thresholding_meta.json",
|
||||
],
|
||||
)
|
||||
|
|
14
mediapipe/tasks/testdata/metadata/score_thresholding_meta.json
vendored
Normal file
14
mediapipe/tasks/testdata/metadata/score_thresholding_meta.json
vendored
Normal file
|
@ -0,0 +1,14 @@
|
|||
{
|
||||
"subgraph_metadata": [
|
||||
{
|
||||
"input_process_units": [
|
||||
{
|
||||
"options_type": "ScoreThresholdingOptions",
|
||||
"options": {
|
||||
"global_score_threshold": 0.5
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
6
third_party/external_files.bzl
vendored
6
third_party/external_files.bzl
vendored
|
@ -694,6 +694,12 @@ def external_files():
|
|||
urls = ["https://storage.googleapis.com/mediapipe-assets/score_calibration.txt?generation=1665422847392804"],
|
||||
)
|
||||
|
||||
http_file(
|
||||
name = "com_google_mediapipe_score_thresholding_meta_json",
|
||||
sha256 = "7bb74f21c2d7f0237675ed7c09d7b7afd3507c8373f51dc75fa0507852f6ee19",
|
||||
urls = ["https://storage.googleapis.com/mediapipe-assets/score_thresholding_meta.json?generation=1667273953630766"],
|
||||
)
|
||||
|
||||
http_file(
|
||||
name = "com_google_mediapipe_segmentation_golden_rotation0_png",
|
||||
sha256 = "9ee993919b753118928ba2d14f7c5c83a6cfc23355e6943dac4ad81eedd73069",
|
||||
|
|
Loading…
Reference in New Issue
Block a user