MediaPipe Tasks Audio embedder C++ API.

PiperOrigin-RevId: 488273381
This commit is contained in:
Jiuqiang Tang 2022-11-13 23:08:25 -08:00 committed by Copybara-Service
parent 0dfa91a166
commit 6c0ca947de
9 changed files with 926 additions and 0 deletions

View File

@ -0,0 +1,79 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
cc_library(
name = "audio_embedder_graph",
srcs = ["audio_embedder_graph.cc"],
deps = [
"//mediapipe/calculators/audio:time_series_framer_calculator",
"//mediapipe/calculators/core:constant_side_packet_calculator",
"//mediapipe/calculators/core:constant_side_packet_calculator_cc_proto",
"//mediapipe/calculators/core:side_packet_to_stream_calculator",
"//mediapipe/calculators/tensor:audio_to_tensor_calculator",
"//mediapipe/calculators/tensor:audio_to_tensor_calculator_cc_proto",
"//mediapipe/calculators/tensor:inference_calculator",
"//mediapipe/framework:calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:matrix",
"//mediapipe/tasks/cc:common",
"//mediapipe/tasks/cc/audio/audio_embedder/proto:audio_embedder_graph_options_cc_proto",
"//mediapipe/tasks/cc/audio/utils:audio_tensor_specs",
"//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
"//mediapipe/tasks/cc/components/processors:embedding_postprocessing_graph",
"//mediapipe/tasks/cc/components/processors/proto:embedding_postprocessing_graph_options_cc_proto",
"//mediapipe/tasks/cc/core:model_resources",
"//mediapipe/tasks/cc/core:model_task_graph",
"//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto",
"//mediapipe/tasks/cc/metadata:metadata_extractor",
"//mediapipe/tasks/metadata:metadata_schema_cc",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/types:optional",
"@org_tensorflow//tensorflow/lite/schema:schema_fbs",
],
alwayslink = 1,
)
cc_library(
name = "audio_embedder",
srcs = ["audio_embedder.cc"],
hdrs = ["audio_embedder.h"],
deps = [
":audio_embedder_graph",
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/formats:matrix",
"//mediapipe/tasks/cc/audio/audio_embedder/proto:audio_embedder_graph_options_cc_proto",
"//mediapipe/tasks/cc/audio/core:audio_task_api_factory",
"//mediapipe/tasks/cc/audio/core:base_audio_task_api",
"//mediapipe/tasks/cc/audio/core:running_mode",
"//mediapipe/tasks/cc/components/containers:embedding_result",
"//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
"//mediapipe/tasks/cc/components/processors:embedder_options",
"//mediapipe/tasks/cc/components/processors/proto:embedder_options_cc_proto",
"//mediapipe/tasks/cc/components/utils:cosine_similarity",
"//mediapipe/tasks/cc/core:base_options",
"//mediapipe/tasks/cc/core:task_runner",
"//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto",
"@com_google_absl//absl/status:statusor",
"@org_tensorflow//tensorflow/lite/core/api:op_resolver",
],
)
# TODO: mediapipe/tasks/cc/audio/utils:test_utils does not compile in the OSS build

View File

