Add metadata writer to image_classifier in model_maker

PiperOrigin-RevId: 485926985
This commit is contained in:
MediaPipe Team 2022-11-03 11:28:56 -07:00 committed by Copybara-Service
parent b472d8ff66
commit d29c3d7512
10 changed files with 225 additions and 54 deletions

View File

@ -52,6 +52,7 @@ class CustomModel(abc.ABC):
"""Prints a summary of the model.""" """Prints a summary of the model."""
self._model.summary() self._model.summary()
# TODO: Remove this method when all tasks use Metadata writer
def export_tflite( def export_tflite(
self, self,
export_dir: str, export_dir: str,
@ -62,7 +63,7 @@ class CustomModel(abc.ABC):
Args: Args:
export_dir: The directory to save exported files. export_dir: The directory to save exported files.
tflite_filename: File name to save tflite model. The full export path is tflite_filename: File name to save TFLite model. The full export path is
{export_dir}/{tflite_filename}. {export_dir}/{tflite_filename}.
quantization_config: The configuration for model quantization. quantization_config: The configuration for model quantization.
preprocess: A callable to preprocess the representative dataset for preprocess: A callable to preprocess the representative dataset for
@ -73,11 +74,11 @@ class CustomModel(abc.ABC):
tf.io.gfile.makedirs(export_dir) tf.io.gfile.makedirs(export_dir)
tflite_filepath = os.path.join(export_dir, tflite_filename) tflite_filepath = os.path.join(export_dir, tflite_filename)
# TODO: Populate metadata to the exported TFLite model. tflite_model = model_util.convert_to_tflite(
model_util.export_tflite(
model=self._model, model=self._model,
tflite_filepath=tflite_filepath,
quantization_config=quantization_config, quantization_config=quantization_config,
preprocess=preprocess) preprocess=preprocess)
model_util.save_tflite(
tflite_model=tflite_model, tflite_file=tflite_filepath)
tf.compat.v1.logging.info( tf.compat.v1.logging.info(
'TensorFlow Lite model exported successfully: %s' % tflite_filepath) 'TensorFlow Lite model exported successfully: %s' % tflite_filepath)

View File

@ -89,28 +89,25 @@ def get_steps_per_epoch(steps_per_epoch: Optional[int] = None,
return len(train_data) // batch_size return len(train_data) // batch_size
def export_tflite( def convert_to_tflite(
model: tf.keras.Model, model: tf.keras.Model,
tflite_filepath: str,
quantization_config: Optional[quantization.QuantizationConfig] = None, quantization_config: Optional[quantization.QuantizationConfig] = None,
supported_ops: Tuple[tf.lite.OpsSet, supported_ops: Tuple[tf.lite.OpsSet,
...] = (tf.lite.OpsSet.TFLITE_BUILTINS,), ...] = (tf.lite.OpsSet.TFLITE_BUILTINS,),
preprocess: Optional[Callable[..., bool]] = None): preprocess: Optional[Callable[..., bool]] = None) -> bytearray:
"""Converts the model to tflite format and saves it. """Converts the input Keras model to TFLite format.
Args: Args:
model: model to be converted to tflite. model: Keras model to be converted to TFLite.
tflite_filepath: File path to save tflite model.
quantization_config: Configuration for post-training quantization. quantization_config: Configuration for post-training quantization.
supported_ops: A list of supported ops in the converted TFLite file. supported_ops: A list of supported ops in the converted TFLite file.
preprocess: A callable to preprocess the representative dataset for preprocess: A callable to preprocess the representative dataset for
quantization. The callable takes three arguments in order: feature, label, quantization. The callable takes three arguments in order: feature, label,
and is_training. and is_training.
"""
if tflite_filepath is None:
raise ValueError(
"TFLite filepath couldn't be None when exporting to tflite.")
Returns:
bytearray of TFLite model
"""
with tempfile.TemporaryDirectory() as temp_dir: with tempfile.TemporaryDirectory() as temp_dir:
save_path = os.path.join(temp_dir, 'saved_model') save_path = os.path.join(temp_dir, 'saved_model')
model.save(save_path, include_optimizer=False, save_format='tf') model.save(save_path, include_optimizer=False, save_format='tf')
@ -122,9 +119,22 @@ def export_tflite(
converter.target_spec.supported_ops = supported_ops converter.target_spec.supported_ops = supported_ops
tflite_model = converter.convert() tflite_model = converter.convert()
return tflite_model
with tf.io.gfile.GFile(tflite_filepath, 'wb') as f:
def save_tflite(tflite_model: bytearray, tflite_file: str) -> None:
"""Saves TFLite file to tflite_file.
Args:
tflite_model: A valid flatbuffer representing the TFLite model.
tflite_file: File path to save TFLite model.
"""
if tflite_file is None:
raise ValueError("TFLite filepath can't be None when exporting to TFLite.")
with tf.io.gfile.GFile(tflite_file, 'wb') as f:
f.write(tflite_model) f.write(tflite_model)
tf.compat.v1.logging.info(
'TensorFlow Lite model exported successfully to: %s' % tflite_file)
class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule): class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
@ -176,14 +186,12 @@ class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
class LiteRunner(object): class LiteRunner(object):
"""A runner to do inference with the TFLite model.""" """A runner to do inference with the TFLite model."""
def __init__(self, tflite_filepath: str): def __init__(self, tflite_model: bytearray):
"""Initializes Lite runner with tflite model file. """Initializes Lite runner from TFLite model buffer.
Args: Args:
tflite_filepath: File path to the TFLite model. tflite_model: A valid flatbuffer representing the TFLite model.
""" """
with tf.io.gfile.GFile(tflite_filepath, 'rb') as f:
tflite_model = f.read()
self.interpreter = tf.lite.Interpreter(model_content=tflite_model) self.interpreter = tf.lite.Interpreter(model_content=tflite_model)
self.interpreter.allocate_tensors() self.interpreter.allocate_tensors()
self.input_details = self.interpreter.get_input_details() self.input_details = self.interpreter.get_input_details()
@ -250,9 +258,9 @@ class LiteRunner(object):
return output_tensors return output_tensors
def get_lite_runner(tflite_filepath: str) -> 'LiteRunner': def get_lite_runner(tflite_buffer: bytearray) -> 'LiteRunner':
"""Returns a `LiteRunner` from file path to TFLite model.""" """Returns a `LiteRunner` from flatbuffer of the TFLite model."""
lite_runner = LiteRunner(tflite_filepath) lite_runner = LiteRunner(tflite_buffer)
return lite_runner return lite_runner

View File

@ -95,13 +95,12 @@ class ModelUtilTest(tf.test.TestCase, parameterized.TestCase):
'name': 'test' 'name': 'test'
}) })
def test_export_tflite(self): def test_convert_to_tflite(self):
input_dim = 4 input_dim = 4
model = test_util.build_model(input_shape=[input_dim], num_classes=2) model = test_util.build_model(input_shape=[input_dim], num_classes=2)
tflite_file = os.path.join(self.get_temp_dir(), 'model.tflite') tflite_model = model_util.convert_to_tflite(model)
model_util.export_tflite(model, tflite_file)
test_util.test_tflite( test_util.test_tflite(
keras_model=model, tflite_file=tflite_file, size=[1, input_dim]) keras_model=model, tflite_model=tflite_model, size=[1, input_dim])
@parameterized.named_parameters( @parameterized.named_parameters(
dict( dict(
@ -118,25 +117,32 @@ class ModelUtilTest(tf.test.TestCase, parameterized.TestCase):
testcase_name='float16_quantize', testcase_name='float16_quantize',
config=quantization.QuantizationConfig.for_float16(), config=quantization.QuantizationConfig.for_float16(),
model_size=1468)) model_size=1468))
def test_export_tflite_quantized(self, config, model_size): def test_convert_to_tflite_quantized(self, config, model_size):
input_dim = 16 input_dim = 16
num_classes = 2 num_classes = 2
max_input_value = 5 max_input_value = 5
model = test_util.build_model( model = test_util.build_model(
input_shape=[input_dim], num_classes=num_classes) input_shape=[input_dim], num_classes=num_classes)
tflite_file = os.path.join(self.get_temp_dir(), 'model_quantized.tflite')
model_util.export_tflite( tflite_model = model_util.convert_to_tflite(
model=model, tflite_filepath=tflite_file, quantization_config=config) model=model, quantization_config=config)
self.assertTrue( self.assertTrue(
test_util.test_tflite( test_util.test_tflite(
keras_model=model, keras_model=model,
tflite_file=tflite_file, tflite_model=tflite_model,
size=[1, input_dim], size=[1, input_dim],
high=max_input_value, high=max_input_value,
atol=1e-00)) atol=1e-00))
self.assertNear(os.path.getsize(tflite_file), model_size, 300) self.assertNear(len(tflite_model), model_size, 300)
def test_save_tflite(self):
input_dim = 4
model = test_util.build_model(input_shape=[input_dim], num_classes=2)
tflite_model = model_util.convert_to_tflite(model)
tflite_file = os.path.join(self.get_temp_dir(), 'model.tflite')
model_util.save_tflite(tflite_model=tflite_model, tflite_file=tflite_file)
test_util.test_tflite_file(
keras_model=model, tflite_file=tflite_file, size=[1, input_dim])
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() tf.test.main()

