Model Maker Gesture Recognizer: add metadata writer and create model bundle.

PiperOrigin-RevId: 485426865
This commit is contained in:
Yuqi Li 2022-11-01 15:00:21 -07:00 committed by Copybara-Service
parent c6a64683f6
commit e719f7b4e2
8 changed files with 194 additions and 9 deletions

View File

@ -197,6 +197,37 @@ class ScoreCalibrationMd:
self._FILE_TYPE) self._FILE_TYPE)
class ScoreThresholdingMd:
  """Holds the score thresholding [1] settings to be written as metadata.

  [1]:
    https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L468
  """

  def __init__(self, global_score_threshold: float) -> None:
    """Stores the threshold that `create_metadata` will emit.

    Args:
      global_score_threshold: The recommended global threshold below which
        results are considered low-confidence and should be filtered out.
    """
    self._global_score_threshold = global_score_threshold

  def create_metadata(self) -> _metadata_fb.ProcessUnitT:
    """Builds the process unit carrying the score thresholding options.

    Returns:
      A Flatbuffers Python object of the score thresholding metadata.
    """
    # Fill in the options payload first, then wrap it in a process unit whose
    # options type tags the payload as score thresholding options.
    thresholding_options = _metadata_fb.ScoreThresholdingOptionsT()
    thresholding_options.globalScoreThreshold = self._global_score_threshold

    process_unit = _metadata_fb.ProcessUnitT()
    process_unit.optionsType = (
        _metadata_fb.ProcessUnitOptions.ScoreThresholdingOptions)
    process_unit.options = thresholding_options
    return process_unit
class TensorMd: class TensorMd:
"""A container for common tensor metadata information. """A container for common tensor metadata information.
@ -374,23 +405,29 @@ class ClassificationTensorMd(TensorMd):
tensor. tensor.
score_calibration_md: information of the score calibration operation [2] in score_calibration_md: information of the score calibration operation [2] in
the classification tensor. the classification tensor.
score_thresholding_md: information of the score thresholding [3] in the
classification tensor.
[1]: [1]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L99 https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L99
[2]: [2]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L456 https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L456
[3]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L468
""" """
# Min and max float values for classification results. # Min and max float values for classification results.
_MIN_FLOAT = 0.0 _MIN_FLOAT = 0.0
_MAX_FLOAT = 1.0 _MAX_FLOAT = 1.0
def __init__(self, def __init__(
self,
name: Optional[str] = None, name: Optional[str] = None,
description: Optional[str] = None, description: Optional[str] = None,
label_files: Optional[List[LabelFileMd]] = None, label_files: Optional[List[LabelFileMd]] = None,
tensor_type: Optional[int] = None, tensor_type: Optional[int] = None,
score_calibration_md: Optional[ScoreCalibrationMd] = None, score_calibration_md: Optional[ScoreCalibrationMd] = None,
tensor_name: Optional[str] = None) -> None: tensor_name: Optional[str] = None,
score_thresholding_md: Optional[ScoreThresholdingMd] = None) -> None:
"""Initializes the instance of ClassificationTensorMd. """Initializes the instance of ClassificationTensorMd.
Args: Args:
@ -404,6 +441,8 @@ class ClassificationTensorMd(TensorMd):
tensor_name: name of the corresponding tensor [3] in the TFLite model. It tensor_name: name of the corresponding tensor [3] in the TFLite model. It
is used to locate the corresponding classification tensor and decide the is used to locate the corresponding classification tensor and decide the
order of the tensor metadata [4] when populating model metadata. order of the tensor metadata [4] when populating model metadata.
score_thresholding_md: information of the score thresholding [5] in the
classification tensor.
[1]: [1]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L99 https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L99
[2]: [2]:
@ -412,8 +451,11 @@ class ClassificationTensorMd(TensorMd):
https://github.com/tensorflow/tensorflow/blob/cb67fef35567298b40ac166b0581cd8ad68e5a3a/tensorflow/lite/schema/schema.fbs#L1129-L1136 https://github.com/tensorflow/tensorflow/blob/cb67fef35567298b40ac166b0581cd8ad68e5a3a/tensorflow/lite/schema/schema.fbs#L1129-L1136
[4]: [4]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L623-L640 https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L623-L640
[5]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L468
""" """
self.score_calibration_md = score_calibration_md self.score_calibration_md = score_calibration_md
self.score_thresholding_md = score_thresholding_md
if tensor_type is _schema_fb.TensorType.UINT8: if tensor_type is _schema_fb.TensorType.UINT8:
min_values = [_MIN_UINT8] min_values = [_MIN_UINT8]
@ -443,4 +485,12 @@ class ClassificationTensorMd(TensorMd):
tensor_metadata.processUnits = [ tensor_metadata.processUnits = [
self.score_calibration_md.create_metadata() self.score_calibration_md.create_metadata()
] ]
if self.score_thresholding_md:
if tensor_metadata.processUnits:
tensor_metadata.processUnits.append(
self.score_thresholding_md.create_metadata())
else:
tensor_metadata.processUnits = [
self.score_thresholding_md.create_metadata()
]
return tensor_metadata return tensor_metadata

