Remove cosineSimilarity() from AudioEmbedder

PiperOrigin-RevId: 512671255
Sebastian Schmidt 2023-02-27 10:46:49 -08:00 committed by Copybara-Service
parent 39b2fec60f
commit 9f59d4d01b
11 changed files with 7 additions and 136 deletions

View File

@@ -35,7 +35,6 @@ cc_library(
"//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
"//mediapipe/tasks/cc/components/processors:embedder_options",
"//mediapipe/tasks/cc/components/processors/proto:embedder_options_cc_proto",
"//mediapipe/tasks/cc/components/utils:cosine_similarity",
"//mediapipe/tasks/cc/core:base_options",
"//mediapipe/tasks/cc/core:task_runner",
"//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto",

View File

@@ -29,7 +29,6 @@ limitations under the License.
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
#include "mediapipe/tasks/cc/components/processors/embedder_options.h"
#include "mediapipe/tasks/cc/components/processors/proto/embedder_options.pb.h"
#include "mediapipe/tasks/cc/components/utils/cosine_similarity.h"
#include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h"
#include "mediapipe/tasks/cc/core/task_runner.h"
#include "tensorflow/lite/core/api/op_resolver.h"
@@ -147,10 +146,4 @@ absl::Status AudioEmbedder::EmbedAsync(Matrix audio_block,
.At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}});
}
absl::StatusOr<double> AudioEmbedder::CosineSimilarity(
const components::containers::Embedding& u,
const components::containers::Embedding& v) {
return components::utils::CosineSimilarity(u, v);
}
} // namespace mediapipe::tasks::audio::audio_embedder

View File

@@ -125,16 +125,6 @@ class AudioEmbedder : core::BaseAudioTaskApi {
// Shuts down the AudioEmbedder when all work is done.
absl::Status Close() { return runner_->Close(); }
// Utility function to compute cosine similarity [1] between two embeddings.
// May return an InvalidArgumentError if e.g. the embeddings are of different
// types (quantized vs. float), have different sizes, or have an L2-norm of 0.
//
// [1]: https://en.wikipedia.org/wiki/Cosine_similarity
static absl::StatusOr<double> CosineSimilarity(
const components::containers::Embedding& u,
const components::containers::Embedding& v);
};
} // namespace mediapipe::tasks::audio::audio_embedder

View File

@@ -54,8 +54,6 @@ constexpr char kModelWithMetadata[] = "yamnet_embedding_metadata.tflite";
constexpr char k16kTestWavFilename[] = "speech_16000_hz_mono.wav";
constexpr char k48kTestWavFilename[] = "speech_48000_hz_mono.wav";
constexpr char k16kTestWavForTwoHeadsFilename[] = "two_heads_16000_hz_mono.wav";
constexpr float kSpeechSimilarities[] = {0.985359, 0.994349, 0.993227, 0.996658,
0.996384};
constexpr int kMilliSecondsPerSecond = 1000;
constexpr int kYamnetNumOfAudioSamples = 15600;
constexpr int kYamnetAudioSampleRate = 16000;
@@ -163,15 +161,9 @@ TEST_F(EmbedTest, SucceedsWithSameAudioAtDifferentSampleRates) {
audio_embedder->Embed(audio_buffer1, 16000));
MP_ASSERT_OK_AND_ASSIGN(auto result2,
audio_embedder->Embed(audio_buffer2, 48000));
int expected_size = sizeof(kSpeechSimilarities) / sizeof(float);
int expected_size = 5;
ASSERT_EQ(result1.size(), expected_size);
ASSERT_EQ(result2.size(), expected_size);
for (int i = 0; i < expected_size; ++i) {
MP_ASSERT_OK_AND_ASSIGN(double similarity, AudioEmbedder::CosineSimilarity(
result1[i].embeddings[0],
result2[i].embeddings[0]));
EXPECT_NEAR(similarity, kSpeechSimilarities[i], 1e-6);
}
MP_EXPECT_OK(audio_embedder->Close());
}
@@ -192,10 +184,6 @@ TEST_F(EmbedTest, SucceedsWithDifferentAudios) {
audio_embedder->Embed(audio_buffer2, kYamnetAudioSampleRate));
ASSERT_EQ(result1.size(), 5);
ASSERT_EQ(result2.size(), 1);
MP_ASSERT_OK_AND_ASSIGN(double similarity, AudioEmbedder::CosineSimilarity(
result1[0].embeddings[0],
result2[0].embeddings[0]));
EXPECT_NEAR(similarity, 0.09017f, 1e-6);
MP_EXPECT_OK(audio_embedder->Close());
}
@@ -258,15 +246,9 @@ TEST_F(EmbedAsyncTest, SucceedsWithSameAudioAtDifferentSampleRates) {
RunAudioEmbedderInStreamMode(k16kTestWavFilename, 16000, &result1);
std::vector<AudioEmbedderResult> result2;
RunAudioEmbedderInStreamMode(k48kTestWavFilename, 48000, &result2);
int expected_size = sizeof(kSpeechSimilarities) / sizeof(float);
int expected_size = 5;
ASSERT_EQ(result1.size(), expected_size);
ASSERT_EQ(result2.size(), expected_size);
for (int i = 0; i < expected_size; ++i) {
MP_ASSERT_OK_AND_ASSIGN(double similarity, AudioEmbedder::CosineSimilarity(
result1[i].embeddings[0],
result2[i].embeddings[0]));
EXPECT_NEAR(similarity, kSpeechSimilarities[i], 1e-6);
}
}
TEST_F(EmbedAsyncTest, SucceedsWithDifferentAudios) {
@@ -276,10 +258,6 @@ TEST_F(EmbedAsyncTest, SucceedsWithDifferentAudios) {
RunAudioEmbedderInStreamMode(k16kTestWavForTwoHeadsFilename, 16000, &result2);
ASSERT_EQ(result1.size(), 5);
ASSERT_EQ(result2.size(), 1);
MP_ASSERT_OK_AND_ASSIGN(double similarity, AudioEmbedder::CosineSimilarity(
result1[0].embeddings[0],
result2[0].embeddings[0]));
EXPECT_NEAR(similarity, 0.09017f, 1e-6);
}
} // namespace

View File

@@ -101,9 +101,7 @@ android_library(
"//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/audio:libmediapipe_tasks_audio_jni_lib",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:audiodata",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:embedding",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:embeddingresult",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/utils:cosinesimilarity",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
"//third_party:autovalue",
"@maven//:com_google_guava_guava",

View File

@@ -26,10 +26,8 @@ import com.google.mediapipe.tasks.audio.audioembedder.proto.AudioEmbedderGraphOp
import com.google.mediapipe.tasks.audio.core.BaseAudioTaskApi;
import com.google.mediapipe.tasks.audio.core.RunningMode;
import com.google.mediapipe.tasks.components.containers.AudioData;
import com.google.mediapipe.tasks.components.containers.Embedding;
import com.google.mediapipe.tasks.components.containers.proto.EmbeddingsProto;
import com.google.mediapipe.tasks.components.processors.proto.EmbedderOptionsProto;
import com.google.mediapipe.tasks.components.utils.CosineSimilarity;
import com.google.mediapipe.tasks.core.BaseOptions;
import com.google.mediapipe.tasks.core.ErrorListener;
import com.google.mediapipe.tasks.core.OutputHandler;
@@ -273,17 +271,6 @@ public final class AudioEmbedder extends BaseAudioTaskApi {
sendAudioStreamData(audioBlock, timestampMs);
}
/**
* Utility function to compute <a href="https://en.wikipedia.org/wiki/Cosine_similarity">cosine
* similarity</a> between two {@link Embedding} objects.
*
* @throws IllegalArgumentException if the embeddings are of different types (float vs.
* quantized), have different sizes, or have an L2-norm of 0.
*/
public static double cosineSimilarity(Embedding u, Embedding v) {
return CosineSimilarity.compute(u, v);
}
/** Options for setting up an {@link AudioEmbedder}. */
@AutoValue
public abstract static class AudioEmbedderOptions extends TaskOptions {

View File

@@ -56,7 +56,6 @@ py_library(
"//mediapipe/tasks/python/audio/core:base_audio_task_api",
"//mediapipe/tasks/python/components/containers:audio_data",
"//mediapipe/tasks/python/components/containers:embedding_result",
"//mediapipe/tasks/python/components/utils:cosine_similarity",
"//mediapipe/tasks/python/core:base_options",
"//mediapipe/tasks/python/core:optional_dependencies",
"//mediapipe/tasks/python/core:task_info",

View File

@@ -26,7 +26,6 @@ from mediapipe.tasks.python.audio.core import audio_task_running_mode as running
from mediapipe.tasks.python.audio.core import base_audio_task_api
from mediapipe.tasks.python.components.containers import audio_data as audio_data_module
from mediapipe.tasks.python.components.containers import embedding_result as embedding_result_module
from mediapipe.tasks.python.components.utils import cosine_similarity
from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.core import task_info as task_info_module
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
@@ -284,26 +283,3 @@ class AudioEmbedder(base_audio_task_api.BaseAudioTaskApi):
packet_creator.create_matrix(audio_block.buffer, transpose=True).at(
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND)
})
@classmethod
def cosine_similarity(cls, u: embedding_result_module.Embedding,
v: embedding_result_module.Embedding) -> float:
"""Utility function to compute cosine similarity between two embedding entries.
May return an InvalidArgumentError if e.g. the feature vectors are
of different types (quantized vs. float), have different sizes, or have
an L2-norm of 0.
Args:
u: An embedding entry.
v: An embedding entry.
Returns:
The cosine similarity for the two embeddings.
Raises:
ValueError: If e.g. the feature vectors are of
different types (quantized vs. float), have different sizes, or have
an L2-norm of 0.
"""
return cosine_similarity.cosine_similarity(u, v)
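
Note: the removed classmethod above was a thin wrapper around the shared cosine-similarity utility whose import is also deleted in this file. A minimal migration sketch, assuming the utility module keeps the path and signature shown in the deleted lines (clip_a, clip_b, and the model filename are placeholder inputs, not part of this change):

# Sketch only: calls the shared utility that the removed wrapper delegated to.
from mediapipe.tasks.python.audio import audio_embedder
from mediapipe.tasks.python.components.utils import cosine_similarity

# clip_a and clip_b are assumed to be AudioData objects prepared by the caller,
# e.g. loaded from mono WAV files as in the tests further below.
with audio_embedder.AudioEmbedder.create_from_model_path(
    'yamnet_embedding_metadata.tflite') as embedder:  # model name from the tests
  result_a = embedder.embed(clip_a)
  result_b = embedder.embed(clip_b)

# embed() returns a list of timestamped AudioEmbedderResult objects; compare the
# first pair of embeddings directly through the shared utility.
similarity = cosine_similarity.cosine_similarity(
    result_a[0].embeddings[0], result_b[0].embeddings[0])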

View File

@@ -42,13 +42,10 @@ _SPEECH_WAV_16K_MONO = 'speech_16000_hz_mono.wav'
_SPEECH_WAV_48K_MONO = 'speech_48000_hz_mono.wav'
_TWO_HEADS_WAV_16K_MONO = 'two_heads_16000_hz_mono.wav'
_TEST_DATA_DIR = 'mediapipe/tasks/testdata/audio'
_SPEECH_SIMILARITIES = [0.985359, 0.994349, 0.993227, 0.996658, 0.996384]
_YAMNET_NUM_OF_SAMPLES = 15600
_MILLSECONDS_PER_SECOND = 1000
# Tolerance for embedding vector coordinate values.
_EPSILON = 3e-6
# Tolerance for cosine similarity evaluation.
_SIMILARITY_TOLERANCE = 1e-6
class ModelFileType(enum.Enum):
@@ -98,27 +95,6 @@ class AudioEmbedderTest(parameterized.TestCase):
else:
self.assertEqual(embedding_result.embedding.dtype, float)
def _check_cosine_similarity(self, result0, result1, expected_similarity):
# Checks cosine similarity.
similarity = _AudioEmbedder.cosine_similarity(result0.embeddings[0],
result1.embeddings[0])
self.assertAlmostEqual(
similarity, expected_similarity, delta=_SIMILARITY_TOLERANCE)
def _check_yamnet_result(self,
embedding_result0_list: List[_AudioEmbedderResult],
embedding_result1_list: List[_AudioEmbedderResult],
expected_similarities: List[float]):
expected_size = len(expected_similarities)
self.assertLen(embedding_result0_list, expected_size)
self.assertLen(embedding_result1_list, expected_size)
for idx in range(expected_size):
embedding_result0 = embedding_result0_list[idx]
embedding_result1 = embedding_result1_list[idx]
self._check_cosine_similarity(embedding_result0, embedding_result1,
expected_similarities[idx])
def test_create_from_file_succeeds_with_valid_model_path(self):
# Creates with default option and valid model file successfully.
with _AudioEmbedder.create_from_model_path(
@@ -176,7 +152,7 @@ class AudioEmbedderTest(parameterized.TestCase):
embedding_result0_list = embedder.embed(self._read_wav_file(audio_file0))
embedding_result1_list = embedder.embed(self._read_wav_file(audio_file1))
# Checks embeddings and cosine similarity.
# Checks embeddings.
expected_result0_value, expected_result1_value = expected_first_values
self._check_embedding_size(embedding_result0_list[0], quantize,
expected_size)
@@ -186,10 +162,8 @@ class AudioEmbedderTest(parameterized.TestCase):
expected_result0_value)
self._check_embedding_value(embedding_result1_list[0],
expected_result1_value)
self._check_yamnet_result(
embedding_result0_list,
embedding_result1_list,
expected_similarities=_SPEECH_SIMILARITIES)
self.assertLen(embedding_result0_list, 5)
self.assertLen(embedding_result1_list, 5)
def test_embed_with_yamnet_model_and_different_inputs(self):
with _AudioEmbedder.create_from_model_path(
@@ -200,10 +174,6 @@ class AudioEmbedderTest(parameterized.TestCase):
self._read_wav_file(_TWO_HEADS_WAV_16K_MONO))
self.assertLen(embedding_result0_list, 5)
self.assertLen(embedding_result1_list, 1)
self._check_cosine_similarity(
embedding_result0_list[0],
embedding_result1_list[0],
expected_similarity=0.09017)
def test_missing_sample_rate_in_audio_clips_mode(self):
options = _AudioEmbedderOptions(
@@ -304,10 +274,8 @@ class AudioEmbedderTest(parameterized.TestCase):
embedder.embed_async(audio_data, timestamp_ms)
embedding_result1_list = embedding_result_list
self._check_yamnet_result(
embedding_result0_list,
embedding_result1_list,
expected_similarities=_SPEECH_SIMILARITIES)
self.assertLen(embedding_result0_list, 5)
self.assertLen(embedding_result1_list, 5)
if __name__ == '__main__':
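
The replacement assertions above only verify the number of results. If the pairwise similarity checks are still wanted somewhere downstream, a hedged sketch that rebuilds them from the deleted constants and the shared utility (the helper name below is hypothetical):

# Sketch: re-creates the deleted similarity check by calling the shared utility
# directly; expected values and tolerance are copied from the removed constants.
from mediapipe.tasks.python.components.utils import cosine_similarity

_SPEECH_SIMILARITIES = [0.985359, 0.994349, 0.993227, 0.996658, 0.996384]
_SIMILARITY_TOLERANCE = 1e-6

def check_yamnet_similarities(result0_list, result1_list):
  """Asserts pairwise cosine similarities between two YAMNet result lists."""
  assert len(result0_list) == len(_SPEECH_SIMILARITIES)
  assert len(result1_list) == len(_SPEECH_SIMILARITIES)
  for idx, expected in enumerate(_SPEECH_SIMILARITIES):
    similarity = cosine_similarity.cosine_similarity(
        result0_list[idx].embeddings[0], result1_list[idx].embeddings[0])
    assert abs(similarity - expected) <= _SIMILARITY_TOLERANCE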

View File

@@ -21,10 +21,8 @@ mediapipe_ts_library(
"//mediapipe/tasks/cc/components/containers/proto:embeddings_jspb_proto",
"//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
"//mediapipe/tasks/web/audio/core:audio_task_runner",
"//mediapipe/tasks/web/components/containers:embedding_result",
"//mediapipe/tasks/web/components/processors:embedder_options",
"//mediapipe/tasks/web/components/processors:embedder_result",
"//mediapipe/tasks/web/components/utils:cosine_similarity",
"//mediapipe/tasks/web/core",
"//mediapipe/tasks/web/core:embedder_options",
"//mediapipe/tasks/web/core:task_runner",

View File

@@ -20,10 +20,8 @@ import {AudioEmbedderGraphOptions as AudioEmbedderGraphOptionsProto} from '../..
import {EmbeddingResult} from '../../../../tasks/cc/components/containers/proto/embeddings_pb';
import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb';
import {AudioTaskRunner} from '../../../../tasks/web/audio/core/audio_task_runner';
import {Embedding} from '../../../../tasks/web/components/containers/embedding_result';
import {convertEmbedderOptionsToProto} from '../../../../tasks/web/components/processors/embedder_options';
import {convertFromEmbeddingResultProto} from '../../../../tasks/web/components/processors/embedder_result';
import {computeCosineSimilarity} from '../../../../tasks/web/components/utils/cosine_similarity';
import {CachedGraphRunner} from '../../../../tasks/web/core/task_runner';
import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
import {WasmModule} from '../../../../web/graph_runner/graph_runner';
@@ -145,19 +143,6 @@ export class AudioEmbedder extends AudioTaskRunner<AudioEmbedderResult[]> {
return this.processAudioClip(audioData, sampleRate);
}
/**
* Utility function to compute cosine similarity [1] between two `Embedding`
* objects.
*
* [1]: https://en.wikipedia.org/wiki/Cosine_similarity
*
* @throws if the embeddings are of different types (float vs. quantized), have
* different sizes, or have an L2-norm of 0.
*/
static cosineSimilarity(u: Embedding, v: Embedding): number {
return computeCosineSimilarity(u, v);
}
protected override process(
audioData: Float32Array, sampleRate: number,
timestampMs: number): AudioEmbedderResult[] {