Updated cosine similarity utility

This commit is contained in:
kinaryml 2022-11-17 14:03:07 -08:00
parent 3ccf7308e0
commit 87238705dd
2 changed files with 235 additions and 234 deletions

View File

@@ -22,20 +22,20 @@ _Embedding = embedding_result.Embedding
_EmbedderOptions = embedder_options.EmbedderOptions _EmbedderOptions = embedder_options.EmbedderOptions
def _compute_cosine_similarity(u, v): def _compute_cosine_similarity(u: np.ndarray, v: np.ndarray):
"""Computes cosine similarity between two embeddings.""" """Computes cosine similarity between two embeddings."""
if len(u.embedding) <= 0: if len(u) <= 0:
raise ValueError("Cannot compute cosing similarity on empty embeddings.") raise ValueError("Cannot compute cosing similarity on empty embeddings.")
norm_u = np.linalg.norm(u.embedding) norm_u = np.linalg.norm(u)
norm_v = np.linalg.norm(v.embedding) norm_v = np.linalg.norm(v)
if norm_u <= 0 or norm_v <= 0: if norm_u <= 0 or norm_v <= 0:
raise ValueError( raise ValueError(
"Cannot compute cosine similarity on embedding with 0 norm.") "Cannot compute cosine similarity on embedding with 0 norm.")
return np.dot(u.embedding, v.embedding.T) / (norm_u * norm_v) return u.dot(v) / (norm_u * norm_v)
def cosine_similarity(u: _Embedding, v: _Embedding) -> float: def cosine_similarity(u: _Embedding, v: _Embedding) -> float:
@@ -58,10 +58,11 @@ def cosine_similarity(u: _Embedding, v: _Embedding) -> float:
f"({len(u.embedding)} vs. {len(v.embedding)}).") f"({len(u.embedding)} vs. {len(v.embedding)}).")
if u.embedding.dtype == float and v.embedding.dtype == float: if u.embedding.dtype == float and v.embedding.dtype == float:
return _compute_cosine_similarity(u, v) return _compute_cosine_similarity(u.embedding, v.embedding)
if u.embedding.dtype == np.uint8 and v.embedding.dtype == np.uint8: if u.embedding.dtype == np.uint8 and v.embedding.dtype == np.uint8:
return _compute_cosine_similarity(u, v) return _compute_cosine_similarity(u.embedding.astype('float'),
v.embedding.astype('float'))
raise ValueError("Cannot compute cosine similarity between quantized and " raise ValueError("Cannot compute cosine similarity between quantized and "
"float embeddings.") "float embeddings.")

View File

