Internal change

PiperOrigin-RevId: 489752009
This commit is contained in:
MediaPipe Team 2022-11-19 21:12:23 -08:00 committed by Copybara-Service
parent a33cb1e05e
commit bdf4078e89
2 changed files with 13 additions and 3 deletions

View File

@ -45,6 +45,7 @@ py_test(
name = "model_util_test",
srcs = ["model_util_test.py"],
deps = [
":file_util",
":model_util",
":quantization",
":test_util",

View File

@ -14,10 +14,12 @@
import os
from typing import Optional
from unittest import mock as unittest_mock
from absl.testing import parameterized
import tensorflow as tf
from mediapipe.model_maker.python.core.utils import file_util
from mediapipe.model_maker.python.core.utils import model_util
from mediapipe.model_maker.python.core.utils import quantization
from mediapipe.model_maker.python.core.utils import test_util
@ -25,11 +27,15 @@ from mediapipe.model_maker.python.core.utils import test_util
class ModelUtilTest(tf.test.TestCase, parameterized.TestCase):
def test_load_keras_model(self):
@unittest_mock.patch.object(file_util, 'get_absolute_path', autospec=True)
def test_load_keras_model(self, mock_get_absolute_path):
input_dim = 4
model = test_util.build_model(input_shape=[input_dim], num_classes=2)
saved_model_path = os.path.join(self.get_temp_dir(), 'saved_model')
model.save(saved_model_path)
# model_util.load_keras_model takes in a relative path to files within the
# model_maker dir, so we patch the function for testing
mock_get_absolute_path.return_value = saved_model_path
loaded_model = model_util.load_keras_model(saved_model_path)
input_tensors = test_util.create_random_sample(size=[1, input_dim])
@ -37,13 +43,16 @@ class ModelUtilTest(tf.test.TestCase, parameterized.TestCase):
loaded_model_output = loaded_model.predict_on_batch(input_tensors)
self.assertTrue((model_output == loaded_model_output).all())
def test_load_tflite_model_buffer(self):
@unittest_mock.patch.object(file_util, 'get_absolute_path', autospec=True)
def test_load_tflite_model_buffer(self, mock_get_absolute_path):
input_dim = 4
model = test_util.build_model(input_shape=[input_dim], num_classes=2)
tflite_model = model_util.convert_to_tflite(model)
tflite_file = os.path.join(self.get_temp_dir(), 'model.tflite')
model_util.save_tflite(tflite_model=tflite_model, tflite_file=tflite_file)
# model_util.load_tflite_model_buffer takes in a relative path to files
# within the model_maker dir, so we patch the function for testing
mock_get_absolute_path.return_value = tflite_file
tflite_model_buffer = model_util.load_tflite_model_buffer(tflite_file)
test_util.test_tflite(
keras_model=model,