Internal change
PiperOrigin-RevId: 489752009
This commit is contained in:
parent
a33cb1e05e
commit
bdf4078e89
|
@ -45,6 +45,7 @@ py_test(
|
|||
name = "model_util_test",
|
||||
srcs = ["model_util_test.py"],
|
||||
deps = [
|
||||
":file_util",
|
||||
":model_util",
|
||||
":quantization",
|
||||
":test_util",
|
||||
|
|
|
@ -14,10 +14,12 @@
|
|||
|
||||
import os
|
||||
from typing import Optional
|
||||
from unittest import mock as unittest_mock
|
||||
|
||||
from absl.testing import parameterized
|
||||
import tensorflow as tf
|
||||
|
||||
from mediapipe.model_maker.python.core.utils import file_util
|
||||
from mediapipe.model_maker.python.core.utils import model_util
|
||||
from mediapipe.model_maker.python.core.utils import quantization
|
||||
from mediapipe.model_maker.python.core.utils import test_util
|
||||
|
@ -25,11 +27,15 @@ from mediapipe.model_maker.python.core.utils import test_util
|
|||
|
||||
class ModelUtilTest(tf.test.TestCase, parameterized.TestCase):
|
||||
|
||||
def test_load_keras_model(self):
|
||||
@unittest_mock.patch.object(file_util, 'get_absolute_path', autospec=True)
|
||||
def test_load_keras_model(self, mock_get_absolute_path):
|
||||
input_dim = 4
|
||||
model = test_util.build_model(input_shape=[input_dim], num_classes=2)
|
||||
saved_model_path = os.path.join(self.get_temp_dir(), 'saved_model')
|
||||
model.save(saved_model_path)
|
||||
# model_util.load_keras_model takes in a relative path to files within the
|
||||
# model_maker dir, so we patch the function for testing
|
||||
mock_get_absolute_path.return_value = saved_model_path
|
||||
loaded_model = model_util.load_keras_model(saved_model_path)
|
||||
|
||||
input_tensors = test_util.create_random_sample(size=[1, input_dim])
|
||||
|
@ -37,13 +43,16 @@ class ModelUtilTest(tf.test.TestCase, parameterized.TestCase):
|
|||
loaded_model_output = loaded_model.predict_on_batch(input_tensors)
|
||||
self.assertTrue((model_output == loaded_model_output).all())
|
||||
|
||||
def test_load_tflite_model_buffer(self):
|
||||
@unittest_mock.patch.object(file_util, 'get_absolute_path', autospec=True)
|
||||
def test_load_tflite_model_buffer(self, mock_get_absolute_path):
|
||||
input_dim = 4
|
||||
model = test_util.build_model(input_shape=[input_dim], num_classes=2)
|
||||
tflite_model = model_util.convert_to_tflite(model)
|
||||
tflite_file = os.path.join(self.get_temp_dir(), 'model.tflite')
|
||||
model_util.save_tflite(tflite_model=tflite_model, tflite_file=tflite_file)
|
||||
|
||||
# model_util.load_tflite_model_buffer takes in a relative path to files
|
||||
# within the model_maker dir, so we patch the function for testing
|
||||
mock_get_absolute_path.return_value = tflite_file
|
||||
tflite_model_buffer = model_util.load_tflite_model_buffer(tflite_file)
|
||||
test_util.test_tflite(
|
||||
keras_model=model,
|
||||
|
|
Loading…
Reference in New Issue
Block a user