View File

@ -79,13 +79,13 @@ def build_model(input_shape: List[int], num_classes: int) -> tf.keras.Model:
return model return model
def is_same_output(tflite_file: str, def is_same_output(tflite_model: bytearray,
keras_model: tf.keras.Model, keras_model: tf.keras.Model,
input_tensors: Union[List[tf.Tensor], tf.Tensor], input_tensors: Union[List[tf.Tensor], tf.Tensor],
atol: float = 1e-04) -> bool: atol: float = 1e-04) -> bool:
"""Returns if the output of TFLite model and keras model are identical.""" """Returns if the output of TFLite model and keras model are identical."""
# Gets output from lite model. # Gets output from lite model.
lite_runner = model_util.get_lite_runner(tflite_file) lite_runner = model_util.get_lite_runner(tflite_model)
lite_output = lite_runner.run(input_tensors) lite_output = lite_runner.run(input_tensors)
# Gets output from keras model. # Gets output from keras model.
@ -95,12 +95,41 @@ def is_same_output(tflite_file: str,
def test_tflite(keras_model: tf.keras.Model, def test_tflite(keras_model: tf.keras.Model,
tflite_file: str, tflite_model: bytearray,
size: Union[int, List[int]], size: Union[int, List[int]],
high: float = 1, high: float = 1,
atol: float = 1e-04) -> bool: atol: float = 1e-04) -> bool:
"""Verifies if the output of TFLite model and TF Keras model are identical. """Verifies if the output of TFLite model and TF Keras model are identical.
Args:
keras_model: Input TensorFlow Keras model.
tflite_model: Input TFLite model flatbuffer.
size: Size of the input tesnor.
high: Higher boundary of the values in input tensors.
atol: Absolute tolerance of the difference between the outputs of Keras
model and TFLite model.
Returns:
True if the output of TFLite model and TF Keras model are identical.
Otherwise, False.
"""
random_input = create_random_sample(size=size, high=high)
random_input = tf.convert_to_tensor(random_input)
return is_same_output(
tflite_model=tflite_model,
keras_model=keras_model,
input_tensors=random_input,
atol=atol)
def test_tflite_file(keras_model: tf.keras.Model,
tflite_file: bytearray,
size: Union[int, List[int]],
high: float = 1,
atol: float = 1e-04) -> bool:
"""Verifies if the output of TFLite model and TF Keras model are identical.
Args: Args:
keras_model: Input TensorFlow Keras model. keras_model: Input TensorFlow Keras model.
tflite_file: Input TFLite model file. tflite_file: Input TFLite model file.
@ -113,11 +142,6 @@ def test_tflite(keras_model: tf.keras.Model,
True if the output of TFLite model and TF Keras model are identical. True if the output of TFLite model and TF Keras model are identical.
Otherwise, False. Otherwise, False.
""" """
random_input = create_random_sample(size=size, high=high) with tf.io.gfile.GFile(tflite_file, "rb") as f:
random_input = tf.convert_to_tensor(random_input) tflite_model = f.read()
return test_tflite(keras_model, tflite_model, size, high, atol)
return is_same_output(
tflite_file=tflite_file,
keras_model=keras_model,
input_tensors=random_input,
atol=atol)

View File

@ -81,6 +81,8 @@ py_library(
"//mediapipe/model_maker/python/core/utils:model_util", "//mediapipe/model_maker/python/core/utils:model_util",
"//mediapipe/model_maker/python/core/utils:quantization", "//mediapipe/model_maker/python/core/utils:quantization",
"//mediapipe/model_maker/python/vision/core:image_preprocessing", "//mediapipe/model_maker/python/vision/core:image_preprocessing",
"//mediapipe/tasks/python/metadata/metadata_writers:image_classifier",
"//mediapipe/tasks/python/metadata/metadata_writers:metadata_writer",
], ],
) )
@ -88,7 +90,11 @@ py_library(
name = "image_classifier_test_lib", name = "image_classifier_test_lib",
testonly = 1, testonly = 1,
srcs = ["image_classifier_test.py"], srcs = ["image_classifier_test.py"],
deps = [":image_classifier_import"], data = ["//mediapipe/model_maker/python/vision/image_classifier/testdata"],
deps = [
":image_classifier_import",
"//mediapipe/tasks/python/test:test_utils",
],
) )
py_test( py_test(

View File

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""APIs to train image classifier model.""" """APIs to train image classifier model."""
import os
from typing import List, Optional from typing import List, Optional
@ -26,6 +27,8 @@ from mediapipe.model_maker.python.vision.core import image_preprocessing
from mediapipe.model_maker.python.vision.image_classifier import hyperparameters as hp from mediapipe.model_maker.python.vision.image_classifier import hyperparameters as hp
from mediapipe.model_maker.python.vision.image_classifier import model_spec as ms from mediapipe.model_maker.python.vision.image_classifier import model_spec as ms
from mediapipe.model_maker.python.vision.image_classifier import train_image_classifier_lib from mediapipe.model_maker.python.vision.image_classifier import train_image_classifier_lib
from mediapipe.tasks.python.metadata.metadata_writers import image_classifier as image_classifier_writer
from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer
class ImageClassifier(classifier.Classifier): class ImageClassifier(classifier.Classifier):
@ -156,15 +159,32 @@ class ImageClassifier(classifier.Classifier):
self, self,
model_name: str = 'model.tflite', model_name: str = 'model.tflite',
quantization_config: Optional[quantization.QuantizationConfig] = None): quantization_config: Optional[quantization.QuantizationConfig] = None):
"""Converts the model to the requested formats and exports to a file. """Converts and saves the model to a TFLite file with metadata included.
Note that only the TFLite file is needed for deployment. This function also
saves a metadata.json file to the same directory as the TFLite file which
can be used to interpret the metadata content in the TFLite file.
Args: Args:
model_name: File name to save tflite model. The full export path is model_name: File name to save TFLite model with metadata. The full export
{export_dir}/{tflite_filename}. path is {self._hparams.model_dir}/{model_name}.
quantization_config: The configuration for model quantization. quantization_config: The configuration for model quantization.
""" """
super().export_tflite( if not tf.io.gfile.exists(self._hparams.model_dir):
self._hparams.model_dir, tf.io.gfile.makedirs(self._hparams.model_dir)
model_name, tflite_file = os.path.join(self._hparams.model_dir, model_name)
quantization_config, metadata_file = os.path.join(self._hparams.model_dir, 'metadata.json')
tflite_model = model_util.convert_to_tflite(
model=self._model,
quantization_config=quantization_config,
preprocess=self._preprocess) preprocess=self._preprocess)
writer = image_classifier_writer.MetadataWriter.create(
tflite_model,
self._model_spec.mean_rgb,
self._model_spec.stddev_rgb,
labels=metadata_writer.Labels().add(self._label_names))
tflite_model_with_metadata, metadata_json = writer.populate()
model_util.save_tflite(tflite_model_with_metadata, tflite_file)
with open(metadata_file, 'w') as f:
f.write(metadata_json)

View File

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import filecmp
import os import os
from absl.testing import parameterized from absl.testing import parameterized
@ -19,6 +20,7 @@ import numpy as np
import tensorflow as tf import tensorflow as tf
from mediapipe.model_maker.python.vision import image_classifier from mediapipe.model_maker.python.vision import image_classifier
from mediapipe.tasks.python.test import test_utils
def _fill_image(rgb, image_size): def _fill_image(rgb, image_size):
@ -86,7 +88,7 @@ class ImageClassifierTest(tf.test.TestCase, parameterized.TestCase):
validation_data=self.test_data) validation_data=self.test_data)
self._test_accuracy(model) self._test_accuracy(model)
def test_efficientnetlite0_model_with_model_maker_retraining_lib(self): def test_efficientnetlite0_model_train_and_export(self):
hparams = image_classifier.HParams( hparams = image_classifier.HParams(
train_epochs=1, batch_size=1, shuffle=True) train_epochs=1, batch_size=1, shuffle=True)
model = image_classifier.ImageClassifier.create( model = image_classifier.ImageClassifier.create(
@ -96,6 +98,19 @@ class ImageClassifierTest(tf.test.TestCase, parameterized.TestCase):
validation_data=self.test_data) validation_data=self.test_data)
self._test_accuracy(model) self._test_accuracy(model)
# Test export_model
model.export_model()
output_metadata_file = os.path.join(hparams.model_dir, 'metadata.json')
output_tflite_file = os.path.join(hparams.model_dir, 'model.tflite')
expected_metadata_file = test_utils.get_test_data_path('metadata.json')
self.assertTrue(os.path.exists(output_tflite_file))
self.assertGreater(os.path.getsize(output_tflite_file), 0)
self.assertTrue(os.path.exists(output_metadata_file))
self.assertGreater(os.path.getsize(output_metadata_file), 0)
self.assertTrue(filecmp.cmp(output_metadata_file, expected_metadata_file))
def _test_accuracy(self, model, threshold=0.0): def _test_accuracy(self, model, threshold=0.0):
_, accuracy = model.evaluate(self.test_data) _, accuracy = model.evaluate(self.test_data)
self.assertGreaterEqual(accuracy, threshold) self.assertGreaterEqual(accuracy, threshold)

View File

@ -0,0 +1,23 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
package(
default_visibility = ["//mediapipe/model_maker/python/vision/image_classifier:__subpackages__"],
licenses = ["notice"], # Apache 2.0
)
filegroup(
name = "testdata",
srcs = ["metadata.json"],
)

View File

@ -0,0 +1,68 @@
{
"name": "ImageClassifier",
"description": "Identify the most prominent object in the image from a known set of categories.",
"subgraph_metadata": [
{
"input_tensor_metadata": [
{
"name": "image",
"description": "Input image to be processed.",
"content": {
"content_properties_type": "ImageProperties",
"content_properties": {
"color_space": "RGB"
}
},
"process_units": [
{
"options_type": "NormalizationOptions",
"options": {
"mean": [
0.0
],
"std": [
255.0
]
}
}
],
"stats": {
"max": [
1.0
],
"min": [
0.0
]
}
}
],
"output_tensor_metadata": [
{
"name": "score",
"description": "Score of the labels respectively.",
"content": {
"content_properties_type": "FeatureProperties",
"content_properties": {
}
},
"stats": {
"max": [
1.0
],
"min": [
0.0
]
},
"associated_files": [
{
"name": "labels.txt",
"description": "Labels for categories that the model can recognize.",
"type": "TENSOR_AXIS_LABELS"
}
]
}
]
}
],
"min_parser_version": "1.0.0"
}

View File

@ -24,7 +24,7 @@ py_library(
srcs = ["test_utils.py"], srcs = ["test_utils.py"],
srcs_version = "PY3", srcs_version = "PY3",
visibility = [ visibility = [
"//mediapipe/model_maker/python/vision/gesture_recognizer:__pkg__", "//mediapipe/model_maker/python:__subpackages__",
"//mediapipe/tasks:internal", "//mediapipe/tasks:internal",
], ],
deps = [ deps = [