diff --git a/mediapipe/tasks/python/components/utils/cosine_similarity.py b/mediapipe/tasks/python/components/utils/cosine_similarity.py index ff8979458..a6245a579 100644 --- a/mediapipe/tasks/python/components/utils/cosine_similarity.py +++ b/mediapipe/tasks/python/components/utils/cosine_similarity.py @@ -20,20 +20,20 @@ from mediapipe.tasks.python.components.containers import embedding_result _Embedding = embedding_result.Embedding -def _compute_cosine_similarity(u, v): +def _compute_cosine_similarity(u: np.ndarray, v: np.ndarray): """Computes cosine similarity between two embeddings.""" - if len(u.embedding) <= 0: + if len(u) <= 0: raise ValueError("Cannot compute cosing similarity on empty embeddings.") - norm_u = np.linalg.norm(u.embedding) - norm_v = np.linalg.norm(v.embedding) + norm_u = np.linalg.norm(u) + norm_v = np.linalg.norm(v) if norm_u <= 0 or norm_v <= 0: raise ValueError( "Cannot compute cosine similarity on embedding with 0 norm.") - return np.dot(u.embedding, v.embedding.T) / (norm_u * norm_v) + return u.dot(v) / (norm_u * norm_v) def cosine_similarity(u: _Embedding, v: _Embedding) -> float: @@ -56,10 +56,13 @@ def cosine_similarity(u: _Embedding, v: _Embedding) -> float: f"({len(u.embedding)} vs. {len(v.embedding)}).") if u.embedding.dtype == float and v.embedding.dtype == float: - return _compute_cosine_similarity(u, v) + return _compute_cosine_similarity(u.embedding, v.embedding) if u.embedding.dtype == np.uint8 and v.embedding.dtype == np.uint8: - return _compute_cosine_similarity(u, v) + return _compute_cosine_similarity( + u.embedding.view("int8").astype("float"), + v.embedding.view("int8").astype("float"), + ) raise ValueError("Cannot compute cosine similarity between quantized and " "float embeddings.") diff --git a/mediapipe/tasks/python/test/vision/image_embedder_test.py b/mediapipe/tasks/python/test/vision/image_embedder_test.py index 11c0cf002..8c7fb59a2 100644 --- a/mediapipe/tasks/python/test/vision/image_embedder_test.py +++ b/mediapipe/tasks/python/test/vision/image_embedder_test.py @@ -119,14 +119,43 @@ class ImageEmbedderTest(parameterized.TestCase): similarity, expected_similarity, delta=_SIMILARITY_TOLERANCE) @parameterized.parameters( - (False, False, False, ModelFileType.FILE_NAME, 0.925519, 1024, - (-0.2101883, -0.193027)), - (True, False, False, ModelFileType.FILE_NAME, 0.925519, 1024, - (-0.0142344, -0.0131606)), - # (False, True, False, ModelFileType.FILE_NAME, - # 0.926791, 1024, (229, 231)), - (False, False, True, ModelFileType.FILE_CONTENT, 0.999931, 1024, - (-0.195062, -0.193027))) + ( + False, + False, + False, + ModelFileType.FILE_NAME, + 0.925519, + 1024, + (-0.2101883, -0.193027), + ), + ( + True, + False, + False, + ModelFileType.FILE_NAME, + 0.925519, + 1024, + (-0.0142344, -0.0131606), + ), + ( + False, + True, + False, + ModelFileType.FILE_NAME, + 0.926791, + 1024, + (229, 231), + ), + ( + False, + False, + True, + ModelFileType.FILE_CONTENT, + 0.999931, + 1024, + (-0.195062, -0.193027), + ), + ) def test_embed(self, l2_normalize, quantize, with_roi, model_file_type, expected_similarity, expected_size, expected_first_values): # Creates embedder.