@ -0,0 +1,156 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/audio/audio_embedder/audio_embedder.h"
#include <map>
#include <memory>
#include <utility>
#include <vector>
#include "absl/status/statusor.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/tasks/cc/audio/audio_embedder/proto/audio_embedder_graph_options.pb.h"
#include "mediapipe/tasks/cc/audio/core/audio_task_api_factory.h"
#include "mediapipe/tasks/cc/components/containers/embedding_result.h"
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
#include "mediapipe/tasks/cc/components/processors/embedder_options.h"
#include "mediapipe/tasks/cc/components/processors/proto/embedder_options.pb.h"
#include "mediapipe/tasks/cc/components/utils/cosine_similarity.h"
#include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h"
#include "mediapipe/tasks/cc/core/task_runner.h"
#include "tensorflow/lite/core/api/op_resolver.h"
namespace mediapipe::tasks::audio::audio_embedder {
namespace {
using ::mediapipe::tasks::components::containers::ConvertToEmbeddingResult;
using ::mediapipe::tasks::components::containers::proto::EmbeddingResult;
constexpr char kAudioStreamName[] = "audio_in";
constexpr char kAudioTag[] = "AUDIO";
constexpr char kEmbeddingsTag[] = "EMBEDDINGS";
constexpr char kTimestampedEmbeddingsTag[] = "TIMESTAMPED_EMBEDDINGS";
constexpr char kEmbeddingsName[] = "embeddings_out";
constexpr char kTimestampedEmbeddingsName[] = "timestamped_embeddings_out";
constexpr char kSampleRateName[] = "sample_rate_in";
constexpr char kSampleRateTag[] = "SAMPLE_RATE";
constexpr char kSubgraphTypeName[] =
"mediapipe.tasks.audio.audio_embedder.AudioEmbedderGraph";
constexpr int kMicroSecondsPerMilliSecond = 1000;
// Creates a MediaPipe graph config that only contains a single subgraph node of
// type "AudioEmbedderGraph".
CalculatorGraphConfig CreateGraphConfig(
std::unique_ptr<proto::AudioEmbedderGraphOptions> options_proto) {
api2::builder::Graph graph;
auto& subgraph = graph.AddNode(kSubgraphTypeName);
graph.In(kAudioTag).SetName(kAudioStreamName) >> subgraph.In(kAudioTag);
graph.In(kSampleRateTag).SetName(kSampleRateName) >>
subgraph.In(kSampleRateTag);
subgraph.GetOptions<proto::AudioEmbedderGraphOptions>().Swap(
options_proto.get());
subgraph.Out(kEmbeddingsTag).SetName(kEmbeddingsName) >>
graph.Out(kEmbeddingsTag);
subgraph.Out(kTimestampedEmbeddingsTag).SetName(kTimestampedEmbeddingsName) >>
graph.Out(kTimestampedEmbeddingsTag);
return graph.GetConfig();
}
// Converts the user-facing AudioEmbedderOptions struct to the internal
// AudioEmbedderGraphOptions proto.
std::unique_ptr<proto::AudioEmbedderGraphOptions>
ConvertAudioEmbedderOptionsToProto(AudioEmbedderOptions* options) {
auto options_proto = std::make_unique<proto::AudioEmbedderGraphOptions>();
auto base_options_proto = std::make_unique<tasks::core::proto::BaseOptions>(
tasks::core::ConvertBaseOptionsToProto(&(options->base_options)));
options_proto->mutable_base_options()->Swap(base_options_proto.get());
options_proto->mutable_base_options()->set_use_stream_mode(
options->running_mode == core::RunningMode::AUDIO_STREAM);
auto embedder_options_proto =
std::make_unique<components::processors::proto::EmbedderOptions>(
components::processors::ConvertEmbedderOptionsToProto(
&(options->embedder_options)));
options_proto->mutable_embedder_options()->Swap(embedder_options_proto.get());
return options_proto;
}
absl::StatusOr<std::vector<AudioEmbedderResult>> ConvertOutputPackets(
absl::StatusOr<tasks::core::PacketMap> status_or_packets) {
if (!status_or_packets.ok()) {
return status_or_packets.status();
}
auto embedding_results = status_or_packets.value()[kTimestampedEmbeddingsName]
.Get<std::vector<EmbeddingResult>>();
std::vector<AudioEmbedderResult> results;
results.reserve(embedding_results.size());
for (const auto& embedding_result : embedding_results) {
results.emplace_back(ConvertToEmbeddingResult(embedding_result));
}
return results;
}
absl::StatusOr<AudioEmbedderResult> ConvertAsyncOutputPackets(
absl::StatusOr<tasks::core::PacketMap> status_or_packets) {
if (!status_or_packets.ok()) {
return status_or_packets.status();
}
return ConvertToEmbeddingResult(
status_or_packets.value()[kEmbeddingsName].Get<EmbeddingResult>());
}
} // namespace
/* static */
absl::StatusOr<std::unique_ptr<AudioEmbedder>> AudioEmbedder::Create(
std::unique_ptr<AudioEmbedderOptions> options) {
auto options_proto = ConvertAudioEmbedderOptionsToProto(options.get());
tasks::core::PacketsCallback packets_callback = nullptr;
if (options->result_callback) {
auto result_callback = options->result_callback;
packets_callback =
[=](absl::StatusOr<tasks::core::PacketMap> status_or_packets) {
result_callback(ConvertAsyncOutputPackets(status_or_packets));
};
}
return core::AudioTaskApiFactory::Create<AudioEmbedder,
proto::AudioEmbedderGraphOptions>(
CreateGraphConfig(std::move(options_proto)),
std::move(options->base_options.op_resolver), options->running_mode,
std::move(packets_callback));
}
absl::StatusOr<std::vector<AudioEmbedderResult>> AudioEmbedder::Embed(
Matrix audio_clip, double audio_sample_rate) {
return ConvertOutputPackets(ProcessAudioClip(
{{kAudioStreamName, MakePacket<Matrix>(std::move(audio_clip))},
{kSampleRateName, MakePacket<double>(audio_sample_rate)}}));
}
absl::Status AudioEmbedder::EmbedAsync(Matrix audio_block,
double audio_sample_rate,
int64 timestamp_ms) {
MP_RETURN_IF_ERROR(CheckOrSetSampleRate(kSampleRateName, audio_sample_rate));
return SendAudioStreamData(
{{kAudioStreamName,
MakePacket<Matrix>(std::move(audio_block))
.At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}});
}
absl::StatusOr<double> AudioEmbedder::CosineSimilarity(
const components::containers::Embedding& u,
const components::containers::Embedding& v) {
return components::utils::CosineSimilarity(u, v);
}
} // namespace mediapipe::tasks::audio::audio_embedder

