Merge pull request #4027 from kinaryml:cosine-sim-python

PiperOrigin-RevId: 518431807
This commit is contained in:
Copybara-Service 2023-03-21 17:59:48 -07:00
commit 5d2a719b54
2 changed files with 47 additions and 15 deletions

View File

@ -20,20 +20,20 @@ from mediapipe.tasks.python.components.containers import embedding_result
_Embedding = embedding_result.Embedding _Embedding = embedding_result.Embedding
def _compute_cosine_similarity(u, v): def _compute_cosine_similarity(u: np.ndarray, v: np.ndarray):
"""Computes cosine similarity between two embeddings.""" """Computes cosine similarity between two embeddings."""
if len(u.embedding) <= 0: if len(u) <= 0:
raise ValueError("Cannot compute cosing similarity on empty embeddings.") raise ValueError("Cannot compute cosing similarity on empty embeddings.")
norm_u = np.linalg.norm(u.embedding) norm_u = np.linalg.norm(u)
norm_v = np.linalg.norm(v.embedding) norm_v = np.linalg.norm(v)
if norm_u <= 0 or norm_v <= 0: if norm_u <= 0 or norm_v <= 0:
raise ValueError( raise ValueError(
"Cannot compute cosine similarity on embedding with 0 norm.") "Cannot compute cosine similarity on embedding with 0 norm.")
return np.dot(u.embedding, v.embedding.T) / (norm_u * norm_v) return u.dot(v) / (norm_u * norm_v)
def cosine_similarity(u: _Embedding, v: _Embedding) -> float: def cosine_similarity(u: _Embedding, v: _Embedding) -> float:
@ -56,10 +56,13 @@ def cosine_similarity(u: _Embedding, v: _Embedding) -> float:
f"({len(u.embedding)} vs. {len(v.embedding)}).") f"({len(u.embedding)} vs. {len(v.embedding)}).")
if u.embedding.dtype == float and v.embedding.dtype == float: if u.embedding.dtype == float and v.embedding.dtype == float:
return _compute_cosine_similarity(u, v) return _compute_cosine_similarity(u.embedding, v.embedding)
if u.embedding.dtype == np.uint8 and v.embedding.dtype == np.uint8: if u.embedding.dtype == np.uint8 and v.embedding.dtype == np.uint8:
return _compute_cosine_similarity(u, v) return _compute_cosine_similarity(
u.embedding.view("int8").astype("float"),
v.embedding.view("int8").astype("float"),
)
raise ValueError("Cannot compute cosine similarity between quantized and " raise ValueError("Cannot compute cosine similarity between quantized and "
"float embeddings.") "float embeddings.")

View File

@ -119,14 +119,43 @@ class ImageEmbedderTest(parameterized.TestCase):
similarity, expected_similarity, delta=_SIMILARITY_TOLERANCE) similarity, expected_similarity, delta=_SIMILARITY_TOLERANCE)
@parameterized.parameters( @parameterized.parameters(
(False, False, False, ModelFileType.FILE_NAME, 0.925519, 1024, (
(-0.2101883, -0.193027)), False,
(True, False, False, ModelFileType.FILE_NAME, 0.925519, 1024, False,
(-0.0142344, -0.0131606)), False,
# (False, True, False, ModelFileType.FILE_NAME, ModelFileType.FILE_NAME,
# 0.926791, 1024, (229, 231)), 0.925519,
(False, False, True, ModelFileType.FILE_CONTENT, 0.999931, 1024, 1024,
(-0.195062, -0.193027))) (-0.2101883, -0.193027),
),
(
True,
False,
False,
ModelFileType.FILE_NAME,
0.925519,
1024,
(-0.0142344, -0.0131606),
),
(
False,
True,
False,
ModelFileType.FILE_NAME,
0.926791,
1024,
(229, 231),
),
(
False,
False,
True,
ModelFileType.FILE_CONTENT,
0.999931,
1024,
(-0.195062, -0.193027),
),
)
def test_embed(self, l2_normalize, quantize, with_roi, model_file_type, def test_embed(self, l2_normalize, quantize, with_roi, model_file_type,
expected_similarity, expected_size, expected_first_values): expected_similarity, expected_size, expected_first_values):
# Creates embedder. # Creates embedder.