Model Maker Gesture Recognizer: add metadata writer and create model bundle.

PiperOrigin-RevId: 485426865
This commit is contained in:
Yuqi Li 2022-11-01 15:00:21 -07:00 committed by Copybara-Service
parent c6a64683f6
commit e719f7b4e2
8 changed files with 194 additions and 9 deletions

View File

@ -197,6 +197,37 @@ class ScoreCalibrationMd:
self._FILE_TYPE)
class ScoreThresholdingMd:
"""A container for score thresholding [1] metadata information.
[1]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L468
"""
def __init__(self, global_score_threshold: float) -> None:
"""Creates a ScoreThresholdingMd object.
Args:
global_score_threshold: The recommended global threshold below which
results are considered low-confidence and should be filtered out.
"""
self._global_score_threshold = global_score_threshold
def create_metadata(self) -> _metadata_fb.ProcessUnitT:
"""Creates the score thresholding metadata based on the information.
Returns:
A Flatbuffers Python object of the score thresholding metadata.
"""
score_thresholding = _metadata_fb.ProcessUnitT()
score_thresholding.optionsType = (
_metadata_fb.ProcessUnitOptions.ScoreThresholdingOptions)
options = _metadata_fb.ScoreThresholdingOptionsT()
options.globalScoreThreshold = self._global_score_threshold
score_thresholding.options = options
return score_thresholding
class TensorMd:
"""A container for common tensor metadata information.
@ -374,23 +405,29 @@ class ClassificationTensorMd(TensorMd):
tensor.
score_calibration_md: information of the score calibration operation [2] in
the classification tensor.
score_thresholding_md: information of the score thresholding [3] in the
classification tensor.
[1]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L99
[2]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L456
[3]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L468
"""
# Min and max float values for classification results.
_MIN_FLOAT = 0.0
_MAX_FLOAT = 1.0
def __init__(self,
name: Optional[str] = None,
description: Optional[str] = None,
label_files: Optional[List[LabelFileMd]] = None,
tensor_type: Optional[int] = None,
score_calibration_md: Optional[ScoreCalibrationMd] = None,
tensor_name: Optional[str] = None) -> None:
def __init__(
self,
name: Optional[str] = None,
description: Optional[str] = None,
label_files: Optional[List[LabelFileMd]] = None,
tensor_type: Optional[int] = None,
score_calibration_md: Optional[ScoreCalibrationMd] = None,
tensor_name: Optional[str] = None,
score_thresholding_md: Optional[ScoreThresholdingMd] = None) -> None:
"""Initializes the instance of ClassificationTensorMd.
Args:
@ -404,6 +441,8 @@ class ClassificationTensorMd(TensorMd):
tensor_name: name of the corresponding tensor [3] in the TFLite model. It
is used to locate the corresponding classification tensor and decide the
order of the tensor metadata [4] when populating model metadata.
score_thresholding_md: information of the score thresholding [5] in the
classification tensor.
[1]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L99
[2]:
@ -412,8 +451,11 @@ class ClassificationTensorMd(TensorMd):
https://github.com/tensorflow/tensorflow/blob/cb67fef35567298b40ac166b0581cd8ad68e5a3a/tensorflow/lite/schema/schema.fbs#L1129-L1136
[4]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L623-L640
[5]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L468
"""
self.score_calibration_md = score_calibration_md
self.score_thresholding_md = score_thresholding_md
if tensor_type is _schema_fb.TensorType.UINT8:
min_values = [_MIN_UINT8]
@ -443,4 +485,12 @@ class ClassificationTensorMd(TensorMd):
tensor_metadata.processUnits = [
self.score_calibration_md.create_metadata()
]
if self.score_thresholding_md:
if tensor_metadata.processUnits:
tensor_metadata.processUnits.append(
self.score_thresholding_md.create_metadata())
else:
tensor_metadata.processUnits = [
self.score_thresholding_md.create_metadata()
]
return tensor_metadata

View File

@ -70,6 +70,18 @@ class LabelItem:
locale: Optional[str] = None
@dataclasses.dataclass
class ScoreThresholding:
"""Parameters to performs thresholding on output tensor values [1].
Attributes:
global_score_threshold: The recommended global threshold below which results
are considered low-confidence and should be filtered out. [1]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L468
"""
global_score_threshold: float
class Labels(object):
"""Simple container holding classification labels of a particular tensor.
@ -407,6 +419,7 @@ class MetadataWriter(object):
self,
labels: Optional[Labels] = None,
score_calibration: Optional[ScoreCalibration] = None,
score_thresholding: Optional[ScoreThresholding] = None,
name: str = _OUTPUT_CLASSIFICATION_NAME,
description: str = _OUTPUT_CLASSIFICATION_DESCRIPTION
) -> 'MetadataWriter':
@ -423,6 +436,7 @@ class MetadataWriter(object):
Args:
labels: an instance of Labels helper class.
score_calibration: an instance of ScoreCalibration helper class.
score_thresholding: an instance of ScoreThresholding.
name: Metadata name of the tensor. Note that this is different from tensor
name in the flatbuffer.
description: human readable description of what the output is.
@ -437,6 +451,10 @@ class MetadataWriter(object):
default_score=score_calibration.default_score,
file_path=self._export_calibration_file('score_calibration.txt',
score_calibration.parameters))
score_thresholding_md = None
if score_thresholding:
score_thresholding_md = metadata_info.ScoreThresholdingMd(
score_thresholding.global_score_threshold)
label_files = None
if labels:
@ -453,6 +471,7 @@ class MetadataWriter(object):
label_files=label_files,
tensor_type=self._output_tensor_type(len(self._output_mds)),
score_calibration_md=calibration_md,
score_thresholding_md=score_thresholding_md,
)
self._output_mds.append(output_md)
return self
@ -545,4 +564,3 @@ class MetadataWriterBase:
A tuple of (model_with_metadata_in_bytes, metdata_json_content)
"""
return self.writer.populate()

