Merge pull request #4027 from kinaryml:cosine-sim-python
PiperOrigin-RevId: 518431807
This commit is contained in:
commit
5d2a719b54
|
@ -20,20 +20,20 @@ from mediapipe.tasks.python.components.containers import embedding_result
|
||||||
_Embedding = embedding_result.Embedding
|
_Embedding = embedding_result.Embedding
|
||||||
|
|
||||||
|
|
||||||
def _compute_cosine_similarity(u, v):
|
def _compute_cosine_similarity(u: np.ndarray, v: np.ndarray):
|
||||||
"""Computes cosine similarity between two embeddings."""
|
"""Computes cosine similarity between two embeddings."""
|
||||||
|
|
||||||
if len(u.embedding) <= 0:
|
if len(u) <= 0:
|
||||||
raise ValueError("Cannot compute cosing similarity on empty embeddings.")
|
raise ValueError("Cannot compute cosing similarity on empty embeddings.")
|
||||||
|
|
||||||
norm_u = np.linalg.norm(u.embedding)
|
norm_u = np.linalg.norm(u)
|
||||||
norm_v = np.linalg.norm(v.embedding)
|
norm_v = np.linalg.norm(v)
|
||||||
|
|
||||||
if norm_u <= 0 or norm_v <= 0:
|
if norm_u <= 0 or norm_v <= 0:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Cannot compute cosine similarity on embedding with 0 norm.")
|
"Cannot compute cosine similarity on embedding with 0 norm.")
|
||||||
|
|
||||||
return np.dot(u.embedding, v.embedding.T) / (norm_u * norm_v)
|
return u.dot(v) / (norm_u * norm_v)
|
||||||
|
|
||||||
|
|
||||||
def cosine_similarity(u: _Embedding, v: _Embedding) -> float:
|
def cosine_similarity(u: _Embedding, v: _Embedding) -> float:
|
||||||
|
@ -56,10 +56,13 @@ def cosine_similarity(u: _Embedding, v: _Embedding) -> float:
|
||||||
f"({len(u.embedding)} vs. {len(v.embedding)}).")
|
f"({len(u.embedding)} vs. {len(v.embedding)}).")
|
||||||
|
|
||||||
if u.embedding.dtype == float and v.embedding.dtype == float:
|
if u.embedding.dtype == float and v.embedding.dtype == float:
|
||||||
return _compute_cosine_similarity(u, v)
|
return _compute_cosine_similarity(u.embedding, v.embedding)
|
||||||
|
|
||||||
if u.embedding.dtype == np.uint8 and v.embedding.dtype == np.uint8:
|
if u.embedding.dtype == np.uint8 and v.embedding.dtype == np.uint8:
|
||||||
return _compute_cosine_similarity(u, v)
|
return _compute_cosine_similarity(
|
||||||
|
u.embedding.view("int8").astype("float"),
|
||||||
|
v.embedding.view("int8").astype("float"),
|
||||||
|
)
|
||||||
|
|
||||||
raise ValueError("Cannot compute cosine similarity between quantized and "
|
raise ValueError("Cannot compute cosine similarity between quantized and "
|
||||||
"float embeddings.")
|
"float embeddings.")
|
||||||
|
|
|
@ -119,14 +119,43 @@ class ImageEmbedderTest(parameterized.TestCase):
|
||||||
similarity, expected_similarity, delta=_SIMILARITY_TOLERANCE)
|
similarity, expected_similarity, delta=_SIMILARITY_TOLERANCE)
|
||||||
|
|
||||||
@parameterized.parameters(
|
@parameterized.parameters(
|
||||||
(False, False, False, ModelFileType.FILE_NAME, 0.925519, 1024,
|
(
|
||||||
(-0.2101883, -0.193027)),
|
False,
|
||||||
(True, False, False, ModelFileType.FILE_NAME, 0.925519, 1024,
|
False,
|
||||||
(-0.0142344, -0.0131606)),
|
False,
|
||||||
# (False, True, False, ModelFileType.FILE_NAME,
|
ModelFileType.FILE_NAME,
|
||||||
# 0.926791, 1024, (229, 231)),
|
0.925519,
|
||||||
(False, False, True, ModelFileType.FILE_CONTENT, 0.999931, 1024,
|
1024,
|
||||||
(-0.195062, -0.193027)))
|
(-0.2101883, -0.193027),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
True,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
ModelFileType.FILE_NAME,
|
||||||
|
0.925519,
|
||||||
|
1024,
|
||||||
|
(-0.0142344, -0.0131606),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
False,
|
||||||
|
True,
|
||||||
|
False,
|
||||||
|
ModelFileType.FILE_NAME,
|
||||||
|
0.926791,
|
||||||
|
1024,
|
||||||
|
(229, 231),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
True,
|
||||||
|
ModelFileType.FILE_CONTENT,
|
||||||
|
0.999931,
|
||||||
|
1024,
|
||||||
|
(-0.195062, -0.193027),
|
||||||
|
),
|
||||||
|
)
|
||||||
def test_embed(self, l2_normalize, quantize, with_roi, model_file_type,
|
def test_embed(self, l2_normalize, quantize, with_roi, model_file_type,
|
||||||
expected_similarity, expected_size, expected_first_values):
|
expected_similarity, expected_size, expected_first_values):
|
||||||
# Creates embedder.
|
# Creates embedder.
|
||||||
|
|
Loading…
Reference in New Issue
Block a user