Metadata Writer: Add Metadata Writer for image classifier.
PiperOrigin-RevId: 483282627
This commit is contained in:
parent
ec2a34d2a4
commit
ab17be9294
|
@ -37,3 +37,9 @@ py_library(
|
||||||
srcs = ["writer_utils.py"],
|
srcs = ["writer_utils.py"],
|
||||||
deps = ["//mediapipe/tasks/metadata:schema_py"],
|
deps = ["//mediapipe/tasks/metadata:schema_py"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "image_classifier",
|
||||||
|
srcs = ["image_classifier.py"],
|
||||||
|
deps = [":metadata_writer"],
|
||||||
|
)
|
||||||
|
|
|
@ -0,0 +1,71 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Writes metadata and label file to the image classifier models."""
|
||||||
|
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer
|
||||||
|
|
||||||
|
_MODEL_NAME = "ImageClassifier"
|
||||||
|
_MODEL_DESCRIPTION = ("Identify the most prominent object in the image from a "
|
||||||
|
"known set of categories.")
|
||||||
|
|
||||||
|
|
||||||
|
class MetadataWriter(metadata_writer.MetadataWriterBase):
|
||||||
|
"""MetadataWriter to write the metadata for image classifier."""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create(
|
||||||
|
cls,
|
||||||
|
model_buffer: bytearray,
|
||||||
|
input_norm_mean: List[float],
|
||||||
|
input_norm_std: List[float],
|
||||||
|
labels: metadata_writer.Labels,
|
||||||
|
score_calibration: Optional[metadata_writer.ScoreCalibration] = None
|
||||||
|
) -> "MetadataWriter":
|
||||||
|
"""Creates MetadataWriter to write the metadata for image classifier.
|
||||||
|
|
||||||
|
The parameters required in this method are mandatory when using MediaPipe
|
||||||
|
Tasks.
|
||||||
|
|
||||||
|
Note that only the output TFLite is used for deployment. The output JSON
|
||||||
|
content is used to interpret the metadata content.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_buffer: A valid flatbuffer loaded from the TFLite model file.
|
||||||
|
input_norm_mean: the mean value used in the input tensor normalization
|
||||||
|
[1].
|
||||||
|
input_norm_std: the std value used in the input tensor normalizarion [1].
|
||||||
|
labels: an instance of Labels helper class used in the output
|
||||||
|
classification tensor [2].
|
||||||
|
score_calibration: A container of the score calibration operation [3] in
|
||||||
|
the classification tensor. Optional if the model does not use score
|
||||||
|
calibration.
|
||||||
|
|
||||||
|
[1]:
|
||||||
|
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L389
|
||||||
|
[2]:
|
||||||
|
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L99
|
||||||
|
[3]:
|
||||||
|
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L456
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
An MetadataWrite object.
|
||||||
|
"""
|
||||||
|
writer = metadata_writer.MetadataWriter(model_buffer)
|
||||||
|
writer.add_genernal_info(_MODEL_NAME, _MODEL_DESCRIPTION)
|
||||||
|
writer.add_image_input(input_norm_mean, input_norm_std)
|
||||||
|
writer.add_classification_output(labels, score_calibration)
|
||||||
|
return cls(writer)
|
|
@ -15,19 +15,22 @@
|
||||||
"""Generic metadata writer."""
|
"""Generic metadata writer."""
|
||||||
|
|
||||||
import collections
|
import collections
|
||||||
|
import csv
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import os
|
import os
|
||||||
import tempfile
|
import tempfile
|
||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
import flatbuffers
|
import flatbuffers
|
||||||
from mediapipe.tasks.metadata import metadata_schema_py_generated as _metadata_fb
|
from mediapipe.tasks.metadata import metadata_schema_py_generated as metadata_fb
|
||||||
from mediapipe.tasks.python.metadata import metadata as _metadata
|
from mediapipe.tasks.python.metadata import metadata
|
||||||
from mediapipe.tasks.python.metadata.metadata_writers import metadata_info
|
from mediapipe.tasks.python.metadata.metadata_writers import metadata_info
|
||||||
from mediapipe.tasks.python.metadata.metadata_writers import writer_utils
|
from mediapipe.tasks.python.metadata.metadata_writers import writer_utils
|
||||||
|
|
||||||
_INPUT_IMAGE_NAME = 'image'
|
_INPUT_IMAGE_NAME = 'image'
|
||||||
_INPUT_IMAGE_DESCRIPTION = 'Input image to be processed.'
|
_INPUT_IMAGE_DESCRIPTION = 'Input image to be processed.'
|
||||||
|
_OUTPUT_CLASSIFICATION_NAME = 'score'
|
||||||
|
_OUTPUT_CLASSIFICATION_DESCRIPTION = 'Score of the labels respectively.'
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
|
@ -140,26 +143,85 @@ class Labels(object):
|
||||||
class ScoreCalibration:
|
class ScoreCalibration:
|
||||||
"""Simple container holding score calibration related parameters."""
|
"""Simple container holding score calibration related parameters."""
|
||||||
|
|
||||||
# A shortcut to avoid client side code importing _metadata_fb
|
# A shortcut to avoid client side code importing metadata_fb
|
||||||
transformation_types = _metadata_fb.ScoreTransformationType
|
transformation_types = metadata_fb.ScoreTransformationType
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
transformation_type: _metadata_fb.ScoreTransformationType,
|
transformation_type: metadata_fb.ScoreTransformationType,
|
||||||
parameters: List[CalibrationParameter],
|
parameters: List[Optional[CalibrationParameter]],
|
||||||
default_score: int = 0):
|
default_score: int = 0):
|
||||||
self.transformation_type = transformation_type
|
self.transformation_type = transformation_type
|
||||||
self.parameters = parameters
|
self.parameters = parameters
|
||||||
self.default_score = default_score
|
self.default_score = default_score
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create_from_file(cls,
|
||||||
|
transformation_type: metadata_fb.ScoreTransformationType,
|
||||||
|
file_path: str,
|
||||||
|
default_score: int = 0) -> 'ScoreCalibration':
|
||||||
|
"""Creates ScoreCalibration from the file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
transformation_type: type of the function used for transforming the
|
||||||
|
uncalibrated score before applying score calibration.
|
||||||
|
file_path: file_path of the score calibration file [1]. Contains
|
||||||
|
sigmoid-based score calibration parameters, formatted as CSV. Lines
|
||||||
|
contain for each index of an output tensor the scale, slope, offset and
|
||||||
|
(optional) min_score parameters to be used for sigmoid fitting (in this
|
||||||
|
order and in `strtof`-compatible [2] format). Scale should be a
|
||||||
|
non-negative value. A line may be left empty to default calibrated
|
||||||
|
scores for this index to default_score. In summary, each line should
|
||||||
|
thus contain 0, 3 or 4 comma-separated values.
|
||||||
|
default_score: the default calibrated score to apply if the uncalibrated
|
||||||
|
score is below min_score or if no parameters were specified for a given
|
||||||
|
index.
|
||||||
|
[1]:
|
||||||
|
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L133
|
||||||
|
[2]:
|
||||||
|
https://en.cppreference.com/w/c/string/byte/strtof
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A ScoreCalibration object.
|
||||||
|
Raises:
|
||||||
|
ValueError: if the score_calibration file is malformed.
|
||||||
|
"""
|
||||||
|
with open(file_path, 'r') as calibration_file:
|
||||||
|
csv_reader = csv.reader(calibration_file, delimiter=',')
|
||||||
|
parameters = []
|
||||||
|
for row in csv_reader:
|
||||||
|
if not row:
|
||||||
|
parameters.append(None)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if len(row) != 3 and len(row) != 4:
|
||||||
|
raise ValueError(
|
||||||
|
f'Expected empty lines or 3 or 4 parameters per line in score'
|
||||||
|
f' calibration file, but got {len(row)}.')
|
||||||
|
|
||||||
|
if float(row[0]) < 0:
|
||||||
|
raise ValueError(
|
||||||
|
f'Expected scale to be a non-negative value, but got '
|
||||||
|
f'{float(row[0])}.')
|
||||||
|
|
||||||
|
parameters.append(
|
||||||
|
CalibrationParameter(
|
||||||
|
scale=float(row[0]),
|
||||||
|
slope=float(row[1]),
|
||||||
|
offset=float(row[2]),
|
||||||
|
min_score=None if len(row) == 3 else float(row[3])))
|
||||||
|
|
||||||
|
return cls(transformation_type, parameters, default_score)
|
||||||
|
|
||||||
|
|
||||||
def _fill_default_tensor_names(
|
def _fill_default_tensor_names(
|
||||||
tensor_metadata: List[_metadata_fb.TensorMetadataT],
|
tensor_metadata_list: List[metadata_fb.TensorMetadataT],
|
||||||
tensor_names_from_model: List[str]):
|
tensor_names_from_model: List[str]):
|
||||||
"""Fills the default tensor names."""
|
"""Fills the default tensor names."""
|
||||||
# If tensor name in metadata is empty, default to the tensor name saved in
|
# If tensor name in metadata is empty, default to the tensor name saved in
|
||||||
# the model.
|
# the model.
|
||||||
for metadata, name in zip(tensor_metadata, tensor_names_from_model):
|
for tensor_metadata, name in zip(tensor_metadata_list,
|
||||||
metadata.name = metadata.name or name
|
tensor_names_from_model):
|
||||||
|
tensor_metadata.name = tensor_metadata.name or name
|
||||||
|
|
||||||
|
|
||||||
def _pair_tensor_metadata(
|
def _pair_tensor_metadata(
|
||||||
|
@ -212,7 +274,7 @@ def _create_metadata_buffer(
|
||||||
input_metadata = [m.create_metadata() for m in input_md]
|
input_metadata = [m.create_metadata() for m in input_md]
|
||||||
else:
|
else:
|
||||||
num_input_tensors = writer_utils.get_subgraph(model_buffer).InputsLength()
|
num_input_tensors = writer_utils.get_subgraph(model_buffer).InputsLength()
|
||||||
input_metadata = [_metadata_fb.TensorMetadataT()] * num_input_tensors
|
input_metadata = [metadata_fb.TensorMetadataT()] * num_input_tensors
|
||||||
|
|
||||||
_fill_default_tensor_names(input_metadata,
|
_fill_default_tensor_names(input_metadata,
|
||||||
writer_utils.get_input_tensor_names(model_buffer))
|
writer_utils.get_input_tensor_names(model_buffer))
|
||||||
|
@ -224,12 +286,12 @@ def _create_metadata_buffer(
|
||||||
output_metadata = [m.create_metadata() for m in output_md]
|
output_metadata = [m.create_metadata() for m in output_md]
|
||||||
else:
|
else:
|
||||||
num_output_tensors = writer_utils.get_subgraph(model_buffer).OutputsLength()
|
num_output_tensors = writer_utils.get_subgraph(model_buffer).OutputsLength()
|
||||||
output_metadata = [_metadata_fb.TensorMetadataT()] * num_output_tensors
|
output_metadata = [metadata_fb.TensorMetadataT()] * num_output_tensors
|
||||||
_fill_default_tensor_names(output_metadata,
|
_fill_default_tensor_names(output_metadata,
|
||||||
writer_utils.get_output_tensor_names(model_buffer))
|
writer_utils.get_output_tensor_names(model_buffer))
|
||||||
|
|
||||||
# Create the subgraph metadata.
|
# Create the subgraph metadata.
|
||||||
subgraph_metadata = _metadata_fb.SubGraphMetadataT()
|
subgraph_metadata = metadata_fb.SubGraphMetadataT()
|
||||||
subgraph_metadata.inputTensorMetadata = input_metadata
|
subgraph_metadata.inputTensorMetadata = input_metadata
|
||||||
subgraph_metadata.outputTensorMetadata = output_metadata
|
subgraph_metadata.outputTensorMetadata = output_metadata
|
||||||
|
|
||||||
|
@ -243,7 +305,7 @@ def _create_metadata_buffer(
|
||||||
b = flatbuffers.Builder(0)
|
b = flatbuffers.Builder(0)
|
||||||
b.Finish(
|
b.Finish(
|
||||||
model_metadata.Pack(b),
|
model_metadata.Pack(b),
|
||||||
_metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)
|
metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)
|
||||||
return b.Output()
|
return b.Output()
|
||||||
|
|
||||||
|
|
||||||
|
@ -291,7 +353,7 @@ class MetadataWriter(object):
|
||||||
name=model_name, description=model_description)
|
name=model_name, description=model_description)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
color_space_types = _metadata_fb.ColorSpaceType
|
color_space_types = metadata_fb.ColorSpaceType
|
||||||
|
|
||||||
def add_feature_input(self,
|
def add_feature_input(self,
|
||||||
name: Optional[str] = None,
|
name: Optional[str] = None,
|
||||||
|
@ -305,7 +367,7 @@ class MetadataWriter(object):
|
||||||
self,
|
self,
|
||||||
norm_mean: List[float],
|
norm_mean: List[float],
|
||||||
norm_std: List[float],
|
norm_std: List[float],
|
||||||
color_space_type: Optional[int] = _metadata_fb.ColorSpaceType.RGB,
|
color_space_type: Optional[int] = metadata_fb.ColorSpaceType.RGB,
|
||||||
name: str = _INPUT_IMAGE_NAME,
|
name: str = _INPUT_IMAGE_NAME,
|
||||||
description: str = _INPUT_IMAGE_DESCRIPTION) -> 'MetadataWriter':
|
description: str = _INPUT_IMAGE_DESCRIPTION) -> 'MetadataWriter':
|
||||||
"""Adds an input image metadata for the image input.
|
"""Adds an input image metadata for the image input.
|
||||||
|
@ -341,9 +403,6 @@ class MetadataWriter(object):
|
||||||
self._input_mds.append(input_md)
|
self._input_mds.append(input_md)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
_OUTPUT_CLASSIFICATION_NAME = 'score'
|
|
||||||
_OUTPUT_CLASSIFICATION_DESCRIPTION = 'Score of the labels respectively'
|
|
||||||
|
|
||||||
def add_classification_output(
|
def add_classification_output(
|
||||||
self,
|
self,
|
||||||
labels: Optional[Labels] = None,
|
labels: Optional[Labels] = None,
|
||||||
|
@ -416,8 +475,7 @@ class MetadataWriter(object):
|
||||||
A tuple of (model_with_metadata_in_bytes, metdata_json_content)
|
A tuple of (model_with_metadata_in_bytes, metdata_json_content)
|
||||||
"""
|
"""
|
||||||
# Populates metadata and associated files into TFLite model buffer.
|
# Populates metadata and associated files into TFLite model buffer.
|
||||||
populator = _metadata.MetadataPopulator.with_model_buffer(
|
populator = metadata.MetadataPopulator.with_model_buffer(self._model_buffer)
|
||||||
self._model_buffer)
|
|
||||||
metadata_buffer = _create_metadata_buffer(
|
metadata_buffer = _create_metadata_buffer(
|
||||||
model_buffer=self._model_buffer,
|
model_buffer=self._model_buffer,
|
||||||
general_md=self._general_md,
|
general_md=self._general_md,
|
||||||
|
@ -429,7 +487,7 @@ class MetadataWriter(object):
|
||||||
populator.populate()
|
populator.populate()
|
||||||
tflite_content = populator.get_model_buffer()
|
tflite_content = populator.get_model_buffer()
|
||||||
|
|
||||||
displayer = _metadata.MetadataDisplayer.with_model_buffer(tflite_content)
|
displayer = metadata.MetadataDisplayer.with_model_buffer(tflite_content)
|
||||||
metadata_json_content = displayer.get_metadata_json()
|
metadata_json_content = displayer.get_metadata_json()
|
||||||
|
|
||||||
return tflite_content, metadata_json_content
|
return tflite_content, metadata_json_content
|
||||||
|
@ -452,9 +510,7 @@ class MetadataWriter(object):
|
||||||
"""Stores calibration parameters in a csv file."""
|
"""Stores calibration parameters in a csv file."""
|
||||||
filepath = os.path.join(self._temp_folder.name, filename)
|
filepath = os.path.join(self._temp_folder.name, filename)
|
||||||
with open(filepath, 'w') as f:
|
with open(filepath, 'w') as f:
|
||||||
for idx, item in enumerate(calibrations):
|
for item in calibrations:
|
||||||
if idx != 0:
|
|
||||||
f.write('\n')
|
|
||||||
if item:
|
if item:
|
||||||
if item.scale is None or item.slope is None or item.offset is None:
|
if item.scale is None or item.slope is None or item.offset is None:
|
||||||
raise ValueError('scale, slope and offset values can not be set to '
|
raise ValueError('scale, slope and offset values can not be set to '
|
||||||
|
@ -463,6 +519,30 @@ class MetadataWriter(object):
|
||||||
f.write(f'{item.scale},{item.slope},{item.offset},{item.min_score}')
|
f.write(f'{item.scale},{item.slope},{item.offset},{item.min_score}')
|
||||||
else:
|
else:
|
||||||
f.write(f'{item.scale},{item.slope},{item.offset}')
|
f.write(f'{item.scale},{item.slope},{item.offset}')
|
||||||
|
f.write('\n')
|
||||||
|
|
||||||
self._associated_files.append(filepath)
|
self._associated_files.append(filepath)
|
||||||
return filepath
|
return filepath
|
||||||
|
|
||||||
|
|
||||||
|
class MetadataWriterBase:
|
||||||
|
"""Base MetadataWriter class which contains the apis exposed to users.
|
||||||
|
|
||||||
|
MetadataWriter for Tasks e.g. image classifier / object detector will inherit
|
||||||
|
this class for their own usage.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, writer: MetadataWriter) -> None:
|
||||||
|
self.writer = writer
|
||||||
|
|
||||||
|
def populate(self) -> Tuple[bytearray, str]:
|
||||||
|
"""Populates metadata into the TFLite file.
|
||||||
|
|
||||||
|
Note that only the output tflite is used for deployment. The output JSON
|
||||||
|
content is used to interpret the metadata content.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tuple of (model_with_metadata_in_bytes, metdata_json_content)
|
||||||
|
"""
|
||||||
|
return self.writer.populate()
|
||||||
|
|
||||||
|
|
|
@ -28,9 +28,28 @@ py_test(
|
||||||
py_test(
|
py_test(
|
||||||
name = "metadata_writer_test",
|
name = "metadata_writer_test",
|
||||||
srcs = ["metadata_writer_test.py"],
|
srcs = ["metadata_writer_test.py"],
|
||||||
data = ["//mediapipe/tasks/testdata/metadata:model_files"],
|
data = [
|
||||||
|
"//mediapipe/tasks/testdata/metadata:data_files",
|
||||||
|
"//mediapipe/tasks/testdata/metadata:model_files",
|
||||||
|
],
|
||||||
deps = [
|
deps = [
|
||||||
"//mediapipe/tasks/python/metadata/metadata_writers:metadata_writer",
|
"//mediapipe/tasks/python/metadata/metadata_writers:metadata_writer",
|
||||||
"//mediapipe/tasks/python/test:test_utils",
|
"//mediapipe/tasks/python/test:test_utils",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
py_test(
|
||||||
|
name = "image_classifier_test",
|
||||||
|
srcs = ["image_classifier_test.py"],
|
||||||
|
data = [
|
||||||
|
"//mediapipe/tasks/testdata/metadata:data_files",
|
||||||
|
"//mediapipe/tasks/testdata/metadata:model_files",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
"//mediapipe/tasks/metadata:metadata_schema_py",
|
||||||
|
"//mediapipe/tasks/python/metadata",
|
||||||
|
"//mediapipe/tasks/python/metadata/metadata_writers:image_classifier",
|
||||||
|
"//mediapipe/tasks/python/metadata/metadata_writers:metadata_writer",
|
||||||
|
"//mediapipe/tasks/python/test:test_utils",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
|
@ -0,0 +1,76 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Tests for metadata_writer.image_classifier."""
|
||||||
|
|
||||||
|
from absl.testing import absltest
|
||||||
|
from absl.testing import parameterized
|
||||||
|
|
||||||
|
from mediapipe.tasks.metadata import metadata_schema_py_generated as metadata_fb
|
||||||
|
from mediapipe.tasks.python.metadata import metadata
|
||||||
|
from mediapipe.tasks.python.metadata.metadata_writers import image_classifier
|
||||||
|
from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer
|
||||||
|
from mediapipe.tasks.python.test import test_utils
|
||||||
|
|
||||||
|
_FLOAT_MODEL = test_utils.get_test_data_path(
|
||||||
|
"mobilenet_v2_1.0_224_without_metadata.tflite")
|
||||||
|
_QUANT_MODEL = test_utils.get_test_data_path(
|
||||||
|
"mobilenet_v2_1.0_224_quant_without_metadata.tflite")
|
||||||
|
_LABEL_FILE = test_utils.get_test_data_path("labels.txt")
|
||||||
|
_SCORE_CALIBRATION_FILE = test_utils.get_test_data_path("score_calibration.txt")
|
||||||
|
_SCORE_CALIBRATION_FILENAME = "score_calibration.txt"
|
||||||
|
_DEFAULT_SCORE_CALIBRATION_VALUE = 0.2
|
||||||
|
_NORM_MEAN = 127.5
|
||||||
|
_NORM_STD = 127.5
|
||||||
|
_FLOAT_JSON = test_utils.get_test_data_path("mobilenet_v2_1.0_224.json")
|
||||||
|
_QUANT_JSON = test_utils.get_test_data_path("mobilenet_v2_1.0_224_quant.json")
|
||||||
|
|
||||||
|
|
||||||
|
class ImageClassifierTest(parameterized.TestCase):
|
||||||
|
|
||||||
|
@parameterized.named_parameters(
|
||||||
|
{
|
||||||
|
"testcase_name": "test_float_model",
|
||||||
|
"model_file": _FLOAT_MODEL,
|
||||||
|
"golden_json": _FLOAT_JSON
|
||||||
|
}, {
|
||||||
|
"testcase_name": "test_quant_model",
|
||||||
|
"model_file": _QUANT_MODEL,
|
||||||
|
"golden_json": _QUANT_JSON
|
||||||
|
})
|
||||||
|
def test_write_metadata(self, model_file: str, golden_json: str):
|
||||||
|
with open(model_file, "rb") as f:
|
||||||
|
model_buffer = f.read()
|
||||||
|
writer = image_classifier.MetadataWriter.create(
|
||||||
|
model_buffer, [_NORM_MEAN], [_NORM_STD],
|
||||||
|
labels=metadata_writer.Labels().add_from_file(_LABEL_FILE),
|
||||||
|
score_calibration=metadata_writer.ScoreCalibration.create_from_file(
|
||||||
|
metadata_fb.ScoreTransformationType.LOG, _SCORE_CALIBRATION_FILE,
|
||||||
|
_DEFAULT_SCORE_CALIBRATION_VALUE))
|
||||||
|
tflite_content, metadata_json = writer.populate()
|
||||||
|
|
||||||
|
with open(golden_json, "r") as f:
|
||||||
|
expected_json = f.read()
|
||||||
|
self.assertEqual(metadata_json, expected_json)
|
||||||
|
|
||||||
|
displayer = metadata.MetadataDisplayer.with_model_buffer(tflite_content)
|
||||||
|
file_buffer = displayer.get_associated_file_buffer(
|
||||||
|
_SCORE_CALIBRATION_FILENAME)
|
||||||
|
with open(_SCORE_CALIBRATION_FILE, "rb") as f:
|
||||||
|
expected_file_buffer = f.read()
|
||||||
|
self.assertEqual(file_buffer, expected_file_buffer)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
absltest.main()
|
|
@ -13,6 +13,9 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
"""Tests for metadata writer classes."""
|
"""Tests for metadata writer classes."""
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
|
||||||
from absl.testing import absltest
|
from absl.testing import absltest
|
||||||
|
|
||||||
from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer
|
from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer
|
||||||
|
@ -20,6 +23,7 @@ from mediapipe.tasks.python.test import test_utils
|
||||||
|
|
||||||
_IMAGE_CLASSIFIER_MODEL = test_utils.get_test_data_path(
|
_IMAGE_CLASSIFIER_MODEL = test_utils.get_test_data_path(
|
||||||
'mobilenet_v1_0.25_224_1_default_1.tflite')
|
'mobilenet_v1_0.25_224_1_default_1.tflite')
|
||||||
|
_SCORE_CALIBRATION_FILE = test_utils.get_test_data_path('score_calibration.txt')
|
||||||
|
|
||||||
|
|
||||||
class LabelsTest(absltest.TestCase):
|
class LabelsTest(absltest.TestCase):
|
||||||
|
@ -49,6 +53,54 @@ class LabelsTest(absltest.TestCase):
|
||||||
])
|
])
|
||||||
|
|
||||||
|
|
||||||
|
class ScoreCalibrationTest(absltest.TestCase):
|
||||||
|
|
||||||
|
def test_create_from_file_successful(self):
|
||||||
|
score_calibration = metadata_writer.ScoreCalibration.create_from_file(
|
||||||
|
metadata_writer.ScoreCalibration.transformation_types.LOG,
|
||||||
|
_SCORE_CALIBRATION_FILE)
|
||||||
|
self.assertLen(score_calibration.parameters, 511)
|
||||||
|
self.assertIsNone(score_calibration.parameters[0])
|
||||||
|
self.assertEqual(
|
||||||
|
score_calibration.parameters[1],
|
||||||
|
metadata_writer.CalibrationParameter(
|
||||||
|
scale=0.9876328110694885,
|
||||||
|
slope=0.36622241139411926,
|
||||||
|
offset=0.5352765321731567,
|
||||||
|
min_score=0.71484375))
|
||||||
|
self.assertEqual(
|
||||||
|
score_calibration.parameters[510],
|
||||||
|
metadata_writer.CalibrationParameter(
|
||||||
|
scale=0.9901729226112366,
|
||||||
|
slope=0.8561913371086121,
|
||||||
|
offset=0.8783953189849854,
|
||||||
|
min_score=0.5859375))
|
||||||
|
|
||||||
|
def test_create_from_file_fail(self):
|
||||||
|
with tempfile.TemporaryDirectory() as temp_dir:
|
||||||
|
test_file = os.path.join(temp_dir, 'score_calibration.csv')
|
||||||
|
with open(test_file, 'w') as f:
|
||||||
|
f.write('0.98,0.5\n')
|
||||||
|
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
ValueError,
|
||||||
|
'Expected empty lines or 3 or 4 parameters per line in score '
|
||||||
|
'calibration file, but got 2.'
|
||||||
|
):
|
||||||
|
metadata_writer.ScoreCalibration.create_from_file(
|
||||||
|
metadata_writer.ScoreCalibration.transformation_types.LOG,
|
||||||
|
test_file)
|
||||||
|
|
||||||
|
with open(test_file, 'w') as f:
|
||||||
|
f.write('-0.98,0.5,0.34\n')
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
ValueError,
|
||||||
|
'Expected scale to be a non-negative value, but got -0.98.'):
|
||||||
|
metadata_writer.ScoreCalibration.create_from_file(
|
||||||
|
metadata_writer.ScoreCalibration.transformation_types.LOG,
|
||||||
|
test_file)
|
||||||
|
|
||||||
|
|
||||||
class MetadataWriterForTaskTest(absltest.TestCase):
|
class MetadataWriterForTaskTest(absltest.TestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
|
@ -197,7 +249,7 @@ class MetadataWriterForTaskTest(absltest.TestCase):
|
||||||
"output_tensor_metadata": [
|
"output_tensor_metadata": [
|
||||||
{
|
{
|
||||||
"name": "score",
|
"name": "score",
|
||||||
"description": "Score of the labels respectively",
|
"description": "Score of the labels respectively.",
|
||||||
"content": {
|
"content": {
|
||||||
"content_properties_type": "FeatureProperties",
|
"content_properties_type": "FeatureProperties",
|
||||||
"content_properties": {
|
"content_properties": {
|
||||||
|
@ -298,7 +350,7 @@ class MetadataWriterForTaskTest(absltest.TestCase):
|
||||||
"output_tensor_metadata": [
|
"output_tensor_metadata": [
|
||||||
{
|
{
|
||||||
"name": "score",
|
"name": "score",
|
||||||
"description": "Score of the labels respectively",
|
"description": "Score of the labels respectively.",
|
||||||
"content": {
|
"content": {
|
||||||
"content_properties_type": "FeatureProperties",
|
"content_properties_type": "FeatureProperties",
|
||||||
"content_properties": {
|
"content_properties": {
|
||||||
|
|
10
mediapipe/tasks/testdata/metadata/BUILD
vendored
10
mediapipe/tasks/testdata/metadata/BUILD
vendored
|
@ -29,6 +29,8 @@ mediapipe_files(srcs = [
|
||||||
"mobile_object_classifier_v0_2_3-metadata-no-name.tflite",
|
"mobile_object_classifier_v0_2_3-metadata-no-name.tflite",
|
||||||
"mobilenet_v1_0.25_224_1_default_1.tflite",
|
"mobilenet_v1_0.25_224_1_default_1.tflite",
|
||||||
"mobilenet_v2_1.0_224_quant.tflite",
|
"mobilenet_v2_1.0_224_quant.tflite",
|
||||||
|
"mobilenet_v2_1.0_224_quant_without_metadata.tflite",
|
||||||
|
"mobilenet_v2_1.0_224_without_metadata.tflite",
|
||||||
])
|
])
|
||||||
|
|
||||||
exports_files([
|
exports_files([
|
||||||
|
@ -48,6 +50,9 @@ exports_files([
|
||||||
"score_calibration.txt",
|
"score_calibration.txt",
|
||||||
"score_calibration_file_meta.json",
|
"score_calibration_file_meta.json",
|
||||||
"score_calibration_tensor_meta.json",
|
"score_calibration_tensor_meta.json",
|
||||||
|
"labels.txt",
|
||||||
|
"mobilenet_v2_1.0_224.json",
|
||||||
|
"mobilenet_v2_1.0_224_quant.json",
|
||||||
])
|
])
|
||||||
|
|
||||||
filegroup(
|
filegroup(
|
||||||
|
@ -59,6 +64,8 @@ filegroup(
|
||||||
"mobile_object_classifier_v0_2_3-metadata-no-name.tflite",
|
"mobile_object_classifier_v0_2_3-metadata-no-name.tflite",
|
||||||
"mobilenet_v1_0.25_224_1_default_1.tflite",
|
"mobilenet_v1_0.25_224_1_default_1.tflite",
|
||||||
"mobilenet_v2_1.0_224_quant.tflite",
|
"mobilenet_v2_1.0_224_quant.tflite",
|
||||||
|
"mobilenet_v2_1.0_224_quant_without_metadata.tflite",
|
||||||
|
"mobilenet_v2_1.0_224_without_metadata.tflite",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -78,6 +85,9 @@ filegroup(
|
||||||
"input_image_tensor_float_meta.json",
|
"input_image_tensor_float_meta.json",
|
||||||
"input_image_tensor_uint8_meta.json",
|
"input_image_tensor_uint8_meta.json",
|
||||||
"input_image_tensor_unsupported_meta.json",
|
"input_image_tensor_unsupported_meta.json",
|
||||||
|
"labels.txt",
|
||||||
|
"mobilenet_v2_1.0_224.json",
|
||||||
|
"mobilenet_v2_1.0_224_quant.json",
|
||||||
"score_calibration.txt",
|
"score_calibration.txt",
|
||||||
"score_calibration_file_meta.json",
|
"score_calibration_file_meta.json",
|
||||||
"score_calibration_tensor_meta.json",
|
"score_calibration_tensor_meta.json",
|
||||||
|
|
1001
mediapipe/tasks/testdata/metadata/labels.txt
vendored
Normal file
1001
mediapipe/tasks/testdata/metadata/labels.txt
vendored
Normal file
File diff suppressed because it is too large
Load Diff
82
mediapipe/tasks/testdata/metadata/mobilenet_v2_1.0_224.json
vendored
Normal file
82
mediapipe/tasks/testdata/metadata/mobilenet_v2_1.0_224.json
vendored
Normal file
|
@ -0,0 +1,82 @@
|
||||||
|
{
|
||||||
|
"name": "ImageClassifier",
|
||||||
|
"description": "Identify the most prominent object in the image from a known set of categories.",
|
||||||
|
"subgraph_metadata": [
|
||||||
|
{
|
||||||
|
"input_tensor_metadata": [
|
||||||
|
{
|
||||||
|
"name": "image",
|
||||||
|
"description": "Input image to be processed.",
|
||||||
|
"content": {
|
||||||
|
"content_properties_type": "ImageProperties",
|
||||||
|
"content_properties": {
|
||||||
|
"color_space": "RGB"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"process_units": [
|
||||||
|
{
|
||||||
|
"options_type": "NormalizationOptions",
|
||||||
|
"options": {
|
||||||
|
"mean": [
|
||||||
|
127.5
|
||||||
|
],
|
||||||
|
"std": [
|
||||||
|
127.5
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"stats": {
|
||||||
|
"max": [
|
||||||
|
1.0
|
||||||
|
],
|
||||||
|
"min": [
|
||||||
|
-1.0
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"output_tensor_metadata": [
|
||||||
|
{
|
||||||
|
"name": "score",
|
||||||
|
"description": "Score of the labels respectively.",
|
||||||
|
"content": {
|
||||||
|
"content_properties_type": "FeatureProperties",
|
||||||
|
"content_properties": {
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"process_units": [
|
||||||
|
{
|
||||||
|
"options_type": "ScoreCalibrationOptions",
|
||||||
|
"options": {
|
||||||
|
"score_transformation": "LOG",
|
||||||
|
"default_score": 0.2
|
||||||
|
}
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"stats": {
|
||||||
|
"max": [
|
||||||
|
1.0
|
||||||
|
],
|
||||||
|
"min": [
|
||||||
|
0.0
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"associated_files": [
|
||||||
|
{
|
||||||
|
"name": "labels.txt",
|
||||||
|
"description": "Labels for categories that the model can recognize.",
|
||||||
|
"type": "TENSOR_AXIS_LABELS"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "score_calibration.txt",
|
||||||
|
"description": "Contains sigmoid-based score calibration parameters. The main purposes of score calibration is to make scores across classes comparable, so that a common threshold can be used for all output classes.",
|
||||||
|
"type": "TENSOR_AXIS_SCORE_CALIBRATION"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"min_parser_version": "1.0.0"
|
||||||
|
}
|
82
mediapipe/tasks/testdata/metadata/mobilenet_v2_1.0_224_quant.json
vendored
Normal file
82
mediapipe/tasks/testdata/metadata/mobilenet_v2_1.0_224_quant.json
vendored
Normal file
|
@ -0,0 +1,82 @@
|
||||||
|
{
|
||||||
|
"name": "ImageClassifier",
|
||||||
|
"description": "Identify the most prominent object in the image from a known set of categories.",
|
||||||
|
"subgraph_metadata": [
|
||||||
|
{
|
||||||
|
"input_tensor_metadata": [
|
||||||
|
{
|
||||||
|
"name": "image",
|
||||||
|
"description": "Input image to be processed.",
|
||||||
|
"content": {
|
||||||
|
"content_properties_type": "ImageProperties",
|
||||||
|
"content_properties": {
|
||||||
|
"color_space": "RGB"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"process_units": [
|
||||||
|
{
|
||||||
|
"options_type": "NormalizationOptions",
|
||||||
|
"options": {
|
||||||
|
"mean": [
|
||||||
|
127.5
|
||||||
|
],
|
||||||
|
"std": [
|
||||||
|
127.5
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"stats": {
|
||||||
|
"max": [
|
||||||
|
255.0
|
||||||
|
],
|
||||||
|
"min": [
|
||||||
|
0.0
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"output_tensor_metadata": [
|
||||||
|
{
|
||||||
|
"name": "score",
|
||||||
|
"description": "Score of the labels respectively.",
|
||||||
|
"content": {
|
||||||
|
"content_properties_type": "FeatureProperties",
|
||||||
|
"content_properties": {
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"process_units": [
|
||||||
|
{
|
||||||
|
"options_type": "ScoreCalibrationOptions",
|
||||||
|
"options": {
|
||||||
|
"score_transformation": "LOG",
|
||||||
|
"default_score": 0.2
|
||||||
|
}
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"stats": {
|
||||||
|
"max": [
|
||||||
|
255.0
|
||||||
|
],
|
||||||
|
"min": [
|
||||||
|
0.0
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"associated_files": [
|
||||||
|
{
|
||||||
|
"name": "labels.txt",
|
||||||
|
"description": "Labels for categories that the model can recognize.",
|
||||||
|
"type": "TENSOR_AXIS_LABELS"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "score_calibration.txt",
|
||||||
|
"description": "Contains sigmoid-based score calibration parameters. The main purposes of score calibration is to make scores across classes comparable, so that a common threshold can be used for all output classes.",
|
||||||
|
"type": "TENSOR_AXIS_SCORE_CALIBRATION"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"min_parser_version": "1.0.0"
|
||||||
|
}
|
30
third_party/external_files.bzl
vendored
30
third_party/external_files.bzl
vendored
|
@ -364,6 +364,12 @@ def external_files():
|
||||||
urls = ["https://storage.googleapis.com/mediapipe-assets/knift_labelmap.txt?generation=1661875792821628"],
|
urls = ["https://storage.googleapis.com/mediapipe-assets/knift_labelmap.txt?generation=1661875792821628"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
http_file(
|
||||||
|
name = "com_google_mediapipe_labels_txt",
|
||||||
|
sha256 = "536feacc519de3d418de26b2effb4d75694a8c4c0063e36499a46fa8061e2da9",
|
||||||
|
urls = ["https://storage.googleapis.com/mediapipe-assets/labels.txt?generation=1665988394538324"],
|
||||||
|
)
|
||||||
|
|
||||||
http_file(
|
http_file(
|
||||||
name = "com_google_mediapipe_left_hands_jpg",
|
name = "com_google_mediapipe_left_hands_jpg",
|
||||||
sha256 = "4b5134daa4cb60465535239535f9f74c2842aba3aa5fd30bf04ef5678f93d87f",
|
sha256 = "4b5134daa4cb60465535239535f9f74c2842aba3aa5fd30bf04ef5678f93d87f",
|
||||||
|
@ -448,18 +454,42 @@ def external_files():
|
||||||
urls = ["https://storage.googleapis.com/mediapipe-assets/mobilenet_v1_0.25_224_quant_without_subgraph_metadata.tflite?generation=1661875836078124"],
|
urls = ["https://storage.googleapis.com/mediapipe-assets/mobilenet_v1_0.25_224_quant_without_subgraph_metadata.tflite?generation=1661875836078124"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
http_file(
|
||||||
|
name = "com_google_mediapipe_mobilenet_v2_1_0_224_json",
|
||||||
|
sha256 = "0eb285a857b4bb1815736d0902ace0af45ea62e90c1dac98844b9ca797cd0d7b",
|
||||||
|
urls = ["https://storage.googleapis.com/mediapipe-assets/mobilenet_v2_1.0_224.json?generation=1665988398778178"],
|
||||||
|
)
|
||||||
|
|
||||||
|
http_file(
|
||||||
|
name = "com_google_mediapipe_mobilenet_v2_1_0_224_quant_json",
|
||||||
|
sha256 = "932f345ebe3d98daf0dc4c88b0f9e694e450390fb394fc217e851338dfec43e6",
|
||||||
|
urls = ["https://storage.googleapis.com/mediapipe-assets/mobilenet_v2_1.0_224_quant.json?generation=1665988401522527"],
|
||||||
|
)
|
||||||
|
|
||||||
http_file(
|
http_file(
|
||||||
name = "com_google_mediapipe_mobilenet_v2_1_0_224_quant_tflite",
|
name = "com_google_mediapipe_mobilenet_v2_1_0_224_quant_tflite",
|
||||||
sha256 = "f08d447cde49b4e0446428aa921aff0a14ea589fa9c5817b31f83128e9a43c1d",
|
sha256 = "f08d447cde49b4e0446428aa921aff0a14ea589fa9c5817b31f83128e9a43c1d",
|
||||||
urls = ["https://storage.googleapis.com/mediapipe-assets/mobilenet_v2_1.0_224_quant.tflite?generation=1664340173966530"],
|
urls = ["https://storage.googleapis.com/mediapipe-assets/mobilenet_v2_1.0_224_quant.tflite?generation=1664340173966530"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
http_file(
|
||||||
|
name = "com_google_mediapipe_mobilenet_v2_1_0_224_quant_without_metadata_tflite",
|
||||||
|
sha256 = "f08d447cde49b4e0446428aa921aff0a14ea589fa9c5817b31f83128e9a43c1d",
|
||||||
|
urls = ["https://storage.googleapis.com/mediapipe-assets/mobilenet_v2_1.0_224_quant_without_metadata.tflite?generation=1665988405130772"],
|
||||||
|
)
|
||||||
|
|
||||||
http_file(
|
http_file(
|
||||||
name = "com_google_mediapipe_mobilenet_v2_1_0_224_tflite",
|
name = "com_google_mediapipe_mobilenet_v2_1_0_224_tflite",
|
||||||
sha256 = "ff5cb7f9e62c92ebdad971f8a98aa6b3106d82a64587a7787c6a385c9e791339",
|
sha256 = "ff5cb7f9e62c92ebdad971f8a98aa6b3106d82a64587a7787c6a385c9e791339",
|
||||||
urls = ["https://storage.googleapis.com/mediapipe-assets/mobilenet_v2_1.0_224.tflite?generation=1661875840611150"],
|
urls = ["https://storage.googleapis.com/mediapipe-assets/mobilenet_v2_1.0_224.tflite?generation=1661875840611150"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
http_file(
|
||||||
|
name = "com_google_mediapipe_mobilenet_v2_1_0_224_without_metadata_tflite",
|
||||||
|
sha256 = "9f3bc29e38e90842a852bfed957dbf5e36f2d97a91dd17736b1e5c0aca8d3303",
|
||||||
|
urls = ["https://storage.googleapis.com/mediapipe-assets/mobilenet_v2_1.0_224_without_metadata.tflite?generation=1665988408360823"],
|
||||||
|
)
|
||||||
|
|
||||||
http_file(
|
http_file(
|
||||||
name = "com_google_mediapipe_mobilenet_v3_small_100_224_embedder_tflite",
|
name = "com_google_mediapipe_mobilenet_v3_small_100_224_embedder_tflite",
|
||||||
sha256 = "f7b9a563cb803bdcba76e8c7e82abde06f5c7a8e67b5e54e43e23095dfe79a78",
|
sha256 = "f7b9a563cb803bdcba76e8c7e82abde06f5c7a8e67b5e54e43e23095dfe79a78",
|
||||||
|
|
Loading…
Reference in New Issue
Block a user