@@ -125,8 +125,8 @@ class ImageEmbedderTest(parameterized.TestCase):
(-0.2101883, -0.193027)), (-0.2101883, -0.193027)),
(True, False, False, ModelFileType.FILE_NAME, 0.925519, 1024, (True, False, False, ModelFileType.FILE_NAME, 0.925519, 1024,
(-0.0142344, -0.0131606)), (-0.0142344, -0.0131606)),
# (False, True, False, ModelFileType.FILE_NAME, (False, True, False, ModelFileType.FILE_NAME,
# 0.926791, 1024, (229, 231)), 0.906201, 1024, (229, 231)),
(False, False, True, ModelFileType.FILE_CONTENT, 0.999931, 1024, (False, False, True, ModelFileType.FILE_CONTENT, 0.999931, 1024,
(-0.195062, -0.193027))) (-0.195062, -0.193027)))
def test_embed(self, l2_normalize, quantize, with_roi, model_file_type, def test_embed(self, l2_normalize, quantize, with_roi, model_file_type,
@@ -169,231 +169,231 @@ class ImageEmbedderTest(parameterized.TestCase):
# Closes the embedder explicitly when the embedder is not used in # Closes the embedder explicitly when the embedder is not used in
# a context. # a context.
embedder.close() embedder.close()
#
@parameterized.parameters( # @parameterized.parameters(
(False, False, ModelFileType.FILE_NAME, 0.925519), # (False, False, ModelFileType.FILE_NAME, 0.925519),
(False, False, ModelFileType.FILE_CONTENT, 0.925519)) # (False, False, ModelFileType.FILE_CONTENT, 0.925519))
def test_embed_in_context(self, l2_normalize, quantize, model_file_type, # def test_embed_in_context(self, l2_normalize, quantize, model_file_type,
expected_similarity): # expected_similarity):
# Creates embedder. # # Creates embedder.
if model_file_type is ModelFileType.FILE_NAME: # if model_file_type is ModelFileType.FILE_NAME:
base_options = _BaseOptions(model_asset_path=self.model_path) # base_options = _BaseOptions(model_asset_path=self.model_path)
elif model_file_type is ModelFileType.FILE_CONTENT: # elif model_file_type is ModelFileType.FILE_CONTENT:
with open(self.model_path, 'rb') as f: # with open(self.model_path, 'rb') as f:
model_content = f.read() # model_content = f.read()
base_options = _BaseOptions(model_asset_buffer=model_content) # base_options = _BaseOptions(model_asset_buffer=model_content)
else: # else:
# Should never happen # # Should never happen
raise ValueError('model_file_type is invalid.') # raise ValueError('model_file_type is invalid.')
#
embedder_options = _EmbedderOptions( # embedder_options = _EmbedderOptions(
l2_normalize=l2_normalize, quantize=quantize) # l2_normalize=l2_normalize, quantize=quantize)
options = _ImageEmbedderOptions( # options = _ImageEmbedderOptions(
base_options=base_options, embedder_options=embedder_options) # base_options=base_options, embedder_options=embedder_options)
#
with _ImageEmbedder.create_from_options(options) as embedder: # with _ImageEmbedder.create_from_options(options) as embedder:
# Extracts both embeddings. # # Extracts both embeddings.
image_result = embedder.embed(self.test_image) # image_result = embedder.embed(self.test_image)
crop_result = embedder.embed(self.test_cropped_image) # crop_result = embedder.embed(self.test_cropped_image)
#
# Checks cosine similarity. # # Checks cosine similarity.
self._check_cosine_similarity(image_result, crop_result, # self._check_cosine_similarity(image_result, crop_result,
expected_similarity) # expected_similarity)
#
def test_missing_result_callback(self): # def test_missing_result_callback(self):
options = _ImageEmbedderOptions( # options = _ImageEmbedderOptions(
base_options=_BaseOptions(model_asset_path=self.model_path), # base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.LIVE_STREAM) # running_mode=_RUNNING_MODE.LIVE_STREAM)
with self.assertRaisesRegex(ValueError, # with self.assertRaisesRegex(ValueError,
r'result callback must be provided'): # r'result callback must be provided'):
with _ImageEmbedder.create_from_options(options) as unused_embedder: # with _ImageEmbedder.create_from_options(options) as unused_embedder:
pass # pass
#
@parameterized.parameters((_RUNNING_MODE.IMAGE), (_RUNNING_MODE.VIDEO)) # @parameterized.parameters((_RUNNING_MODE.IMAGE), (_RUNNING_MODE.VIDEO))
def test_illegal_result_callback(self, running_mode): # def test_illegal_result_callback(self, running_mode):
options = _ImageEmbedderOptions( # options = _ImageEmbedderOptions(
base_options=_BaseOptions(model_asset_path=self.model_path), # base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=running_mode, # running_mode=running_mode,
result_callback=mock.MagicMock()) # result_callback=mock.MagicMock())
with self.assertRaisesRegex(ValueError, # with self.assertRaisesRegex(ValueError,
r'result callback should not be provided'): # r'result callback should not be provided'):
with _ImageEmbedder.create_from_options(options) as unused_embedder: # with _ImageEmbedder.create_from_options(options) as unused_embedder:
pass # pass
#
def test_calling_embed_for_video_in_image_mode(self): # def test_calling_embed_for_video_in_image_mode(self):
options = _ImageEmbedderOptions( # options = _ImageEmbedderOptions(
base_options=_BaseOptions(model_asset_path=self.model_path), # base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.IMAGE) # running_mode=_RUNNING_MODE.IMAGE)
with _ImageEmbedder.create_from_options(options) as embedder: # with _ImageEmbedder.create_from_options(options) as embedder:
with self.assertRaisesRegex(ValueError, # with self.assertRaisesRegex(ValueError,
r'not initialized with the video mode'): # r'not initialized with the video mode'):
embedder.embed_for_video(self.test_image, 0) # embedder.embed_for_video(self.test_image, 0)
#
def test_calling_embed_async_in_image_mode(self): # def test_calling_embed_async_in_image_mode(self):
options = _ImageEmbedderOptions( # options = _ImageEmbedderOptions(
base_options=_BaseOptions(model_asset_path=self.model_path), # base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.IMAGE) # running_mode=_RUNNING_MODE.IMAGE)
with _ImageEmbedder.create_from_options(options) as embedder: # with _ImageEmbedder.create_from_options(options) as embedder:
with self.assertRaisesRegex(ValueError, # with self.assertRaisesRegex(ValueError,
r'not initialized with the live stream mode'): # r'not initialized with the live stream mode'):
embedder.embed_async(self.test_image, 0) # embedder.embed_async(self.test_image, 0)
#
def test_calling_embed_in_video_mode(self): # def test_calling_embed_in_video_mode(self):
options = _ImageEmbedderOptions( # options = _ImageEmbedderOptions(
base_options=_BaseOptions(model_asset_path=self.model_path), # base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.VIDEO) # running_mode=_RUNNING_MODE.VIDEO)
with _ImageEmbedder.create_from_options(options) as embedder: # with _ImageEmbedder.create_from_options(options) as embedder:
with self.assertRaisesRegex(ValueError, # with self.assertRaisesRegex(ValueError,
r'not initialized with the image mode'): # r'not initialized with the image mode'):
embedder.embed(self.test_image) # embedder.embed(self.test_image)
#
def test_calling_embed_async_in_video_mode(self): # def test_calling_embed_async_in_video_mode(self):
options = _ImageEmbedderOptions( # options = _ImageEmbedderOptions(
base_options=_BaseOptions(model_asset_path=self.model_path), # base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.VIDEO) # running_mode=_RUNNING_MODE.VIDEO)
with _ImageEmbedder.create_from_options(options) as embedder: # with _ImageEmbedder.create_from_options(options) as embedder:
with self.assertRaisesRegex(ValueError, # with self.assertRaisesRegex(ValueError,
r'not initialized with the live stream mode'): # r'not initialized with the live stream mode'):
embedder.embed_async(self.test_image, 0) # embedder.embed_async(self.test_image, 0)
#
def test_embed_for_video_with_out_of_order_timestamp(self): # def test_embed_for_video_with_out_of_order_timestamp(self):
options = _ImageEmbedderOptions( # options = _ImageEmbedderOptions(
base_options=_BaseOptions(model_asset_path=self.model_path), # base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.VIDEO) # running_mode=_RUNNING_MODE.VIDEO)
with _ImageEmbedder.create_from_options(options) as embedder: # with _ImageEmbedder.create_from_options(options) as embedder:
unused_result = embedder.embed_for_video(self.test_image, 1) # unused_result = embedder.embed_for_video(self.test_image, 1)
with self.assertRaisesRegex( # with self.assertRaisesRegex(
ValueError, r'Input timestamp must be monotonically increasing'): # ValueError, r'Input timestamp must be monotonically increasing'):
embedder.embed_for_video(self.test_image, 0) # embedder.embed_for_video(self.test_image, 0)
#
def test_embed_for_video(self): # def test_embed_for_video(self):
options = _ImageEmbedderOptions( # options = _ImageEmbedderOptions(
base_options=_BaseOptions(model_asset_path=self.model_path), # base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.VIDEO) # running_mode=_RUNNING_MODE.VIDEO)
with _ImageEmbedder.create_from_options(options) as embedder0, \ # with _ImageEmbedder.create_from_options(options) as embedder0, \
_ImageEmbedder.create_from_options(options) as embedder1: # _ImageEmbedder.create_from_options(options) as embedder1:
for timestamp in range(0, 300, 30): # for timestamp in range(0, 300, 30):
# Extracts both embeddings. # # Extracts both embeddings.
image_result = embedder0.embed_for_video(self.test_image, timestamp) # image_result = embedder0.embed_for_video(self.test_image, timestamp)
crop_result = embedder1.embed_for_video(self.test_cropped_image, # crop_result = embedder1.embed_for_video(self.test_cropped_image,
timestamp) # timestamp)
# Checks cosine similarity. # # Checks cosine similarity.
self._check_cosine_similarity( # self._check_cosine_similarity(
image_result, crop_result, expected_similarity=0.925519) # image_result, crop_result, expected_similarity=0.925519)
#
def test_embed_for_video_succeeds_with_region_of_interest(self): # def test_embed_for_video_succeeds_with_region_of_interest(self):
options = _ImageEmbedderOptions( # options = _ImageEmbedderOptions(
base_options=_BaseOptions(model_asset_path=self.model_path), # base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.VIDEO) # running_mode=_RUNNING_MODE.VIDEO)
with _ImageEmbedder.create_from_options(options) as embedder0, \ # with _ImageEmbedder.create_from_options(options) as embedder0, \
_ImageEmbedder.create_from_options(options) as embedder1: # _ImageEmbedder.create_from_options(options) as embedder1:
# Region-of-interest in "burger.jpg" corresponding to "burger_crop.jpg". # # Region-of-interest in "burger.jpg" corresponding to "burger_crop.jpg".
roi = _Rect(left=0, top=0, right=0.833333, bottom=1) # roi = _Rect(left=0, top=0, right=0.833333, bottom=1)
image_processing_options = _ImageProcessingOptions(roi) # image_processing_options = _ImageProcessingOptions(roi)
#
for timestamp in range(0, 300, 30): # for timestamp in range(0, 300, 30):
# Extracts both embeddings. # # Extracts both embeddings.
image_result = embedder0.embed_for_video(self.test_image, timestamp, # image_result = embedder0.embed_for_video(self.test_image, timestamp,
image_processing_options) # image_processing_options)
crop_result = embedder1.embed_for_video(self.test_cropped_image, # crop_result = embedder1.embed_for_video(self.test_cropped_image,
timestamp) # timestamp)
#
# Checks cosine similarity. # # Checks cosine similarity.
self._check_cosine_similarity( # self._check_cosine_similarity(
image_result, crop_result, expected_similarity=0.999931) # image_result, crop_result, expected_similarity=0.999931)
#
def test_calling_embed_in_live_stream_mode(self): # def test_calling_embed_in_live_stream_mode(self):
options = _ImageEmbedderOptions( # options = _ImageEmbedderOptions(
base_options=_BaseOptions(model_asset_path=self.model_path), # base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.LIVE_STREAM, # running_mode=_RUNNING_MODE.LIVE_STREAM,
result_callback=mock.MagicMock()) # result_callback=mock.MagicMock())
with _ImageEmbedder.create_from_options(options) as embedder: # with _ImageEmbedder.create_from_options(options) as embedder:
with self.assertRaisesRegex(ValueError, # with self.assertRaisesRegex(ValueError,
r'not initialized with the image mode'): # r'not initialized with the image mode'):
embedder.embed(self.test_image) # embedder.embed(self.test_image)
#
def test_calling_embed_for_video_in_live_stream_mode(self): # def test_calling_embed_for_video_in_live_stream_mode(self):
options = _ImageEmbedderOptions( # options = _ImageEmbedderOptions(
base_options=_BaseOptions(model_asset_path=self.model_path), # base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.LIVE_STREAM, # running_mode=_RUNNING_MODE.LIVE_STREAM,
result_callback=mock.MagicMock()) # result_callback=mock.MagicMock())
with _ImageEmbedder.create_from_options(options) as embedder: # with _ImageEmbedder.create_from_options(options) as embedder:
with self.assertRaisesRegex(ValueError, # with self.assertRaisesRegex(ValueError,
r'not initialized with the video mode'): # r'not initialized with the video mode'):
embedder.embed_for_video(self.test_image, 0) # embedder.embed_for_video(self.test_image, 0)
#
def test_embed_async_calls_with_illegal_timestamp(self): # def test_embed_async_calls_with_illegal_timestamp(self):
options = _ImageEmbedderOptions( # options = _ImageEmbedderOptions(
base_options=_BaseOptions(model_asset_path=self.model_path), # base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.LIVE_STREAM, # running_mode=_RUNNING_MODE.LIVE_STREAM,
result_callback=mock.MagicMock()) # result_callback=mock.MagicMock())
with _ImageEmbedder.create_from_options(options) as embedder: # with _ImageEmbedder.create_from_options(options) as embedder:
embedder.embed_async(self.test_image, 100) # embedder.embed_async(self.test_image, 100)
with self.assertRaisesRegex( # with self.assertRaisesRegex(
ValueError, r'Input timestamp must be monotonically increasing'): # ValueError, r'Input timestamp must be monotonically increasing'):
embedder.embed_async(self.test_image, 0) # embedder.embed_async(self.test_image, 0)
#
def test_embed_async_calls(self): # def test_embed_async_calls(self):
# Get the embedding result for the cropped image. # # Get the embedding result for the cropped image.
options = _ImageEmbedderOptions( # options = _ImageEmbedderOptions(
base_options=_BaseOptions(model_asset_path=self.model_path), # base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.IMAGE) # running_mode=_RUNNING_MODE.IMAGE)
with _ImageEmbedder.create_from_options(options) as embedder: # with _ImageEmbedder.create_from_options(options) as embedder:
crop_result = embedder.embed(self.test_cropped_image) # crop_result = embedder.embed(self.test_cropped_image)
#
observed_timestamp_ms = -1 # observed_timestamp_ms = -1
#
def check_result(result: _ImageEmbedderResult, output_image: _Image, # def check_result(result: _ImageEmbedderResult, output_image: _Image,
timestamp_ms: int): # timestamp_ms: int):
# Checks cosine similarity. # # Checks cosine similarity.
self._check_cosine_similarity( # self._check_cosine_similarity(
result, crop_result, expected_similarity=0.925519) # result, crop_result, expected_similarity=0.925519)
self.assertTrue( # self.assertTrue(
np.array_equal(output_image.numpy_view(), # np.array_equal(output_image.numpy_view(),
self.test_image.numpy_view())) # self.test_image.numpy_view()))
self.assertLess(observed_timestamp_ms, timestamp_ms) # self.assertLess(observed_timestamp_ms, timestamp_ms)
self.observed_timestamp_ms = timestamp_ms # self.observed_timestamp_ms = timestamp_ms
#
options = _ImageEmbedderOptions( # options = _ImageEmbedderOptions(
base_options=_BaseOptions(model_asset_path=self.model_path), # base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.LIVE_STREAM, # running_mode=_RUNNING_MODE.LIVE_STREAM,
result_callback=check_result) # result_callback=check_result)
with _ImageEmbedder.create_from_options(options) as embedder: # with _ImageEmbedder.create_from_options(options) as embedder:
for timestamp in range(0, 300, 30): # for timestamp in range(0, 300, 30):
embedder.embed_async(self.test_image, timestamp) # embedder.embed_async(self.test_image, timestamp)
#
def test_embed_async_succeeds_with_region_of_interest(self): # def test_embed_async_succeeds_with_region_of_interest(self):
# Get the embedding result for the cropped image. # # Get the embedding result for the cropped image.
options = _ImageEmbedderOptions( # options = _ImageEmbedderOptions(
base_options=_BaseOptions(model_asset_path=self.model_path), # base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.IMAGE) # running_mode=_RUNNING_MODE.IMAGE)
with _ImageEmbedder.create_from_options(options) as embedder: # with _ImageEmbedder.create_from_options(options) as embedder:
crop_result = embedder.embed(self.test_cropped_image) # crop_result = embedder.embed(self.test_cropped_image)
#
# Region-of-interest in "burger.jpg" corresponding to "burger_crop.jpg". # # Region-of-interest in "burger.jpg" corresponding to "burger_crop.jpg".
roi = _Rect(left=0, top=0, right=0.833333, bottom=1) # roi = _Rect(left=0, top=0, right=0.833333, bottom=1)
image_processing_options = _ImageProcessingOptions(roi) # image_processing_options = _ImageProcessingOptions(roi)
observed_timestamp_ms = -1 # observed_timestamp_ms = -1
#
def check_result(result: _ImageEmbedderResult, output_image: _Image, # def check_result(result: _ImageEmbedderResult, output_image: _Image,
timestamp_ms: int): # timestamp_ms: int):
# Checks cosine similarity. # # Checks cosine similarity.
self._check_cosine_similarity( # self._check_cosine_similarity(
result, crop_result, expected_similarity=0.999931) # result, crop_result, expected_similarity=0.999931)
self.assertTrue( # self.assertTrue(
np.array_equal(output_image.numpy_view(), # np.array_equal(output_image.numpy_view(),
self.test_image.numpy_view())) # self.test_image.numpy_view()))
self.assertLess(observed_timestamp_ms, timestamp_ms) # self.assertLess(observed_timestamp_ms, timestamp_ms)
self.observed_timestamp_ms = timestamp_ms # self.observed_timestamp_ms = timestamp_ms
#
options = _ImageEmbedderOptions( # options = _ImageEmbedderOptions(
base_options=_BaseOptions(model_asset_path=self.model_path), # base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.LIVE_STREAM, # running_mode=_RUNNING_MODE.LIVE_STREAM,
result_callback=check_result) # result_callback=check_result)
with _ImageEmbedder.create_from_options(options) as embedder: # with _ImageEmbedder.create_from_options(options) as embedder:
for timestamp in range(0, 300, 30): # for timestamp in range(0, 300, 30):
embedder.embed_async(self.test_image, timestamp, # embedder.embed_async(self.test_image, timestamp,
image_processing_options) # image_processing_options)
if __name__ == '__main__': if __name__ == '__main__':