View File

@ -0,0 +1,139 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_AUDIO_AUDIO_EMBEDDER_AUDIO_EMBEDDER_H_
#define MEDIAPIPE_TASKS_CC_AUDIO_AUDIO_EMBEDDER_AUDIO_EMBEDDER_H_
#include <memory>
#include <utility>
#include <vector>
#include "absl/status/statusor.h"
#include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/tasks/cc/audio/core/base_audio_task_api.h"
#include "mediapipe/tasks/cc/audio/core/running_mode.h"
#include "mediapipe/tasks/cc/components/containers/embedding_result.h"
#include "mediapipe/tasks/cc/components/processors/embedder_options.h"
#include "mediapipe/tasks/cc/core/base_options.h"
namespace mediapipe::tasks::audio::audio_embedder {
// Alias the shared EmbeddingResult struct as result type.
using AudioEmbedderResult =
::mediapipe::tasks::components::containers::EmbeddingResult;
struct AudioEmbedderOptions {
// Base options for configuring Task library, such as specifying the TfLite
// model file with metadata, accelerator options, op resolver, etc.
tasks::core::BaseOptions base_options;
// Options for configuring the embedder behavior, such as score threshold,
// number of results, etc.
components::processors::EmbedderOptions embedder_options;
// The running mode of the audio embedder. Default to the audio clips mode.
// Audio embedder has two running modes:
// 1) The audio clips mode for running embedding on independent audio clips.
// 2) The audio stream mode for running embedding on the audio stream,
// such as from microphone. In this mode, the "result_callback" below must
// be specified to receive the embedding results asynchronously.
core::RunningMode running_mode = core::RunningMode::AUDIO_CLIPS;
// The user-defined result callback for processing audio stream data.
// The result callback should only be specified when the running mode is set
// to RunningMode::AUDIO_STREAM.
std::function<void(absl::StatusOr<AudioEmbedderResult>)> result_callback =
nullptr;
};
// Performs embedding extraction on audio clips or audio stream.
//
// The API expects a TFLite model with TFLite Model Metadata.
//
// Input tensor:
// (kTfLiteFloat32)
// - input audio buffer of size `[batch * samples]`.
// - batch inference is not supported (`batch` is required to be 1).
// - for multi-channel models, the channels need be interleaved.
// At least one output tensor with:
// (kTfLiteUInt8/kTfLiteFloat32)
// - `N` components corresponding to the `N` dimensions of the returned
// feature vector for this output layer.
// - Either 2 or 4 dimensions, i.e. `[1 x N]` or `[1 x 1 x 1 x N]`.
class AudioEmbedder : core::BaseAudioTaskApi {
public:
using BaseAudioTaskApi::BaseAudioTaskApi;
// Creates an AudioEmbedder from the provided options. A non-default
// OpResolver can be specified in the BaseOptions in order to support custom
// Ops or specify a subset of built-in Ops.
static absl::StatusOr<std::unique_ptr<AudioEmbedder>> Create(
std::unique_ptr<AudioEmbedderOptions> options);
// Performs embedding extraction on the provided audio clips. Only use this
// method when the AudioEmbedder is created with the audio clips running mode.
//
// The audio clip is represented as a MediaPipe Matrix that has the number of
// channels rows and the number of samples per channel columns. The method
// accepts audio clips with various length and audio sample rate. It's
// required to provide the corresponding audio sample rate along with the
// input audio clips.
//
// The input audio clip may be longer than what the model is able to process
// in a single inference. When this occurs, the input audio clip is split into
// multiple chunks starting at different timestamps. For this reason, this
// function returns a vector of EmbeddingResult objects, each associated
// with a timestamp corresponding to the start (in milliseconds) of the chunk
// data that was extracted.
absl::StatusOr<std::vector<AudioEmbedderResult>> Embed(
Matrix audio_clip, double audio_sample_rate);
// Sends audio stream data to embedder, and the results will be available via
// the "result_callback" provided in the AudioEmbedderOptions. Only use this
// method when the AudioEmbedder is created with the audio stream running
// mode.
//
// The audio block is represented as a MediaPipe Matrix that has the number
// of channels rows and the number of samples per channel columns. The audio
// data will be resampled, accumulated, and framed to the proper size for the
// underlying model to consume. It's required to provide the corresponding
// audio sample rate along with the input audio block as well as a timestamp
// (in milliseconds) to indicate the start time of the input audio block. The
// timestamps must be monotonically increasing.
//
// The input audio block may be longer than what the model is able to process
// in a single inference. When this occurs, the input audio block is split
// into multiple chunks. For this reason, the callback may be called multiple
// times (once per chunk) for each call to this function.
absl::Status EmbedAsync(Matrix audio_block, double audio_sample_rate,
int64 timestamp_ms);
// Shuts down the AudioEmbedder when all works are done.
absl::Status Close() { return runner_->Close(); }
// Utility function to compute cosine similarity [1] between two embeddings.
// May return an InvalidArgumentError if e.g. the embeddings are of different
// types (quantized vs. float), have different sizes, or have a an L2-norm of
// 0.
//
// [1]: https://en.wikipedia.org/wiki/Cosine_similarity
static absl::StatusOr<double> CosineSimilarity(
const components::containers::Embedding& u,
const components::containers::Embedding& v);
};
} // namespace mediapipe::tasks::audio::audio_embedder
#endif // MEDIAPIPE_TASKS_CC_AUDIO_AUDIO_EMBEDDER_AUDIO_EMBEDDER_H_

