Internal change
PiperOrigin-RevId: 489752009
This commit is contained in:
parent
a33cb1e05e
commit
bdf4078e89
|
@ -45,6 +45,7 @@ py_test(
|
||||||
name = "model_util_test",
|
name = "model_util_test",
|
||||||
srcs = ["model_util_test.py"],
|
srcs = ["model_util_test.py"],
|
||||||
deps = [
|
deps = [
|
||||||
|
":file_util",
|
||||||
":model_util",
|
":model_util",
|
||||||
":quantization",
|
":quantization",
|
||||||
":test_util",
|
":test_util",
|
||||||
|
|
|
@ -14,10 +14,12 @@
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
from unittest import mock as unittest_mock
|
||||||
|
|
||||||
from absl.testing import parameterized
|
from absl.testing import parameterized
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from mediapipe.model_maker.python.core.utils import file_util
|
||||||
from mediapipe.model_maker.python.core.utils import model_util
|
from mediapipe.model_maker.python.core.utils import model_util
|
||||||
from mediapipe.model_maker.python.core.utils import quantization
|
from mediapipe.model_maker.python.core.utils import quantization
|
||||||
from mediapipe.model_maker.python.core.utils import test_util
|
from mediapipe.model_maker.python.core.utils import test_util
|
||||||
|
@ -25,11 +27,15 @@ from mediapipe.model_maker.python.core.utils import test_util
|
||||||
|
|
||||||
class ModelUtilTest(tf.test.TestCase, parameterized.TestCase):
|
class ModelUtilTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
def test_load_keras_model(self):
|
@unittest_mock.patch.object(file_util, 'get_absolute_path', autospec=True)
|
||||||
|
def test_load_keras_model(self, mock_get_absolute_path):
|
||||||
input_dim = 4
|
input_dim = 4
|
||||||
model = test_util.build_model(input_shape=[input_dim], num_classes=2)
|
model = test_util.build_model(input_shape=[input_dim], num_classes=2)
|
||||||
saved_model_path = os.path.join(self.get_temp_dir(), 'saved_model')
|
saved_model_path = os.path.join(self.get_temp_dir(), 'saved_model')
|
||||||
model.save(saved_model_path)
|
model.save(saved_model_path)
|
||||||
|
# model_util.load_keras_model takes in a relative path to files within the
|
||||||
|
# model_maker dir, so we patch the function for testing
|
||||||
|
mock_get_absolute_path.return_value = saved_model_path
|
||||||
loaded_model = model_util.load_keras_model(saved_model_path)
|
loaded_model = model_util.load_keras_model(saved_model_path)
|
||||||
|
|
||||||
input_tensors = test_util.create_random_sample(size=[1, input_dim])
|
input_tensors = test_util.create_random_sample(size=[1, input_dim])
|
||||||
|
@ -37,13 +43,16 @@ class ModelUtilTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
loaded_model_output = loaded_model.predict_on_batch(input_tensors)
|
loaded_model_output = loaded_model.predict_on_batch(input_tensors)
|
||||||
self.assertTrue((model_output == loaded_model_output).all())
|
self.assertTrue((model_output == loaded_model_output).all())
|
||||||
|
|
||||||
def test_load_tflite_model_buffer(self):
|
@unittest_mock.patch.object(file_util, 'get_absolute_path', autospec=True)
|
||||||
|
def test_load_tflite_model_buffer(self, mock_get_absolute_path):
|
||||||
input_dim = 4
|
input_dim = 4
|
||||||
model = test_util.build_model(input_shape=[input_dim], num_classes=2)
|
model = test_util.build_model(input_shape=[input_dim], num_classes=2)
|
||||||
tflite_model = model_util.convert_to_tflite(model)
|
tflite_model = model_util.convert_to_tflite(model)
|
||||||
tflite_file = os.path.join(self.get_temp_dir(), 'model.tflite')
|
tflite_file = os.path.join(self.get_temp_dir(), 'model.tflite')
|
||||||
model_util.save_tflite(tflite_model=tflite_model, tflite_file=tflite_file)
|
model_util.save_tflite(tflite_model=tflite_model, tflite_file=tflite_file)
|
||||||
|
# model_util.load_tflite_model_buffer takes in a relative path to files
|
||||||
|
# within the model_maker dir, so we patch the function for testing
|
||||||
|
mock_get_absolute_path.return_value = tflite_file
|
||||||
tflite_model_buffer = model_util.load_tflite_model_buffer(tflite_file)
|
tflite_model_buffer = model_util.load_tflite_model_buffer(tflite_file)
|
||||||
test_util.test_tflite(
|
test_util.test_tflite(
|
||||||
keras_model=model,
|
keras_model=model,
|
||||||
|
|
Loading…
Reference in New Issue
Block a user