Add metadata writer to image_classifier in model_maker
PiperOrigin-RevId: 485926985
This commit is contained in:
parent
b472d8ff66
commit
d29c3d7512
|
@ -52,6 +52,7 @@ class CustomModel(abc.ABC):
|
|||
"""Prints a summary of the model."""
|
||||
self._model.summary()
|
||||
|
||||
# TODO: Remove this method when all tasks use Metadata writer
|
||||
def export_tflite(
|
||||
self,
|
||||
export_dir: str,
|
||||
|
@ -62,7 +63,7 @@ class CustomModel(abc.ABC):
|
|||
|
||||
Args:
|
||||
export_dir: The directory to save exported files.
|
||||
tflite_filename: File name to save tflite model. The full export path is
|
||||
tflite_filename: File name to save TFLite model. The full export path is
|
||||
{export_dir}/{tflite_filename}.
|
||||
quantization_config: The configuration for model quantization.
|
||||
preprocess: A callable to preprocess the representative dataset for
|
||||
|
@ -73,11 +74,11 @@ class CustomModel(abc.ABC):
|
|||
tf.io.gfile.makedirs(export_dir)
|
||||
|
||||
tflite_filepath = os.path.join(export_dir, tflite_filename)
|
||||
# TODO: Populate metadata to the exported TFLite model.
|
||||
model_util.export_tflite(
|
||||
tflite_model = model_util.convert_to_tflite(
|
||||
model=self._model,
|
||||
tflite_filepath=tflite_filepath,
|
||||
quantization_config=quantization_config,
|
||||
preprocess=preprocess)
|
||||
model_util.save_tflite(
|
||||
tflite_model=tflite_model, tflite_file=tflite_filepath)
|
||||
tf.compat.v1.logging.info(
|
||||
'TensorFlow Lite model exported successfully: %s' % tflite_filepath)
|
||||
|
|
|
@ -89,28 +89,25 @@ def get_steps_per_epoch(steps_per_epoch: Optional[int] = None,
|
|||
return len(train_data) // batch_size
|
||||
|
||||
|
||||
def export_tflite(
|
||||
def convert_to_tflite(
|
||||
model: tf.keras.Model,
|
||||
tflite_filepath: str,
|
||||
quantization_config: Optional[quantization.QuantizationConfig] = None,
|
||||
supported_ops: Tuple[tf.lite.OpsSet,
|
||||
...] = (tf.lite.OpsSet.TFLITE_BUILTINS,),
|
||||
preprocess: Optional[Callable[..., bool]] = None):
|
||||
"""Converts the model to tflite format and saves it.
|
||||
preprocess: Optional[Callable[..., bool]] = None) -> bytearray:
|
||||
"""Converts the input Keras model to TFLite format.
|
||||
|
||||
Args:
|
||||
model: model to be converted to tflite.
|
||||
tflite_filepath: File path to save tflite model.
|
||||
model: Keras model to be converted to TFLite.
|
||||
quantization_config: Configuration for post-training quantization.
|
||||
supported_ops: A list of supported ops in the converted TFLite file.
|
||||
preprocess: A callable to preprocess the representative dataset for
|
||||
quantization. The callable takes three arguments in order: feature, label,
|
||||
and is_training.
|
||||
"""
|
||||
if tflite_filepath is None:
|
||||
raise ValueError(
|
||||
"TFLite filepath couldn't be None when exporting to tflite.")
|
||||
|
||||
Returns:
|
||||
bytearray of TFLite model
|
||||
"""
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
save_path = os.path.join(temp_dir, 'saved_model')
|
||||
model.save(save_path, include_optimizer=False, save_format='tf')
|
||||
|
@ -122,9 +119,22 @@ def export_tflite(
|
|||
|
||||
converter.target_spec.supported_ops = supported_ops
|
||||
tflite_model = converter.convert()
|
||||
return tflite_model
|
||||
|
||||
with tf.io.gfile.GFile(tflite_filepath, 'wb') as f:
|
||||
|
||||
def save_tflite(tflite_model: bytearray, tflite_file: str) -> None:
|
||||
"""Saves TFLite file to tflite_file.
|
||||
|
||||
Args:
|
||||
tflite_model: A valid flatbuffer representing the TFLite model.
|
||||
tflite_file: File path to save TFLite model.
|
||||
"""
|
||||
if tflite_file is None:
|
||||
raise ValueError("TFLite filepath can't be None when exporting to TFLite.")
|
||||
with tf.io.gfile.GFile(tflite_file, 'wb') as f:
|
||||
f.write(tflite_model)
|
||||
tf.compat.v1.logging.info(
|
||||
'TensorFlow Lite model exported successfully to: %s' % tflite_file)
|
||||
|
||||
|
||||
class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
|
||||
|
@ -176,14 +186,12 @@ class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
|
|||
class LiteRunner(object):
|
||||
"""A runner to do inference with the TFLite model."""
|
||||
|
||||
def __init__(self, tflite_filepath: str):
|
||||
"""Initializes Lite runner with tflite model file.
|
||||
def __init__(self, tflite_model: bytearray):
|
||||
"""Initializes Lite runner from TFLite model buffer.
|
||||
|
||||
Args:
|
||||
tflite_filepath: File path to the TFLite model.
|
||||
tflite_model: A valid flatbuffer representing the TFLite model.
|
||||
"""
|
||||
with tf.io.gfile.GFile(tflite_filepath, 'rb') as f:
|
||||
tflite_model = f.read()
|
||||
self.interpreter = tf.lite.Interpreter(model_content=tflite_model)
|
||||
self.interpreter.allocate_tensors()
|
||||
self.input_details = self.interpreter.get_input_details()
|
||||
|
@ -250,9 +258,9 @@ class LiteRunner(object):
|
|||
return output_tensors
|
||||
|
||||
|
||||
def get_lite_runner(tflite_filepath: str) -> 'LiteRunner':
|
||||
"""Returns a `LiteRunner` from file path to TFLite model."""
|
||||
lite_runner = LiteRunner(tflite_filepath)
|
||||
def get_lite_runner(tflite_buffer: bytearray) -> 'LiteRunner':
|
||||
"""Returns a `LiteRunner` from flatbuffer of the TFLite model."""
|
||||
lite_runner = LiteRunner(tflite_buffer)
|
||||
return lite_runner
|
||||
|
||||
|
||||
|
|
|
@ -95,13 +95,12 @@ class ModelUtilTest(tf.test.TestCase, parameterized.TestCase):
|
|||
'name': 'test'
|
||||
})
|
||||
|
||||
def test_export_tflite(self):
|
||||
def test_convert_to_tflite(self):
|
||||
input_dim = 4
|
||||
model = test_util.build_model(input_shape=[input_dim], num_classes=2)
|
||||
tflite_file = os.path.join(self.get_temp_dir(), 'model.tflite')
|
||||
model_util.export_tflite(model, tflite_file)
|
||||
tflite_model = model_util.convert_to_tflite(model)
|
||||
test_util.test_tflite(
|
||||
keras_model=model, tflite_file=tflite_file, size=[1, input_dim])
|
||||
keras_model=model, tflite_model=tflite_model, size=[1, input_dim])
|
||||
|
||||
@parameterized.named_parameters(
|
||||
dict(
|
||||
|
@ -118,25 +117,32 @@ class ModelUtilTest(tf.test.TestCase, parameterized.TestCase):
|
|||
testcase_name='float16_quantize',
|
||||
config=quantization.QuantizationConfig.for_float16(),
|
||||
model_size=1468))
|
||||
def test_export_tflite_quantized(self, config, model_size):
|
||||
def test_convert_to_tflite_quantized(self, config, model_size):
|
||||
input_dim = 16
|
||||
num_classes = 2
|
||||
max_input_value = 5
|
||||
model = test_util.build_model(
|
||||
input_shape=[input_dim], num_classes=num_classes)
|
||||
tflite_file = os.path.join(self.get_temp_dir(), 'model_quantized.tflite')
|
||||
|
||||
model_util.export_tflite(
|
||||
model=model, tflite_filepath=tflite_file, quantization_config=config)
|
||||
tflite_model = model_util.convert_to_tflite(
|
||||
model=model, quantization_config=config)
|
||||
self.assertTrue(
|
||||
test_util.test_tflite(
|
||||
keras_model=model,
|
||||
tflite_file=tflite_file,
|
||||
tflite_model=tflite_model,
|
||||
size=[1, input_dim],
|
||||
high=max_input_value,
|
||||
atol=1e-00))
|
||||
self.assertNear(os.path.getsize(tflite_file), model_size, 300)
|
||||
self.assertNear(len(tflite_model), model_size, 300)
|
||||
|
||||
def test_save_tflite(self):
|
||||
input_dim = 4
|
||||
model = test_util.build_model(input_shape=[input_dim], num_classes=2)
|
||||
tflite_model = model_util.convert_to_tflite(model)
|
||||
tflite_file = os.path.join(self.get_temp_dir(), 'model.tflite')
|
||||
model_util.save_tflite(tflite_model=tflite_model, tflite_file=tflite_file)
|
||||
test_util.test_tflite_file(
|
||||
keras_model=model, tflite_file=tflite_file, size=[1, input_dim])
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
||||
|
|
|
@ -79,13 +79,13 @@ def build_model(input_shape: List[int], num_classes: int) -> tf.keras.Model:
|
|||
return model
|
||||
|
||||
|
||||
def is_same_output(tflite_file: str,
|
||||
def is_same_output(tflite_model: bytearray,
|
||||
keras_model: tf.keras.Model,
|
||||
input_tensors: Union[List[tf.Tensor], tf.Tensor],
|
||||
atol: float = 1e-04) -> bool:
|
||||
"""Returns if the output of TFLite model and keras model are identical."""
|
||||
# Gets output from lite model.
|
||||
lite_runner = model_util.get_lite_runner(tflite_file)
|
||||
lite_runner = model_util.get_lite_runner(tflite_model)
|
||||
lite_output = lite_runner.run(input_tensors)
|
||||
|
||||
# Gets output from keras model.
|
||||
|
@ -95,12 +95,41 @@ def is_same_output(tflite_file: str,
|
|||
|
||||
|
||||
def test_tflite(keras_model: tf.keras.Model,
|
||||
tflite_file: str,
|
||||
tflite_model: bytearray,
|
||||
size: Union[int, List[int]],
|
||||
high: float = 1,
|
||||
atol: float = 1e-04) -> bool:
|
||||
"""Verifies if the output of TFLite model and TF Keras model are identical.
|
||||
|
||||
Args:
|
||||
keras_model: Input TensorFlow Keras model.
|
||||
tflite_model: Input TFLite model flatbuffer.
|
||||
size: Size of the input tesnor.
|
||||
high: Higher boundary of the values in input tensors.
|
||||
atol: Absolute tolerance of the difference between the outputs of Keras
|
||||
model and TFLite model.
|
||||
|
||||
Returns:
|
||||
True if the output of TFLite model and TF Keras model are identical.
|
||||
Otherwise, False.
|
||||
"""
|
||||
random_input = create_random_sample(size=size, high=high)
|
||||
random_input = tf.convert_to_tensor(random_input)
|
||||
|
||||
return is_same_output(
|
||||
tflite_model=tflite_model,
|
||||
keras_model=keras_model,
|
||||
input_tensors=random_input,
|
||||
atol=atol)
|
||||
|
||||
|
||||
def test_tflite_file(keras_model: tf.keras.Model,
|
||||
tflite_file: bytearray,
|
||||
size: Union[int, List[int]],
|
||||
high: float = 1,
|
||||
atol: float = 1e-04) -> bool:
|
||||
"""Verifies if the output of TFLite model and TF Keras model are identical.
|
||||
|
||||
Args:
|
||||
keras_model: Input TensorFlow Keras model.
|
||||
tflite_file: Input TFLite model file.
|
||||
|
@ -113,11 +142,6 @@ def test_tflite(keras_model: tf.keras.Model,
|
|||
True if the output of TFLite model and TF Keras model are identical.
|
||||
Otherwise, False.
|
||||
"""
|
||||
random_input = create_random_sample(size=size, high=high)
|
||||
random_input = tf.convert_to_tensor(random_input)
|
||||
|
||||
return is_same_output(
|
||||
tflite_file=tflite_file,
|
||||
keras_model=keras_model,
|
||||
input_tensors=random_input,
|
||||
atol=atol)
|
||||
with tf.io.gfile.GFile(tflite_file, "rb") as f:
|
||||
tflite_model = f.read()
|
||||
return test_tflite(keras_model, tflite_model, size, high, atol)
|
||||
|
|
|
@ -81,6 +81,8 @@ py_library(
|
|||
"//mediapipe/model_maker/python/core/utils:model_util",
|
||||
"//mediapipe/model_maker/python/core/utils:quantization",
|
||||
"//mediapipe/model_maker/python/vision/core:image_preprocessing",
|
||||
"//mediapipe/tasks/python/metadata/metadata_writers:image_classifier",
|
||||
"//mediapipe/tasks/python/metadata/metadata_writers:metadata_writer",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -88,7 +90,11 @@ py_library(
|
|||
name = "image_classifier_test_lib",
|
||||
testonly = 1,
|
||||
srcs = ["image_classifier_test.py"],
|
||||
deps = [":image_classifier_import"],
|
||||
data = ["//mediapipe/model_maker/python/vision/image_classifier/testdata"],
|
||||
deps = [
|
||||
":image_classifier_import",
|
||||
"//mediapipe/tasks/python/test:test_utils",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""APIs to train image classifier model."""
|
||||
import os
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
|
@ -26,6 +27,8 @@ from mediapipe.model_maker.python.vision.core import image_preprocessing
|
|||
from mediapipe.model_maker.python.vision.image_classifier import hyperparameters as hp
|
||||
from mediapipe.model_maker.python.vision.image_classifier import model_spec as ms
|
||||
from mediapipe.model_maker.python.vision.image_classifier import train_image_classifier_lib
|
||||
from mediapipe.tasks.python.metadata.metadata_writers import image_classifier as image_classifier_writer
|
||||
from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer
|
||||
|
||||
|
||||
class ImageClassifier(classifier.Classifier):
|
||||
|
@ -156,15 +159,32 @@ class ImageClassifier(classifier.Classifier):
|
|||
self,
|
||||
model_name: str = 'model.tflite',
|
||||
quantization_config: Optional[quantization.QuantizationConfig] = None):
|
||||
"""Converts the model to the requested formats and exports to a file.
|
||||
"""Converts and saves the model to a TFLite file with metadata included.
|
||||
|
||||
Note that only the TFLite file is needed for deployment. This function also
|
||||
saves a metadata.json file to the same directory as the TFLite file which
|
||||
can be used to interpret the metadata content in the TFLite file.
|
||||
|
||||
Args:
|
||||
model_name: File name to save tflite model. The full export path is
|
||||
{export_dir}/{tflite_filename}.
|
||||
model_name: File name to save TFLite model with metadata. The full export
|
||||
path is {self._hparams.model_dir}/{model_name}.
|
||||
quantization_config: The configuration for model quantization.
|
||||
"""
|
||||
super().export_tflite(
|
||||
self._hparams.model_dir,
|
||||
model_name,
|
||||
quantization_config,
|
||||
if not tf.io.gfile.exists(self._hparams.model_dir):
|
||||
tf.io.gfile.makedirs(self._hparams.model_dir)
|
||||
tflite_file = os.path.join(self._hparams.model_dir, model_name)
|
||||
metadata_file = os.path.join(self._hparams.model_dir, 'metadata.json')
|
||||
|
||||
tflite_model = model_util.convert_to_tflite(
|
||||
model=self._model,
|
||||
quantization_config=quantization_config,
|
||||
preprocess=self._preprocess)
|
||||
writer = image_classifier_writer.MetadataWriter.create(
|
||||
tflite_model,
|
||||
self._model_spec.mean_rgb,
|
||||
self._model_spec.stddev_rgb,
|
||||
labels=metadata_writer.Labels().add(self._label_names))
|
||||
tflite_model_with_metadata, metadata_json = writer.populate()
|
||||
model_util.save_tflite(tflite_model_with_metadata, tflite_file)
|
||||
with open(metadata_file, 'w') as f:
|
||||
f.write(metadata_json)
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import filecmp
|
||||
import os
|
||||
|
||||
from absl.testing import parameterized
|
||||
|
@ -19,6 +20,7 @@ import numpy as np
|
|||
import tensorflow as tf
|
||||
|
||||
from mediapipe.model_maker.python.vision import image_classifier
|
||||
from mediapipe.tasks.python.test import test_utils
|
||||
|
||||
|
||||
def _fill_image(rgb, image_size):
|
||||
|
@ -86,7 +88,7 @@ class ImageClassifierTest(tf.test.TestCase, parameterized.TestCase):
|
|||
validation_data=self.test_data)
|
||||
self._test_accuracy(model)
|
||||
|
||||
def test_efficientnetlite0_model_with_model_maker_retraining_lib(self):
|
||||
def test_efficientnetlite0_model_train_and_export(self):
|
||||
hparams = image_classifier.HParams(
|
||||
train_epochs=1, batch_size=1, shuffle=True)
|
||||
model = image_classifier.ImageClassifier.create(
|
||||
|
@ -96,6 +98,19 @@ class ImageClassifierTest(tf.test.TestCase, parameterized.TestCase):
|
|||
validation_data=self.test_data)
|
||||
self._test_accuracy(model)
|
||||
|
||||
# Test export_model
|
||||
model.export_model()
|
||||
output_metadata_file = os.path.join(hparams.model_dir, 'metadata.json')
|
||||
output_tflite_file = os.path.join(hparams.model_dir, 'model.tflite')
|
||||
expected_metadata_file = test_utils.get_test_data_path('metadata.json')
|
||||
|
||||
self.assertTrue(os.path.exists(output_tflite_file))
|
||||
self.assertGreater(os.path.getsize(output_tflite_file), 0)
|
||||
|
||||
self.assertTrue(os.path.exists(output_metadata_file))
|
||||
self.assertGreater(os.path.getsize(output_metadata_file), 0)
|
||||
self.assertTrue(filecmp.cmp(output_metadata_file, expected_metadata_file))
|
||||
|
||||
def _test_accuracy(self, model, threshold=0.0):
|
||||
_, accuracy = model.evaluate(self.test_data)
|
||||
self.assertGreaterEqual(accuracy, threshold)
|
||||
|
|
23
mediapipe/model_maker/python/vision/image_classifier/testdata/BUILD
vendored
Normal file
23
mediapipe/model_maker/python/vision/image_classifier/testdata/BUILD
vendored
Normal file
|
@ -0,0 +1,23 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
package(
|
||||
default_visibility = ["//mediapipe/model_maker/python/vision/image_classifier:__subpackages__"],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "testdata",
|
||||
srcs = ["metadata.json"],
|
||||
)
|
68
mediapipe/model_maker/python/vision/image_classifier/testdata/metadata.json
vendored
Normal file
68
mediapipe/model_maker/python/vision/image_classifier/testdata/metadata.json
vendored
Normal file
|
@ -0,0 +1,68 @@
|
|||
{
|
||||
"name": "ImageClassifier",
|
||||
"description": "Identify the most prominent object in the image from a known set of categories.",
|
||||
"subgraph_metadata": [
|
||||
{
|
||||
"input_tensor_metadata": [
|
||||
{
|
||||
"name": "image",
|
||||
"description": "Input image to be processed.",
|
||||
"content": {
|
||||
"content_properties_type": "ImageProperties",
|
||||
"content_properties": {
|
||||
"color_space": "RGB"
|
||||
}
|
||||
},
|
||||
"process_units": [
|
||||
{
|
||||
"options_type": "NormalizationOptions",
|
||||
"options": {
|
||||
"mean": [
|
||||
0.0
|
||||
],
|
||||
"std": [
|
||||
255.0
|
||||
]
|
||||
}
|
||||
}
|
||||
],
|
||||
"stats": {
|
||||
"max": [
|
||||
1.0
|
||||
],
|
||||
"min": [
|
||||
0.0
|
||||
]
|
||||
}
|
||||
}
|
||||
],
|
||||
"output_tensor_metadata": [
|
||||
{
|
||||
"name": "score",
|
||||
"description": "Score of the labels respectively.",
|
||||
"content": {
|
||||
"content_properties_type": "FeatureProperties",
|
||||
"content_properties": {
|
||||
}
|
||||
},
|
||||
"stats": {
|
||||
"max": [
|
||||
1.0
|
||||
],
|
||||
"min": [
|
||||
0.0
|
||||
]
|
||||
},
|
||||
"associated_files": [
|
||||
{
|
||||
"name": "labels.txt",
|
||||
"description": "Labels for categories that the model can recognize.",
|
||||
"type": "TENSOR_AXIS_LABELS"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
"min_parser_version": "1.0.0"
|
||||
}
|
|
@ -24,7 +24,7 @@ py_library(
|
|||
srcs = ["test_utils.py"],
|
||||
srcs_version = "PY3",
|
||||
visibility = [
|
||||
"//mediapipe/model_maker/python/vision/gesture_recognizer:__pkg__",
|
||||
"//mediapipe/model_maker/python:__subpackages__",
|
||||
"//mediapipe/tasks:internal",
|
||||
],
|
||||
deps = [
|
||||
|
|
Loading…
Reference in New Issue
Block a user