Model Maker Gesture Recognizer: add metadata writer and create model bundle.
PiperOrigin-RevId: 485426865
This commit is contained in:
parent
c6a64683f6
commit
e719f7b4e2
|
@ -197,6 +197,37 @@ class ScoreCalibrationMd:
|
||||||
self._FILE_TYPE)
|
self._FILE_TYPE)
|
||||||
|
|
||||||
|
|
||||||
|
class ScoreThresholdingMd:
|
||||||
|
"""A container for score thresholding [1] metadata information.
|
||||||
|
|
||||||
|
[1]:
|
||||||
|
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L468
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, global_score_threshold: float) -> None:
|
||||||
|
"""Creates a ScoreThresholdingMd object.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
global_score_threshold: The recommended global threshold below which
|
||||||
|
results are considered low-confidence and should be filtered out.
|
||||||
|
"""
|
||||||
|
self._global_score_threshold = global_score_threshold
|
||||||
|
|
||||||
|
def create_metadata(self) -> _metadata_fb.ProcessUnitT:
|
||||||
|
"""Creates the score thresholding metadata based on the information.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A Flatbuffers Python object of the score thresholding metadata.
|
||||||
|
"""
|
||||||
|
score_thresholding = _metadata_fb.ProcessUnitT()
|
||||||
|
score_thresholding.optionsType = (
|
||||||
|
_metadata_fb.ProcessUnitOptions.ScoreThresholdingOptions)
|
||||||
|
options = _metadata_fb.ScoreThresholdingOptionsT()
|
||||||
|
options.globalScoreThreshold = self._global_score_threshold
|
||||||
|
score_thresholding.options = options
|
||||||
|
return score_thresholding
|
||||||
|
|
||||||
|
|
||||||
class TensorMd:
|
class TensorMd:
|
||||||
"""A container for common tensor metadata information.
|
"""A container for common tensor metadata information.
|
||||||
|
|
||||||
|
@ -374,23 +405,29 @@ class ClassificationTensorMd(TensorMd):
|
||||||
tensor.
|
tensor.
|
||||||
score_calibration_md: information of the score calibration operation [2] in
|
score_calibration_md: information of the score calibration operation [2] in
|
||||||
the classification tensor.
|
the classification tensor.
|
||||||
|
score_thresholding_md: information of the score thresholding [3] in the
|
||||||
|
classification tensor.
|
||||||
[1]:
|
[1]:
|
||||||
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L99
|
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L99
|
||||||
[2]:
|
[2]:
|
||||||
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L456
|
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L456
|
||||||
|
[3]:
|
||||||
|
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L468
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Min and max float values for classification results.
|
# Min and max float values for classification results.
|
||||||
_MIN_FLOAT = 0.0
|
_MIN_FLOAT = 0.0
|
||||||
_MAX_FLOAT = 1.0
|
_MAX_FLOAT = 1.0
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(
|
||||||
name: Optional[str] = None,
|
self,
|
||||||
description: Optional[str] = None,
|
name: Optional[str] = None,
|
||||||
label_files: Optional[List[LabelFileMd]] = None,
|
description: Optional[str] = None,
|
||||||
tensor_type: Optional[int] = None,
|
label_files: Optional[List[LabelFileMd]] = None,
|
||||||
score_calibration_md: Optional[ScoreCalibrationMd] = None,
|
tensor_type: Optional[int] = None,
|
||||||
tensor_name: Optional[str] = None) -> None:
|
score_calibration_md: Optional[ScoreCalibrationMd] = None,
|
||||||
|
tensor_name: Optional[str] = None,
|
||||||
|
score_thresholding_md: Optional[ScoreThresholdingMd] = None) -> None:
|
||||||
"""Initializes the instance of ClassificationTensorMd.
|
"""Initializes the instance of ClassificationTensorMd.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -404,6 +441,8 @@ class ClassificationTensorMd(TensorMd):
|
||||||
tensor_name: name of the corresponding tensor [3] in the TFLite model. It
|
tensor_name: name of the corresponding tensor [3] in the TFLite model. It
|
||||||
is used to locate the corresponding classification tensor and decide the
|
is used to locate the corresponding classification tensor and decide the
|
||||||
order of the tensor metadata [4] when populating model metadata.
|
order of the tensor metadata [4] when populating model metadata.
|
||||||
|
score_thresholding_md: information of the score thresholding [5] in the
|
||||||
|
classification tensor.
|
||||||
[1]:
|
[1]:
|
||||||
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L99
|
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L99
|
||||||
[2]:
|
[2]:
|
||||||
|
@ -412,8 +451,11 @@ class ClassificationTensorMd(TensorMd):
|
||||||
https://github.com/tensorflow/tensorflow/blob/cb67fef35567298b40ac166b0581cd8ad68e5a3a/tensorflow/lite/schema/schema.fbs#L1129-L1136
|
https://github.com/tensorflow/tensorflow/blob/cb67fef35567298b40ac166b0581cd8ad68e5a3a/tensorflow/lite/schema/schema.fbs#L1129-L1136
|
||||||
[4]:
|
[4]:
|
||||||
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L623-L640
|
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L623-L640
|
||||||
|
[5]:
|
||||||
|
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L468
|
||||||
"""
|
"""
|
||||||
self.score_calibration_md = score_calibration_md
|
self.score_calibration_md = score_calibration_md
|
||||||
|
self.score_thresholding_md = score_thresholding_md
|
||||||
|
|
||||||
if tensor_type is _schema_fb.TensorType.UINT8:
|
if tensor_type is _schema_fb.TensorType.UINT8:
|
||||||
min_values = [_MIN_UINT8]
|
min_values = [_MIN_UINT8]
|
||||||
|
@ -443,4 +485,12 @@ class ClassificationTensorMd(TensorMd):
|
||||||
tensor_metadata.processUnits = [
|
tensor_metadata.processUnits = [
|
||||||
self.score_calibration_md.create_metadata()
|
self.score_calibration_md.create_metadata()
|
||||||
]
|
]
|
||||||
|
if self.score_thresholding_md:
|
||||||
|
if tensor_metadata.processUnits:
|
||||||
|
tensor_metadata.processUnits.append(
|
||||||
|
self.score_thresholding_md.create_metadata())
|
||||||
|
else:
|
||||||
|
tensor_metadata.processUnits = [
|
||||||
|
self.score_thresholding_md.create_metadata()
|
||||||
|
]
|
||||||
return tensor_metadata
|
return tensor_metadata
|
||||||
|
|
|
@ -70,6 +70,18 @@ class LabelItem:
|
||||||
locale: Optional[str] = None
|
locale: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class ScoreThresholding:
|
||||||
|
"""Parameters to performs thresholding on output tensor values [1].
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
global_score_threshold: The recommended global threshold below which results
|
||||||
|
are considered low-confidence and should be filtered out. [1]:
|
||||||
|
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L468
|
||||||
|
"""
|
||||||
|
global_score_threshold: float
|
||||||
|
|
||||||
|
|
||||||
class Labels(object):
|
class Labels(object):
|
||||||
"""Simple container holding classification labels of a particular tensor.
|
"""Simple container holding classification labels of a particular tensor.
|
||||||
|
|
||||||
|
@ -407,6 +419,7 @@ class MetadataWriter(object):
|
||||||
self,
|
self,
|
||||||
labels: Optional[Labels] = None,
|
labels: Optional[Labels] = None,
|
||||||
score_calibration: Optional[ScoreCalibration] = None,
|
score_calibration: Optional[ScoreCalibration] = None,
|
||||||
|
score_thresholding: Optional[ScoreThresholding] = None,
|
||||||
name: str = _OUTPUT_CLASSIFICATION_NAME,
|
name: str = _OUTPUT_CLASSIFICATION_NAME,
|
||||||
description: str = _OUTPUT_CLASSIFICATION_DESCRIPTION
|
description: str = _OUTPUT_CLASSIFICATION_DESCRIPTION
|
||||||
) -> 'MetadataWriter':
|
) -> 'MetadataWriter':
|
||||||
|
@ -423,6 +436,7 @@ class MetadataWriter(object):
|
||||||
Args:
|
Args:
|
||||||
labels: an instance of Labels helper class.
|
labels: an instance of Labels helper class.
|
||||||
score_calibration: an instance of ScoreCalibration helper class.
|
score_calibration: an instance of ScoreCalibration helper class.
|
||||||
|
score_thresholding: an instance of ScoreThresholding.
|
||||||
name: Metadata name of the tensor. Note that this is different from tensor
|
name: Metadata name of the tensor. Note that this is different from tensor
|
||||||
name in the flatbuffer.
|
name in the flatbuffer.
|
||||||
description: human readable description of what the output is.
|
description: human readable description of what the output is.
|
||||||
|
@ -437,6 +451,10 @@ class MetadataWriter(object):
|
||||||
default_score=score_calibration.default_score,
|
default_score=score_calibration.default_score,
|
||||||
file_path=self._export_calibration_file('score_calibration.txt',
|
file_path=self._export_calibration_file('score_calibration.txt',
|
||||||
score_calibration.parameters))
|
score_calibration.parameters))
|
||||||
|
score_thresholding_md = None
|
||||||
|
if score_thresholding:
|
||||||
|
score_thresholding_md = metadata_info.ScoreThresholdingMd(
|
||||||
|
score_thresholding.global_score_threshold)
|
||||||
|
|
||||||
label_files = None
|
label_files = None
|
||||||
if labels:
|
if labels:
|
||||||
|
@ -453,6 +471,7 @@ class MetadataWriter(object):
|
||||||
label_files=label_files,
|
label_files=label_files,
|
||||||
tensor_type=self._output_tensor_type(len(self._output_mds)),
|
tensor_type=self._output_tensor_type(len(self._output_mds)),
|
||||||
score_calibration_md=calibration_md,
|
score_calibration_md=calibration_md,
|
||||||
|
score_thresholding_md=score_thresholding_md,
|
||||||
)
|
)
|
||||||
self._output_mds.append(output_md)
|
self._output_mds.append(output_md)
|
||||||
return self
|
return self
|
||||||
|
@ -545,4 +564,3 @@ class MetadataWriterBase:
|
||||||
A tuple of (model_with_metadata_in_bytes, metdata_json_content)
|
A tuple of (model_with_metadata_in_bytes, metdata_json_content)
|
||||||
"""
|
"""
|
||||||
return self.writer.populate()
|
return self.writer.populate()
|
||||||
|
|
||||||
|
|
|
@ -14,7 +14,8 @@
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
"""Helper methods for writing metadata into TFLite models."""
|
"""Helper methods for writing metadata into TFLite models."""
|
||||||
|
|
||||||
from typing import List
|
from typing import Dict, List
|
||||||
|
import zipfile
|
||||||
|
|
||||||
from mediapipe.tasks.metadata import schema_py_generated as _schema_fb
|
from mediapipe.tasks.metadata import schema_py_generated as _schema_fb
|
||||||
|
|
||||||
|
@ -83,3 +84,20 @@ def get_subgraph(model_buffer: bytearray) -> _schema_fb.SubGraph:
|
||||||
# multiple subgraphs yet, but models with mini-benchmark may have multiple
|
# multiple subgraphs yet, but models with mini-benchmark may have multiple
|
||||||
# subgraphs for acceleration evaluation purpose.
|
# subgraphs for acceleration evaluation purpose.
|
||||||
return model.Subgraphs(0)
|
return model.Subgraphs(0)
|
||||||
|
|
||||||
|
|
||||||
|
def create_model_asset_bundle(input_models: Dict[str, bytes],
|
||||||
|
output_path: str) -> None:
|
||||||
|
"""Creates the model asset bundle.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_models: A dict of input models with key as the model file name and
|
||||||
|
value as the model content.
|
||||||
|
output_path: The output file path to save the model asset bundle.
|
||||||
|
"""
|
||||||
|
if not input_models or len(input_models) < 2:
|
||||||
|
raise ValueError("Needs at least two input models for model asset bundle.")
|
||||||
|
|
||||||
|
with zipfile.ZipFile(output_path, mode="w") as zf:
|
||||||
|
for file_name, file_buffer in input_models.items():
|
||||||
|
zf.writestr(file_name, file_buffer)
|
||||||
|
|
|
@ -307,6 +307,25 @@ class ScoreCalibrationMdTest(absltest.TestCase):
|
||||||
malformed_calibration_file)
|
malformed_calibration_file)
|
||||||
|
|
||||||
|
|
||||||
|
class ScoreThresholdingMdTest(absltest.TestCase):
|
||||||
|
_DEFAULT_GLOBAL_THRESHOLD = 0.5
|
||||||
|
_EXPECTED_TENSOR_JSON = test_utils.get_test_data_path(
|
||||||
|
"score_thresholding_meta.json")
|
||||||
|
|
||||||
|
def test_create_metadata_should_succeed(self):
|
||||||
|
score_thresholding_md = metadata_info.ScoreThresholdingMd(
|
||||||
|
global_score_threshold=self._DEFAULT_GLOBAL_THRESHOLD)
|
||||||
|
|
||||||
|
score_thresholding_metadata = score_thresholding_md.create_metadata()
|
||||||
|
|
||||||
|
metadata_json = _metadata.convert_to_json(
|
||||||
|
_create_dummy_model_metadata_with_process_uint(
|
||||||
|
score_thresholding_metadata))
|
||||||
|
with open(self._EXPECTED_TENSOR_JSON, "r") as f:
|
||||||
|
expected_json = f.read()
|
||||||
|
self.assertEqual(metadata_json, expected_json)
|
||||||
|
|
||||||
|
|
||||||
def _create_dummy_model_metadata_with_tensor(
|
def _create_dummy_model_metadata_with_tensor(
|
||||||
tensor_metadata: _metadata_fb.TensorMetadataT) -> bytes:
|
tensor_metadata: _metadata_fb.TensorMetadataT) -> bytes:
|
||||||
# Create a dummy model using the tensor metadata.
|
# Create a dummy model using the tensor metadata.
|
||||||
|
|
|
@ -405,6 +405,64 @@ class MetadataWriterForTaskTest(absltest.TestCase):
|
||||||
}
|
}
|
||||||
""")
|
""")
|
||||||
|
|
||||||
|
def test_add_classification_output_with_score_thresholding(self):
|
||||||
|
writer = metadata_writer.MetadataWriter.create(
|
||||||
|
self.image_classifier_model_buffer)
|
||||||
|
writer.add_classification_output(
|
||||||
|
labels=metadata_writer.Labels().add(['a', 'b', 'c']),
|
||||||
|
score_thresholding=metadata_writer.ScoreThresholding(
|
||||||
|
global_score_threshold=0.5))
|
||||||
|
_, metadata_json = writer.populate()
|
||||||
|
print(metadata_json)
|
||||||
|
self.assertJsonEqual(
|
||||||
|
metadata_json, """{
|
||||||
|
"subgraph_metadata": [
|
||||||
|
{
|
||||||
|
"input_tensor_metadata": [
|
||||||
|
{
|
||||||
|
"name": "input"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"output_tensor_metadata": [
|
||||||
|
{
|
||||||
|
"name": "score",
|
||||||
|
"description": "Score of the labels respectively.",
|
||||||
|
"content": {
|
||||||
|
"content_properties_type": "FeatureProperties",
|
||||||
|
"content_properties": {
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"process_units": [
|
||||||
|
{
|
||||||
|
"options_type": "ScoreThresholdingOptions",
|
||||||
|
"options": {
|
||||||
|
"global_score_threshold": 0.5
|
||||||
|
}
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"stats": {
|
||||||
|
"max": [
|
||||||
|
1.0
|
||||||
|
],
|
||||||
|
"min": [
|
||||||
|
0.0
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"associated_files": [
|
||||||
|
{
|
||||||
|
"name": "labels.txt",
|
||||||
|
"description": "Labels for categories that the model can recognize.",
|
||||||
|
"type": "TENSOR_AXIS_LABELS"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"min_parser_version": "1.0.0"
|
||||||
|
}
|
||||||
|
""")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
absltest.main()
|
absltest.main()
|
||||||
|
|
2
mediapipe/tasks/testdata/metadata/BUILD
vendored
2
mediapipe/tasks/testdata/metadata/BUILD
vendored
|
@ -50,6 +50,7 @@ exports_files([
|
||||||
"score_calibration.txt",
|
"score_calibration.txt",
|
||||||
"score_calibration_file_meta.json",
|
"score_calibration_file_meta.json",
|
||||||
"score_calibration_tensor_meta.json",
|
"score_calibration_tensor_meta.json",
|
||||||
|
"score_thresholding_meta.json",
|
||||||
"labels.txt",
|
"labels.txt",
|
||||||
"mobilenet_v2_1.0_224.json",
|
"mobilenet_v2_1.0_224.json",
|
||||||
"mobilenet_v2_1.0_224_quant.json",
|
"mobilenet_v2_1.0_224_quant.json",
|
||||||
|
@ -91,5 +92,6 @@ filegroup(
|
||||||
"score_calibration.txt",
|
"score_calibration.txt",
|
||||||
"score_calibration_file_meta.json",
|
"score_calibration_file_meta.json",
|
||||||
"score_calibration_tensor_meta.json",
|
"score_calibration_tensor_meta.json",
|
||||||
|
"score_thresholding_meta.json",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
14
mediapipe/tasks/testdata/metadata/score_thresholding_meta.json
vendored
Normal file
14
mediapipe/tasks/testdata/metadata/score_thresholding_meta.json
vendored
Normal file
|
@ -0,0 +1,14 @@
|
||||||
|
{
|
||||||
|
"subgraph_metadata": [
|
||||||
|
{
|
||||||
|
"input_process_units": [
|
||||||
|
{
|
||||||
|
"options_type": "ScoreThresholdingOptions",
|
||||||
|
"options": {
|
||||||
|
"global_score_threshold": 0.5
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
6
third_party/external_files.bzl
vendored
6
third_party/external_files.bzl
vendored
|
@ -694,6 +694,12 @@ def external_files():
|
||||||
urls = ["https://storage.googleapis.com/mediapipe-assets/score_calibration.txt?generation=1665422847392804"],
|
urls = ["https://storage.googleapis.com/mediapipe-assets/score_calibration.txt?generation=1665422847392804"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
http_file(
|
||||||
|
name = "com_google_mediapipe_score_thresholding_meta_json",
|
||||||
|
sha256 = "7bb74f21c2d7f0237675ed7c09d7b7afd3507c8373f51dc75fa0507852f6ee19",
|
||||||
|
urls = ["https://storage.googleapis.com/mediapipe-assets/score_thresholding_meta.json?generation=1667273953630766"],
|
||||||
|
)
|
||||||
|
|
||||||
http_file(
|
http_file(
|
||||||
name = "com_google_mediapipe_segmentation_golden_rotation0_png",
|
name = "com_google_mediapipe_segmentation_golden_rotation0_png",
|
||||||
sha256 = "9ee993919b753118928ba2d14f7c5c83a6cfc23355e6943dac4ad81eedd73069",
|
sha256 = "9ee993919b753118928ba2d14f7c5c83a6cfc23355e6943dac4ad81eedd73069",
|
||||||
|
|
Loading…
Reference in New Issue
Block a user