Merge pull request #4027 from kinaryml:cosine-sim-python
PiperOrigin-RevId: 518431807
This commit is contained in:
commit
5d2a719b54
|
@ -20,20 +20,20 @@ from mediapipe.tasks.python.components.containers import embedding_result
|
|||
_Embedding = embedding_result.Embedding
|
||||
|
||||
|
||||
def _compute_cosine_similarity(u, v):
|
||||
def _compute_cosine_similarity(u: np.ndarray, v: np.ndarray):
|
||||
"""Computes cosine similarity between two embeddings."""
|
||||
|
||||
if len(u.embedding) <= 0:
|
||||
if len(u) <= 0:
|
||||
raise ValueError("Cannot compute cosing similarity on empty embeddings.")
|
||||
|
||||
norm_u = np.linalg.norm(u.embedding)
|
||||
norm_v = np.linalg.norm(v.embedding)
|
||||
norm_u = np.linalg.norm(u)
|
||||
norm_v = np.linalg.norm(v)
|
||||
|
||||
if norm_u <= 0 or norm_v <= 0:
|
||||
raise ValueError(
|
||||
"Cannot compute cosine similarity on embedding with 0 norm.")
|
||||
|
||||
return np.dot(u.embedding, v.embedding.T) / (norm_u * norm_v)
|
||||
return u.dot(v) / (norm_u * norm_v)
|
||||
|
||||
|
||||
def cosine_similarity(u: _Embedding, v: _Embedding) -> float:
|
||||
|
@ -56,10 +56,13 @@ def cosine_similarity(u: _Embedding, v: _Embedding) -> float:
|
|||
f"({len(u.embedding)} vs. {len(v.embedding)}).")
|
||||
|
||||
if u.embedding.dtype == float and v.embedding.dtype == float:
|
||||
return _compute_cosine_similarity(u, v)
|
||||
return _compute_cosine_similarity(u.embedding, v.embedding)
|
||||
|
||||
if u.embedding.dtype == np.uint8 and v.embedding.dtype == np.uint8:
|
||||
return _compute_cosine_similarity(u, v)
|
||||
return _compute_cosine_similarity(
|
||||
u.embedding.view("int8").astype("float"),
|
||||
v.embedding.view("int8").astype("float"),
|
||||
)
|
||||
|
||||
raise ValueError("Cannot compute cosine similarity between quantized and "
|
||||
"float embeddings.")
|
||||
|
|
|
@ -119,14 +119,43 @@ class ImageEmbedderTest(parameterized.TestCase):
|
|||
similarity, expected_similarity, delta=_SIMILARITY_TOLERANCE)
|
||||
|
||||
@parameterized.parameters(
|
||||
(False, False, False, ModelFileType.FILE_NAME, 0.925519, 1024,
|
||||
(-0.2101883, -0.193027)),
|
||||
(True, False, False, ModelFileType.FILE_NAME, 0.925519, 1024,
|
||||
(-0.0142344, -0.0131606)),
|
||||
# (False, True, False, ModelFileType.FILE_NAME,
|
||||
# 0.926791, 1024, (229, 231)),
|
||||
(False, False, True, ModelFileType.FILE_CONTENT, 0.999931, 1024,
|
||||
(-0.195062, -0.193027)))
|
||||
(
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
ModelFileType.FILE_NAME,
|
||||
0.925519,
|
||||
1024,
|
||||
(-0.2101883, -0.193027),
|
||||
),
|
||||
(
|
||||
True,
|
||||
False,
|
||||
False,
|
||||
ModelFileType.FILE_NAME,
|
||||
0.925519,
|
||||
1024,
|
||||
(-0.0142344, -0.0131606),
|
||||
),
|
||||
(
|
||||
False,
|
||||
True,
|
||||
False,
|
||||
ModelFileType.FILE_NAME,
|
||||
0.926791,
|
||||
1024,
|
||||
(229, 231),
|
||||
),
|
||||
(
|
||||
False,
|
||||
False,
|
||||
True,
|
||||
ModelFileType.FILE_CONTENT,
|
||||
0.999931,
|
||||
1024,
|
||||
(-0.195062, -0.193027),
|
||||
),
|
||||
)
|
||||
def test_embed(self, l2_normalize, quantize, with_roi, model_file_type,
|
||||
expected_similarity, expected_size, expected_first_values):
|
||||
# Creates embedder.
|
||||
|
|
Loading…
Reference in New Issue
Block a user