View File

@ -0,0 +1,186 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <stdint.h>
#include <utility>
#include <vector>
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/types/optional.h"
#include "mediapipe/calculators/core/constant_side_packet_calculator.pb.h"
#include "mediapipe/calculators/tensor/audio_to_tensor_calculator.pb.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/tasks/cc/audio/audio_embedder/proto/audio_embedder_graph_options.pb.h"
#include "mediapipe/tasks/cc/audio/utils/audio_tensor_specs.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
#include "mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.h"
#include "mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/core/model_task_graph.h"
#include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h"
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
#include "mediapipe/tasks/metadata/metadata_schema_generated.h"
#include "tensorflow/lite/schema/schema_generated.h"
namespace mediapipe::tasks::audio::audio_embedder {
namespace {
using ::mediapipe::api2::Input;
using ::mediapipe::api2::Output;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source;
using ::mediapipe::tasks::components::containers::proto::EmbeddingResult;
constexpr char kAudioTag[] = "AUDIO";
constexpr char kEmbeddingsTag[] = "EMBEDDINGS";
constexpr char kTimestampedEmbeddingsTag[] = "TIMESTAMPED_EMBEDDINGS";
constexpr char kSampleRateTag[] = "SAMPLE_RATE";
constexpr char kTensorsTag[] = "TENSORS";
constexpr char kTimestampsTag[] = "TIMESTAMPS";
// Struct holding the different output streams produced by the audio embedder
// graph.
struct AudioEmbedderOutputStreams {
Source<EmbeddingResult> embeddings;
Source<std::vector<EmbeddingResult>> timestamped_embeddings;
};
// Builds an AudioTensorSpecs for configuring the preprocessing calculators.
absl::StatusOr<AudioTensorSpecs> BuildPreprocessingSpecs(
const core::ModelResources& model_resources) {
const tflite::Model& model = *model_resources.GetTfLiteModel();
if (model.subgraphs()->size() != 1) {
return CreateStatusWithPayload(absl::StatusCode::kInvalidArgument,
"Audio embedding tflite models are "
"assumed to have a single subgraph.",
MediaPipeTasksStatus::kInvalidArgumentError);
}
const auto* primary_subgraph = (*model.subgraphs())[0];
if (primary_subgraph->inputs()->size() != 1) {
return CreateStatusWithPayload(absl::StatusCode::kInvalidArgument,
"Audio embedding tflite models are "
"assumed to have a single input.",
MediaPipeTasksStatus::kInvalidArgumentError);
}
const auto* input_tensor =
(*primary_subgraph->tensors())[(*primary_subgraph->inputs())[0]];
ASSIGN_OR_RETURN(
const auto* audio_tensor_metadata,
GetAudioTensorMetadataIfAny(*model_resources.GetMetadataExtractor(), 0));
return BuildInputAudioTensorSpecs(*input_tensor, audio_tensor_metadata);
}
// Fills in the AudioToTensorCalculatorOptions based on the AudioTensorSpecs.
void ConfigureAudioToTensorCalculator(
const AudioTensorSpecs& audio_tensor_specs, bool use_stream_mode,
AudioToTensorCalculatorOptions* options) {
options->set_num_channels(audio_tensor_specs.num_channels);
options->set_num_samples(audio_tensor_specs.num_samples);
options->set_target_sample_rate(audio_tensor_specs.sample_rate);
options->set_stream_mode(use_stream_mode);
}
} // namespace
class AudioEmbedderGraph : public core::ModelTaskGraph {
public:
absl::StatusOr<CalculatorGraphConfig> GetConfig(
SubgraphContext* sc) override {
ASSIGN_OR_RETURN(
const auto* model_resources,
CreateModelResources<proto::AudioEmbedderGraphOptions>(sc));
Graph graph;
ASSIGN_OR_RETURN(
auto output_streams,
BuildAudioEmbeddingTask(
sc->Options<proto::AudioEmbedderGraphOptions>(), *model_resources,
graph[Input<Matrix>(kAudioTag)],
absl::make_optional(graph[Input<double>(kSampleRateTag)]), graph));
output_streams.embeddings >> graph[Output<EmbeddingResult>(kEmbeddingsTag)];
output_streams.timestamped_embeddings >>
graph[Output<std::vector<EmbeddingResult>>(kTimestampedEmbeddingsTag)];
return graph.GetConfig();
}
private:
absl::StatusOr<AudioEmbedderOutputStreams> BuildAudioEmbeddingTask(
const proto::AudioEmbedderGraphOptions& task_options,
const core::ModelResources& model_resources, Source<Matrix> audio_in,
absl::optional<Source<double>> sample_rate_in, Graph& graph) {
const bool use_stream_mode = task_options.base_options().use_stream_mode();
const auto* metadata_extractor = model_resources.GetMetadataExtractor();
// Checks that metadata is available.
if (metadata_extractor->GetModelMetadata() == nullptr ||
metadata_extractor->GetModelMetadata()->subgraph_metadata() ==
nullptr) {
return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument,
"Audio embedder models require TFLite Model Metadata but none was "
"found",
MediaPipeTasksStatus::kMetadataNotFoundError);
}
// Adds AudioToTensorCalculator and connects it to the graph input streams.
ASSIGN_OR_RETURN(auto audio_tensor_specs,
BuildPreprocessingSpecs(model_resources));
auto& audio_to_tensor = graph.AddNode("AudioToTensorCalculator");
ConfigureAudioToTensorCalculator(
audio_tensor_specs, use_stream_mode,
&audio_to_tensor.GetOptions<AudioToTensorCalculatorOptions>());
audio_in >> audio_to_tensor.In(kAudioTag);
if (sample_rate_in.has_value()) {
sample_rate_in.value() >> audio_to_tensor.In(kSampleRateTag);
}
// Adds inference subgraph and connects its input stream to the output
// tensors produced by the AudioToTensorCalculator.
auto& inference = AddInference(
model_resources, task_options.base_options().acceleration(), graph);
audio_to_tensor.Out(kTensorsTag) >> inference.In(kTensorsTag);
// Adds postprocessing calculators and connects its input stream to the
// inference results.
auto& postprocessing = graph.AddNode(
"mediapipe.tasks.components.processors.EmbeddingPostprocessingGraph");
MP_RETURN_IF_ERROR(components::processors::ConfigureEmbeddingPostprocessing(
model_resources, task_options.embedder_options(),
&postprocessing.GetOptions<components::processors::proto::
EmbeddingPostprocessingGraphOptions>()));
inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag);
// Time aggregation is only needed for performing audio embedding on
// audio files. Disables timestamp aggregation by not connecting the
// "TIMESTAMPS" streams.
if (!use_stream_mode) {
audio_to_tensor.Out(kTimestampsTag) >> postprocessing.In(kTimestampsTag);
}
// Outputs both streams as graph output streams/
return AudioEmbedderOutputStreams{
/*embeddings=*/postprocessing[Output<EmbeddingResult>(kEmbeddingsTag)],
/*timestamped_embeddings=*/
postprocessing[Output<std::vector<EmbeddingResult>>(
kTimestampedEmbeddingsTag)],
};
}
};
REGISTER_MEDIAPIPE_GRAPH(
::mediapipe::tasks::audio::audio_embedder::AudioEmbedderGraph);
} // namespace mediapipe::tasks::audio::audio_embedder

