Added some sanity tests
This commit is contained in:
parent
36c50ff8f3
commit
dd6fdedd5f
|
@ -69,6 +69,34 @@ class ImageEmbedderTest(parameterized.TestCase):
|
||||||
self.model_path = test_utils.get_test_data_path(
|
self.model_path = test_utils.get_test_data_path(
|
||||||
os.path.join(_TEST_DATA_DIR, _MODEL_FILE))
|
os.path.join(_TEST_DATA_DIR, _MODEL_FILE))
|
||||||
|
|
||||||
|
def test_create_from_file_succeeds_with_valid_model_path(self):
|
||||||
|
# Creates with default option and valid model file successfully.
|
||||||
|
with _ImageEmbedder.create_from_model_path(self.model_path) as embedder:
|
||||||
|
self.assertIsInstance(embedder, _ImageEmbedder)
|
||||||
|
|
||||||
|
def test_create_from_options_succeeds_with_valid_model_path(self):
|
||||||
|
# Creates with options containing model file successfully.
|
||||||
|
base_options = _BaseOptions(model_asset_path=self.model_path)
|
||||||
|
options = _ImageEmbedderOptions(base_options=base_options)
|
||||||
|
with _ImageEmbedder.create_from_options(options) as embedder:
|
||||||
|
self.assertIsInstance(embedder, _ImageEmbedder)
|
||||||
|
|
||||||
|
def test_create_from_options_fails_with_invalid_model_path(self):
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
RuntimeError, 'Unable to open file at /path/to/invalid/model.tflite'):
|
||||||
|
base_options = _BaseOptions(
|
||||||
|
model_asset_path='/path/to/invalid/model.tflite')
|
||||||
|
options = _ImageEmbedderOptions(base_options=base_options)
|
||||||
|
_ImageEmbedder.create_from_options(options)
|
||||||
|
|
||||||
|
def test_create_from_options_succeeds_with_valid_model_content(self):
|
||||||
|
# Creates with options containing model content successfully.
|
||||||
|
with open(self.model_path, 'rb') as f:
|
||||||
|
base_options = _BaseOptions(model_asset_buffer=f.read())
|
||||||
|
options = _ImageEmbedderOptions(base_options=base_options)
|
||||||
|
embedder = _ImageEmbedder.create_from_options(options)
|
||||||
|
self.assertIsInstance(embedder, _ImageEmbedder)
|
||||||
|
|
||||||
def _check_cosine_similarity(self, result0, result1, quantize,
|
def _check_cosine_similarity(self, result0, result1, quantize,
|
||||||
expected_similarity):
|
expected_similarity):
|
||||||
# Checks head_index and head_name.
|
# Checks head_index and head_name.
|
||||||
|
|
Loading…
Reference in New Issue
Block a user