Metadata Writer: Add Metadata Writer for image classifier.
PiperOrigin-RevId: 483282627
This commit is contained in:
parent
ec2a34d2a4
commit
ab17be9294
|
@ -37,3 +37,9 @@ py_library(
|
|||
srcs = ["writer_utils.py"],
|
||||
deps = ["//mediapipe/tasks/metadata:schema_py"],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "image_classifier",
|
||||
srcs = ["image_classifier.py"],
|
||||
deps = [":metadata_writer"],
|
||||
)
|
||||
|
|
|
@ -0,0 +1,71 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Writes metadata and label file to the image classifier models."""
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer
|
||||
|
||||
_MODEL_NAME = "ImageClassifier"
|
||||
_MODEL_DESCRIPTION = ("Identify the most prominent object in the image from a "
|
||||
"known set of categories.")
|
||||
|
||||
|
||||
class MetadataWriter(metadata_writer.MetadataWriterBase):
|
||||
"""MetadataWriter to write the metadata for image classifier."""
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
model_buffer: bytearray,
|
||||
input_norm_mean: List[float],
|
||||
input_norm_std: List[float],
|
||||
labels: metadata_writer.Labels,
|
||||
score_calibration: Optional[metadata_writer.ScoreCalibration] = None
|
||||
) -> "MetadataWriter":
|
||||
"""Creates MetadataWriter to write the metadata for image classifier.
|
||||
|
||||
The parameters required in this method are mandatory when using MediaPipe
|
||||
Tasks.
|
||||
|
||||
Note that only the output TFLite is used for deployment. The output JSON
|
||||
content is used to interpret the metadata content.
|
||||
|
||||
Args:
|
||||
model_buffer: A valid flatbuffer loaded from the TFLite model file.
|
||||
input_norm_mean: the mean value used in the input tensor normalization
|
||||
[1].
|
||||
input_norm_std: the std value used in the input tensor normalizarion [1].
|
||||
labels: an instance of Labels helper class used in the output
|
||||
classification tensor [2].
|
||||
score_calibration: A container of the score calibration operation [3] in
|
||||
the classification tensor. Optional if the model does not use score
|
||||
calibration.
|
||||
|
||||
[1]:
|
||||
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L389
|
||||
[2]:
|
||||
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L99
|
||||
[3]:
|
||||
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L456
|
||||
|
||||
Returns:
|
||||
An MetadataWrite object.
|
||||
"""
|
||||
writer = metadata_writer.MetadataWriter(model_buffer)
|
||||
writer.add_genernal_info(_MODEL_NAME, _MODEL_DESCRIPTION)
|
||||
writer.add_image_input(input_norm_mean, input_norm_std)
|
||||
writer.add_classification_output(labels, score_calibration)
|
||||
return cls(writer)
|
|
@ -15,19 +15,22 @@
|
|||
"""Generic metadata writer."""
|
||||
|
||||
import collections
|
||||
import csv
|
||||
import dataclasses
|
||||
import os
|
||||
import tempfile
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import flatbuffers
|
||||
from mediapipe.tasks.metadata import metadata_schema_py_generated as _metadata_fb
|
||||
from mediapipe.tasks.python.metadata import metadata as _metadata
|
||||
from mediapipe.tasks.metadata import metadata_schema_py_generated as metadata_fb
|
||||
from mediapipe.tasks.python.metadata import metadata
|
||||
from mediapipe.tasks.python.metadata.metadata_writers import metadata_info
|
||||
from mediapipe.tasks.python.metadata.metadata_writers import writer_utils
|
||||
|
||||
_INPUT_IMAGE_NAME = 'image'
|
||||
_INPUT_IMAGE_DESCRIPTION = 'Input image to be processed.'
|
||||
_OUTPUT_CLASSIFICATION_NAME = 'score'
|
||||
_OUTPUT_CLASSIFICATION_DESCRIPTION = 'Score of the labels respectively.'
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
|
@ -140,26 +143,85 @@ class Labels(object):
|
|||
class ScoreCalibration:
|
||||
"""Simple container holding score calibration related parameters."""
|
||||
|
||||
# A shortcut to avoid client side code importing _metadata_fb
|
||||
transformation_types = _metadata_fb.ScoreTransformationType
|
||||
# A shortcut to avoid client side code importing metadata_fb
|
||||
transformation_types = metadata_fb.ScoreTransformationType
|
||||
|
||||
def __init__(self,
|
||||
transformation_type: _metadata_fb.ScoreTransformationType,
|
||||
parameters: List[CalibrationParameter],
|
||||
transformation_type: metadata_fb.ScoreTransformationType,
|
||||
parameters: List[Optional[CalibrationParameter]],
|
||||
default_score: int = 0):
|
||||
self.transformation_type = transformation_type
|
||||
self.parameters = parameters
|
||||
self.default_score = default_score
|
||||
|
||||
@classmethod
|
||||
def create_from_file(cls,
|
||||
transformation_type: metadata_fb.ScoreTransformationType,
|
||||
file_path: str,
|
||||
default_score: int = 0) -> 'ScoreCalibration':
|
||||
"""Creates ScoreCalibration from the file.
|
||||
|
||||
Args:
|
||||
transformation_type: type of the function used for transforming the
|
||||
uncalibrated score before applying score calibration.
|
||||
file_path: file_path of the score calibration file [1]. Contains
|
||||
sigmoid-based score calibration parameters, formatted as CSV. Lines
|
||||
contain for each index of an output tensor the scale, slope, offset and
|
||||
(optional) min_score parameters to be used for sigmoid fitting (in this
|
||||
order and in `strtof`-compatible [2] format). Scale should be a
|
||||
non-negative value. A line may be left empty to default calibrated
|
||||
scores for this index to default_score. In summary, each line should
|
||||
thus contain 0, 3 or 4 comma-separated values.
|
||||
default_score: the default calibrated score to apply if the uncalibrated
|
||||
score is below min_score or if no parameters were specified for a given
|
||||
index.
|
||||
[1]:
|
||||
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L133
|
||||
[2]:
|
||||
https://en.cppreference.com/w/c/string/byte/strtof
|
||||
|
||||
Returns:
|
||||
A ScoreCalibration object.
|
||||
Raises:
|
||||
ValueError: if the score_calibration file is malformed.
|
||||
"""
|
||||
with open(file_path, 'r') as calibration_file:
|
||||
csv_reader = csv.reader(calibration_file, delimiter=',')
|
||||
parameters = []
|
||||
for row in csv_reader:
|
||||
if not row:
|
||||
parameters.append(None)
|
||||
continue
|
||||
|
||||
if len(row) != 3 and len(row) != 4:
|
||||
raise ValueError(
|
||||
f'Expected empty lines or 3 or 4 parameters per line in score'
|
||||
f' calibration file, but got {len(row)}.')
|
||||
|
||||
if float(row[0]) < 0:
|
||||
raise ValueError(
|
||||
f'Expected scale to be a non-negative value, but got '
|
||||
f'{float(row[0])}.')
|
||||
|
||||
parameters.append(
|
||||
CalibrationParameter(
|
||||
scale=float(row[0]),
|
||||
slope=float(row[1]),
|
||||
offset=float(row[2]),
|
||||
min_score=None if len(row) == 3 else float(row[3])))
|
||||
|
||||
return cls(transformation_type, parameters, default_score)
|
||||
|
||||
|
||||
def _fill_default_tensor_names(
|
||||
tensor_metadata: List[_metadata_fb.TensorMetadataT],
|
||||
tensor_metadata_list: List[metadata_fb.TensorMetadataT],
|
||||
tensor_names_from_model: List[str]):
|
||||
"""Fills the default tensor names."""
|
||||
# If tensor name in metadata is empty, default to the tensor name saved in
|
||||
# the model.
|
||||
for metadata, name in zip(tensor_metadata, tensor_names_from_model):
|
||||
metadata.name = metadata.name or name
|
||||
for tensor_metadata, name in zip(tensor_metadata_list,
|
||||
tensor_names_from_model):
|
||||
tensor_metadata.name = tensor_metadata.name or name
|
||||
|
||||
|
||||
def _pair_tensor_metadata(
|
||||
|
@ -212,7 +274,7 @@ def _create_metadata_buffer(
|
|||
input_metadata = [m.create_metadata() for m in input_md]
|
||||
else:
|
||||
num_input_tensors = writer_utils.get_subgraph(model_buffer).InputsLength()
|
||||
input_metadata = [_metadata_fb.TensorMetadataT()] * num_input_tensors
|
||||
input_metadata = [metadata_fb.TensorMetadataT()] * num_input_tensors
|
||||
|
||||
_fill_default_tensor_names(input_metadata,
|
||||
writer_utils.get_input_tensor_names(model_buffer))
|
||||
|
@ -224,12 +286,12 @@ def _create_metadata_buffer(
|
|||
output_metadata = [m.create_metadata() for m in output_md]
|
||||
else:
|
||||
num_output_tensors = writer_utils.get_subgraph(model_buffer).OutputsLength()
|
||||
output_metadata = [_metadata_fb.TensorMetadataT()] * num_output_tensors
|
||||
output_metadata = [metadata_fb.TensorMetadataT()] * num_output_tensors
|
||||
_fill_default_tensor_names(output_metadata,
|
||||
writer_utils.get_output_tensor_names(model_buffer))
|
||||
|
||||
# Create the subgraph metadata.
|
||||
subgraph_metadata = _metadata_fb.SubGraphMetadataT()
|
||||
subgraph_metadata = metadata_fb.SubGraphMetadataT()
|
||||
subgraph_metadata.inputTensorMetadata = input_metadata
|
||||
subgraph_metadata.outputTensorMetadata = output_metadata
|
||||
|
||||
|
@ -243,7 +305,7 @@ def _create_metadata_buffer(
|
|||
b = flatbuffers.Builder(0)
|
||||
b.Finish(
|
||||
model_metadata.Pack(b),
|
||||
_metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)
|
||||
metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)
|
||||
return b.Output()
|
||||
|
||||
|
||||
|
@ -291,7 +353,7 @@ class MetadataWriter(object):
|
|||
name=model_name, description=model_description)
|
||||
return self
|
||||
|
||||
color_space_types = _metadata_fb.ColorSpaceType
|
||||
color_space_types = metadata_fb.ColorSpaceType
|
||||
|
||||
def add_feature_input(self,
|
||||
name: Optional[str] = None,
|
||||
|
@ -305,7 +367,7 @@ class MetadataWriter(object):
|
|||
self,
|
||||
norm_mean: List[float],
|
||||
norm_std: List[float],
|
||||
color_space_type: Optional[int] = _metadata_fb.ColorSpaceType.RGB,
|
||||
color_space_type: Optional[int] = metadata_fb.ColorSpaceType.RGB,
|
||||
name: str = _INPUT_IMAGE_NAME,
|
||||
description: str = _INPUT_IMAGE_DESCRIPTION) -> 'MetadataWriter':
|
||||
"""Adds an input image metadata for the image input.
|
||||
|
@ -341,9 +403,6 @@ class MetadataWriter(object):
|
|||
self._input_mds.append(input_md)
|
||||
return self
|
||||
|
||||
_OUTPUT_CLASSIFICATION_NAME = 'score'
|
||||
_OUTPUT_CLASSIFICATION_DESCRIPTION = 'Score of the labels respectively'
|
||||
|
||||
def add_classification_output(
|
||||
self,
|
||||
labels: Optional[Labels] = None,
|
||||
|
@ -416,8 +475,7 @@ class MetadataWriter(object):
|
|||
A tuple of (model_with_metadata_in_bytes, metdata_json_content)
|
||||
"""
|
||||
# Populates metadata and associated files into TFLite model buffer.
|
||||
populator = _metadata.MetadataPopulator.with_model_buffer(
|
||||
self._model_buffer)
|
||||
populator = metadata.MetadataPopulator.with_model_buffer(self._model_buffer)
|
||||
metadata_buffer = _create_metadata_buffer(
|
||||
model_buffer=self._model_buffer,
|
||||
general_md=self._general_md,
|
||||
|
@ -429,7 +487,7 @@ class MetadataWriter(object):
|
|||
populator.populate()
|
||||
tflite_content = populator.get_model_buffer()
|
||||
|
||||
displayer = _metadata.MetadataDisplayer.with_model_buffer(tflite_content)
|
||||
displayer = metadata.MetadataDisplayer.with_model_buffer(tflite_content)
|
||||
metadata_json_content = displayer.get_metadata_json()
|
||||
|
||||
return tflite_content, metadata_json_content
|
||||
|
@ -452,9 +510,7 @@ class MetadataWriter(object):
|
|||
"""Stores calibration parameters in a csv file."""
|
||||
filepath = os.path.join(self._temp_folder.name, filename)
|
||||
with open(filepath, 'w') as f:
|
||||
for idx, item in enumerate(calibrations):
|
||||
if idx != 0:
|
||||
f.write('\n')
|
||||
for item in calibrations:
|
||||
if item:
|
||||
if item.scale is None or item.slope is None or item.offset is None:
|
||||
raise ValueError('scale, slope and offset values can not be set to '
|
||||
|
@ -463,6 +519,30 @@ class MetadataWriter(object):
|
|||
f.write(f'{item.scale},{item.slope},{item.offset},{item.min_score}')
|
||||
else:
|
||||
f.write(f'{item.scale},{item.slope},{item.offset}')
|
||||
f.write('\n')
|
||||
|
||||
self._associated_files.append(filepath)
|
||||
self._associated_files.append(filepath)
|
||||
return filepath
|
||||
|
||||
|
||||
class MetadataWriterBase:
|
||||
"""Base MetadataWriter class which contains the apis exposed to users.
|
||||
|
||||
MetadataWriter for Tasks e.g. image classifier / object detector will inherit
|
||||
this class for their own usage.
|
||||
"""
|
||||
|
||||
def __init__(self, writer: MetadataWriter) -> None:
|
||||
self.writer = writer
|
||||
|
||||
def populate(self) -> Tuple[bytearray, str]:
|
||||
"""Populates metadata into the TFLite file.
|
||||
|
||||
Note that only the output tflite is used for deployment. The output JSON
|
||||
content is used to interpret the metadata content.
|
||||
|
||||
Returns:
|
||||
A tuple of (model_with_metadata_in_bytes, metdata_json_content)
|
||||
"""
|
||||
return self.writer.populate()
|
||||
|
||||
|
|
|
@ -28,9 +28,28 @@ py_test(
|
|||
py_test(
|
||||
name = "metadata_writer_test",
|
||||
srcs = ["metadata_writer_test.py"],
|
||||
data = ["//mediapipe/tasks/testdata/metadata:model_files"],
|
||||
data = [
|
||||
"//mediapipe/tasks/testdata/metadata:data_files",
|
||||
"//mediapipe/tasks/testdata/metadata:model_files",
|
||||
],
|
||||
deps = [
|
||||
"//mediapipe/tasks/python/metadata/metadata_writers:metadata_writer",
|
||||
"//mediapipe/tasks/python/test:test_utils",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "image_classifier_test",
|
||||
srcs = ["image_classifier_test.py"],
|
||||
data = [
|
||||
"//mediapipe/tasks/testdata/metadata:data_files",
|
||||
"//mediapipe/tasks/testdata/metadata:model_files",
|
||||
],
|
||||
deps = [
|
||||
"//mediapipe/tasks/metadata:metadata_schema_py",
|
||||
"//mediapipe/tasks/python/metadata",
|
||||
"//mediapipe/tasks/python/metadata/metadata_writers:image_classifier",
|
||||
"//mediapipe/tasks/python/metadata/metadata_writers:metadata_writer",
|
||||
"//mediapipe/tasks/python/test:test_utils",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -0,0 +1,76 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for metadata_writer.image_classifier."""
|
||||
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
|
||||
from mediapipe.tasks.metadata import metadata_schema_py_generated as metadata_fb
|
||||
from mediapipe.tasks.python.metadata import metadata
|
||||
from mediapipe.tasks.python.metadata.metadata_writers import image_classifier
|
||||
from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer
|
||||
from mediapipe.tasks.python.test import test_utils
|
||||
|
||||
_FLOAT_MODEL = test_utils.get_test_data_path(
|
||||
"mobilenet_v2_1.0_224_without_metadata.tflite")
|
||||
_QUANT_MODEL = test_utils.get_test_data_path(
|
||||
"mobilenet_v2_1.0_224_quant_without_metadata.tflite")
|
||||
_LABEL_FILE = test_utils.get_test_data_path("labels.txt")
|
||||
_SCORE_CALIBRATION_FILE = test_utils.get_test_data_path("score_calibration.txt")
|
||||
_SCORE_CALIBRATION_FILENAME = "score_calibration.txt"
|
||||
_DEFAULT_SCORE_CALIBRATION_VALUE = 0.2
|
||||
_NORM_MEAN = 127.5
|
||||
_NORM_STD = 127.5
|
||||
_FLOAT_JSON = test_utils.get_test_data_path("mobilenet_v2_1.0_224.json")
|
||||
_QUANT_JSON = test_utils.get_test_data_path("mobilenet_v2_1.0_224_quant.json")
|
||||
|
||||
|
||||
class ImageClassifierTest(parameterized.TestCase):
|
||||
|
||||
@parameterized.named_parameters(
|
||||
{
|
||||
"testcase_name": "test_float_model",
|
||||
"model_file": _FLOAT_MODEL,
|
||||
"golden_json": _FLOAT_JSON
|
||||
}, {
|
||||
"testcase_name": "test_quant_model",
|
||||
"model_file": _QUANT_MODEL,
|
||||
"golden_json": _QUANT_JSON
|
||||
})
|
||||
def test_write_metadata(self, model_file: str, golden_json: str):
|
||||
with open(model_file, "rb") as f:
|
||||
model_buffer = f.read()
|
||||
writer = image_classifier.MetadataWriter.create(
|
||||
model_buffer, [_NORM_MEAN], [_NORM_STD],
|
||||
labels=metadata_writer.Labels().add_from_file(_LABEL_FILE),
|
||||
score_calibration=metadata_writer.ScoreCalibration.create_from_file(
|
||||
metadata_fb.ScoreTransformationType.LOG, _SCORE_CALIBRATION_FILE,
|
||||
_DEFAULT_SCORE_CALIBRATION_VALUE))
|
||||
tflite_content, metadata_json = writer.populate()
|
||||
|
||||
with open(golden_json, "r") as f:
|
||||
expected_json = f.read()
|
||||
self.assertEqual(metadata_json, expected_json)
|
||||
|
||||
displayer = metadata.MetadataDisplayer.with_model_buffer(tflite_content)
|
||||
file_buffer = displayer.get_associated_file_buffer(
|
||||
_SCORE_CALIBRATION_FILENAME)
|
||||
with open(_SCORE_CALIBRATION_FILE, "rb") as f:
|
||||
expected_file_buffer = f.read()
|
||||
self.assertEqual(file_buffer, expected_file_buffer)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main()
|
|
@ -13,6 +13,9 @@
|
|||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for metadata writer classes."""
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
from absl.testing import absltest
|
||||
|
||||
from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer
|
||||
|
@ -20,6 +23,7 @@ from mediapipe.tasks.python.test import test_utils
|
|||
|
||||
_IMAGE_CLASSIFIER_MODEL = test_utils.get_test_data_path(
|
||||
'mobilenet_v1_0.25_224_1_default_1.tflite')
|
||||
_SCORE_CALIBRATION_FILE = test_utils.get_test_data_path('score_calibration.txt')
|
||||
|
||||
|
||||
class LabelsTest(absltest.TestCase):
|
||||
|
@ -49,6 +53,54 @@ class LabelsTest(absltest.TestCase):
|
|||
])
|
||||
|
||||
|
||||
class ScoreCalibrationTest(absltest.TestCase):
|
||||
|
||||
def test_create_from_file_successful(self):
|
||||
score_calibration = metadata_writer.ScoreCalibration.create_from_file(
|
||||
metadata_writer.ScoreCalibration.transformation_types.LOG,
|
||||
_SCORE_CALIBRATION_FILE)
|
||||
self.assertLen(score_calibration.parameters, 511)
|
||||
self.assertIsNone(score_calibration.parameters[0])
|
||||
self.assertEqual(
|
||||
score_calibration.parameters[1],
|
||||
metadata_writer.CalibrationParameter(
|
||||
scale=0.9876328110694885,
|
||||
slope=0.36622241139411926,
|
||||
offset=0.5352765321731567,
|
||||
min_score=0.71484375))
|
||||
self.assertEqual(
|
||||
score_calibration.parameters[510],
|
||||
metadata_writer.CalibrationParameter(
|
||||
scale=0.9901729226112366,
|
||||
slope=0.8561913371086121,
|
||||
offset=0.8783953189849854,
|
||||
min_score=0.5859375))
|
||||
|
||||
def test_create_from_file_fail(self):
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
test_file = os.path.join(temp_dir, 'score_calibration.csv')
|
||||
with open(test_file, 'w') as f:
|
||||
f.write('0.98,0.5\n')
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
'Expected empty lines or 3 or 4 parameters per line in score '
|
||||
'calibration file, but got 2.'
|
||||
):
|
||||
metadata_writer.ScoreCalibration.create_from_file(
|
||||
metadata_writer.ScoreCalibration.transformation_types.LOG,
|
||||
test_file)
|
||||
|
||||
with open(test_file, 'w') as f:
|
||||
f.write('-0.98,0.5,0.34\n')
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
'Expected scale to be a non-negative value, but got -0.98.'):
|
||||
metadata_writer.ScoreCalibration.create_from_file(
|
||||
metadata_writer.ScoreCalibration.transformation_types.LOG,
|
||||
test_file)
|
||||
|
||||
|
||||
class MetadataWriterForTaskTest(absltest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
|
@ -197,7 +249,7 @@ class MetadataWriterForTaskTest(absltest.TestCase):
|
|||
"output_tensor_metadata": [
|
||||
{
|
||||
"name": "score",
|
||||
"description": "Score of the labels respectively",
|
||||
"description": "Score of the labels respectively.",
|
||||
"content": {
|
||||
"content_properties_type": "FeatureProperties",
|
||||
"content_properties": {
|
||||
|
@ -298,7 +350,7 @@ class MetadataWriterForTaskTest(absltest.TestCase):
|
|||
"output_tensor_metadata": [
|
||||
{
|
||||
"name": "score",
|
||||
"description": "Score of the labels respectively",
|
||||
"description": "Score of the labels respectively.",
|
||||
"content": {
|
||||
"content_properties_type": "FeatureProperties",
|
||||
"content_properties": {
|
||||
|
|
10
mediapipe/tasks/testdata/metadata/BUILD
vendored
10
mediapipe/tasks/testdata/metadata/BUILD
vendored
|
@ -29,6 +29,8 @@ mediapipe_files(srcs = [
|
|||
"mobile_object_classifier_v0_2_3-metadata-no-name.tflite",
|
||||
"mobilenet_v1_0.25_224_1_default_1.tflite",
|
||||
"mobilenet_v2_1.0_224_quant.tflite",
|
||||
"mobilenet_v2_1.0_224_quant_without_metadata.tflite",
|
||||
"mobilenet_v2_1.0_224_without_metadata.tflite",
|
||||
])
|
||||
|
||||
exports_files([
|
||||
|
@ -48,6 +50,9 @@ exports_files([
|
|||
"score_calibration.txt",
|
||||
"score_calibration_file_meta.json",
|
||||
"score_calibration_tensor_meta.json",
|
||||
"labels.txt",
|
||||
"mobilenet_v2_1.0_224.json",
|
||||
"mobilenet_v2_1.0_224_quant.json",
|
||||
])
|
||||
|
||||
filegroup(
|
||||
|
@ -59,6 +64,8 @@ filegroup(
|
|||
"mobile_object_classifier_v0_2_3-metadata-no-name.tflite",
|
||||
"mobilenet_v1_0.25_224_1_default_1.tflite",
|
||||
"mobilenet_v2_1.0_224_quant.tflite",
|
||||
"mobilenet_v2_1.0_224_quant_without_metadata.tflite",
|
||||
"mobilenet_v2_1.0_224_without_metadata.tflite",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -78,6 +85,9 @@ filegroup(
|
|||
"input_image_tensor_float_meta.json",
|
||||
"input_image_tensor_uint8_meta.json",
|
||||
"input_image_tensor_unsupported_meta.json",
|
||||
"labels.txt",
|
||||
"mobilenet_v2_1.0_224.json",
|
||||
"mobilenet_v2_1.0_224_quant.json",
|
||||
"score_calibration.txt",
|
||||
"score_calibration_file_meta.json",
|
||||
"score_calibration_tensor_meta.json",
|
||||
|
|
1001
mediapipe/tasks/testdata/metadata/labels.txt
vendored
Normal file
1001
mediapipe/tasks/testdata/metadata/labels.txt
vendored
Normal file
File diff suppressed because it is too large
Load Diff
82
mediapipe/tasks/testdata/metadata/mobilenet_v2_1.0_224.json
vendored
Normal file
82
mediapipe/tasks/testdata/metadata/mobilenet_v2_1.0_224.json
vendored
Normal file
|
@ -0,0 +1,82 @@
|
|||
{
|
||||
"name": "ImageClassifier",
|
||||
"description": "Identify the most prominent object in the image from a known set of categories.",
|
||||
"subgraph_metadata": [
|
||||
{
|
||||
"input_tensor_metadata": [
|
||||
{
|
||||
"name": "image",
|
||||
"description": "Input image to be processed.",
|
||||
"content": {
|
||||
"content_properties_type": "ImageProperties",
|
||||
"content_properties": {
|
||||
"color_space": "RGB"
|
||||
}
|
||||
},
|
||||
"process_units": [
|
||||
{
|
||||
"options_type": "NormalizationOptions",
|
||||
"options": {
|
||||
"mean": [
|
||||
127.5
|
||||
],
|
||||
"std": [
|
||||
127.5
|
||||
]
|
||||
}
|
||||
}
|
||||
],
|
||||
"stats": {
|
||||
"max": [
|
||||
1.0
|
||||
],
|
||||
"min": [
|
||||
-1.0
|
||||
]
|
||||
}
|
||||
}
|
||||
],
|
||||
"output_tensor_metadata": [
|
||||
{
|
||||
"name": "score",
|
||||
"description": "Score of the labels respectively.",
|
||||
"content": {
|
||||
"content_properties_type": "FeatureProperties",
|
||||
"content_properties": {
|
||||
}
|
||||
},
|
||||
"process_units": [
|
||||
{
|
||||
"options_type": "ScoreCalibrationOptions",
|
||||
"options": {
|
||||
"score_transformation": "LOG",
|
||||
"default_score": 0.2
|
||||
}
|
||||
}
|
||||
],
|
||||
"stats": {
|
||||
"max": [
|
||||
1.0
|
||||
],
|
||||
"min": [
|
||||
0.0
|
||||
]
|
||||
},
|
||||
"associated_files": [
|
||||
{
|
||||
"name": "labels.txt",
|
||||
"description": "Labels for categories that the model can recognize.",
|
||||
"type": "TENSOR_AXIS_LABELS"
|
||||
},
|
||||
{
|
||||
"name": "score_calibration.txt",
|
||||
"description": "Contains sigmoid-based score calibration parameters. The main purposes of score calibration is to make scores across classes comparable, so that a common threshold can be used for all output classes.",
|
||||
"type": "TENSOR_AXIS_SCORE_CALIBRATION"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
"min_parser_version": "1.0.0"
|
||||
}
|
82
mediapipe/tasks/testdata/metadata/mobilenet_v2_1.0_224_quant.json
vendored
Normal file
82
mediapipe/tasks/testdata/metadata/mobilenet_v2_1.0_224_quant.json
vendored
Normal file
|
@ -0,0 +1,82 @@
|
|||
{
|
||||
"name": "ImageClassifier",
|
||||
"description": "Identify the most prominent object in the image from a known set of categories.",
|
||||
"subgraph_metadata": [
|
||||
{
|
||||
"input_tensor_metadata": [
|
||||
{
|
||||
"name": "image",
|
||||
"description": "Input image to be processed.",
|
||||
"content": {
|
||||
"content_properties_type": "ImageProperties",
|
||||
"content_properties": {
|
||||
"color_space": "RGB"
|
||||
}
|
||||
},
|
||||
"process_units": [
|
||||
{
|
||||
"options_type": "NormalizationOptions",
|
||||
"options": {
|
||||
"mean": [
|
||||
127.5
|
||||
],
|
||||
"std": [
|
||||
127.5
|
||||
]
|
||||
}
|
||||
}
|
||||
],
|
||||
"stats": {
|
||||
"max": [
|
||||
255.0
|
||||
],
|
||||
"min": [
|
||||
0.0
|
||||
]
|
||||
}
|
||||
}
|
||||
],
|
||||
"output_tensor_metadata": [
|
||||
{
|
||||
"name": "score",
|
||||
"description": "Score of the labels respectively.",
|
||||
"content": {
|
||||
"content_properties_type": "FeatureProperties",
|
||||
"content_properties": {
|
||||
}
|
||||
},
|
||||
"process_units": [
|
||||
{
|
||||
"options_type": "ScoreCalibrationOptions",
|
||||
"options": {
|
||||
"score_transformation": "LOG",
|
||||
"default_score": 0.2
|
||||
}
|
||||
}
|
||||
],
|
||||
"stats": {
|
||||
"max": [
|
||||
255.0
|
||||
],
|
||||
"min": [
|
||||
0.0
|
||||
]
|
||||
},
|
||||
"associated_files": [
|
||||
{
|
||||
"name": "labels.txt",
|
||||
"description": "Labels for categories that the model can recognize.",
|
||||
"type": "TENSOR_AXIS_LABELS"
|
||||
},
|
||||
{
|
||||
"name": "score_calibration.txt",
|
||||
"description": "Contains sigmoid-based score calibration parameters. The main purposes of score calibration is to make scores across classes comparable, so that a common threshold can be used for all output classes.",
|
||||
"type": "TENSOR_AXIS_SCORE_CALIBRATION"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
"min_parser_version": "1.0.0"
|
||||
}
|
30
third_party/external_files.bzl
vendored
30
third_party/external_files.bzl
vendored
|
@ -364,6 +364,12 @@ def external_files():
|
|||
urls = ["https://storage.googleapis.com/mediapipe-assets/knift_labelmap.txt?generation=1661875792821628"],
|
||||
)
|
||||
|
||||
http_file(
|
||||
name = "com_google_mediapipe_labels_txt",
|
||||
sha256 = "536feacc519de3d418de26b2effb4d75694a8c4c0063e36499a46fa8061e2da9",
|
||||
urls = ["https://storage.googleapis.com/mediapipe-assets/labels.txt?generation=1665988394538324"],
|
||||
)
|
||||
|
||||
http_file(
|
||||
name = "com_google_mediapipe_left_hands_jpg",
|
||||
sha256 = "4b5134daa4cb60465535239535f9f74c2842aba3aa5fd30bf04ef5678f93d87f",
|
||||
|
@ -448,18 +454,42 @@ def external_files():
|
|||
urls = ["https://storage.googleapis.com/mediapipe-assets/mobilenet_v1_0.25_224_quant_without_subgraph_metadata.tflite?generation=1661875836078124"],
|
||||
)
|
||||
|
||||
http_file(
|
||||
name = "com_google_mediapipe_mobilenet_v2_1_0_224_json",
|
||||
sha256 = "0eb285a857b4bb1815736d0902ace0af45ea62e90c1dac98844b9ca797cd0d7b",
|
||||
urls = ["https://storage.googleapis.com/mediapipe-assets/mobilenet_v2_1.0_224.json?generation=1665988398778178"],
|
||||
)
|
||||
|
||||
http_file(
|
||||
name = "com_google_mediapipe_mobilenet_v2_1_0_224_quant_json",
|
||||
sha256 = "932f345ebe3d98daf0dc4c88b0f9e694e450390fb394fc217e851338dfec43e6",
|
||||
urls = ["https://storage.googleapis.com/mediapipe-assets/mobilenet_v2_1.0_224_quant.json?generation=1665988401522527"],
|
||||
)
|
||||
|
||||
http_file(
|
||||
name = "com_google_mediapipe_mobilenet_v2_1_0_224_quant_tflite",
|
||||
sha256 = "f08d447cde49b4e0446428aa921aff0a14ea589fa9c5817b31f83128e9a43c1d",
|
||||
urls = ["https://storage.googleapis.com/mediapipe-assets/mobilenet_v2_1.0_224_quant.tflite?generation=1664340173966530"],
|
||||
)
|
||||
|
||||
http_file(
|
||||
name = "com_google_mediapipe_mobilenet_v2_1_0_224_quant_without_metadata_tflite",
|
||||
sha256 = "f08d447cde49b4e0446428aa921aff0a14ea589fa9c5817b31f83128e9a43c1d",
|
||||
urls = ["https://storage.googleapis.com/mediapipe-assets/mobilenet_v2_1.0_224_quant_without_metadata.tflite?generation=1665988405130772"],
|
||||
)
|
||||
|
||||
http_file(
|
||||
name = "com_google_mediapipe_mobilenet_v2_1_0_224_tflite",
|
||||
sha256 = "ff5cb7f9e62c92ebdad971f8a98aa6b3106d82a64587a7787c6a385c9e791339",
|
||||
urls = ["https://storage.googleapis.com/mediapipe-assets/mobilenet_v2_1.0_224.tflite?generation=1661875840611150"],
|
||||
)
|
||||
|
||||
http_file(
|
||||
name = "com_google_mediapipe_mobilenet_v2_1_0_224_without_metadata_tflite",
|
||||
sha256 = "9f3bc29e38e90842a852bfed957dbf5e36f2d97a91dd17736b1e5c0aca8d3303",
|
||||
urls = ["https://storage.googleapis.com/mediapipe-assets/mobilenet_v2_1.0_224_without_metadata.tflite?generation=1665988408360823"],
|
||||
)
|
||||
|
||||
http_file(
|
||||
name = "com_google_mediapipe_mobilenet_v3_small_100_224_embedder_tflite",
|
||||
sha256 = "f7b9a563cb803bdcba76e8c7e82abde06f5c7a8e67b5e54e43e23095dfe79a78",
|
||||
|
|
Loading…
Reference in New Issue
Block a user