Merge pull request #4027 from kinaryml:cosine-sim-python

PiperOrigin-RevId: 518431807
This commit is contained in:
Copybara-Service 2023-03-21 17:59:48 -07:00
commit 5d2a719b54
2 changed files with 47 additions and 15 deletions

View File

@@ -20,20 +20,20 @@ from mediapipe.tasks.python.components.containers import embedding_result
_Embedding = embedding_result.Embedding
def _compute_cosine_similarity(u: np.ndarray, v: np.ndarray) -> float:
  """Computes cosine similarity between two embedding vectors.

  The result is the dot product of `u` and `v` divided by the product of
  their Euclidean norms.

  Args:
    u: First embedding, as a NumPy array.
    v: Second embedding, as a NumPy array.

  Returns:
    The cosine similarity `u . v / (|u| * |v|)`.

  Raises:
    ValueError: If `u` is empty, or if either `u` or `v` has zero norm.
  """
  if len(u) <= 0:
    raise ValueError("Cannot compute cosine similarity on empty embeddings.")
  norm_u = np.linalg.norm(u)
  norm_v = np.linalg.norm(v)
  # A zero norm would make the division below undefined, so reject it up front.
  if norm_u <= 0 or norm_v <= 0:
    raise ValueError(
        "Cannot compute cosine similarity on embedding with 0 norm.")
  return u.dot(v) / (norm_u * norm_v)
def cosine_similarity(u: _Embedding, v: _Embedding) -> float:
@@ -56,10 +56,13 @@ def cosine_similarity(u: _Embedding, v: _Embedding) -> float:
f"({len(u.embedding)} vs. {len(v.embedding)}).")
if u.embedding.dtype == float and v.embedding.dtype == float:
return _compute_cosine_similarity(u, v)
return _compute_cosine_similarity(u.embedding, v.embedding)
if u.embedding.dtype == np.uint8 and v.embedding.dtype == np.uint8:
return _compute_cosine_similarity(u, v)
return _compute_cosine_similarity(
u.embedding.view("int8").astype("float"),
v.embedding.view("int8").astype("float"),
)
raise ValueError("Cannot compute cosine similarity between quantized and "
"float embeddings.")

View File

@@ -119,14 +119,43 @@ class ImageEmbedderTest(parameterized.TestCase):
similarity, expected_similarity, delta=_SIMILARITY_TOLERANCE)
@parameterized.parameters(
(False, False, False, ModelFileType.FILE_NAME, 0.925519, 1024,
(-0.2101883, -0.193027)),
(True, False, False, ModelFileType.FILE_NAME, 0.925519, 1024,
(-0.0142344, -0.0131606)),
# (False, True, False, ModelFileType.FILE_NAME,
# 0.926791, 1024, (229, 231)),
(False, False, True, ModelFileType.FILE_CONTENT, 0.999931, 1024,
(-0.195062, -0.193027)))
(
False,
False,
False,
ModelFileType.FILE_NAME,
0.925519,
1024,
(-0.2101883, -0.193027),
),
(
True,
False,
False,
ModelFileType.FILE_NAME,
0.925519,
1024,
(-0.0142344, -0.0131606),
),
(
False,
True,
False,
ModelFileType.FILE_NAME,
0.926791,
1024,
(229, 231),
),
(
False,
False,
True,
ModelFileType.FILE_CONTENT,
0.999931,
1024,
(-0.195062, -0.193027),
),
)
def test_embed(self, l2_normalize, quantize, with_roi, model_file_type,
expected_similarity, expected_size, expected_first_values):
# Creates embedder.