View File

@ -0,0 +1,289 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/audio/audio_embedder/audio_embedder.h"
#include <algorithm>
#include <memory>
#include <new>
#include <string>
#include <utility>
#include <vector>
#include "absl/flags/flag.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/cord.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/tasks/cc/audio/core/running_mode.h"
#include "mediapipe/tasks/cc/audio/utils/test_utils.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/containers/embedding_result.h"
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
namespace mediapipe {
namespace tasks {
namespace audio {
namespace audio_embedder {
namespace {
using ::absl::StatusOr;
using ::mediapipe::file::JoinPath;
using ::testing::HasSubstr;
using ::testing::Optional;
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/audio";
constexpr char kModelWithMetadata[] = "yamnet_embedding_metadata.tflite";
constexpr char k16kTestWavFilename[] = "speech_16000_hz_mono.wav";
constexpr char k48kTestWavFilename[] = "speech_48000_hz_mono.wav";
constexpr char k16kTestWavForTwoHeadsFilename[] = "two_heads_16000_hz_mono.wav";
constexpr float kSpeechSimilarities[] = {0.985359, 0.994349, 0.993227, 0.996658,
0.996384};
constexpr int kMilliSecondsPerSecond = 1000;
constexpr int kYamnetNumOfAudioSamples = 15600;
constexpr int kYamnetAudioSampleRate = 16000;
Matrix GetAudioData(absl::string_view filename) {
std::string wav_file_path = JoinPath("./", kTestDataDirectory, filename);
int buffer_size;
auto audio_data = internal::ReadWavFile(wav_file_path, &buffer_size);
Eigen::Map<Matrix> matrix_mapping(audio_data->get(), 1, buffer_size);
return matrix_mapping.matrix();
}
class CreateFromOptionsTest : public tflite_shims::testing::Test {};
TEST_F(CreateFromOptionsTest, FailsWithMissingModel) {
auto audio_embedder =
AudioEmbedder::Create(std::make_unique<AudioEmbedderOptions>());
EXPECT_EQ(audio_embedder.status().code(), absl::StatusCode::kInvalidArgument);
EXPECT_THAT(
audio_embedder.status().message(),
HasSubstr("ExternalFile must specify at least one of 'file_content', "
"'file_name', 'file_pointer_meta' or 'file_descriptor_meta'."));
EXPECT_THAT(audio_embedder.status().GetPayload(kMediaPipeTasksPayload),
Optional(absl::Cord(absl::StrCat(
MediaPipeTasksStatus::kRunnerInitializationError))));
}
TEST_F(CreateFromOptionsTest, SucceedsForModelWithMetadata) {
auto options = std::make_unique<AudioEmbedderOptions>();
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kModelWithMetadata);
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<AudioEmbedder> audio_embedder,
AudioEmbedder::Create(std::move(options)));
}
TEST_F(CreateFromOptionsTest, FailsWithIllegalCallbackInAudioClipsMode) {
auto options = std::make_unique<AudioEmbedderOptions>();
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kModelWithMetadata);
options->running_mode = core::RunningMode::AUDIO_CLIPS;
options->result_callback = [](absl::StatusOr<AudioEmbedderResult>) {};
auto audio_embedder = AudioEmbedder::Create(std::move(options));
EXPECT_EQ(audio_embedder.status().code(), absl::StatusCode::kInvalidArgument);
EXPECT_THAT(
audio_embedder.status().message(),
HasSubstr("a user-defined result callback shouldn't be provided"));
EXPECT_THAT(audio_embedder.status().GetPayload(kMediaPipeTasksPayload),
Optional(absl::Cord(absl::StrCat(
MediaPipeTasksStatus::kInvalidTaskGraphConfigError))));
}
TEST_F(CreateFromOptionsTest, FailsWithMissingCallbackInAudioStreamMode) {
auto options = std::make_unique<AudioEmbedderOptions>();
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kModelWithMetadata);
options->running_mode = core::RunningMode::AUDIO_STREAM;
auto audio_embedder = AudioEmbedder::Create(std::move(options));
EXPECT_EQ(audio_embedder.status().code(), absl::StatusCode::kInvalidArgument);
EXPECT_THAT(audio_embedder.status().message(),
HasSubstr("a user-defined result callback must be provided"));
EXPECT_THAT(audio_embedder.status().GetPayload(kMediaPipeTasksPayload),
Optional(absl::Cord(absl::StrCat(
MediaPipeTasksStatus::kInvalidTaskGraphConfigError))));
}
class EmbedTest : public tflite_shims::testing::Test {};
TEST_F(EmbedTest, SucceedsWithSilentAudio) {
auto options = std::make_unique<AudioEmbedderOptions>();
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kModelWithMetadata);
options->running_mode = core::RunningMode::AUDIO_CLIPS;
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<AudioEmbedder> audio_embedder,
AudioEmbedder::Create(std::move(options)));
Matrix silent_data(1, kYamnetNumOfAudioSamples);
silent_data.setZero();
MP_ASSERT_OK_AND_ASSIGN(
auto result, audio_embedder->Embed(silent_data, kYamnetAudioSampleRate));
EXPECT_EQ(result.size(), 1);
EXPECT_EQ(result[0].embeddings[0].float_embedding.size(), 1024);
constexpr float kValueDiffTolerance = 3e-6;
EXPECT_NEAR(result[0].embeddings[0].float_embedding[0], 2.07613f,
kValueDiffTolerance);
EXPECT_NEAR(result[0].embeddings[0].float_embedding[1], 0.392721f,
kValueDiffTolerance);
EXPECT_NEAR(result[0].embeddings[0].float_embedding[2], 0.543622f,
kValueDiffTolerance);
}
TEST_F(EmbedTest, SucceedsWithSameAudioAtDifferentSampleRates) {
auto audio_buffer1 = GetAudioData(k16kTestWavFilename);
auto audio_buffer2 = GetAudioData(k48kTestWavFilename);
auto options = std::make_unique<AudioEmbedderOptions>();
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kModelWithMetadata);
options->running_mode = core::RunningMode::AUDIO_CLIPS;
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<AudioEmbedder> audio_embedder,
AudioEmbedder::Create(std::move(options)));
MP_ASSERT_OK_AND_ASSIGN(auto result1,
audio_embedder->Embed(audio_buffer1, 16000));
MP_ASSERT_OK_AND_ASSIGN(auto result2,
audio_embedder->Embed(audio_buffer2, 48000));
int expected_size = sizeof(kSpeechSimilarities) / sizeof(float);
ASSERT_EQ(result1.size(), expected_size);
ASSERT_EQ(result2.size(), expected_size);
for (int i = 0; i < expected_size; ++i) {
MP_ASSERT_OK_AND_ASSIGN(double similarity, AudioEmbedder::CosineSimilarity(
result1[i].embeddings[0],
result2[i].embeddings[0]));
EXPECT_NEAR(similarity, kSpeechSimilarities[i], 1e-6);
}
MP_EXPECT_OK(audio_embedder->Close());
}
TEST_F(EmbedTest, SucceedsWithDifferentAudios) {
auto audio_buffer1 = GetAudioData(k16kTestWavFilename);
auto audio_buffer2 = GetAudioData(k16kTestWavForTwoHeadsFilename);
auto options = std::make_unique<AudioEmbedderOptions>();
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kModelWithMetadata);
options->running_mode = core::RunningMode::AUDIO_CLIPS;
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<AudioEmbedder> audio_embedder,
AudioEmbedder::Create(std::move(options)));
MP_ASSERT_OK_AND_ASSIGN(
auto result1,
audio_embedder->Embed(audio_buffer1, kYamnetAudioSampleRate));
MP_ASSERT_OK_AND_ASSIGN(
auto result2,
audio_embedder->Embed(audio_buffer2, kYamnetAudioSampleRate));
ASSERT_EQ(result1.size(), 5);
ASSERT_EQ(result2.size(), 1);
MP_ASSERT_OK_AND_ASSIGN(double similarity, AudioEmbedder::CosineSimilarity(
result1[0].embeddings[0],
result2[0].embeddings[0]));
EXPECT_NEAR(similarity, 0.09017f, 1e-6);
MP_EXPECT_OK(audio_embedder->Close());
}
class EmbedAsyncTest : public tflite_shims::testing::Test {
protected:
void RunAudioEmbedderInStreamMode(std::string audio_file_name,
int sample_rate_hz,
std::vector<AudioEmbedderResult>* result) {
auto audio_buffer = GetAudioData(audio_file_name);
auto options = std::make_unique<AudioEmbedderOptions>();
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kModelWithMetadata);
options->running_mode = core::RunningMode::AUDIO_STREAM;
options->result_callback =
[result](absl::StatusOr<AudioEmbedderResult> status_or_result) {
MP_ASSERT_OK_AND_ASSIGN(result->emplace_back(), status_or_result);
};
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<AudioEmbedder> audio_embedder,
AudioEmbedder::Create(std::move(options)));
int start_col = 0;
static unsigned int rseed = 0;
while (start_col < audio_buffer.cols()) {
int num_samples = std::min(
(int)(audio_buffer.cols() - start_col),
rand_r(&rseed) % 10 + kYamnetNumOfAudioSamples * sample_rate_hz /
kYamnetAudioSampleRate);
MP_ASSERT_OK(audio_embedder->EmbedAsync(
audio_buffer.block(0, start_col, 1, num_samples), sample_rate_hz,
start_col * kMilliSecondsPerSecond / sample_rate_hz));
start_col += num_samples;
}
MP_ASSERT_OK(audio_embedder->Close());
}
};
TEST_F(EmbedAsyncTest, FailsWithOutOfOrderInputTimestamps) {
auto options = std::make_unique<AudioEmbedderOptions>();
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kModelWithMetadata);
options->running_mode = core::RunningMode::AUDIO_STREAM;
options->result_callback =
[](absl::StatusOr<AudioEmbedderResult> status_or_result) { return; };
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<AudioEmbedder> audio_embedder,
AudioEmbedder::Create(std::move(options)));
MP_ASSERT_OK(audio_embedder->EmbedAsync(Matrix(1, kYamnetNumOfAudioSamples),
kYamnetAudioSampleRate, 100));
auto status = audio_embedder->EmbedAsync(Matrix(1, kYamnetNumOfAudioSamples),
kYamnetAudioSampleRate, 0);
EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
EXPECT_THAT(status.message(),
HasSubstr("timestamp must be monotonically increasing"));
EXPECT_THAT(status.GetPayload(kMediaPipeTasksPayload),
Optional(absl::Cord(absl::StrCat(
MediaPipeTasksStatus::kRunnerInvalidTimestampError))));
MP_ASSERT_OK(audio_embedder->Close());
}
TEST_F(EmbedAsyncTest, SucceedsWithSameAudioAtDifferentSampleRates) {
std::vector<AudioEmbedderResult> result1;
RunAudioEmbedderInStreamMode(k16kTestWavFilename, 16000, &result1);
std::vector<AudioEmbedderResult> result2;
RunAudioEmbedderInStreamMode(k48kTestWavFilename, 48000, &result2);
int expected_size = sizeof(kSpeechSimilarities) / sizeof(float);
ASSERT_EQ(result1.size(), expected_size);
ASSERT_EQ(result2.size(), expected_size);
for (int i = 0; i < expected_size; ++i) {
MP_ASSERT_OK_AND_ASSIGN(double similarity, AudioEmbedder::CosineSimilarity(
result1[i].embeddings[0],
result2[i].embeddings[0]));
EXPECT_NEAR(similarity, kSpeechSimilarities[i], 1e-6);
}
}
TEST_F(EmbedAsyncTest, SucceedsWithDifferentAudios) {
std::vector<AudioEmbedderResult> result1;
RunAudioEmbedderInStreamMode(k16kTestWavFilename, 16000, &result1);
std::vector<AudioEmbedderResult> result2;
RunAudioEmbedderInStreamMode(k16kTestWavForTwoHeadsFilename, 16000, &result2);
ASSERT_EQ(result1.size(), 5);
ASSERT_EQ(result2.size(), 1);
MP_ASSERT_OK_AND_ASSIGN(double similarity, AudioEmbedder::CosineSimilarity(
result1[0].embeddings[0],
result2[0].embeddings[0]));
EXPECT_NEAR(similarity, 0.09017f, 1e-6);
}
} // namespace
} // namespace audio_embedder
} // namespace audio
} // namespace tasks
} // namespace mediapipe

