Add cosine APIs to Embedder tasks

PiperOrigin-RevId: 490444597
This commit is contained in:
Sebastian Schmidt 2022-11-23 02:02:18 -08:00 committed by Copybara-Service
parent 05681fc0e1
commit c5ce523697
8 changed files with 121 additions and 0 deletions

View File

@ -22,6 +22,7 @@ mediapipe_ts_library(
"//mediapipe/tasks/web/components/containers:embedding_result", "//mediapipe/tasks/web/components/containers:embedding_result",
"//mediapipe/tasks/web/components/processors:embedder_options", "//mediapipe/tasks/web/components/processors:embedder_options",
"//mediapipe/tasks/web/components/processors:embedder_result", "//mediapipe/tasks/web/components/processors:embedder_result",
"//mediapipe/tasks/web/components/utils:cosine_similarity",
"//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core",
"//mediapipe/tasks/web/core:embedder_options", "//mediapipe/tasks/web/core:embedder_options",
"//mediapipe/web/graph_runner:graph_runner_ts", "//mediapipe/web/graph_runner:graph_runner_ts",

View File

@ -20,8 +20,10 @@ import {AudioEmbedderGraphOptions as AudioEmbedderGraphOptionsProto} from '../..
import {EmbeddingResult} from '../../../../tasks/cc/components/containers/proto/embeddings_pb'; import {EmbeddingResult} from '../../../../tasks/cc/components/containers/proto/embeddings_pb';
import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb';
import {AudioTaskRunner} from '../../../../tasks/web/audio/core/audio_task_runner'; import {AudioTaskRunner} from '../../../../tasks/web/audio/core/audio_task_runner';
import {Embedding} from '../../../../tasks/web/components/containers/embedding_result';
import {convertEmbedderOptionsToProto} from '../../../../tasks/web/components/processors/embedder_options'; import {convertEmbedderOptionsToProto} from '../../../../tasks/web/components/processors/embedder_options';
import {convertFromEmbeddingResultProto} from '../../../../tasks/web/components/processors/embedder_result'; import {convertFromEmbeddingResultProto} from '../../../../tasks/web/components/processors/embedder_result';
import {computeCosineSimilarity} from '../../../../tasks/web/components/utils/cosine_similarity';
import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options'; import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options';
import {createMediaPipeLib, FileLocator} from '../../../../web/graph_runner/graph_runner'; import {createMediaPipeLib, FileLocator} from '../../../../web/graph_runner/graph_runner';
// Placeholder for internal dependency on trusted resource url // Placeholder for internal dependency on trusted resource url
@ -144,6 +146,19 @@ export class AudioEmbedder extends AudioTaskRunner<AudioEmbedderResult[]> {
return this.processAudioClip(audioData, sampleRate); return this.processAudioClip(audioData, sampleRate);
} }
/**
* Utility function to compute cosine similarity[1] between two `Embedding`
* objects.
*
* [1]: https://en.wikipedia.org/wiki/Cosine_similarity
*
* @throws if the embeddings are of different types(float vs. quantized), have
* different sizes, or have an L2-norm of 0.
*/
static cosineSimilarity(u: Embedding, v: Embedding): number {
return computeCosineSimilarity(u, v);
}
protected override process( protected override process(
audioData: Float32Array, sampleRate: number, audioData: Float32Array, sampleRate: number,
timestampMs: number): AudioEmbedderResult[] { timestampMs: number): AudioEmbedderResult[] {

View File

@ -0,0 +1,11 @@
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library")
package(default_visibility = ["//mediapipe/tasks:internal"])
mediapipe_ts_library(
name = "cosine_similarity",
srcs = ["cosine_similarity.ts"],
deps = [
"//mediapipe/tasks/web/components/containers:embedding_result",
],
)

View File

@ -0,0 +1,62 @@
/**
* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
*
* <p>Licensed under the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License. You may obtain a
* copy of the License at
*
* <p>http://www.apache.org/licenses/LICENSE-2.0
*
* <p>Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/
import {Embedding} from '../../../../tasks/web/components/containers/embedding_result';
/**
* Computes cosine similarity[1] between two `Embedding` objects.
*
* [1]: https://en.wikipedia.org/wiki/Cosine_similarity
*
* @throws if the embeddings are of different types (float vs. quantized),
* have different sizes, or have an L2-norm of 0.
*/
export function computeCosineSimilarity(u: Embedding, v: Embedding): number {
if (u.floatEmbedding && v.floatEmbedding) {
return compute(u.floatEmbedding, v.floatEmbedding);
}
if (u.quantizedEmbedding && v.quantizedEmbedding) {
return compute(
convertToBytes(u.quantizedEmbedding),
convertToBytes(v.quantizedEmbedding));
}
throw new Error(
'Cannot compute cosine similarity between quantized and float embeddings.');
}
function convertToBytes(data: Uint8Array): number[] {
return Array.from(data, v => v - 128);
}
function compute(u: number[], v: number[]) {
if (u.length !== v.length) {
throw new Error(
`Cannot compute cosine similarity between embeddings of different sizes (${
u.length} vs. ${v.length}).`);
}
let dotProduct = 0.0;
let normU = 0.0;
let normV = 0.0;
for (let i = 0; i < u.length; i++) {
dotProduct += u[i] * v[i];
normU += u[i] * u[i];
normV += v[i] * v[i];
}
if (normU <= 0 || normV <= 0) {
throw new Error(
'Cannot compute cosine similarity on embedding with 0 norm.');
}
return dotProduct / Math.sqrt(normU * normV);
}

View File

@ -22,6 +22,7 @@ mediapipe_ts_library(
"//mediapipe/tasks/web/components/processors:base_options", "//mediapipe/tasks/web/components/processors:base_options",
"//mediapipe/tasks/web/components/processors:embedder_options", "//mediapipe/tasks/web/components/processors:embedder_options",
"//mediapipe/tasks/web/components/processors:embedder_result", "//mediapipe/tasks/web/components/processors:embedder_result",
"//mediapipe/tasks/web/components/utils:cosine_similarity",
"//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core",
"//mediapipe/tasks/web/core:embedder_options", "//mediapipe/tasks/web/core:embedder_options",
"//mediapipe/tasks/web/core:task_runner", "//mediapipe/tasks/web/core:task_runner",

View File

@ -18,9 +18,11 @@ import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
import {CalculatorOptions} from '../../../../framework/calculator_options_pb'; import {CalculatorOptions} from '../../../../framework/calculator_options_pb';
import {EmbeddingResult} from '../../../../tasks/cc/components/containers/proto/embeddings_pb'; import {EmbeddingResult} from '../../../../tasks/cc/components/containers/proto/embeddings_pb';
import {TextEmbedderGraphOptions as TextEmbedderGraphOptionsProto} from '../../../../tasks/cc/text/text_embedder/proto/text_embedder_graph_options_pb'; import {TextEmbedderGraphOptions as TextEmbedderGraphOptionsProto} from '../../../../tasks/cc/text/text_embedder/proto/text_embedder_graph_options_pb';
import {Embedding} from '../../../../tasks/web/components/containers/embedding_result';
import {convertBaseOptionsToProto} from '../../../../tasks/web/components/processors/base_options'; import {convertBaseOptionsToProto} from '../../../../tasks/web/components/processors/base_options';
import {convertEmbedderOptionsToProto} from '../../../../tasks/web/components/processors/embedder_options'; import {convertEmbedderOptionsToProto} from '../../../../tasks/web/components/processors/embedder_options';
import {convertFromEmbeddingResultProto} from '../../../../tasks/web/components/processors/embedder_result'; import {convertFromEmbeddingResultProto} from '../../../../tasks/web/components/processors/embedder_result';
import {computeCosineSimilarity} from '../../../../tasks/web/components/utils/cosine_similarity';
import {TaskRunner} from '../../../../tasks/web/core/task_runner'; import {TaskRunner} from '../../../../tasks/web/core/task_runner';
import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options'; import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options';
import {createMediaPipeLib, FileLocator} from '../../../../web/graph_runner/graph_runner'; import {createMediaPipeLib, FileLocator} from '../../../../web/graph_runner/graph_runner';
@ -143,6 +145,19 @@ export class TextEmbedder extends TaskRunner {
return this.embeddingResult; return this.embeddingResult;
} }
/**
* Utility function to compute cosine similarity[1] between two `Embedding`
* objects.
*
* [1]: https://en.wikipedia.org/wiki/Cosine_similarity
*
* @throws if the embeddings are of different types(float vs. quantized), have
* different sizes, or have an L2-norm of 0.
*/
static cosineSimilarity(u: Embedding, v: Embedding): number {
return computeCosineSimilarity(u, v);
}
/** Updates the MediaPipe graph configuration. */ /** Updates the MediaPipe graph configuration. */
private refreshGraph(): void { private refreshGraph(): void {
const graphConfig = new CalculatorGraphConfig(); const graphConfig = new CalculatorGraphConfig();

View File

@ -21,6 +21,7 @@ mediapipe_ts_library(
"//mediapipe/tasks/web/components/containers:embedding_result", "//mediapipe/tasks/web/components/containers:embedding_result",
"//mediapipe/tasks/web/components/processors:embedder_options", "//mediapipe/tasks/web/components/processors:embedder_options",
"//mediapipe/tasks/web/components/processors:embedder_result", "//mediapipe/tasks/web/components/processors:embedder_result",
"//mediapipe/tasks/web/components/utils:cosine_similarity",
"//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core",
"//mediapipe/tasks/web/core:embedder_options", "//mediapipe/tasks/web/core:embedder_options",
"//mediapipe/tasks/web/vision/core:vision_task_options", "//mediapipe/tasks/web/vision/core:vision_task_options",

View File

@ -19,8 +19,10 @@ import {CalculatorOptions} from '../../../../framework/calculator_options_pb';
import {EmbeddingResult} from '../../../../tasks/cc/components/containers/proto/embeddings_pb'; import {EmbeddingResult} from '../../../../tasks/cc/components/containers/proto/embeddings_pb';
import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb';
import {ImageEmbedderGraphOptions} from '../../../../tasks/cc/vision/image_embedder/proto/image_embedder_graph_options_pb'; import {ImageEmbedderGraphOptions} from '../../../../tasks/cc/vision/image_embedder/proto/image_embedder_graph_options_pb';
import {Embedding} from '../../../../tasks/web/components/containers/embedding_result';
import {convertEmbedderOptionsToProto} from '../../../../tasks/web/components/processors/embedder_options'; import {convertEmbedderOptionsToProto} from '../../../../tasks/web/components/processors/embedder_options';
import {convertFromEmbeddingResultProto} from '../../../../tasks/web/components/processors/embedder_result'; import {convertFromEmbeddingResultProto} from '../../../../tasks/web/components/processors/embedder_result';
import {computeCosineSimilarity} from '../../../../tasks/web/components/utils/cosine_similarity';
import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options'; import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options';
import {VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; import {VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
import {createMediaPipeLib, FileLocator, ImageSource} from '../../../../web/graph_runner/graph_runner'; import {createMediaPipeLib, FileLocator, ImageSource} from '../../../../web/graph_runner/graph_runner';
@ -157,6 +159,19 @@ export class ImageEmbedder extends VisionTaskRunner<ImageEmbedderResult> {
return this.processVideoData(imageFrame, timestamp); return this.processVideoData(imageFrame, timestamp);
} }
/**
* Utility function to compute cosine similarity[1] between two `Embedding`
* objects.
*
* [1]: https://en.wikipedia.org/wiki/Cosine_similarity
*
* @throws if the embeddings are of different types(float vs. quantized), have
* different sizes, or have an L2-norm of 0.
*/
static cosineSimilarity(u: Embedding, v: Embedding): number {
return computeCosineSimilarity(u, v);
}
/** Runs the embedding extraction and blocks on the response. */ /** Runs the embedding extraction and blocks on the response. */
protected process(image: ImageSource, timestamp: number): protected process(image: ImageSource, timestamp: number):
ImageEmbedderResult { ImageEmbedderResult {