From c5ce5236972a6045f42bb23d526ebb27a7e58bb7 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Wed, 23 Nov 2022 02:02:18 -0800 Subject: [PATCH] Add cosine APIs to Embedder tasks PiperOrigin-RevId: 490444597 --- .../tasks/web/audio/audio_embedder/BUILD | 1 + .../audio/audio_embedder/audio_embedder.ts | 15 +++++ mediapipe/tasks/web/components/utils/BUILD | 11 ++++ .../web/components/utils/cosine_similarity.ts | 62 +++++++++++++++++++ mediapipe/tasks/web/text/text_embedder/BUILD | 1 + .../web/text/text_embedder/text_embedder.ts | 15 +++++ .../tasks/web/vision/image_embedder/BUILD | 1 + .../vision/image_embedder/image_embedder.ts | 15 +++++ 8 files changed, 121 insertions(+) create mode 100644 mediapipe/tasks/web/components/utils/BUILD create mode 100644 mediapipe/tasks/web/components/utils/cosine_similarity.ts diff --git a/mediapipe/tasks/web/audio/audio_embedder/BUILD b/mediapipe/tasks/web/audio/audio_embedder/BUILD index 7d9a994a3..1a66464bd 100644 --- a/mediapipe/tasks/web/audio/audio_embedder/BUILD +++ b/mediapipe/tasks/web/audio/audio_embedder/BUILD @@ -22,6 +22,7 @@ mediapipe_ts_library( "//mediapipe/tasks/web/components/containers:embedding_result", "//mediapipe/tasks/web/components/processors:embedder_options", "//mediapipe/tasks/web/components/processors:embedder_result", + "//mediapipe/tasks/web/components/utils:cosine_similarity", "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:embedder_options", "//mediapipe/web/graph_runner:graph_runner_ts", diff --git a/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts index 46a7b6729..9dce02862 100644 --- a/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts +++ b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts @@ -20,8 +20,10 @@ import {AudioEmbedderGraphOptions as AudioEmbedderGraphOptionsProto} from '../.. import {EmbeddingResult} from '../../../../tasks/cc/components/containers/proto/embeddings_pb'; import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; import {AudioTaskRunner} from '../../../../tasks/web/audio/core/audio_task_runner'; +import {Embedding} from '../../../../tasks/web/components/containers/embedding_result'; import {convertEmbedderOptionsToProto} from '../../../../tasks/web/components/processors/embedder_options'; import {convertFromEmbeddingResultProto} from '../../../../tasks/web/components/processors/embedder_result'; +import {computeCosineSimilarity} from '../../../../tasks/web/components/utils/cosine_similarity'; import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options'; import {createMediaPipeLib, FileLocator} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url @@ -144,6 +146,19 @@ export class AudioEmbedder extends AudioTaskRunner { return this.processAudioClip(audioData, sampleRate); } + /** + * Utility function to compute cosine similarity[1] between two `Embedding` + * objects. + * + * [1]: https://en.wikipedia.org/wiki/Cosine_similarity + * + * @throws if the embeddings are of different types(float vs. quantized), have + * different sizes, or have an L2-norm of 0. + */ + static cosineSimilarity(u: Embedding, v: Embedding): number { + return computeCosineSimilarity(u, v); + } + protected override process( audioData: Float32Array, sampleRate: number, timestampMs: number): AudioEmbedderResult[] { diff --git a/mediapipe/tasks/web/components/utils/BUILD b/mediapipe/tasks/web/components/utils/BUILD new file mode 100644 index 000000000..1c1ba69ca --- /dev/null +++ b/mediapipe/tasks/web/components/utils/BUILD @@ -0,0 +1,11 @@ +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +mediapipe_ts_library( + name = "cosine_similarity", + srcs = ["cosine_similarity.ts"], + deps = [ + "//mediapipe/tasks/web/components/containers:embedding_result", + ], +) diff --git a/mediapipe/tasks/web/components/utils/cosine_similarity.ts b/mediapipe/tasks/web/components/utils/cosine_similarity.ts new file mode 100644 index 000000000..fb1d0c185 --- /dev/null +++ b/mediapipe/tasks/web/components/utils/cosine_similarity.ts @@ -0,0 +1,62 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + *

Licensed under the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. You may obtain a + * copy of the License at + * + *

http://www.apache.org/licenses/LICENSE-2.0 + * + *

Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +import {Embedding} from '../../../../tasks/web/components/containers/embedding_result'; + +/** + * Computes cosine similarity[1] between two `Embedding` objects. + * + * [1]: https://en.wikipedia.org/wiki/Cosine_similarity + * + * @throws if the embeddings are of different types (float vs. quantized), + * have different sizes, or have an L2-norm of 0. + */ +export function computeCosineSimilarity(u: Embedding, v: Embedding): number { + if (u.floatEmbedding && v.floatEmbedding) { + return compute(u.floatEmbedding, v.floatEmbedding); + } + if (u.quantizedEmbedding && v.quantizedEmbedding) { + return compute( + convertToBytes(u.quantizedEmbedding), + convertToBytes(v.quantizedEmbedding)); + } + throw new Error( + 'Cannot compute cosine similarity between quantized and float embeddings.'); +} +function convertToBytes(data: Uint8Array): number[] { + return Array.from(data, v => v - 128); +} + +function compute(u: number[], v: number[]) { + if (u.length !== v.length) { + throw new Error( + `Cannot compute cosine similarity between embeddings of different sizes (${ + u.length} vs. ${v.length}).`); + } + let dotProduct = 0.0; + let normU = 0.0; + let normV = 0.0; + for (let i = 0; i < u.length; i++) { + dotProduct += u[i] * v[i]; + normU += u[i] * u[i]; + normV += v[i] * v[i]; + } + if (normU <= 0 || normV <= 0) { + throw new Error( + 'Cannot compute cosine similarity on embedding with 0 norm.'); + } + return dotProduct / Math.sqrt(normU * normV); +} diff --git a/mediapipe/tasks/web/text/text_embedder/BUILD b/mediapipe/tasks/web/text/text_embedder/BUILD index c555f8d33..3f92b8ae1 100644 --- a/mediapipe/tasks/web/text/text_embedder/BUILD +++ b/mediapipe/tasks/web/text/text_embedder/BUILD @@ -22,6 +22,7 @@ mediapipe_ts_library( "//mediapipe/tasks/web/components/processors:base_options", "//mediapipe/tasks/web/components/processors:embedder_options", "//mediapipe/tasks/web/components/processors:embedder_result", + "//mediapipe/tasks/web/components/utils:cosine_similarity", "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:embedder_options", "//mediapipe/tasks/web/core:task_runner", diff --git a/mediapipe/tasks/web/text/text_embedder/text_embedder.ts b/mediapipe/tasks/web/text/text_embedder/text_embedder.ts index 57b91d575..2042a0985 100644 --- a/mediapipe/tasks/web/text/text_embedder/text_embedder.ts +++ b/mediapipe/tasks/web/text/text_embedder/text_embedder.ts @@ -18,9 +18,11 @@ import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; import {CalculatorOptions} from '../../../../framework/calculator_options_pb'; import {EmbeddingResult} from '../../../../tasks/cc/components/containers/proto/embeddings_pb'; import {TextEmbedderGraphOptions as TextEmbedderGraphOptionsProto} from '../../../../tasks/cc/text/text_embedder/proto/text_embedder_graph_options_pb'; +import {Embedding} from '../../../../tasks/web/components/containers/embedding_result'; import {convertBaseOptionsToProto} from '../../../../tasks/web/components/processors/base_options'; import {convertEmbedderOptionsToProto} from '../../../../tasks/web/components/processors/embedder_options'; import {convertFromEmbeddingResultProto} from '../../../../tasks/web/components/processors/embedder_result'; +import {computeCosineSimilarity} from '../../../../tasks/web/components/utils/cosine_similarity'; import {TaskRunner} from '../../../../tasks/web/core/task_runner'; import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options'; import {createMediaPipeLib, FileLocator} from '../../../../web/graph_runner/graph_runner'; @@ -143,6 +145,19 @@ export class TextEmbedder extends TaskRunner { return this.embeddingResult; } + /** + * Utility function to compute cosine similarity[1] between two `Embedding` + * objects. + * + * [1]: https://en.wikipedia.org/wiki/Cosine_similarity + * + * @throws if the embeddings are of different types(float vs. quantized), have + * different sizes, or have an L2-norm of 0. + */ + static cosineSimilarity(u: Embedding, v: Embedding): number { + return computeCosineSimilarity(u, v); + } + /** Updates the MediaPipe graph configuration. */ private refreshGraph(): void { const graphConfig = new CalculatorGraphConfig(); diff --git a/mediapipe/tasks/web/vision/image_embedder/BUILD b/mediapipe/tasks/web/vision/image_embedder/BUILD index feb3ae054..2f012dc5e 100644 --- a/mediapipe/tasks/web/vision/image_embedder/BUILD +++ b/mediapipe/tasks/web/vision/image_embedder/BUILD @@ -21,6 +21,7 @@ mediapipe_ts_library( "//mediapipe/tasks/web/components/containers:embedding_result", "//mediapipe/tasks/web/components/processors:embedder_options", "//mediapipe/tasks/web/components/processors:embedder_result", + "//mediapipe/tasks/web/components/utils:cosine_similarity", "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:embedder_options", "//mediapipe/tasks/web/vision/core:vision_task_options", diff --git a/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts b/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts index c60665052..f96f1e961 100644 --- a/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts +++ b/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts @@ -19,8 +19,10 @@ import {CalculatorOptions} from '../../../../framework/calculator_options_pb'; import {EmbeddingResult} from '../../../../tasks/cc/components/containers/proto/embeddings_pb'; import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; import {ImageEmbedderGraphOptions} from '../../../../tasks/cc/vision/image_embedder/proto/image_embedder_graph_options_pb'; +import {Embedding} from '../../../../tasks/web/components/containers/embedding_result'; import {convertEmbedderOptionsToProto} from '../../../../tasks/web/components/processors/embedder_options'; import {convertFromEmbeddingResultProto} from '../../../../tasks/web/components/processors/embedder_result'; +import {computeCosineSimilarity} from '../../../../tasks/web/components/utils/cosine_similarity'; import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options'; import {VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; import {createMediaPipeLib, FileLocator, ImageSource} from '../../../../web/graph_runner/graph_runner'; @@ -157,6 +159,19 @@ export class ImageEmbedder extends VisionTaskRunner { return this.processVideoData(imageFrame, timestamp); } + /** + * Utility function to compute cosine similarity[1] between two `Embedding` + * objects. + * + * [1]: https://en.wikipedia.org/wiki/Cosine_similarity + * + * @throws if the embeddings are of different types(float vs. quantized), have + * different sizes, or have an L2-norm of 0. + */ + static cosineSimilarity(u: Embedding, v: Embedding): number { + return computeCosineSimilarity(u, v); + } + /** Runs the embedding extraction and blocks on the response. */ protected process(image: ImageSource, timestamp: number): ImageEmbedderResult {