Add metadata writer to image_classifier in model_maker

PiperOrigin-RevId: 485926985
MediaPipe Team 2022-11-03 11:28:56 -07:00 committed by Copybara-Service
parent b472d8ff66
commit d29c3d7512
10 changed files with 225 additions and 54 deletions


@ -52,6 +52,7 @@ class CustomModel(abc.ABC):
"""Prints a summary of the model."""
self._model.summary()
# TODO: Remove this method when all tasks use Metadata writer
def export_tflite(
self,
export_dir: str,
@ -62,7 +63,7 @@ class CustomModel(abc.ABC):
Args:
export_dir: The directory to save exported files.
tflite_filename: File name to save tflite model. The full export path is
tflite_filename: File name to save TFLite model. The full export path is
{export_dir}/{tflite_filename}.
quantization_config: The configuration for model quantization.
preprocess: A callable to preprocess the representative dataset for
@ -73,11 +74,11 @@ class CustomModel(abc.ABC):
tf.io.gfile.makedirs(export_dir)
tflite_filepath = os.path.join(export_dir, tflite_filename)
# TODO: Populate metadata to the exported TFLite model.
model_util.export_tflite(
tflite_model = model_util.convert_to_tflite(
model=self._model,
tflite_filepath=tflite_filepath,
quantization_config=quantization_config,
preprocess=preprocess)
model_util.save_tflite(
tflite_model=tflite_model, tflite_file=tflite_filepath)
tf.compat.v1.logging.info(
'TensorFlow Lite model exported successfully: %s' % tflite_filepath)

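For reference, exporting is now a two-step flow: convert_to_tflite produces the flatbuffer in memory and save_tflite writes it to disk. A minimal sketch of the equivalent call sequence, assuming a compiled tf.keras model in `model` and an output directory in `export_dir` (both placeholders, as is the helper name):

import os

import tensorflow as tf

from mediapipe.model_maker.python.core.utils import model_util


def export_without_metadata(model: tf.keras.Model,
                            export_dir: str,
                            tflite_filename: str = 'model.tflite') -> None:
  """Converts a Keras model and writes the TFLite flatbuffer to disk."""
  tf.io.gfile.makedirs(export_dir)
  tflite_filepath = os.path.join(export_dir, tflite_filename)
  # Step 1: convert the Keras model to an in-memory TFLite flatbuffer.
  tflite_model = model_util.convert_to_tflite(model=model)
  # Step 2: persist the flatbuffer; saving is now separate from conversion.
  model_util.save_tflite(tflite_model=tflite_model, tflite_file=tflite_filepath)

Keeping conversion separate from saving is what lets the image classifier below attach metadata to the buffer before anything is written to the filesystem.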

@ -89,28 +89,25 @@ def get_steps_per_epoch(steps_per_epoch: Optional[int] = None,
return len(train_data) // batch_size
def export_tflite(
def convert_to_tflite(
model: tf.keras.Model,
tflite_filepath: str,
quantization_config: Optional[quantization.QuantizationConfig] = None,
supported_ops: Tuple[tf.lite.OpsSet,
...] = (tf.lite.OpsSet.TFLITE_BUILTINS,),
preprocess: Optional[Callable[..., bool]] = None):
"""Converts the model to tflite format and saves it.
preprocess: Optional[Callable[..., bool]] = None) -> bytearray:
"""Converts the input Keras model to TFLite format.
Args:
model: model to be converted to tflite.
tflite_filepath: File path to save tflite model.
model: Keras model to be converted to TFLite.
quantization_config: Configuration for post-training quantization.
supported_ops: A list of supported ops in the converted TFLite file.
preprocess: A callable to preprocess the representative dataset for
quantization. The callable takes three arguments in order: feature, label,
and is_training.
"""
if tflite_filepath is None:
raise ValueError(
"TFLite filepath couldn't be None when exporting to tflite.")
Returns:
bytearray of TFLite model
"""
with tempfile.TemporaryDirectory() as temp_dir:
save_path = os.path.join(temp_dir, 'saved_model')
model.save(save_path, include_optimizer=False, save_format='tf')
@ -122,9 +119,22 @@ def export_tflite(
converter.target_spec.supported_ops = supported_ops
tflite_model = converter.convert()
return tflite_model
with tf.io.gfile.GFile(tflite_filepath, 'wb') as f:
def save_tflite(tflite_model: bytearray, tflite_file: str) -> None:
"""Saves TFLite file to tflite_file.
Args:
tflite_model: A valid flatbuffer representing the TFLite model.
tflite_file: File path to save TFLite model.
"""
if tflite_file is None:
raise ValueError("TFLite filepath can't be None when exporting to TFLite.")
with tf.io.gfile.GFile(tflite_file, 'wb') as f:
f.write(tflite_model)
tf.compat.v1.logging.info(
'TensorFlow Lite model exported successfully to: %s' % tflite_file)
class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
@ -176,14 +186,12 @@ class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
class LiteRunner(object):
"""A runner to do inference with the TFLite model."""
def __init__(self, tflite_filepath: str):
"""Initializes Lite runner with tflite model file.
def __init__(self, tflite_model: bytearray):
"""Initializes Lite runner from TFLite model buffer.
Args:
tflite_filepath: File path to the TFLite model.
tflite_model: A valid flatbuffer representing the TFLite model.
"""
with tf.io.gfile.GFile(tflite_filepath, 'rb') as f:
tflite_model = f.read()
self.interpreter = tf.lite.Interpreter(model_content=tflite_model)
self.interpreter.allocate_tensors()
self.input_details = self.interpreter.get_input_details()
@ -250,9 +258,9 @@ class LiteRunner(object):
return output_tensors
def get_lite_runner(tflite_filepath: str) -> 'LiteRunner':
"""Returns a `LiteRunner` from file path to TFLite model."""
lite_runner = LiteRunner(tflite_filepath)
def get_lite_runner(tflite_buffer: bytearray) -> 'LiteRunner':
"""Returns a `LiteRunner` from flatbuffer of the TFLite model."""
lite_runner = LiteRunner(tflite_buffer)
return lite_runner
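
Because LiteRunner and get_lite_runner now take the serialized model instead of a path, callers that still have a .tflite file on disk read the bytes first. A minimal sketch, assuming a model file at `tflite_filepath` and an input tensor shaped for that model (the wrapper name run_from_file is illustrative only):

import tensorflow as tf

from mediapipe.model_maker.python.core.utils import model_util


def run_from_file(tflite_filepath: str, input_tensor: tf.Tensor):
  """Loads a TFLite flatbuffer from disk and runs inference via LiteRunner."""
  # get_lite_runner expects the raw flatbuffer bytes, not a file path.
  with tf.io.gfile.GFile(tflite_filepath, 'rb') as f:
    tflite_buffer = f.read()
  lite_runner = model_util.get_lite_runner(tflite_buffer)
  # LiteRunner.run accepts a single tensor or a list of input tensors.
  return lite_runner.run(input_tensor)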


@ -95,13 +95,12 @@ class ModelUtilTest(tf.test.TestCase, parameterized.TestCase):
'name': 'test'
})
def test_export_tflite(self):
def test_convert_to_tflite(self):
input_dim = 4
model = test_util.build_model(input_shape=[input_dim], num_classes=2)
tflite_file = os.path.join(self.get_temp_dir(), 'model.tflite')
model_util.export_tflite(model, tflite_file)
tflite_model = model_util.convert_to_tflite(model)
test_util.test_tflite(
keras_model=model, tflite_file=tflite_file, size=[1, input_dim])
keras_model=model, tflite_model=tflite_model, size=[1, input_dim])
@parameterized.named_parameters(
dict(
@ -118,25 +117,32 @@ class ModelUtilTest(tf.test.TestCase, parameterized.TestCase):
testcase_name='float16_quantize',
config=quantization.QuantizationConfig.for_float16(),
model_size=1468))
def test_export_tflite_quantized(self, config, model_size):
def test_convert_to_tflite_quantized(self, config, model_size):
input_dim = 16
num_classes = 2
max_input_value = 5
model = test_util.build_model(
input_shape=[input_dim], num_classes=num_classes)
tflite_file = os.path.join(self.get_temp_dir(), 'model_quantized.tflite')
model_util.export_tflite(
model=model, tflite_filepath=tflite_file, quantization_config=config)
tflite_model = model_util.convert_to_tflite(
model=model, quantization_config=config)
self.assertTrue(
test_util.test_tflite(
keras_model=model,
tflite_file=tflite_file,
tflite_model=tflite_model,
size=[1, input_dim],
high=max_input_value,
atol=1e-00))
self.assertNear(os.path.getsize(tflite_file), model_size, 300)
self.assertNear(len(tflite_model), model_size, 300)
def test_save_tflite(self):
input_dim = 4
model = test_util.build_model(input_shape=[input_dim], num_classes=2)
tflite_model = model_util.convert_to_tflite(model)
tflite_file = os.path.join(self.get_temp_dir(), 'model.tflite')
model_util.save_tflite(tflite_model=tflite_model, tflite_file=tflite_file)
test_util.test_tflite_file(
keras_model=model, tflite_file=tflite_file, size=[1, input_dim])
if __name__ == '__main__':
tf.test.main()


@ -79,13 +79,13 @@ def build_model(input_shape: List[int], num_classes: int) -> tf.keras.Model:
return model
def is_same_output(tflite_file: str,
def is_same_output(tflite_model: bytearray,
keras_model: tf.keras.Model,
input_tensors: Union[List[tf.Tensor], tf.Tensor],
atol: float = 1e-04) -> bool:
"""Returns if the output of TFLite model and keras model are identical."""
# Gets output from lite model.
lite_runner = model_util.get_lite_runner(tflite_file)
lite_runner = model_util.get_lite_runner(tflite_model)
lite_output = lite_runner.run(input_tensors)
# Gets output from keras model.
@ -95,12 +95,41 @@ def is_same_output(tflite_file: str,
def test_tflite(keras_model: tf.keras.Model,
tflite_file: str,
tflite_model: bytearray,
size: Union[int, List[int]],
high: float = 1,
atol: float = 1e-04) -> bool:
"""Verifies if the output of TFLite model and TF Keras model are identical.
Args:
keras_model: Input TensorFlow Keras model.
tflite_model: Input TFLite model flatbuffer.
size: Size of the input tensor.
high: Higher boundary of the values in input tensors.
atol: Absolute tolerance of the difference between the outputs of Keras
model and TFLite model.
Returns:
True if the outputs of the TFLite model and TF Keras model are identical.
Otherwise, False.
"""
random_input = create_random_sample(size=size, high=high)
random_input = tf.convert_to_tensor(random_input)
return is_same_output(
tflite_model=tflite_model,
keras_model=keras_model,
input_tensors=random_input,
atol=atol)
def test_tflite_file(keras_model: tf.keras.Model,
tflite_file: str,
size: Union[int, List[int]],
high: float = 1,
atol: float = 1e-04) -> bool:
"""Verifies if the output of TFLite model and TF Keras model are identical.
Args:
keras_model: Input TensorFlow Keras model.
tflite_file: Input TFLite model file.
@ -113,11 +142,6 @@ def test_tflite(keras_model: tf.keras.Model,
True if the outputs of the TFLite model and TF Keras model are identical.
Otherwise, False.
"""
random_input = create_random_sample(size=size, high=high)
random_input = tf.convert_to_tensor(random_input)
return is_same_output(
tflite_file=tflite_file,
keras_model=keras_model,
input_tensors=random_input,
atol=atol)
with tf.io.gfile.GFile(tflite_file, "rb") as f:
tflite_model = f.read()
return test_tflite(keras_model, tflite_model, size, high, atol)
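
With the split, test_tflite verifies a flatbuffer that is still in memory, while test_tflite_file is a thin wrapper that reads the file and delegates to it. A short sketch mirroring the updated model_util_test above (the temporary path is a placeholder):

import os

from mediapipe.model_maker.python.core.utils import model_util
from mediapipe.model_maker.python.core.utils import test_util

model = test_util.build_model(input_shape=[4], num_classes=2)
tflite_model = model_util.convert_to_tflite(model)

# In-memory check against the Keras model.
assert test_util.test_tflite(
    keras_model=model, tflite_model=tflite_model, size=[1, 4])

# File-based check: save the buffer first, then verify the written file.
tflite_file = os.path.join('/tmp', 'model.tflite')  # placeholder path
model_util.save_tflite(tflite_model=tflite_model, tflite_file=tflite_file)
assert test_util.test_tflite_file(
    keras_model=model, tflite_file=tflite_file, size=[1, 4])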


@ -81,6 +81,8 @@ py_library(
"//mediapipe/model_maker/python/core/utils:model_util",
"//mediapipe/model_maker/python/core/utils:quantization",
"//mediapipe/model_maker/python/vision/core:image_preprocessing",
"//mediapipe/tasks/python/metadata/metadata_writers:image_classifier",
"//mediapipe/tasks/python/metadata/metadata_writers:metadata_writer",
],
)
@ -88,7 +90,11 @@ py_library(
name = "image_classifier_test_lib",
testonly = 1,
srcs = ["image_classifier_test.py"],
deps = [":image_classifier_import"],
data = ["//mediapipe/model_maker/python/vision/image_classifier/testdata"],
deps = [
":image_classifier_import",
"//mediapipe/tasks/python/test:test_utils",
],
)
py_test(


@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""APIs to train image classifier model."""
import os
from typing import List, Optional
@ -26,6 +27,8 @@ from mediapipe.model_maker.python.vision.core import image_preprocessing
from mediapipe.model_maker.python.vision.image_classifier import hyperparameters as hp
from mediapipe.model_maker.python.vision.image_classifier import model_spec as ms
from mediapipe.model_maker.python.vision.image_classifier import train_image_classifier_lib
from mediapipe.tasks.python.metadata.metadata_writers import image_classifier as image_classifier_writer
from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer
class ImageClassifier(classifier.Classifier):
@ -156,15 +159,32 @@ class ImageClassifier(classifier.Classifier):
self,
model_name: str = 'model.tflite',
quantization_config: Optional[quantization.QuantizationConfig] = None):
"""Converts the model to the requested formats and exports to a file.
"""Converts and saves the model to a TFLite file with metadata included.
Note that only the TFLite file is needed for deployment. This function also
saves a metadata.json file to the same directory as the TFLite file; the JSON file
can be used to interpret the metadata content in the TFLite file.
Args:
model_name: File name to save tflite model. The full export path is
{export_dir}/{tflite_filename}.
model_name: File name to save TFLite model with metadata. The full export
path is {self._hparams.model_dir}/{model_name}.
quantization_config: The configuration for model quantization.
"""
super().export_tflite(
self._hparams.model_dir,
model_name,
quantization_config,
if not tf.io.gfile.exists(self._hparams.model_dir):
tf.io.gfile.makedirs(self._hparams.model_dir)
tflite_file = os.path.join(self._hparams.model_dir, model_name)
metadata_file = os.path.join(self._hparams.model_dir, 'metadata.json')
tflite_model = model_util.convert_to_tflite(
model=self._model,
quantization_config=quantization_config,
preprocess=self._preprocess)
writer = image_classifier_writer.MetadataWriter.create(
tflite_model,
self._model_spec.mean_rgb,
self._model_spec.stddev_rgb,
labels=metadata_writer.Labels().add(self._label_names))
tflite_model_with_metadata, metadata_json = writer.populate()
model_util.save_tflite(tflite_model_with_metadata, tflite_file)
with open(metadata_file, 'w') as f:
f.write(metadata_json)
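
From the caller's side the API is unchanged; export_model now simply produces two files: the TFLite model with metadata embedded and the companion metadata.json. A minimal usage sketch, patterned on the test below; the create() arguments, the SupportedModels value, and the datasets are placeholders rather than a verbatim API listing:

from mediapipe.model_maker.python.vision import image_classifier

hparams = image_classifier.HParams(train_epochs=1, batch_size=1)
model = image_classifier.ImageClassifier.create(
    model_spec=image_classifier.SupportedModels.EFFICIENTNET_LITE0,  # assumed enum value
    train_data=train_data,            # placeholder Dataset
    validation_data=validation_data,  # placeholder Dataset
    hparams=hparams)

# Writes {hparams.model_dir}/model.tflite (with metadata) and
# {hparams.model_dir}/metadata.json describing that metadata.
model.export_model()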


@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import filecmp
import os
from absl.testing import parameterized
@ -19,6 +20,7 @@ import numpy as np
import tensorflow as tf
from mediapipe.model_maker.python.vision import image_classifier
from mediapipe.tasks.python.test import test_utils
def _fill_image(rgb, image_size):
@ -86,7 +88,7 @@ class ImageClassifierTest(tf.test.TestCase, parameterized.TestCase):
validation_data=self.test_data)
self._test_accuracy(model)
def test_efficientnetlite0_model_with_model_maker_retraining_lib(self):
def test_efficientnetlite0_model_train_and_export(self):
hparams = image_classifier.HParams(
train_epochs=1, batch_size=1, shuffle=True)
model = image_classifier.ImageClassifier.create(
@ -96,6 +98,19 @@ class ImageClassifierTest(tf.test.TestCase, parameterized.TestCase):
validation_data=self.test_data)
self._test_accuracy(model)
# Test export_model
model.export_model()
output_metadata_file = os.path.join(hparams.model_dir, 'metadata.json')
output_tflite_file = os.path.join(hparams.model_dir, 'model.tflite')
expected_metadata_file = test_utils.get_test_data_path('metadata.json')
self.assertTrue(os.path.exists(output_tflite_file))
self.assertGreater(os.path.getsize(output_tflite_file), 0)
self.assertTrue(os.path.exists(output_metadata_file))
self.assertGreater(os.path.getsize(output_metadata_file), 0)
self.assertTrue(filecmp.cmp(output_metadata_file, expected_metadata_file))
def _test_accuracy(self, model, threshold=0.0):
_, accuracy = model.evaluate(self.test_data)
self.assertGreaterEqual(accuracy, threshold)


@ -0,0 +1,23 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
package(
default_visibility = ["//mediapipe/model_maker/python/vision/image_classifier:__subpackages__"],
licenses = ["notice"], # Apache 2.0
)
filegroup(
name = "testdata",
srcs = ["metadata.json"],
)


@ -0,0 +1,68 @@
{
"name": "ImageClassifier",
"description": "Identify the most prominent object in the image from a known set of categories.",
"subgraph_metadata": [
{
"input_tensor_metadata": [
{
"name": "image",
"description": "Input image to be processed.",
"content": {
"content_properties_type": "ImageProperties",
"content_properties": {
"color_space": "RGB"
}
},
"process_units": [
{
"options_type": "NormalizationOptions",
"options": {
"mean": [
0.0
],
"std": [
255.0
]
}
}
],
"stats": {
"max": [
1.0
],
"min": [
0.0
]
}
}
],
"output_tensor_metadata": [
{
"name": "score",
"description": "Score of the labels respectively.",
"content": {
"content_properties_type": "FeatureProperties",
"content_properties": {
}
},
"stats": {
"max": [
1.0
],
"min": [
0.0
]
},
"associated_files": [
{
"name": "labels.txt",
"description": "Labels for categories that the model can recognize.",
"type": "TENSOR_AXIS_LABELS"
}
]
}
]
}
],
"min_parser_version": "1.0.0"
}
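
The NormalizationOptions above follow the standard TFLite metadata convention, normalized = (input - mean) / std; with mean 0.0 and std 255.0 this maps 8-bit pixel values into the [0.0, 1.0] range reported in the stats block. A small illustrative sketch (NumPy, not part of the change):

import numpy as np


def normalize(pixels: np.ndarray, mean: float = 0.0, std: float = 255.0) -> np.ndarray:
  # Applies (input - mean) / std, matching the NormalizationOptions above.
  return (pixels.astype(np.float32) - mean) / std


print(normalize(np.array([0, 128, 255])))  # -> [0.0, 0.5019608, 1.0]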


@ -24,7 +24,7 @@ py_library(
srcs = ["test_utils.py"],
srcs_version = "PY3",
visibility = [
"//mediapipe/model_maker/python/vision/gesture_recognizer:__pkg__",
"//mediapipe/model_maker/python:__subpackages__",
"//mediapipe/tasks:internal",
],
deps = [