View File

@ -0,0 +1,30 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library")
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
mediapipe_proto_library(
name = "audio_embedder_graph_options_proto",
srcs = ["audio_embedder_graph_options.proto"],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
"//mediapipe/tasks/cc/components/processors/proto:embedder_options_proto",
"//mediapipe/tasks/cc/core/proto:base_options_proto",
],
)

View File

@ -0,0 +1,39 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
syntax = "proto2";
package mediapipe.tasks.audio.audio_embedder.proto;
import "mediapipe/framework/calculator.proto";
import "mediapipe/tasks/cc/components/processors/proto/embedder_options.proto";
import "mediapipe/tasks/cc/core/proto/base_options.proto";
option java_package = "com.google.mediapipe.tasks.audio.audioembedder.proto";
option java_outer_classname = "AudioEmbedderGraphOptionsProto";
message AudioEmbedderGraphOptions {
extend mediapipe.CalculatorOptions {
optional AudioEmbedderGraphOptions ext = 487277289;
}
// Base options for configuring MediaPipe Tasks, such as specifying the TfLite
// model file with metadata, accelerator options, etc.
optional core.proto.BaseOptions base_options = 1;
// Options for configuring the embedder behavior, such as normalization or
// quantization.
optional components.processors.proto.EmbedderOptions embedder_options = 2;
}

