Add metadata writer to image_classifier in model_maker
PiperOrigin-RevId: 485926985
This commit is contained in:
parent
b472d8ff66
commit
d29c3d7512
|
@ -52,6 +52,7 @@ class CustomModel(abc.ABC):
|
||||||
"""Prints a summary of the model."""
|
"""Prints a summary of the model."""
|
||||||
self._model.summary()
|
self._model.summary()
|
||||||
|
|
||||||
|
# TODO: Remove this method when all tasks use Metadata writer
|
||||||
def export_tflite(
|
def export_tflite(
|
||||||
self,
|
self,
|
||||||
export_dir: str,
|
export_dir: str,
|
||||||
|
@ -62,7 +63,7 @@ class CustomModel(abc.ABC):
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
export_dir: The directory to save exported files.
|
export_dir: The directory to save exported files.
|
||||||
tflite_filename: File name to save tflite model. The full export path is
|
tflite_filename: File name to save TFLite model. The full export path is
|
||||||
{export_dir}/{tflite_filename}.
|
{export_dir}/{tflite_filename}.
|
||||||
quantization_config: The configuration for model quantization.
|
quantization_config: The configuration for model quantization.
|
||||||
preprocess: A callable to preprocess the representative dataset for
|
preprocess: A callable to preprocess the representative dataset for
|
||||||
|
@ -73,11 +74,11 @@ class CustomModel(abc.ABC):
|
||||||
tf.io.gfile.makedirs(export_dir)
|
tf.io.gfile.makedirs(export_dir)
|
||||||
|
|
||||||
tflite_filepath = os.path.join(export_dir, tflite_filename)
|
tflite_filepath = os.path.join(export_dir, tflite_filename)
|
||||||
# TODO: Populate metadata to the exported TFLite model.
|
tflite_model = model_util.convert_to_tflite(
|
||||||
model_util.export_tflite(
|
|
||||||
model=self._model,
|
model=self._model,
|
||||||
tflite_filepath=tflite_filepath,
|
|
||||||
quantization_config=quantization_config,
|
quantization_config=quantization_config,
|
||||||
preprocess=preprocess)
|
preprocess=preprocess)
|
||||||
|
model_util.save_tflite(
|
||||||
|
tflite_model=tflite_model, tflite_file=tflite_filepath)
|
||||||
tf.compat.v1.logging.info(
|
tf.compat.v1.logging.info(
|
||||||
'TensorFlow Lite model exported successfully: %s' % tflite_filepath)
|
'TensorFlow Lite model exported successfully: %s' % tflite_filepath)
|
||||||
|
|
|
@ -89,28 +89,25 @@ def get_steps_per_epoch(steps_per_epoch: Optional[int] = None,
|
||||||
return len(train_data) // batch_size
|
return len(train_data) // batch_size
|
||||||
|
|
||||||
|
|
||||||
def export_tflite(
|
def convert_to_tflite(
|
||||||
model: tf.keras.Model,
|
model: tf.keras.Model,
|
||||||
tflite_filepath: str,
|
|
||||||
quantization_config: Optional[quantization.QuantizationConfig] = None,
|
quantization_config: Optional[quantization.QuantizationConfig] = None,
|
||||||
supported_ops: Tuple[tf.lite.OpsSet,
|
supported_ops: Tuple[tf.lite.OpsSet,
|
||||||
...] = (tf.lite.OpsSet.TFLITE_BUILTINS,),
|
...] = (tf.lite.OpsSet.TFLITE_BUILTINS,),
|
||||||
preprocess: Optional[Callable[..., bool]] = None):
|
preprocess: Optional[Callable[..., bool]] = None) -> bytearray:
|
||||||
"""Converts the model to tflite format and saves it.
|
"""Converts the input Keras model to TFLite format.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model: model to be converted to tflite.
|
model: Keras model to be converted to TFLite.
|
||||||
tflite_filepath: File path to save tflite model.
|
|
||||||
quantization_config: Configuration for post-training quantization.
|
quantization_config: Configuration for post-training quantization.
|
||||||
supported_ops: A list of supported ops in the converted TFLite file.
|
supported_ops: A list of supported ops in the converted TFLite file.
|
||||||
preprocess: A callable to preprocess the representative dataset for
|
preprocess: A callable to preprocess the representative dataset for
|
||||||
quantization. The callable takes three arguments in order: feature, label,
|
quantization. The callable takes three arguments in order: feature, label,
|
||||||
and is_training.
|
and is_training.
|
||||||
"""
|
|
||||||
if tflite_filepath is None:
|
|
||||||
raise ValueError(
|
|
||||||
"TFLite filepath couldn't be None when exporting to tflite.")
|
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bytearray of TFLite model
|
||||||
|
"""
|
||||||
with tempfile.TemporaryDirectory() as temp_dir:
|
with tempfile.TemporaryDirectory() as temp_dir:
|
||||||
save_path = os.path.join(temp_dir, 'saved_model')
|
save_path = os.path.join(temp_dir, 'saved_model')
|
||||||
model.save(save_path, include_optimizer=False, save_format='tf')
|
model.save(save_path, include_optimizer=False, save_format='tf')
|
||||||
|
@ -122,9 +119,22 @@ def export_tflite(
|
||||||
|
|
||||||
converter.target_spec.supported_ops = supported_ops
|
converter.target_spec.supported_ops = supported_ops
|
||||||
tflite_model = converter.convert()
|
tflite_model = converter.convert()
|
||||||
|
return tflite_model
|
||||||
|
|
||||||
with tf.io.gfile.GFile(tflite_filepath, 'wb') as f:
|
|
||||||
|
def save_tflite(tflite_model: bytearray, tflite_file: str) -> None:
|
||||||
|
"""Saves TFLite file to tflite_file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tflite_model: A valid flatbuffer representing the TFLite model.
|
||||||
|
tflite_file: File path to save TFLite model.
|
||||||
|
"""
|
||||||
|
if tflite_file is None:
|
||||||
|
raise ValueError("TFLite filepath can't be None when exporting to TFLite.")
|
||||||
|
with tf.io.gfile.GFile(tflite_file, 'wb') as f:
|
||||||
f.write(tflite_model)
|
f.write(tflite_model)
|
||||||
|
tf.compat.v1.logging.info(
|
||||||
|
'TensorFlow Lite model exported successfully to: %s' % tflite_file)
|
||||||
|
|
||||||
|
|
||||||
class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
|
class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
|
||||||
|
@ -176,14 +186,12 @@ class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
|
||||||
class LiteRunner(object):
|
class LiteRunner(object):
|
||||||
"""A runner to do inference with the TFLite model."""
|
"""A runner to do inference with the TFLite model."""
|
||||||
|
|
||||||
def __init__(self, tflite_filepath: str):
|
def __init__(self, tflite_model: bytearray):
|
||||||
"""Initializes Lite runner with tflite model file.
|
"""Initializes Lite runner from TFLite model buffer.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
tflite_filepath: File path to the TFLite model.
|
tflite_model: A valid flatbuffer representing the TFLite model.
|
||||||
"""
|
"""
|
||||||
with tf.io.gfile.GFile(tflite_filepath, 'rb') as f:
|
|
||||||
tflite_model = f.read()
|
|
||||||
self.interpreter = tf.lite.Interpreter(model_content=tflite_model)
|
self.interpreter = tf.lite.Interpreter(model_content=tflite_model)
|
||||||
self.interpreter.allocate_tensors()
|
self.interpreter.allocate_tensors()
|
||||||
self.input_details = self.interpreter.get_input_details()
|
self.input_details = self.interpreter.get_input_details()
|
||||||
|
@ -250,9 +258,9 @@ class LiteRunner(object):
|
||||||
return output_tensors
|
return output_tensors
|
||||||
|
|
||||||
|
|
||||||
def get_lite_runner(tflite_filepath: str) -> 'LiteRunner':
|
def get_lite_runner(tflite_buffer: bytearray) -> 'LiteRunner':
|
||||||
"""Returns a `LiteRunner` from file path to TFLite model."""
|
"""Returns a `LiteRunner` from flatbuffer of the TFLite model."""
|
||||||
lite_runner = LiteRunner(tflite_filepath)
|
lite_runner = LiteRunner(tflite_buffer)
|
||||||
return lite_runner
|
return lite_runner
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -95,13 +95,12 @@ class ModelUtilTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
'name': 'test'
|
'name': 'test'
|
||||||
})
|
})
|
||||||
|
|
||||||
def test_export_tflite(self):
|
def test_convert_to_tflite(self):
|
||||||
input_dim = 4
|
input_dim = 4
|
||||||
model = test_util.build_model(input_shape=[input_dim], num_classes=2)
|
model = test_util.build_model(input_shape=[input_dim], num_classes=2)
|
||||||
tflite_file = os.path.join(self.get_temp_dir(), 'model.tflite')
|
tflite_model = model_util.convert_to_tflite(model)
|
||||||
model_util.export_tflite(model, tflite_file)
|
|
||||||
test_util.test_tflite(
|
test_util.test_tflite(
|
||||||
keras_model=model, tflite_file=tflite_file, size=[1, input_dim])
|
keras_model=model, tflite_model=tflite_model, size=[1, input_dim])
|
||||||
|
|
||||||
@parameterized.named_parameters(
|
@parameterized.named_parameters(
|
||||||
dict(
|
dict(
|
||||||
|
@ -118,25 +117,32 @@ class ModelUtilTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
testcase_name='float16_quantize',
|
testcase_name='float16_quantize',
|
||||||
config=quantization.QuantizationConfig.for_float16(),
|
config=quantization.QuantizationConfig.for_float16(),
|
||||||
model_size=1468))
|
model_size=1468))
|
||||||
def test_export_tflite_quantized(self, config, model_size):
|
def test_convert_to_tflite_quantized(self, config, model_size):
|
||||||
input_dim = 16
|
input_dim = 16
|
||||||
num_classes = 2
|
num_classes = 2
|
||||||
max_input_value = 5
|
max_input_value = 5
|
||||||
model = test_util.build_model(
|
model = test_util.build_model(
|
||||||
input_shape=[input_dim], num_classes=num_classes)
|
input_shape=[input_dim], num_classes=num_classes)
|
||||||
tflite_file = os.path.join(self.get_temp_dir(), 'model_quantized.tflite')
|
|
||||||
|
|
||||||
model_util.export_tflite(
|
tflite_model = model_util.convert_to_tflite(
|
||||||
model=model, tflite_filepath=tflite_file, quantization_config=config)
|
model=model, quantization_config=config)
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
test_util.test_tflite(
|
test_util.test_tflite(
|
||||||
keras_model=model,
|
keras_model=model,
|
||||||
tflite_file=tflite_file,
|
tflite_model=tflite_model,
|
||||||
size=[1, input_dim],
|
size=[1, input_dim],
|
||||||
high=max_input_value,
|
high=max_input_value,
|
||||||
atol=1e-00))
|
atol=1e-00))
|
||||||
self.assertNear(os.path.getsize(tflite_file), model_size, 300)
|
self.assertNear(len(tflite_model), model_size, 300)
|
||||||
|
|
||||||
|
def test_save_tflite(self):
|
||||||
|
input_dim = 4
|
||||||
|
model = test_util.build_model(input_shape=[input_dim], num_classes=2)
|
||||||
|
tflite_model = model_util.convert_to_tflite(model)
|
||||||
|
tflite_file = os.path.join(self.get_temp_dir(), 'model.tflite')
|
||||||
|
model_util.save_tflite(tflite_model=tflite_model, tflite_file=tflite_file)
|
||||||
|
test_util.test_tflite_file(
|
||||||
|
keras_model=model, tflite_file=tflite_file, size=[1, input_dim])
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
tf.test.main()
|
tf.test.main()
|
||||||
|
|
|
@ -79,13 +79,13 @@ def build_model(input_shape: List[int], num_classes: int) -> tf.keras.Model:
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
def is_same_output(tflite_file: str,
|
def is_same_output(tflite_model: bytearray,
|
||||||
keras_model: tf.keras.Model,
|
keras_model: tf.keras.Model,
|
||||||
input_tensors: Union[List[tf.Tensor], tf.Tensor],
|
input_tensors: Union[List[tf.Tensor], tf.Tensor],
|
||||||
atol: float = 1e-04) -> bool:
|
atol: float = 1e-04) -> bool:
|
||||||
"""Returns if the output of TFLite model and keras model are identical."""
|
"""Returns if the output of TFLite model and keras model are identical."""
|
||||||
# Gets output from lite model.
|
# Gets output from lite model.
|
||||||
lite_runner = model_util.get_lite_runner(tflite_file)
|
lite_runner = model_util.get_lite_runner(tflite_model)
|
||||||
lite_output = lite_runner.run(input_tensors)
|
lite_output = lite_runner.run(input_tensors)
|
||||||
|
|
||||||
# Gets output from keras model.
|
# Gets output from keras model.
|
||||||
|
@ -95,7 +95,36 @@ def is_same_output(tflite_file: str,
|
||||||
|
|
||||||
|
|
||||||
def test_tflite(keras_model: tf.keras.Model,
|
def test_tflite(keras_model: tf.keras.Model,
|
||||||
tflite_file: str,
|
tflite_model: bytearray,
|
||||||
|
size: Union[int, List[int]],
|
||||||
|
high: float = 1,
|
||||||
|
atol: float = 1e-04) -> bool:
|
||||||
|
"""Verifies if the output of TFLite model and TF Keras model are identical.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
keras_model: Input TensorFlow Keras model.
|
||||||
|
tflite_model: Input TFLite model flatbuffer.
|
||||||
|
size: Size of the input tesnor.
|
||||||
|
high: Higher boundary of the values in input tensors.
|
||||||
|
atol: Absolute tolerance of the difference between the outputs of Keras
|
||||||
|
model and TFLite model.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if the output of TFLite model and TF Keras model are identical.
|
||||||
|
Otherwise, False.
|
||||||
|
"""
|
||||||
|
random_input = create_random_sample(size=size, high=high)
|
||||||
|
random_input = tf.convert_to_tensor(random_input)
|
||||||
|
|
||||||
|
return is_same_output(
|
||||||
|
tflite_model=tflite_model,
|
||||||
|
keras_model=keras_model,
|
||||||
|
input_tensors=random_input,
|
||||||
|
atol=atol)
|
||||||
|
|
||||||
|
|
||||||
|
def test_tflite_file(keras_model: tf.keras.Model,
|
||||||
|
tflite_file: bytearray,
|
||||||
size: Union[int, List[int]],
|
size: Union[int, List[int]],
|
||||||
high: float = 1,
|
high: float = 1,
|
||||||
atol: float = 1e-04) -> bool:
|
atol: float = 1e-04) -> bool:
|
||||||
|
@ -113,11 +142,6 @@ def test_tflite(keras_model: tf.keras.Model,
|
||||||
True if the output of TFLite model and TF Keras model are identical.
|
True if the output of TFLite model and TF Keras model are identical.
|
||||||
Otherwise, False.
|
Otherwise, False.
|
||||||
"""
|
"""
|
||||||
random_input = create_random_sample(size=size, high=high)
|
with tf.io.gfile.GFile(tflite_file, "rb") as f:
|
||||||
random_input = tf.convert_to_tensor(random_input)
|
tflite_model = f.read()
|
||||||
|
return test_tflite(keras_model, tflite_model, size, high, atol)
|
||||||
return is_same_output(
|
|
||||||
tflite_file=tflite_file,
|
|
||||||
keras_model=keras_model,
|
|
||||||
input_tensors=random_input,
|
|
||||||
atol=atol)
|
|
||||||
|
|
|
@ -81,6 +81,8 @@ py_library(
|
||||||
"//mediapipe/model_maker/python/core/utils:model_util",
|
"//mediapipe/model_maker/python/core/utils:model_util",
|
||||||
"//mediapipe/model_maker/python/core/utils:quantization",
|
"//mediapipe/model_maker/python/core/utils:quantization",
|
||||||
"//mediapipe/model_maker/python/vision/core:image_preprocessing",
|
"//mediapipe/model_maker/python/vision/core:image_preprocessing",
|
||||||
|
"//mediapipe/tasks/python/metadata/metadata_writers:image_classifier",
|
||||||
|
"//mediapipe/tasks/python/metadata/metadata_writers:metadata_writer",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -88,7 +90,11 @@ py_library(
|
||||||
name = "image_classifier_test_lib",
|
name = "image_classifier_test_lib",
|
||||||
testonly = 1,
|
testonly = 1,
|
||||||
srcs = ["image_classifier_test.py"],
|
srcs = ["image_classifier_test.py"],
|
||||||
deps = [":image_classifier_import"],
|
data = ["//mediapipe/model_maker/python/vision/image_classifier/testdata"],
|
||||||
|
deps = [
|
||||||
|
":image_classifier_import",
|
||||||
|
"//mediapipe/tasks/python/test:test_utils",
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
py_test(
|
py_test(
|
||||||
|
|
|
@ -12,6 +12,7 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""APIs to train image classifier model."""
|
"""APIs to train image classifier model."""
|
||||||
|
import os
|
||||||
|
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
|
@ -26,6 +27,8 @@ from mediapipe.model_maker.python.vision.core import image_preprocessing
|
||||||
from mediapipe.model_maker.python.vision.image_classifier import hyperparameters as hp
|
from mediapipe.model_maker.python.vision.image_classifier import hyperparameters as hp
|
||||||
from mediapipe.model_maker.python.vision.image_classifier import model_spec as ms
|
from mediapipe.model_maker.python.vision.image_classifier import model_spec as ms
|
||||||
from mediapipe.model_maker.python.vision.image_classifier import train_image_classifier_lib
|
from mediapipe.model_maker.python.vision.image_classifier import train_image_classifier_lib
|
||||||
|
from mediapipe.tasks.python.metadata.metadata_writers import image_classifier as image_classifier_writer
|
||||||
|
from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer
|
||||||
|
|
||||||
|
|
||||||
class ImageClassifier(classifier.Classifier):
|
class ImageClassifier(classifier.Classifier):
|
||||||
|
@ -156,15 +159,32 @@ class ImageClassifier(classifier.Classifier):
|
||||||
self,
|
self,
|
||||||
model_name: str = 'model.tflite',
|
model_name: str = 'model.tflite',
|
||||||
quantization_config: Optional[quantization.QuantizationConfig] = None):
|
quantization_config: Optional[quantization.QuantizationConfig] = None):
|
||||||
"""Converts the model to the requested formats and exports to a file.
|
"""Converts and saves the model to a TFLite file with metadata included.
|
||||||
|
|
||||||
|
Note that only the TFLite file is needed for deployment. This function also
|
||||||
|
saves a metadata.json file to the same directory as the TFLite file which
|
||||||
|
can be used to interpret the metadata content in the TFLite file.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model_name: File name to save tflite model. The full export path is
|
model_name: File name to save TFLite model with metadata. The full export
|
||||||
{export_dir}/{tflite_filename}.
|
path is {self._hparams.model_dir}/{model_name}.
|
||||||
quantization_config: The configuration for model quantization.
|
quantization_config: The configuration for model quantization.
|
||||||
"""
|
"""
|
||||||
super().export_tflite(
|
if not tf.io.gfile.exists(self._hparams.model_dir):
|
||||||
self._hparams.model_dir,
|
tf.io.gfile.makedirs(self._hparams.model_dir)
|
||||||
model_name,
|
tflite_file = os.path.join(self._hparams.model_dir, model_name)
|
||||||
quantization_config,
|
metadata_file = os.path.join(self._hparams.model_dir, 'metadata.json')
|
||||||
|
|
||||||
|
tflite_model = model_util.convert_to_tflite(
|
||||||
|
model=self._model,
|
||||||
|
quantization_config=quantization_config,
|
||||||
preprocess=self._preprocess)
|
preprocess=self._preprocess)
|
||||||
|
writer = image_classifier_writer.MetadataWriter.create(
|
||||||
|
tflite_model,
|
||||||
|
self._model_spec.mean_rgb,
|
||||||
|
self._model_spec.stddev_rgb,
|
||||||
|
labels=metadata_writer.Labels().add(self._label_names))
|
||||||
|
tflite_model_with_metadata, metadata_json = writer.populate()
|
||||||
|
model_util.save_tflite(tflite_model_with_metadata, tflite_file)
|
||||||
|
with open(metadata_file, 'w') as f:
|
||||||
|
f.write(metadata_json)
|
||||||
|
|
|
@ -12,6 +12,7 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import filecmp
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from absl.testing import parameterized
|
from absl.testing import parameterized
|
||||||
|
@ -19,6 +20,7 @@ import numpy as np
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
from mediapipe.model_maker.python.vision import image_classifier
|
from mediapipe.model_maker.python.vision import image_classifier
|
||||||
|
from mediapipe.tasks.python.test import test_utils
|
||||||
|
|
||||||
|
|
||||||
def _fill_image(rgb, image_size):
|
def _fill_image(rgb, image_size):
|
||||||
|
@ -86,7 +88,7 @@ class ImageClassifierTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
validation_data=self.test_data)
|
validation_data=self.test_data)
|
||||||
self._test_accuracy(model)
|
self._test_accuracy(model)
|
||||||
|
|
||||||
def test_efficientnetlite0_model_with_model_maker_retraining_lib(self):
|
def test_efficientnetlite0_model_train_and_export(self):
|
||||||
hparams = image_classifier.HParams(
|
hparams = image_classifier.HParams(
|
||||||
train_epochs=1, batch_size=1, shuffle=True)
|
train_epochs=1, batch_size=1, shuffle=True)
|
||||||
model = image_classifier.ImageClassifier.create(
|
model = image_classifier.ImageClassifier.create(
|
||||||
|
@ -96,6 +98,19 @@ class ImageClassifierTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
validation_data=self.test_data)
|
validation_data=self.test_data)
|
||||||
self._test_accuracy(model)
|
self._test_accuracy(model)
|
||||||
|
|
||||||
|
# Test export_model
|
||||||
|
model.export_model()
|
||||||
|
output_metadata_file = os.path.join(hparams.model_dir, 'metadata.json')
|
||||||
|
output_tflite_file = os.path.join(hparams.model_dir, 'model.tflite')
|
||||||
|
expected_metadata_file = test_utils.get_test_data_path('metadata.json')
|
||||||
|
|
||||||
|
self.assertTrue(os.path.exists(output_tflite_file))
|
||||||
|
self.assertGreater(os.path.getsize(output_tflite_file), 0)
|
||||||
|
|
||||||
|
self.assertTrue(os.path.exists(output_metadata_file))
|
||||||
|
self.assertGreater(os.path.getsize(output_metadata_file), 0)
|
||||||
|
self.assertTrue(filecmp.cmp(output_metadata_file, expected_metadata_file))
|
||||||
|
|
||||||
def _test_accuracy(self, model, threshold=0.0):
|
def _test_accuracy(self, model, threshold=0.0):
|
||||||
_, accuracy = model.evaluate(self.test_data)
|
_, accuracy = model.evaluate(self.test_data)
|
||||||
self.assertGreaterEqual(accuracy, threshold)
|
self.assertGreaterEqual(accuracy, threshold)
|
||||||
|
|
23
mediapipe/model_maker/python/vision/image_classifier/testdata/BUILD
vendored
Normal file
23
mediapipe/model_maker/python/vision/image_classifier/testdata/BUILD
vendored
Normal file
|
@ -0,0 +1,23 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
package(
|
||||||
|
default_visibility = ["//mediapipe/model_maker/python/vision/image_classifier:__subpackages__"],
|
||||||
|
licenses = ["notice"], # Apache 2.0
|
||||||
|
)
|
||||||
|
|
||||||
|
filegroup(
|
||||||
|
name = "testdata",
|
||||||
|
srcs = ["metadata.json"],
|
||||||
|
)
|
68
mediapipe/model_maker/python/vision/image_classifier/testdata/metadata.json
vendored
Normal file
68
mediapipe/model_maker/python/vision/image_classifier/testdata/metadata.json
vendored
Normal file
|
@ -0,0 +1,68 @@
|
||||||
|
{
|
||||||
|
"name": "ImageClassifier",
|
||||||
|
"description": "Identify the most prominent object in the image from a known set of categories.",
|
||||||
|
"subgraph_metadata": [
|
||||||
|
{
|
||||||
|
"input_tensor_metadata": [
|
||||||
|
{
|
||||||
|
"name": "image",
|
||||||
|
"description": "Input image to be processed.",
|
||||||
|
"content": {
|
||||||
|
"content_properties_type": "ImageProperties",
|
||||||
|
"content_properties": {
|
||||||
|
"color_space": "RGB"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"process_units": [
|
||||||
|
{
|
||||||
|
"options_type": "NormalizationOptions",
|
||||||
|
"options": {
|
||||||
|
"mean": [
|
||||||
|
0.0
|
||||||
|
],
|
||||||
|
"std": [
|
||||||
|
255.0
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"stats": {
|
||||||
|
"max": [
|
||||||
|
1.0
|
||||||
|
],
|
||||||
|
"min": [
|
||||||
|
0.0
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"output_tensor_metadata": [
|
||||||
|
{
|
||||||
|
"name": "score",
|
||||||
|
"description": "Score of the labels respectively.",
|
||||||
|
"content": {
|
||||||
|
"content_properties_type": "FeatureProperties",
|
||||||
|
"content_properties": {
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"stats": {
|
||||||
|
"max": [
|
||||||
|
1.0
|
||||||
|
],
|
||||||
|
"min": [
|
||||||
|
0.0
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"associated_files": [
|
||||||
|
{
|
||||||
|
"name": "labels.txt",
|
||||||
|
"description": "Labels for categories that the model can recognize.",
|
||||||
|
"type": "TENSOR_AXIS_LABELS"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"min_parser_version": "1.0.0"
|
||||||
|
}
|
|
@ -24,7 +24,7 @@ py_library(
|
||||||
srcs = ["test_utils.py"],
|
srcs = ["test_utils.py"],
|
||||||
srcs_version = "PY3",
|
srcs_version = "PY3",
|
||||||
visibility = [
|
visibility = [
|
||||||
"//mediapipe/model_maker/python/vision/gesture_recognizer:__pkg__",
|
"//mediapipe/model_maker/python:__subpackages__",
|
||||||
"//mediapipe/tasks:internal",
|
"//mediapipe/tasks:internal",
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
|
|
Loading…
Reference in New Issue
Block a user