View File

@ -70,6 +70,18 @@ class LabelItem:
locale: Optional[str] = None locale: Optional[str] = None
@dataclasses.dataclass
class ScoreThresholding:
  """Parameters to perform thresholding on output tensor values [1].

  Attributes:
    global_score_threshold: The recommended global threshold below which results
      are considered low-confidence and should be filtered out.

  [1]:
    https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L468
  """
  global_score_threshold: float
class Labels(object): class Labels(object):
"""Simple container holding classification labels of a particular tensor. """Simple container holding classification labels of a particular tensor.
@ -407,6 +419,7 @@ class MetadataWriter(object):
self, self,
labels: Optional[Labels] = None, labels: Optional[Labels] = None,
score_calibration: Optional[ScoreCalibration] = None, score_calibration: Optional[ScoreCalibration] = None,
score_thresholding: Optional[ScoreThresholding] = None,
name: str = _OUTPUT_CLASSIFICATION_NAME, name: str = _OUTPUT_CLASSIFICATION_NAME,
description: str = _OUTPUT_CLASSIFICATION_DESCRIPTION description: str = _OUTPUT_CLASSIFICATION_DESCRIPTION
) -> 'MetadataWriter': ) -> 'MetadataWriter':
@ -423,6 +436,7 @@ class MetadataWriter(object):
Args: Args:
labels: an instance of Labels helper class. labels: an instance of Labels helper class.
score_calibration: an instance of ScoreCalibration helper class. score_calibration: an instance of ScoreCalibration helper class.
score_thresholding: an instance of ScoreThresholding.
name: Metadata name of the tensor. Note that this is different from tensor name: Metadata name of the tensor. Note that this is different from tensor
name in the flatbuffer. name in the flatbuffer.
description: human readable description of what the output is. description: human readable description of what the output is.
@ -437,6 +451,10 @@ class MetadataWriter(object):
default_score=score_calibration.default_score, default_score=score_calibration.default_score,
file_path=self._export_calibration_file('score_calibration.txt', file_path=self._export_calibration_file('score_calibration.txt',
score_calibration.parameters)) score_calibration.parameters))
score_thresholding_md = None
if score_thresholding:
score_thresholding_md = metadata_info.ScoreThresholdingMd(
score_thresholding.global_score_threshold)
label_files = None label_files = None
if labels: if labels:
@ -453,6 +471,7 @@ class MetadataWriter(object):
label_files=label_files, label_files=label_files,
tensor_type=self._output_tensor_type(len(self._output_mds)), tensor_type=self._output_tensor_type(len(self._output_mds)),
score_calibration_md=calibration_md, score_calibration_md=calibration_md,
score_thresholding_md=score_thresholding_md,
) )
self._output_mds.append(output_md) self._output_mds.append(output_md)
return self return self
@ -545,4 +564,3 @@ class MetadataWriterBase:
A tuple of (model_with_metadata_in_bytes, metdata_json_content) A tuple of (model_with_metadata_in_bytes, metdata_json_content)
""" """
return self.writer.populate() return self.writer.populate()

View File

@ -14,7 +14,8 @@
# ============================================================================== # ==============================================================================
"""Helper methods for writing metadata into TFLite models.""" """Helper methods for writing metadata into TFLite models."""
from typing import List from typing import Dict, List
import zipfile
from mediapipe.tasks.metadata import schema_py_generated as _schema_fb from mediapipe.tasks.metadata import schema_py_generated as _schema_fb
@ -83,3 +84,20 @@ def get_subgraph(model_buffer: bytearray) -> _schema_fb.SubGraph:
# multiple subgraphs yet, but models with mini-benchmark may have multiple # multiple subgraphs yet, but models with mini-benchmark may have multiple
# subgraphs for acceleration evaluation purpose. # subgraphs for acceleration evaluation purpose.
return model.Subgraphs(0) return model.Subgraphs(0)
def create_model_asset_bundle(input_models: Dict[str, bytes],
                              output_path: str) -> None:
  """Packs several models into one model asset bundle archive.

  Every entry of `input_models` is written into a zip archive created at
  `output_path`; the dict key becomes the member name and the dict value the
  member content.

  Args:
    input_models: A dict of input models with key as the model file name and
      value as the model content.
    output_path: The output file path to save the model asset bundle.

  Raises:
    ValueError: If `input_models` is empty, None, or holds fewer than two
      models.
  """
  # A missing or empty mapping counts as zero models.
  model_count = len(input_models) if input_models else 0
  if model_count < 2:
    raise ValueError("Needs at least two input models for model asset bundle.")

  with zipfile.ZipFile(output_path, mode="w") as bundle:
    for member_name, member_content in input_models.items():
      bundle.writestr(member_name, member_content)

View File

