Added some sanity tests

This commit is contained in:
kinaryml 2022-11-08 23:15:58 -08:00
parent 36c50ff8f3
commit dd6fdedd5f

View File

@ -69,6 +69,34 @@ class ImageEmbedderTest(parameterized.TestCase):
self.model_path = test_utils.get_test_data_path(
os.path.join(_TEST_DATA_DIR, _MODEL_FILE))
def test_create_from_file_succeeds_with_valid_model_path(self):
# Creates with default option and valid model file successfully.
with _ImageEmbedder.create_from_model_path(self.model_path) as embedder:
self.assertIsInstance(embedder, _ImageEmbedder)
def test_create_from_options_succeeds_with_valid_model_path(self):
# Creates with options containing model file successfully.
base_options = _BaseOptions(model_asset_path=self.model_path)
options = _ImageEmbedderOptions(base_options=base_options)
with _ImageEmbedder.create_from_options(options) as embedder:
self.assertIsInstance(embedder, _ImageEmbedder)
def test_create_from_options_fails_with_invalid_model_path(self):
with self.assertRaisesRegex(
RuntimeError, 'Unable to open file at /path/to/invalid/model.tflite'):
base_options = _BaseOptions(
model_asset_path='/path/to/invalid/model.tflite')
options = _ImageEmbedderOptions(base_options=base_options)
_ImageEmbedder.create_from_options(options)
def test_create_from_options_succeeds_with_valid_model_content(self):
# Creates with options containing model content successfully.
with open(self.model_path, 'rb') as f:
base_options = _BaseOptions(model_asset_buffer=f.read())
options = _ImageEmbedderOptions(base_options=base_options)
embedder = _ImageEmbedder.create_from_options(options)
self.assertIsInstance(embedder, _ImageEmbedder)
def _check_cosine_similarity(self, result0, result1, quantize,
expected_similarity):
# Checks head_index and head_name.