MediaPipe Tasks Audio embedder C++ API.
PiperOrigin-RevId: 488273381
This commit is contained in:
parent
0dfa91a166
commit
6c0ca947de
79
mediapipe/tasks/cc/audio/audio_embedder/BUILD
Normal file
79
mediapipe/tasks/cc/audio/audio_embedder/BUILD
Normal file
|
@ -0,0 +1,79 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||||
|
|
||||||
|
licenses(["notice"])
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "audio_embedder_graph",
|
||||||
|
srcs = ["audio_embedder_graph.cc"],
|
||||||
|
deps = [
|
||||||
|
"//mediapipe/calculators/audio:time_series_framer_calculator",
|
||||||
|
"//mediapipe/calculators/core:constant_side_packet_calculator",
|
||||||
|
"//mediapipe/calculators/core:constant_side_packet_calculator_cc_proto",
|
||||||
|
"//mediapipe/calculators/core:side_packet_to_stream_calculator",
|
||||||
|
"//mediapipe/calculators/tensor:audio_to_tensor_calculator",
|
||||||
|
"//mediapipe/calculators/tensor:audio_to_tensor_calculator_cc_proto",
|
||||||
|
"//mediapipe/calculators/tensor:inference_calculator",
|
||||||
|
"//mediapipe/framework:calculator_cc_proto",
|
||||||
|
"//mediapipe/framework:calculator_framework",
|
||||||
|
"//mediapipe/framework/api2:builder",
|
||||||
|
"//mediapipe/framework/api2:port",
|
||||||
|
"//mediapipe/framework/formats:matrix",
|
||||||
|
"//mediapipe/tasks/cc:common",
|
||||||
|
"//mediapipe/tasks/cc/audio/audio_embedder/proto:audio_embedder_graph_options_cc_proto",
|
||||||
|
"//mediapipe/tasks/cc/audio/utils:audio_tensor_specs",
|
||||||
|
"//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
|
||||||
|
"//mediapipe/tasks/cc/components/processors:embedding_postprocessing_graph",
|
||||||
|
"//mediapipe/tasks/cc/components/processors/proto:embedding_postprocessing_graph_options_cc_proto",
|
||||||
|
"//mediapipe/tasks/cc/core:model_resources",
|
||||||
|
"//mediapipe/tasks/cc/core:model_task_graph",
|
||||||
|
"//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto",
|
||||||
|
"//mediapipe/tasks/cc/metadata:metadata_extractor",
|
||||||
|
"//mediapipe/tasks/metadata:metadata_schema_cc",
|
||||||
|
"@com_google_absl//absl/status",
|
||||||
|
"@com_google_absl//absl/status:statusor",
|
||||||
|
"@com_google_absl//absl/types:optional",
|
||||||
|
"@org_tensorflow//tensorflow/lite/schema:schema_fbs",
|
||||||
|
],
|
||||||
|
alwayslink = 1,
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "audio_embedder",
|
||||||
|
srcs = ["audio_embedder.cc"],
|
||||||
|
hdrs = ["audio_embedder.h"],
|
||||||
|
deps = [
|
||||||
|
":audio_embedder_graph",
|
||||||
|
"//mediapipe/framework/api2:builder",
|
||||||
|
"//mediapipe/framework/formats:matrix",
|
||||||
|
"//mediapipe/tasks/cc/audio/audio_embedder/proto:audio_embedder_graph_options_cc_proto",
|
||||||
|
"//mediapipe/tasks/cc/audio/core:audio_task_api_factory",
|
||||||
|
"//mediapipe/tasks/cc/audio/core:base_audio_task_api",
|
||||||
|
"//mediapipe/tasks/cc/audio/core:running_mode",
|
||||||
|
"//mediapipe/tasks/cc/components/containers:embedding_result",
|
||||||
|
"//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
|
||||||
|
"//mediapipe/tasks/cc/components/processors:embedder_options",
|
||||||
|
"//mediapipe/tasks/cc/components/processors/proto:embedder_options_cc_proto",
|
||||||
|
"//mediapipe/tasks/cc/components/utils:cosine_similarity",
|
||||||
|
"//mediapipe/tasks/cc/core:base_options",
|
||||||
|
"//mediapipe/tasks/cc/core:task_runner",
|
||||||
|
"//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto",
|
||||||
|
"@com_google_absl//absl/status:statusor",
|
||||||
|
"@org_tensorflow//tensorflow/lite/core/api:op_resolver",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
# TODO: mediapipe/tasks/cc/audio/utils:test_utils does not compile in the OSS build
|
156
mediapipe/tasks/cc/audio/audio_embedder/audio_embedder.cc
Normal file
156
mediapipe/tasks/cc/audio/audio_embedder/audio_embedder.cc
Normal file
|
@ -0,0 +1,156 @@
|
||||||
|
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "mediapipe/tasks/cc/audio/audio_embedder/audio_embedder.h"
|
||||||
|
|
||||||
|
#include <map>
|
||||||
|
#include <memory>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "absl/status/statusor.h"
|
||||||
|
#include "mediapipe/framework/api2/builder.h"
|
||||||
|
#include "mediapipe/framework/formats/matrix.h"
|
||||||
|
#include "mediapipe/tasks/cc/audio/audio_embedder/proto/audio_embedder_graph_options.pb.h"
|
||||||
|
#include "mediapipe/tasks/cc/audio/core/audio_task_api_factory.h"
|
||||||
|
#include "mediapipe/tasks/cc/components/containers/embedding_result.h"
|
||||||
|
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
|
||||||
|
#include "mediapipe/tasks/cc/components/processors/embedder_options.h"
|
||||||
|
#include "mediapipe/tasks/cc/components/processors/proto/embedder_options.pb.h"
|
||||||
|
#include "mediapipe/tasks/cc/components/utils/cosine_similarity.h"
|
||||||
|
#include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h"
|
||||||
|
#include "mediapipe/tasks/cc/core/task_runner.h"
|
||||||
|
#include "tensorflow/lite/core/api/op_resolver.h"
|
||||||
|
|
||||||
|
namespace mediapipe::tasks::audio::audio_embedder {
|
||||||
|
namespace {
|
||||||
|
using ::mediapipe::tasks::components::containers::ConvertToEmbeddingResult;
|
||||||
|
using ::mediapipe::tasks::components::containers::proto::EmbeddingResult;
|
||||||
|
constexpr char kAudioStreamName[] = "audio_in";
|
||||||
|
constexpr char kAudioTag[] = "AUDIO";
|
||||||
|
constexpr char kEmbeddingsTag[] = "EMBEDDINGS";
|
||||||
|
constexpr char kTimestampedEmbeddingsTag[] = "TIMESTAMPED_EMBEDDINGS";
|
||||||
|
constexpr char kEmbeddingsName[] = "embeddings_out";
|
||||||
|
constexpr char kTimestampedEmbeddingsName[] = "timestamped_embeddings_out";
|
||||||
|
constexpr char kSampleRateName[] = "sample_rate_in";
|
||||||
|
constexpr char kSampleRateTag[] = "SAMPLE_RATE";
|
||||||
|
constexpr char kSubgraphTypeName[] =
|
||||||
|
"mediapipe.tasks.audio.audio_embedder.AudioEmbedderGraph";
|
||||||
|
constexpr int kMicroSecondsPerMilliSecond = 1000;
|
||||||
|
|
||||||
|
// Creates a MediaPipe graph config that only contains a single subgraph node of
|
||||||
|
// type "AudioEmbedderGraph".
|
||||||
|
CalculatorGraphConfig CreateGraphConfig(
|
||||||
|
std::unique_ptr<proto::AudioEmbedderGraphOptions> options_proto) {
|
||||||
|
api2::builder::Graph graph;
|
||||||
|
auto& subgraph = graph.AddNode(kSubgraphTypeName);
|
||||||
|
graph.In(kAudioTag).SetName(kAudioStreamName) >> subgraph.In(kAudioTag);
|
||||||
|
graph.In(kSampleRateTag).SetName(kSampleRateName) >>
|
||||||
|
subgraph.In(kSampleRateTag);
|
||||||
|
subgraph.GetOptions<proto::AudioEmbedderGraphOptions>().Swap(
|
||||||
|
options_proto.get());
|
||||||
|
subgraph.Out(kEmbeddingsTag).SetName(kEmbeddingsName) >>
|
||||||
|
graph.Out(kEmbeddingsTag);
|
||||||
|
subgraph.Out(kTimestampedEmbeddingsTag).SetName(kTimestampedEmbeddingsName) >>
|
||||||
|
graph.Out(kTimestampedEmbeddingsTag);
|
||||||
|
return graph.GetConfig();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Converts the user-facing AudioEmbedderOptions struct to the internal
|
||||||
|
// AudioEmbedderGraphOptions proto.
|
||||||
|
std::unique_ptr<proto::AudioEmbedderGraphOptions>
|
||||||
|
ConvertAudioEmbedderOptionsToProto(AudioEmbedderOptions* options) {
|
||||||
|
auto options_proto = std::make_unique<proto::AudioEmbedderGraphOptions>();
|
||||||
|
auto base_options_proto = std::make_unique<tasks::core::proto::BaseOptions>(
|
||||||
|
tasks::core::ConvertBaseOptionsToProto(&(options->base_options)));
|
||||||
|
options_proto->mutable_base_options()->Swap(base_options_proto.get());
|
||||||
|
options_proto->mutable_base_options()->set_use_stream_mode(
|
||||||
|
options->running_mode == core::RunningMode::AUDIO_STREAM);
|
||||||
|
auto embedder_options_proto =
|
||||||
|
std::make_unique<components::processors::proto::EmbedderOptions>(
|
||||||
|
components::processors::ConvertEmbedderOptionsToProto(
|
||||||
|
&(options->embedder_options)));
|
||||||
|
options_proto->mutable_embedder_options()->Swap(embedder_options_proto.get());
|
||||||
|
return options_proto;
|
||||||
|
}
|
||||||
|
|
||||||
|
absl::StatusOr<std::vector<AudioEmbedderResult>> ConvertOutputPackets(
|
||||||
|
absl::StatusOr<tasks::core::PacketMap> status_or_packets) {
|
||||||
|
if (!status_or_packets.ok()) {
|
||||||
|
return status_or_packets.status();
|
||||||
|
}
|
||||||
|
auto embedding_results = status_or_packets.value()[kTimestampedEmbeddingsName]
|
||||||
|
.Get<std::vector<EmbeddingResult>>();
|
||||||
|
std::vector<AudioEmbedderResult> results;
|
||||||
|
results.reserve(embedding_results.size());
|
||||||
|
for (const auto& embedding_result : embedding_results) {
|
||||||
|
results.emplace_back(ConvertToEmbeddingResult(embedding_result));
|
||||||
|
}
|
||||||
|
return results;
|
||||||
|
}
|
||||||
|
|
||||||
|
absl::StatusOr<AudioEmbedderResult> ConvertAsyncOutputPackets(
|
||||||
|
absl::StatusOr<tasks::core::PacketMap> status_or_packets) {
|
||||||
|
if (!status_or_packets.ok()) {
|
||||||
|
return status_or_packets.status();
|
||||||
|
}
|
||||||
|
return ConvertToEmbeddingResult(
|
||||||
|
status_or_packets.value()[kEmbeddingsName].Get<EmbeddingResult>());
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
/* static */
|
||||||
|
absl::StatusOr<std::unique_ptr<AudioEmbedder>> AudioEmbedder::Create(
|
||||||
|
std::unique_ptr<AudioEmbedderOptions> options) {
|
||||||
|
auto options_proto = ConvertAudioEmbedderOptionsToProto(options.get());
|
||||||
|
tasks::core::PacketsCallback packets_callback = nullptr;
|
||||||
|
if (options->result_callback) {
|
||||||
|
auto result_callback = options->result_callback;
|
||||||
|
packets_callback =
|
||||||
|
[=](absl::StatusOr<tasks::core::PacketMap> status_or_packets) {
|
||||||
|
result_callback(ConvertAsyncOutputPackets(status_or_packets));
|
||||||
|
};
|
||||||
|
}
|
||||||
|
return core::AudioTaskApiFactory::Create<AudioEmbedder,
|
||||||
|
proto::AudioEmbedderGraphOptions>(
|
||||||
|
CreateGraphConfig(std::move(options_proto)),
|
||||||
|
std::move(options->base_options.op_resolver), options->running_mode,
|
||||||
|
std::move(packets_callback));
|
||||||
|
}
|
||||||
|
|
||||||
|
absl::StatusOr<std::vector<AudioEmbedderResult>> AudioEmbedder::Embed(
|
||||||
|
Matrix audio_clip, double audio_sample_rate) {
|
||||||
|
return ConvertOutputPackets(ProcessAudioClip(
|
||||||
|
{{kAudioStreamName, MakePacket<Matrix>(std::move(audio_clip))},
|
||||||
|
{kSampleRateName, MakePacket<double>(audio_sample_rate)}}));
|
||||||
|
}
|
||||||
|
|
||||||
|
absl::Status AudioEmbedder::EmbedAsync(Matrix audio_block,
|
||||||
|
double audio_sample_rate,
|
||||||
|
int64 timestamp_ms) {
|
||||||
|
MP_RETURN_IF_ERROR(CheckOrSetSampleRate(kSampleRateName, audio_sample_rate));
|
||||||
|
return SendAudioStreamData(
|
||||||
|
{{kAudioStreamName,
|
||||||
|
MakePacket<Matrix>(std::move(audio_block))
|
||||||
|
.At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}});
|
||||||
|
}
|
||||||
|
|
||||||
|
absl::StatusOr<double> AudioEmbedder::CosineSimilarity(
|
||||||
|
const components::containers::Embedding& u,
|
||||||
|
const components::containers::Embedding& v) {
|
||||||
|
return components::utils::CosineSimilarity(u, v);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mediapipe::tasks::audio::audio_embedder
|
139
mediapipe/tasks/cc/audio/audio_embedder/audio_embedder.h
Normal file
139
mediapipe/tasks/cc/audio/audio_embedder/audio_embedder.h
Normal file
|
@ -0,0 +1,139 @@
|
||||||
|
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef MEDIAPIPE_TASKS_CC_AUDIO_AUDIO_EMBEDDER_AUDIO_EMBEDDER_H_
|
||||||
|
#define MEDIAPIPE_TASKS_CC_AUDIO_AUDIO_EMBEDDER_AUDIO_EMBEDDER_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "absl/status/statusor.h"
|
||||||
|
#include "mediapipe/framework/formats/matrix.h"
|
||||||
|
#include "mediapipe/tasks/cc/audio/core/base_audio_task_api.h"
|
||||||
|
#include "mediapipe/tasks/cc/audio/core/running_mode.h"
|
||||||
|
#include "mediapipe/tasks/cc/components/containers/embedding_result.h"
|
||||||
|
#include "mediapipe/tasks/cc/components/processors/embedder_options.h"
|
||||||
|
#include "mediapipe/tasks/cc/core/base_options.h"
|
||||||
|
|
||||||
|
namespace mediapipe::tasks::audio::audio_embedder {
|
||||||
|
|
||||||
|
// Alias the shared EmbeddingResult struct as result type.
|
||||||
|
using AudioEmbedderResult =
|
||||||
|
::mediapipe::tasks::components::containers::EmbeddingResult;
|
||||||
|
|
||||||
|
struct AudioEmbedderOptions {
|
||||||
|
// Base options for configuring Task library, such as specifying the TfLite
|
||||||
|
// model file with metadata, accelerator options, op resolver, etc.
|
||||||
|
tasks::core::BaseOptions base_options;
|
||||||
|
|
||||||
|
// Options for configuring the embedder behavior, such as score threshold,
|
||||||
|
// number of results, etc.
|
||||||
|
components::processors::EmbedderOptions embedder_options;
|
||||||
|
|
||||||
|
// The running mode of the audio embedder. Default to the audio clips mode.
|
||||||
|
// Audio embedder has two running modes:
|
||||||
|
// 1) The audio clips mode for running embedding on independent audio clips.
|
||||||
|
// 2) The audio stream mode for running embedding on the audio stream,
|
||||||
|
// such as from microphone. In this mode, the "result_callback" below must
|
||||||
|
// be specified to receive the embedding results asynchronously.
|
||||||
|
core::RunningMode running_mode = core::RunningMode::AUDIO_CLIPS;
|
||||||
|
|
||||||
|
// The user-defined result callback for processing audio stream data.
|
||||||
|
// The result callback should only be specified when the running mode is set
|
||||||
|
// to RunningMode::AUDIO_STREAM.
|
||||||
|
std::function<void(absl::StatusOr<AudioEmbedderResult>)> result_callback =
|
||||||
|
nullptr;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Performs embedding extraction on audio clips or audio stream.
|
||||||
|
//
|
||||||
|
// The API expects a TFLite model with TFLite Model Metadata.
|
||||||
|
//
|
||||||
|
// Input tensor:
|
||||||
|
// (kTfLiteFloat32)
|
||||||
|
// - input audio buffer of size `[batch * samples]`.
|
||||||
|
// - batch inference is not supported (`batch` is required to be 1).
|
||||||
|
// - for multi-channel models, the channels need be interleaved.
|
||||||
|
// At least one output tensor with:
|
||||||
|
// (kTfLiteUInt8/kTfLiteFloat32)
|
||||||
|
// - `N` components corresponding to the `N` dimensions of the returned
|
||||||
|
// feature vector for this output layer.
|
||||||
|
// - Either 2 or 4 dimensions, i.e. `[1 x N]` or `[1 x 1 x 1 x N]`.
|
||||||
|
class AudioEmbedder : core::BaseAudioTaskApi {
|
||||||
|
public:
|
||||||
|
using BaseAudioTaskApi::BaseAudioTaskApi;
|
||||||
|
|
||||||
|
// Creates an AudioEmbedder from the provided options. A non-default
|
||||||
|
// OpResolver can be specified in the BaseOptions in order to support custom
|
||||||
|
// Ops or specify a subset of built-in Ops.
|
||||||
|
static absl::StatusOr<std::unique_ptr<AudioEmbedder>> Create(
|
||||||
|
std::unique_ptr<AudioEmbedderOptions> options);
|
||||||
|
|
||||||
|
// Performs embedding extraction on the provided audio clips. Only use this
|
||||||
|
// method when the AudioEmbedder is created with the audio clips running mode.
|
||||||
|
//
|
||||||
|
// The audio clip is represented as a MediaPipe Matrix that has the number of
|
||||||
|
// channels rows and the number of samples per channel columns. The method
|
||||||
|
// accepts audio clips with various length and audio sample rate. It's
|
||||||
|
// required to provide the corresponding audio sample rate along with the
|
||||||
|
// input audio clips.
|
||||||
|
//
|
||||||
|
// The input audio clip may be longer than what the model is able to process
|
||||||
|
// in a single inference. When this occurs, the input audio clip is split into
|
||||||
|
// multiple chunks starting at different timestamps. For this reason, this
|
||||||
|
// function returns a vector of EmbeddingResult objects, each associated
|
||||||
|
// with a timestamp corresponding to the start (in milliseconds) of the chunk
|
||||||
|
// data that was extracted.
|
||||||
|
absl::StatusOr<std::vector<AudioEmbedderResult>> Embed(
|
||||||
|
Matrix audio_clip, double audio_sample_rate);
|
||||||
|
|
||||||
|
// Sends audio stream data to embedder, and the results will be available via
|
||||||
|
// the "result_callback" provided in the AudioEmbedderOptions. Only use this
|
||||||
|
// method when the AudioEmbedder is created with the audio stream running
|
||||||
|
// mode.
|
||||||
|
//
|
||||||
|
// The audio block is represented as a MediaPipe Matrix that has the number
|
||||||
|
// of channels rows and the number of samples per channel columns. The audio
|
||||||
|
// data will be resampled, accumulated, and framed to the proper size for the
|
||||||
|
// underlying model to consume. It's required to provide the corresponding
|
||||||
|
// audio sample rate along with the input audio block as well as a timestamp
|
||||||
|
// (in milliseconds) to indicate the start time of the input audio block. The
|
||||||
|
// timestamps must be monotonically increasing.
|
||||||
|
//
|
||||||
|
// The input audio block may be longer than what the model is able to process
|
||||||
|
// in a single inference. When this occurs, the input audio block is split
|
||||||
|
// into multiple chunks. For this reason, the callback may be called multiple
|
||||||
|
// times (once per chunk) for each call to this function.
|
||||||
|
absl::Status EmbedAsync(Matrix audio_block, double audio_sample_rate,
|
||||||
|
int64 timestamp_ms);
|
||||||
|
|
||||||
|
// Shuts down the AudioEmbedder when all works are done.
|
||||||
|
absl::Status Close() { return runner_->Close(); }
|
||||||
|
|
||||||
|
// Utility function to compute cosine similarity [1] between two embeddings.
|
||||||
|
// May return an InvalidArgumentError if e.g. the embeddings are of different
|
||||||
|
// types (quantized vs. float), have different sizes, or have a an L2-norm of
|
||||||
|
// 0.
|
||||||
|
//
|
||||||
|
// [1]: https://en.wikipedia.org/wiki/Cosine_similarity
|
||||||
|
static absl::StatusOr<double> CosineSimilarity(
|
||||||
|
const components::containers::Embedding& u,
|
||||||
|
const components::containers::Embedding& v);
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace mediapipe::tasks::audio::audio_embedder
|
||||||
|
|
||||||
|
#endif // MEDIAPIPE_TASKS_CC_AUDIO_AUDIO_EMBEDDER_AUDIO_EMBEDDER_H_
|
186
mediapipe/tasks/cc/audio/audio_embedder/audio_embedder_graph.cc
Normal file
186
mediapipe/tasks/cc/audio/audio_embedder/audio_embedder_graph.cc
Normal file
|
@ -0,0 +1,186 @@
|
||||||
|
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include <stdint.h>
|
||||||
|
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "absl/status/status.h"
|
||||||
|
#include "absl/status/statusor.h"
|
||||||
|
#include "absl/types/optional.h"
|
||||||
|
#include "mediapipe/calculators/core/constant_side_packet_calculator.pb.h"
|
||||||
|
#include "mediapipe/calculators/tensor/audio_to_tensor_calculator.pb.h"
|
||||||
|
#include "mediapipe/framework/api2/builder.h"
|
||||||
|
#include "mediapipe/framework/api2/port.h"
|
||||||
|
#include "mediapipe/framework/calculator.pb.h"
|
||||||
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
|
#include "mediapipe/framework/formats/matrix.h"
|
||||||
|
#include "mediapipe/tasks/cc/audio/audio_embedder/proto/audio_embedder_graph_options.pb.h"
|
||||||
|
#include "mediapipe/tasks/cc/audio/utils/audio_tensor_specs.h"
|
||||||
|
#include "mediapipe/tasks/cc/common.h"
|
||||||
|
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
|
||||||
|
#include "mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.h"
|
||||||
|
#include "mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options.pb.h"
|
||||||
|
#include "mediapipe/tasks/cc/core/model_resources.h"
|
||||||
|
#include "mediapipe/tasks/cc/core/model_task_graph.h"
|
||||||
|
#include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h"
|
||||||
|
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
|
||||||
|
#include "mediapipe/tasks/metadata/metadata_schema_generated.h"
|
||||||
|
#include "tensorflow/lite/schema/schema_generated.h"
|
||||||
|
|
||||||
|
namespace mediapipe::tasks::audio::audio_embedder {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
using ::mediapipe::api2::Input;
|
||||||
|
using ::mediapipe::api2::Output;
|
||||||
|
using ::mediapipe::api2::builder::Graph;
|
||||||
|
using ::mediapipe::api2::builder::Source;
|
||||||
|
using ::mediapipe::tasks::components::containers::proto::EmbeddingResult;
|
||||||
|
|
||||||
|
constexpr char kAudioTag[] = "AUDIO";
|
||||||
|
constexpr char kEmbeddingsTag[] = "EMBEDDINGS";
|
||||||
|
constexpr char kTimestampedEmbeddingsTag[] = "TIMESTAMPED_EMBEDDINGS";
|
||||||
|
constexpr char kSampleRateTag[] = "SAMPLE_RATE";
|
||||||
|
constexpr char kTensorsTag[] = "TENSORS";
|
||||||
|
constexpr char kTimestampsTag[] = "TIMESTAMPS";
|
||||||
|
|
||||||
|
// Struct holding the different output streams produced by the audio embedder
|
||||||
|
// graph.
|
||||||
|
struct AudioEmbedderOutputStreams {
|
||||||
|
Source<EmbeddingResult> embeddings;
|
||||||
|
Source<std::vector<EmbeddingResult>> timestamped_embeddings;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Builds an AudioTensorSpecs for configuring the preprocessing calculators.
|
||||||
|
absl::StatusOr<AudioTensorSpecs> BuildPreprocessingSpecs(
|
||||||
|
const core::ModelResources& model_resources) {
|
||||||
|
const tflite::Model& model = *model_resources.GetTfLiteModel();
|
||||||
|
if (model.subgraphs()->size() != 1) {
|
||||||
|
return CreateStatusWithPayload(absl::StatusCode::kInvalidArgument,
|
||||||
|
"Audio embedding tflite models are "
|
||||||
|
"assumed to have a single subgraph.",
|
||||||
|
MediaPipeTasksStatus::kInvalidArgumentError);
|
||||||
|
}
|
||||||
|
const auto* primary_subgraph = (*model.subgraphs())[0];
|
||||||
|
if (primary_subgraph->inputs()->size() != 1) {
|
||||||
|
return CreateStatusWithPayload(absl::StatusCode::kInvalidArgument,
|
||||||
|
"Audio embedding tflite models are "
|
||||||
|
"assumed to have a single input.",
|
||||||
|
MediaPipeTasksStatus::kInvalidArgumentError);
|
||||||
|
}
|
||||||
|
const auto* input_tensor =
|
||||||
|
(*primary_subgraph->tensors())[(*primary_subgraph->inputs())[0]];
|
||||||
|
ASSIGN_OR_RETURN(
|
||||||
|
const auto* audio_tensor_metadata,
|
||||||
|
GetAudioTensorMetadataIfAny(*model_resources.GetMetadataExtractor(), 0));
|
||||||
|
return BuildInputAudioTensorSpecs(*input_tensor, audio_tensor_metadata);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fills in the AudioToTensorCalculatorOptions based on the AudioTensorSpecs.
|
||||||
|
void ConfigureAudioToTensorCalculator(
|
||||||
|
const AudioTensorSpecs& audio_tensor_specs, bool use_stream_mode,
|
||||||
|
AudioToTensorCalculatorOptions* options) {
|
||||||
|
options->set_num_channels(audio_tensor_specs.num_channels);
|
||||||
|
options->set_num_samples(audio_tensor_specs.num_samples);
|
||||||
|
options->set_target_sample_rate(audio_tensor_specs.sample_rate);
|
||||||
|
options->set_stream_mode(use_stream_mode);
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
class AudioEmbedderGraph : public core::ModelTaskGraph {
|
||||||
|
public:
|
||||||
|
absl::StatusOr<CalculatorGraphConfig> GetConfig(
|
||||||
|
SubgraphContext* sc) override {
|
||||||
|
ASSIGN_OR_RETURN(
|
||||||
|
const auto* model_resources,
|
||||||
|
CreateModelResources<proto::AudioEmbedderGraphOptions>(sc));
|
||||||
|
Graph graph;
|
||||||
|
ASSIGN_OR_RETURN(
|
||||||
|
auto output_streams,
|
||||||
|
BuildAudioEmbeddingTask(
|
||||||
|
sc->Options<proto::AudioEmbedderGraphOptions>(), *model_resources,
|
||||||
|
graph[Input<Matrix>(kAudioTag)],
|
||||||
|
absl::make_optional(graph[Input<double>(kSampleRateTag)]), graph));
|
||||||
|
output_streams.embeddings >> graph[Output<EmbeddingResult>(kEmbeddingsTag)];
|
||||||
|
output_streams.timestamped_embeddings >>
|
||||||
|
graph[Output<std::vector<EmbeddingResult>>(kTimestampedEmbeddingsTag)];
|
||||||
|
return graph.GetConfig();
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
absl::StatusOr<AudioEmbedderOutputStreams> BuildAudioEmbeddingTask(
|
||||||
|
const proto::AudioEmbedderGraphOptions& task_options,
|
||||||
|
const core::ModelResources& model_resources, Source<Matrix> audio_in,
|
||||||
|
absl::optional<Source<double>> sample_rate_in, Graph& graph) {
|
||||||
|
const bool use_stream_mode = task_options.base_options().use_stream_mode();
|
||||||
|
const auto* metadata_extractor = model_resources.GetMetadataExtractor();
|
||||||
|
// Checks that metadata is available.
|
||||||
|
if (metadata_extractor->GetModelMetadata() == nullptr ||
|
||||||
|
metadata_extractor->GetModelMetadata()->subgraph_metadata() ==
|
||||||
|
nullptr) {
|
||||||
|
return CreateStatusWithPayload(
|
||||||
|
absl::StatusCode::kInvalidArgument,
|
||||||
|
"Audio embedder models require TFLite Model Metadata but none was "
|
||||||
|
"found",
|
||||||
|
MediaPipeTasksStatus::kMetadataNotFoundError);
|
||||||
|
}
|
||||||
|
// Adds AudioToTensorCalculator and connects it to the graph input streams.
|
||||||
|
ASSIGN_OR_RETURN(auto audio_tensor_specs,
|
||||||
|
BuildPreprocessingSpecs(model_resources));
|
||||||
|
auto& audio_to_tensor = graph.AddNode("AudioToTensorCalculator");
|
||||||
|
ConfigureAudioToTensorCalculator(
|
||||||
|
audio_tensor_specs, use_stream_mode,
|
||||||
|
&audio_to_tensor.GetOptions<AudioToTensorCalculatorOptions>());
|
||||||
|
audio_in >> audio_to_tensor.In(kAudioTag);
|
||||||
|
if (sample_rate_in.has_value()) {
|
||||||
|
sample_rate_in.value() >> audio_to_tensor.In(kSampleRateTag);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Adds inference subgraph and connects its input stream to the output
|
||||||
|
// tensors produced by the AudioToTensorCalculator.
|
||||||
|
auto& inference = AddInference(
|
||||||
|
model_resources, task_options.base_options().acceleration(), graph);
|
||||||
|
audio_to_tensor.Out(kTensorsTag) >> inference.In(kTensorsTag);
|
||||||
|
// Adds postprocessing calculators and connects its input stream to the
|
||||||
|
// inference results.
|
||||||
|
auto& postprocessing = graph.AddNode(
|
||||||
|
"mediapipe.tasks.components.processors.EmbeddingPostprocessingGraph");
|
||||||
|
MP_RETURN_IF_ERROR(components::processors::ConfigureEmbeddingPostprocessing(
|
||||||
|
model_resources, task_options.embedder_options(),
|
||||||
|
&postprocessing.GetOptions<components::processors::proto::
|
||||||
|
EmbeddingPostprocessingGraphOptions>()));
|
||||||
|
inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag);
|
||||||
|
// Time aggregation is only needed for performing audio embedding on
|
||||||
|
// audio files. Disables timestamp aggregation by not connecting the
|
||||||
|
// "TIMESTAMPS" streams.
|
||||||
|
if (!use_stream_mode) {
|
||||||
|
audio_to_tensor.Out(kTimestampsTag) >> postprocessing.In(kTimestampsTag);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Outputs both streams as graph output streams/
|
||||||
|
return AudioEmbedderOutputStreams{
|
||||||
|
/*embeddings=*/postprocessing[Output<EmbeddingResult>(kEmbeddingsTag)],
|
||||||
|
/*timestamped_embeddings=*/
|
||||||
|
postprocessing[Output<std::vector<EmbeddingResult>>(
|
||||||
|
kTimestampedEmbeddingsTag)],
|
||||||
|
};
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
REGISTER_MEDIAPIPE_GRAPH(
|
||||||
|
::mediapipe::tasks::audio::audio_embedder::AudioEmbedderGraph);
|
||||||
|
|
||||||
|
} // namespace mediapipe::tasks::audio::audio_embedder
|
289
mediapipe/tasks/cc/audio/audio_embedder/audio_embedder_test.cc
Normal file
289
mediapipe/tasks/cc/audio/audio_embedder/audio_embedder_test.cc
Normal file
|
@ -0,0 +1,289 @@
|
||||||
|
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "mediapipe/tasks/cc/audio/audio_embedder/audio_embedder.h"
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <memory>
|
||||||
|
#include <new>
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "absl/flags/flag.h"
|
||||||
|
#include "absl/status/status.h"
|
||||||
|
#include "absl/status/statusor.h"
|
||||||
|
#include "absl/strings/cord.h"
|
||||||
|
#include "absl/strings/str_cat.h"
|
||||||
|
#include "absl/strings/string_view.h"
|
||||||
|
#include "mediapipe/framework/deps/file_path.h"
|
||||||
|
#include "mediapipe/framework/formats/matrix.h"
|
||||||
|
#include "mediapipe/framework/port/gmock.h"
|
||||||
|
#include "mediapipe/framework/port/gtest.h"
|
||||||
|
#include "mediapipe/tasks/cc/audio/core/running_mode.h"
|
||||||
|
#include "mediapipe/tasks/cc/audio/utils/test_utils.h"
|
||||||
|
#include "mediapipe/tasks/cc/common.h"
|
||||||
|
#include "mediapipe/tasks/cc/components/containers/embedding_result.h"
|
||||||
|
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
||||||
|
|
||||||
|
namespace mediapipe {
|
||||||
|
namespace tasks {
|
||||||
|
namespace audio {
|
||||||
|
namespace audio_embedder {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
using ::absl::StatusOr;
|
||||||
|
using ::mediapipe::file::JoinPath;
|
||||||
|
using ::testing::HasSubstr;
|
||||||
|
using ::testing::Optional;
|
||||||
|
|
||||||
|
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/audio";
|
||||||
|
constexpr char kModelWithMetadata[] = "yamnet_embedding_metadata.tflite";
|
||||||
|
constexpr char k16kTestWavFilename[] = "speech_16000_hz_mono.wav";
|
||||||
|
constexpr char k48kTestWavFilename[] = "speech_48000_hz_mono.wav";
|
||||||
|
constexpr char k16kTestWavForTwoHeadsFilename[] = "two_heads_16000_hz_mono.wav";
|
||||||
|
constexpr float kSpeechSimilarities[] = {0.985359, 0.994349, 0.993227, 0.996658,
|
||||||
|
0.996384};
|
||||||
|
constexpr int kMilliSecondsPerSecond = 1000;
|
||||||
|
constexpr int kYamnetNumOfAudioSamples = 15600;
|
||||||
|
constexpr int kYamnetAudioSampleRate = 16000;
|
||||||
|
|
||||||
|
Matrix GetAudioData(absl::string_view filename) {
|
||||||
|
std::string wav_file_path = JoinPath("./", kTestDataDirectory, filename);
|
||||||
|
int buffer_size;
|
||||||
|
auto audio_data = internal::ReadWavFile(wav_file_path, &buffer_size);
|
||||||
|
Eigen::Map<Matrix> matrix_mapping(audio_data->get(), 1, buffer_size);
|
||||||
|
return matrix_mapping.matrix();
|
||||||
|
}
|
||||||
|
|
||||||
|
class CreateFromOptionsTest : public tflite_shims::testing::Test {};
|
||||||
|
|
||||||
|
TEST_F(CreateFromOptionsTest, FailsWithMissingModel) {
|
||||||
|
auto audio_embedder =
|
||||||
|
AudioEmbedder::Create(std::make_unique<AudioEmbedderOptions>());
|
||||||
|
|
||||||
|
EXPECT_EQ(audio_embedder.status().code(), absl::StatusCode::kInvalidArgument);
|
||||||
|
EXPECT_THAT(
|
||||||
|
audio_embedder.status().message(),
|
||||||
|
HasSubstr("ExternalFile must specify at least one of 'file_content', "
|
||||||
|
"'file_name', 'file_pointer_meta' or 'file_descriptor_meta'."));
|
||||||
|
EXPECT_THAT(audio_embedder.status().GetPayload(kMediaPipeTasksPayload),
|
||||||
|
Optional(absl::Cord(absl::StrCat(
|
||||||
|
MediaPipeTasksStatus::kRunnerInitializationError))));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(CreateFromOptionsTest, SucceedsForModelWithMetadata) {
|
||||||
|
auto options = std::make_unique<AudioEmbedderOptions>();
|
||||||
|
options->base_options.model_asset_path =
|
||||||
|
JoinPath("./", kTestDataDirectory, kModelWithMetadata);
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<AudioEmbedder> audio_embedder,
|
||||||
|
AudioEmbedder::Create(std::move(options)));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(CreateFromOptionsTest, FailsWithIllegalCallbackInAudioClipsMode) {
|
||||||
|
auto options = std::make_unique<AudioEmbedderOptions>();
|
||||||
|
options->base_options.model_asset_path =
|
||||||
|
JoinPath("./", kTestDataDirectory, kModelWithMetadata);
|
||||||
|
options->running_mode = core::RunningMode::AUDIO_CLIPS;
|
||||||
|
options->result_callback = [](absl::StatusOr<AudioEmbedderResult>) {};
|
||||||
|
|
||||||
|
auto audio_embedder = AudioEmbedder::Create(std::move(options));
|
||||||
|
|
||||||
|
EXPECT_EQ(audio_embedder.status().code(), absl::StatusCode::kInvalidArgument);
|
||||||
|
EXPECT_THAT(
|
||||||
|
audio_embedder.status().message(),
|
||||||
|
HasSubstr("a user-defined result callback shouldn't be provided"));
|
||||||
|
EXPECT_THAT(audio_embedder.status().GetPayload(kMediaPipeTasksPayload),
|
||||||
|
Optional(absl::Cord(absl::StrCat(
|
||||||
|
MediaPipeTasksStatus::kInvalidTaskGraphConfigError))));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(CreateFromOptionsTest, FailsWithMissingCallbackInAudioStreamMode) {
|
||||||
|
auto options = std::make_unique<AudioEmbedderOptions>();
|
||||||
|
options->base_options.model_asset_path =
|
||||||
|
JoinPath("./", kTestDataDirectory, kModelWithMetadata);
|
||||||
|
options->running_mode = core::RunningMode::AUDIO_STREAM;
|
||||||
|
|
||||||
|
auto audio_embedder = AudioEmbedder::Create(std::move(options));
|
||||||
|
|
||||||
|
EXPECT_EQ(audio_embedder.status().code(), absl::StatusCode::kInvalidArgument);
|
||||||
|
EXPECT_THAT(audio_embedder.status().message(),
|
||||||
|
HasSubstr("a user-defined result callback must be provided"));
|
||||||
|
EXPECT_THAT(audio_embedder.status().GetPayload(kMediaPipeTasksPayload),
|
||||||
|
Optional(absl::Cord(absl::StrCat(
|
||||||
|
MediaPipeTasksStatus::kInvalidTaskGraphConfigError))));
|
||||||
|
}
|
||||||
|
|
||||||
|
class EmbedTest : public tflite_shims::testing::Test {};
|
||||||
|
|
||||||
|
TEST_F(EmbedTest, SucceedsWithSilentAudio) {
|
||||||
|
auto options = std::make_unique<AudioEmbedderOptions>();
|
||||||
|
options->base_options.model_asset_path =
|
||||||
|
JoinPath("./", kTestDataDirectory, kModelWithMetadata);
|
||||||
|
options->running_mode = core::RunningMode::AUDIO_CLIPS;
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<AudioEmbedder> audio_embedder,
|
||||||
|
AudioEmbedder::Create(std::move(options)));
|
||||||
|
Matrix silent_data(1, kYamnetNumOfAudioSamples);
|
||||||
|
silent_data.setZero();
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
|
auto result, audio_embedder->Embed(silent_data, kYamnetAudioSampleRate));
|
||||||
|
EXPECT_EQ(result.size(), 1);
|
||||||
|
EXPECT_EQ(result[0].embeddings[0].float_embedding.size(), 1024);
|
||||||
|
constexpr float kValueDiffTolerance = 3e-6;
|
||||||
|
EXPECT_NEAR(result[0].embeddings[0].float_embedding[0], 2.07613f,
|
||||||
|
kValueDiffTolerance);
|
||||||
|
EXPECT_NEAR(result[0].embeddings[0].float_embedding[1], 0.392721f,
|
||||||
|
kValueDiffTolerance);
|
||||||
|
EXPECT_NEAR(result[0].embeddings[0].float_embedding[2], 0.543622f,
|
||||||
|
kValueDiffTolerance);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(EmbedTest, SucceedsWithSameAudioAtDifferentSampleRates) {
|
||||||
|
auto audio_buffer1 = GetAudioData(k16kTestWavFilename);
|
||||||
|
auto audio_buffer2 = GetAudioData(k48kTestWavFilename);
|
||||||
|
auto options = std::make_unique<AudioEmbedderOptions>();
|
||||||
|
options->base_options.model_asset_path =
|
||||||
|
JoinPath("./", kTestDataDirectory, kModelWithMetadata);
|
||||||
|
options->running_mode = core::RunningMode::AUDIO_CLIPS;
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<AudioEmbedder> audio_embedder,
|
||||||
|
AudioEmbedder::Create(std::move(options)));
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(auto result1,
|
||||||
|
audio_embedder->Embed(audio_buffer1, 16000));
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(auto result2,
|
||||||
|
audio_embedder->Embed(audio_buffer2, 48000));
|
||||||
|
int expected_size = sizeof(kSpeechSimilarities) / sizeof(float);
|
||||||
|
ASSERT_EQ(result1.size(), expected_size);
|
||||||
|
ASSERT_EQ(result2.size(), expected_size);
|
||||||
|
for (int i = 0; i < expected_size; ++i) {
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(double similarity, AudioEmbedder::CosineSimilarity(
|
||||||
|
result1[i].embeddings[0],
|
||||||
|
result2[i].embeddings[0]));
|
||||||
|
EXPECT_NEAR(similarity, kSpeechSimilarities[i], 1e-6);
|
||||||
|
}
|
||||||
|
MP_EXPECT_OK(audio_embedder->Close());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(EmbedTest, SucceedsWithDifferentAudios) {
|
||||||
|
auto audio_buffer1 = GetAudioData(k16kTestWavFilename);
|
||||||
|
auto audio_buffer2 = GetAudioData(k16kTestWavForTwoHeadsFilename);
|
||||||
|
auto options = std::make_unique<AudioEmbedderOptions>();
|
||||||
|
options->base_options.model_asset_path =
|
||||||
|
JoinPath("./", kTestDataDirectory, kModelWithMetadata);
|
||||||
|
options->running_mode = core::RunningMode::AUDIO_CLIPS;
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<AudioEmbedder> audio_embedder,
|
||||||
|
AudioEmbedder::Create(std::move(options)));
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
|
auto result1,
|
||||||
|
audio_embedder->Embed(audio_buffer1, kYamnetAudioSampleRate));
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
|
auto result2,
|
||||||
|
audio_embedder->Embed(audio_buffer2, kYamnetAudioSampleRate));
|
||||||
|
ASSERT_EQ(result1.size(), 5);
|
||||||
|
ASSERT_EQ(result2.size(), 1);
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(double similarity, AudioEmbedder::CosineSimilarity(
|
||||||
|
result1[0].embeddings[0],
|
||||||
|
result2[0].embeddings[0]));
|
||||||
|
EXPECT_NEAR(similarity, 0.09017f, 1e-6);
|
||||||
|
MP_EXPECT_OK(audio_embedder->Close());
|
||||||
|
}
|
||||||
|
|
||||||
|
class EmbedAsyncTest : public tflite_shims::testing::Test {
|
||||||
|
protected:
|
||||||
|
void RunAudioEmbedderInStreamMode(std::string audio_file_name,
|
||||||
|
int sample_rate_hz,
|
||||||
|
std::vector<AudioEmbedderResult>* result) {
|
||||||
|
auto audio_buffer = GetAudioData(audio_file_name);
|
||||||
|
auto options = std::make_unique<AudioEmbedderOptions>();
|
||||||
|
options->base_options.model_asset_path =
|
||||||
|
JoinPath("./", kTestDataDirectory, kModelWithMetadata);
|
||||||
|
options->running_mode = core::RunningMode::AUDIO_STREAM;
|
||||||
|
options->result_callback =
|
||||||
|
[result](absl::StatusOr<AudioEmbedderResult> status_or_result) {
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(result->emplace_back(), status_or_result);
|
||||||
|
};
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<AudioEmbedder> audio_embedder,
|
||||||
|
AudioEmbedder::Create(std::move(options)));
|
||||||
|
int start_col = 0;
|
||||||
|
static unsigned int rseed = 0;
|
||||||
|
while (start_col < audio_buffer.cols()) {
|
||||||
|
int num_samples = std::min(
|
||||||
|
(int)(audio_buffer.cols() - start_col),
|
||||||
|
rand_r(&rseed) % 10 + kYamnetNumOfAudioSamples * sample_rate_hz /
|
||||||
|
kYamnetAudioSampleRate);
|
||||||
|
MP_ASSERT_OK(audio_embedder->EmbedAsync(
|
||||||
|
audio_buffer.block(0, start_col, 1, num_samples), sample_rate_hz,
|
||||||
|
start_col * kMilliSecondsPerSecond / sample_rate_hz));
|
||||||
|
start_col += num_samples;
|
||||||
|
}
|
||||||
|
MP_ASSERT_OK(audio_embedder->Close());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(EmbedAsyncTest, FailsWithOutOfOrderInputTimestamps) {
|
||||||
|
auto options = std::make_unique<AudioEmbedderOptions>();
|
||||||
|
options->base_options.model_asset_path =
|
||||||
|
JoinPath("./", kTestDataDirectory, kModelWithMetadata);
|
||||||
|
options->running_mode = core::RunningMode::AUDIO_STREAM;
|
||||||
|
options->result_callback =
|
||||||
|
[](absl::StatusOr<AudioEmbedderResult> status_or_result) { return; };
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<AudioEmbedder> audio_embedder,
|
||||||
|
AudioEmbedder::Create(std::move(options)));
|
||||||
|
MP_ASSERT_OK(audio_embedder->EmbedAsync(Matrix(1, kYamnetNumOfAudioSamples),
|
||||||
|
kYamnetAudioSampleRate, 100));
|
||||||
|
auto status = audio_embedder->EmbedAsync(Matrix(1, kYamnetNumOfAudioSamples),
|
||||||
|
kYamnetAudioSampleRate, 0);
|
||||||
|
EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
|
||||||
|
EXPECT_THAT(status.message(),
|
||||||
|
HasSubstr("timestamp must be monotonically increasing"));
|
||||||
|
EXPECT_THAT(status.GetPayload(kMediaPipeTasksPayload),
|
||||||
|
Optional(absl::Cord(absl::StrCat(
|
||||||
|
MediaPipeTasksStatus::kRunnerInvalidTimestampError))));
|
||||||
|
MP_ASSERT_OK(audio_embedder->Close());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(EmbedAsyncTest, SucceedsWithSameAudioAtDifferentSampleRates) {
|
||||||
|
std::vector<AudioEmbedderResult> result1;
|
||||||
|
RunAudioEmbedderInStreamMode(k16kTestWavFilename, 16000, &result1);
|
||||||
|
std::vector<AudioEmbedderResult> result2;
|
||||||
|
RunAudioEmbedderInStreamMode(k48kTestWavFilename, 48000, &result2);
|
||||||
|
int expected_size = sizeof(kSpeechSimilarities) / sizeof(float);
|
||||||
|
ASSERT_EQ(result1.size(), expected_size);
|
||||||
|
ASSERT_EQ(result2.size(), expected_size);
|
||||||
|
for (int i = 0; i < expected_size; ++i) {
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(double similarity, AudioEmbedder::CosineSimilarity(
|
||||||
|
result1[i].embeddings[0],
|
||||||
|
result2[i].embeddings[0]));
|
||||||
|
EXPECT_NEAR(similarity, kSpeechSimilarities[i], 1e-6);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(EmbedAsyncTest, SucceedsWithDifferentAudios) {
|
||||||
|
std::vector<AudioEmbedderResult> result1;
|
||||||
|
RunAudioEmbedderInStreamMode(k16kTestWavFilename, 16000, &result1);
|
||||||
|
std::vector<AudioEmbedderResult> result2;
|
||||||
|
RunAudioEmbedderInStreamMode(k16kTestWavForTwoHeadsFilename, 16000, &result2);
|
||||||
|
ASSERT_EQ(result1.size(), 5);
|
||||||
|
ASSERT_EQ(result2.size(), 1);
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(double similarity, AudioEmbedder::CosineSimilarity(
|
||||||
|
result1[0].embeddings[0],
|
||||||
|
result2[0].embeddings[0]));
|
||||||
|
EXPECT_NEAR(similarity, 0.09017f, 1e-6);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace audio_embedder
|
||||||
|
} // namespace audio
|
||||||
|
} // namespace tasks
|
||||||
|
} // namespace mediapipe
|
30
mediapipe/tasks/cc/audio/audio_embedder/proto/BUILD
Normal file
30
mediapipe/tasks/cc/audio/audio_embedder/proto/BUILD
Normal file
|
@ -0,0 +1,30 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library")
|
||||||
|
|
||||||
|
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||||
|
|
||||||
|
licenses(["notice"])
|
||||||
|
|
||||||
|
mediapipe_proto_library(
|
||||||
|
name = "audio_embedder_graph_options_proto",
|
||||||
|
srcs = ["audio_embedder_graph_options.proto"],
|
||||||
|
deps = [
|
||||||
|
"//mediapipe/framework:calculator_options_proto",
|
||||||
|
"//mediapipe/framework:calculator_proto",
|
||||||
|
"//mediapipe/tasks/cc/components/processors/proto:embedder_options_proto",
|
||||||
|
"//mediapipe/tasks/cc/core/proto:base_options_proto",
|
||||||
|
],
|
||||||
|
)
|
|
@ -0,0 +1,39 @@
|
||||||
|
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
syntax = "proto2";
|
||||||
|
|
||||||
|
package mediapipe.tasks.audio.audio_embedder.proto;
|
||||||
|
|
||||||
|
import "mediapipe/framework/calculator.proto";
|
||||||
|
import "mediapipe/tasks/cc/components/processors/proto/embedder_options.proto";
|
||||||
|
import "mediapipe/tasks/cc/core/proto/base_options.proto";
|
||||||
|
|
||||||
|
option java_package = "com.google.mediapipe.tasks.audio.audioembedder.proto";
|
||||||
|
option java_outer_classname = "AudioEmbedderGraphOptionsProto";
|
||||||
|
|
||||||
|
message AudioEmbedderGraphOptions {
|
||||||
|
extend mediapipe.CalculatorOptions {
|
||||||
|
optional AudioEmbedderGraphOptions ext = 487277289;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Base options for configuring MediaPipe Tasks, such as specifying the TfLite
|
||||||
|
// model file with metadata, accelerator options, etc.
|
||||||
|
optional core.proto.BaseOptions base_options = 1;
|
||||||
|
|
||||||
|
// Options for configuring the embedder behavior, such as normalization or
|
||||||
|
// quantization.
|
||||||
|
optional components.processors.proto.EmbedderOptions embedder_options = 2;
|
||||||
|
}
|
2
mediapipe/tasks/testdata/audio/BUILD
vendored
2
mediapipe/tasks/testdata/audio/BUILD
vendored
|
@ -30,6 +30,7 @@ mediapipe_files(srcs = [
|
||||||
"two_heads_16000_hz_mono.wav",
|
"two_heads_16000_hz_mono.wav",
|
||||||
"two_heads_44100_hz_mono.wav",
|
"two_heads_44100_hz_mono.wav",
|
||||||
"yamnet_audio_classifier_with_metadata.tflite",
|
"yamnet_audio_classifier_with_metadata.tflite",
|
||||||
|
"yamnet_embedding_metadata.tflite",
|
||||||
])
|
])
|
||||||
|
|
||||||
filegroup(
|
filegroup(
|
||||||
|
@ -38,6 +39,7 @@ filegroup(
|
||||||
"model_without_metadata.tflite",
|
"model_without_metadata.tflite",
|
||||||
"two_heads.tflite",
|
"two_heads.tflite",
|
||||||
"yamnet_audio_classifier_with_metadata.tflite",
|
"yamnet_audio_classifier_with_metadata.tflite",
|
||||||
|
"yamnet_embedding_metadata.tflite",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
6
third_party/external_files.bzl
vendored
6
third_party/external_files.bzl
vendored
|
@ -982,6 +982,12 @@ def external_files():
|
||||||
urls = ["https://storage.googleapis.com/mediapipe-assets/yamnet_audio_classifier_with_metadata.tflite?generation=1661875980774466"],
|
urls = ["https://storage.googleapis.com/mediapipe-assets/yamnet_audio_classifier_with_metadata.tflite?generation=1661875980774466"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
http_file(
|
||||||
|
name = "com_google_mediapipe_yamnet_embedding_metadata_tflite",
|
||||||
|
sha256 = "7baa72708e3919bae5a5dc78d932847bc28008af14febd083eff62d28af9c72a",
|
||||||
|
urls = ["https://storage.googleapis.com/mediapipe-assets/yamnet_embedding_metadata.tflite?generation=1668295071595506"],
|
||||||
|
)
|
||||||
|
|
||||||
http_file(
|
http_file(
|
||||||
name = "com_google_mediapipe_gesture_embedder_keras_metadata_pb",
|
name = "com_google_mediapipe_gesture_embedder_keras_metadata_pb",
|
||||||
sha256 = "24268b69429be4e307f9ab099ba20d1de7c40e4191a53f6a92dcbbd97a7047d3",
|
sha256 = "24268b69429be4e307f9ab099ba20d1de7c40e4191a53f6a92dcbbd97a7047d3",
|
||||||
|
|
Loading…
Reference in New Issue
Block a user