@ -307,6 +307,25 @@ class ScoreCalibrationMdTest(absltest.TestCase):
malformed_calibration_file) malformed_calibration_file)
class ScoreThresholdingMdTest(absltest.TestCase):
  """Tests for `metadata_info.ScoreThresholdingMd`."""

  _DEFAULT_GLOBAL_THRESHOLD = 0.5
  _EXPECTED_TENSOR_JSON = test_utils.get_test_data_path(
      "score_thresholding_meta.json")

  def test_create_metadata_should_succeed(self):
    """Compares the generated metadata JSON with the golden file content."""
    process_unit = metadata_info.ScoreThresholdingMd(
        global_score_threshold=self._DEFAULT_GLOBAL_THRESHOLD
    ).create_metadata()
    dummy_model_metadata = _create_dummy_model_metadata_with_process_uint(
        process_unit)
    actual_json = _metadata.convert_to_json(dummy_model_metadata)

    with open(self._EXPECTED_TENSOR_JSON, "r") as golden_file:
      golden_json = golden_file.read()
    self.assertEqual(actual_json, golden_json)
def _create_dummy_model_metadata_with_tensor( def _create_dummy_model_metadata_with_tensor(
tensor_metadata: _metadata_fb.TensorMetadataT) -> bytes: tensor_metadata: _metadata_fb.TensorMetadataT) -> bytes:
# Create a dummy model using the tensor metadata. # Create a dummy model using the tensor metadata.

View File

@ -405,6 +405,64 @@ class MetadataWriterForTaskTest(absltest.TestCase):
} }
""") """)
def test_add_classification_output_with_score_thresholding(self):
  """Checks that score thresholding appears in the output tensor metadata.

  Adds a classification output with labels and a global score threshold of
  0.5, then compares the populated metadata JSON with the expected content,
  which lists a ScoreThresholdingOptions process unit on the output tensor.
  """
  writer = metadata_writer.MetadataWriter.create(
      self.image_classifier_model_buffer)
  writer.add_classification_output(
      labels=metadata_writer.Labels().add(['a', 'b', 'c']),
      score_thresholding=metadata_writer.ScoreThresholding(
          global_score_threshold=0.5))
  # Only the JSON part of the populate() result is checked here. The leftover
  # debug print of the JSON was dropped: the assertion below already reports
  # mismatches.
  _, metadata_json = writer.populate()
  self.assertJsonEqual(
      metadata_json, """{
"subgraph_metadata": [
{
"input_tensor_metadata": [
{
"name": "input"
}
],
"output_tensor_metadata": [
{
"name": "score",
"description": "Score of the labels respectively.",
"content": {
"content_properties_type": "FeatureProperties",
"content_properties": {
}
},
"process_units": [
{
"options_type": "ScoreThresholdingOptions",
"options": {
"global_score_threshold": 0.5
}
}
],
"stats": {
"max": [
1.0
],
"min": [
0.0
]
},
"associated_files": [
{
"name": "labels.txt",
"description": "Labels for categories that the model can recognize.",
"type": "TENSOR_AXIS_LABELS"
}
]
}
]
}
],
"min_parser_version": "1.0.0"
}
""")
if __name__ == '__main__': if __name__ == '__main__':
absltest.main() absltest.main()

View File

@ -50,6 +50,7 @@ exports_files([
"score_calibration.txt", "score_calibration.txt",
"score_calibration_file_meta.json", "score_calibration_file_meta.json",
"score_calibration_tensor_meta.json", "score_calibration_tensor_meta.json",
"score_thresholding_meta.json",
"labels.txt", "labels.txt",
"mobilenet_v2_1.0_224.json", "mobilenet_v2_1.0_224.json",
"mobilenet_v2_1.0_224_quant.json", "mobilenet_v2_1.0_224_quant.json",
@ -91,5 +92,6 @@ filegroup(
"score_calibration.txt", "score_calibration.txt",
"score_calibration_file_meta.json", "score_calibration_file_meta.json",
"score_calibration_tensor_meta.json", "score_calibration_tensor_meta.json",
"score_thresholding_meta.json",
], ],
) )

View File

@ -0,0 +1,14 @@
{
"subgraph_metadata": [
{
"input_process_units": [
{
"options_type": "ScoreThresholdingOptions",
"options": {
"global_score_threshold": 0.5
}
}
]
}
]
}

View File

@ -694,6 +694,12 @@ def external_files():
urls = ["https://storage.googleapis.com/mediapipe-assets/score_calibration.txt?generation=1665422847392804"], urls = ["https://storage.googleapis.com/mediapipe-assets/score_calibration.txt?generation=1665422847392804"],
) )
http_file(
name = "com_google_mediapipe_score_thresholding_meta_json",
sha256 = "7bb74f21c2d7f0237675ed7c09d7b7afd3507c8373f51dc75fa0507852f6ee19",
urls = ["https://storage.googleapis.com/mediapipe-assets/score_thresholding_meta.json?generation=1667273953630766"],
)
http_file( http_file(
name = "com_google_mediapipe_segmentation_golden_rotation0_png", name = "com_google_mediapipe_segmentation_golden_rotation0_png",
sha256 = "9ee993919b753118928ba2d14f7c5c83a6cfc23355e6943dac4ad81eedd73069", sha256 = "9ee993919b753118928ba2d14f7c5c83a6cfc23355e6943dac4ad81eedd73069",