View File

@ -14,7 +14,8 @@
# ==============================================================================
"""Helper methods for writing metadata into TFLite models."""
from typing import List
from typing import Dict, List
import zipfile
from mediapipe.tasks.metadata import schema_py_generated as _schema_fb
@ -83,3 +84,20 @@ def get_subgraph(model_buffer: bytearray) -> _schema_fb.SubGraph:
# multiple subgraphs yet, but models with mini-benchmark may have multiple
# subgraphs for acceleration evaluation purpose.
return model.Subgraphs(0)
def create_model_asset_bundle(input_models: Dict[str, bytes],
output_path: str) -> None:
"""Creates the model asset bundle.
Args:
input_models: A dict of input models with key as the model file name and
value as the model content.
output_path: The output file path to save the model asset bundle.
"""
if not input_models or len(input_models) < 2:
raise ValueError("Needs at least two input models for model asset bundle.")
with zipfile.ZipFile(output_path, mode="w") as zf:
for file_name, file_buffer in input_models.items():
zf.writestr(file_name, file_buffer)

View File

@ -307,6 +307,25 @@ class ScoreCalibrationMdTest(absltest.TestCase):
malformed_calibration_file)
class ScoreThresholdingMdTest(absltest.TestCase):
_DEFAULT_GLOBAL_THRESHOLD = 0.5
_EXPECTED_TENSOR_JSON = test_utils.get_test_data_path(
"score_thresholding_meta.json")
def test_create_metadata_should_succeed(self):
score_thresholding_md = metadata_info.ScoreThresholdingMd(
global_score_threshold=self._DEFAULT_GLOBAL_THRESHOLD)
score_thresholding_metadata = score_thresholding_md.create_metadata()
metadata_json = _metadata.convert_to_json(
_create_dummy_model_metadata_with_process_uint(
score_thresholding_metadata))
with open(self._EXPECTED_TENSOR_JSON, "r") as f:
expected_json = f.read()
self.assertEqual(metadata_json, expected_json)
def _create_dummy_model_metadata_with_tensor(
tensor_metadata: _metadata_fb.TensorMetadataT) -> bytes:
# Create a dummy model using the tensor metadata.

View File

@ -405,6 +405,64 @@ class MetadataWriterForTaskTest(absltest.TestCase):
}
""")
def test_add_classification_output_with_score_thresholding(self):
writer = metadata_writer.MetadataWriter.create(
self.image_classifier_model_buffer)
writer.add_classification_output(
labels=metadata_writer.Labels().add(['a', 'b', 'c']),
score_thresholding=metadata_writer.ScoreThresholding(
global_score_threshold=0.5))
_, metadata_json = writer.populate()
print(metadata_json)
self.assertJsonEqual(
metadata_json, """{
"subgraph_metadata": [
{
"input_tensor_metadata": [
{
"name": "input"
}
],
"output_tensor_metadata": [
{
"name": "score",
"description": "Score of the labels respectively.",
"content": {
"content_properties_type": "FeatureProperties",
"content_properties": {
}
},
"process_units": [
{
"options_type": "ScoreThresholdingOptions",
"options": {
"global_score_threshold": 0.5
}
}
],
"stats": {
"max": [
1.0
],
"min": [
0.0
]
},
"associated_files": [
{
"name": "labels.txt",
"description": "Labels for categories that the model can recognize.",
"type": "TENSOR_AXIS_LABELS"
}
]
}
]
}
],
"min_parser_version": "1.0.0"
}
""")
if __name__ == '__main__':
absltest.main()

View File

@ -50,6 +50,7 @@ exports_files([
"score_calibration.txt",
"score_calibration_file_meta.json",
"score_calibration_tensor_meta.json",
"score_thresholding_meta.json",
"labels.txt",
"mobilenet_v2_1.0_224.json",
"mobilenet_v2_1.0_224_quant.json",
@ -91,5 +92,6 @@ filegroup(
"score_calibration.txt",
"score_calibration_file_meta.json",
"score_calibration_tensor_meta.json",
"score_thresholding_meta.json",
],
)

View File

@ -0,0 +1,14 @@
{
"subgraph_metadata": [
{
"input_process_units": [
{
"options_type": "ScoreThresholdingOptions",
"options": {
"global_score_threshold": 0.5
}
}
]
}
]
}

View File

@ -694,6 +694,12 @@ def external_files():
urls = ["https://storage.googleapis.com/mediapipe-assets/score_calibration.txt?generation=1665422847392804"],
)
http_file(
name = "com_google_mediapipe_score_thresholding_meta_json",
sha256 = "7bb74f21c2d7f0237675ed7c09d7b7afd3507c8373f51dc75fa0507852f6ee19",
urls = ["https://storage.googleapis.com/mediapipe-assets/score_thresholding_meta.json?generation=1667273953630766"],
)
http_file(
name = "com_google_mediapipe_segmentation_golden_rotation0_png",
sha256 = "9ee993919b753118928ba2d14f7c5c83a6cfc23355e6943dac4ad81eedd73069",