View File

@ -30,6 +30,7 @@ mediapipe_files(srcs = [
"two_heads_16000_hz_mono.wav", "two_heads_16000_hz_mono.wav",
"two_heads_44100_hz_mono.wav", "two_heads_44100_hz_mono.wav",
"yamnet_audio_classifier_with_metadata.tflite", "yamnet_audio_classifier_with_metadata.tflite",
"yamnet_embedding_metadata.tflite",
]) ])
filegroup( filegroup(
@ -38,6 +39,7 @@ filegroup(
"model_without_metadata.tflite", "model_without_metadata.tflite",
"two_heads.tflite", "two_heads.tflite",
"yamnet_audio_classifier_with_metadata.tflite", "yamnet_audio_classifier_with_metadata.tflite",
"yamnet_embedding_metadata.tflite",
], ],
) )

View File

@ -982,6 +982,12 @@ def external_files():
urls = ["https://storage.googleapis.com/mediapipe-assets/yamnet_audio_classifier_with_metadata.tflite?generation=1661875980774466"], urls = ["https://storage.googleapis.com/mediapipe-assets/yamnet_audio_classifier_with_metadata.tflite?generation=1661875980774466"],
) )
http_file(
name = "com_google_mediapipe_yamnet_embedding_metadata_tflite",
sha256 = "7baa72708e3919bae5a5dc78d932847bc28008af14febd083eff62d28af9c72a",
urls = ["https://storage.googleapis.com/mediapipe-assets/yamnet_embedding_metadata.tflite?generation=1668295071595506"],
)
http_file( http_file(
name = "com_google_mediapipe_gesture_embedder_keras_metadata_pb", name = "com_google_mediapipe_gesture_embedder_keras_metadata_pb",
sha256 = "24268b69429be4e307f9ab099ba20d1de7c40e4191a53f6a92dcbbd97a7047d3", sha256 = "24268b69429be4e307f9ab099ba20d1de7c40e4191a53f6a92dcbbd97a7047d3",