From 6c0ca947dea30f2beb3cb11a404dc14af7d4390a Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Sun, 13 Nov 2022 23:08:25 -0800 Subject: [PATCH] MediaPipe Tasks Audio embedder C++ API. PiperOrigin-RevId: 488273381 --- mediapipe/tasks/cc/audio/audio_embedder/BUILD | 79 +++++ .../cc/audio/audio_embedder/audio_embedder.cc | 156 ++++++++++ .../cc/audio/audio_embedder/audio_embedder.h | 139 +++++++++ .../audio_embedder/audio_embedder_graph.cc | 186 +++++++++++ .../audio_embedder/audio_embedder_test.cc | 289 ++++++++++++++++++ .../tasks/cc/audio/audio_embedder/proto/BUILD | 30 ++ .../proto/audio_embedder_graph_options.proto | 39 +++ mediapipe/tasks/testdata/audio/BUILD | 2 + third_party/external_files.bzl | 6 + 9 files changed, 926 insertions(+) create mode 100644 mediapipe/tasks/cc/audio/audio_embedder/BUILD create mode 100644 mediapipe/tasks/cc/audio/audio_embedder/audio_embedder.cc create mode 100644 mediapipe/tasks/cc/audio/audio_embedder/audio_embedder.h create mode 100644 mediapipe/tasks/cc/audio/audio_embedder/audio_embedder_graph.cc create mode 100644 mediapipe/tasks/cc/audio/audio_embedder/audio_embedder_test.cc create mode 100644 mediapipe/tasks/cc/audio/audio_embedder/proto/BUILD create mode 100644 mediapipe/tasks/cc/audio/audio_embedder/proto/audio_embedder_graph_options.proto diff --git a/mediapipe/tasks/cc/audio/audio_embedder/BUILD b/mediapipe/tasks/cc/audio/audio_embedder/BUILD new file mode 100644 index 000000000..b982ef39a --- /dev/null +++ b/mediapipe/tasks/cc/audio/audio_embedder/BUILD @@ -0,0 +1,79 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +cc_library( + name = "audio_embedder_graph", + srcs = ["audio_embedder_graph.cc"], + deps = [ + "//mediapipe/calculators/audio:time_series_framer_calculator", + "//mediapipe/calculators/core:constant_side_packet_calculator", + "//mediapipe/calculators/core:constant_side_packet_calculator_cc_proto", + "//mediapipe/calculators/core:side_packet_to_stream_calculator", + "//mediapipe/calculators/tensor:audio_to_tensor_calculator", + "//mediapipe/calculators/tensor:audio_to_tensor_calculator_cc_proto", + "//mediapipe/calculators/tensor:inference_calculator", + "//mediapipe/framework:calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/formats:matrix", + "//mediapipe/tasks/cc:common", + "//mediapipe/tasks/cc/audio/audio_embedder/proto:audio_embedder_graph_options_cc_proto", + "//mediapipe/tasks/cc/audio/utils:audio_tensor_specs", + "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto", + "//mediapipe/tasks/cc/components/processors:embedding_postprocessing_graph", + "//mediapipe/tasks/cc/components/processors/proto:embedding_postprocessing_graph_options_cc_proto", + "//mediapipe/tasks/cc/core:model_resources", + "//mediapipe/tasks/cc/core:model_task_graph", + "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", + "//mediapipe/tasks/cc/metadata:metadata_extractor", + "//mediapipe/tasks/metadata:metadata_schema_cc", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/types:optional", + "@org_tensorflow//tensorflow/lite/schema:schema_fbs", + ], + alwayslink = 1, +) + +cc_library( + name = "audio_embedder", + srcs = ["audio_embedder.cc"], + hdrs = ["audio_embedder.h"], + deps = [ + ":audio_embedder_graph", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/formats:matrix", + "//mediapipe/tasks/cc/audio/audio_embedder/proto:audio_embedder_graph_options_cc_proto", + "//mediapipe/tasks/cc/audio/core:audio_task_api_factory", + "//mediapipe/tasks/cc/audio/core:base_audio_task_api", + "//mediapipe/tasks/cc/audio/core:running_mode", + "//mediapipe/tasks/cc/components/containers:embedding_result", + "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto", + "//mediapipe/tasks/cc/components/processors:embedder_options", + "//mediapipe/tasks/cc/components/processors/proto:embedder_options_cc_proto", + "//mediapipe/tasks/cc/components/utils:cosine_similarity", + "//mediapipe/tasks/cc/core:base_options", + "//mediapipe/tasks/cc/core:task_runner", + "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", + "@com_google_absl//absl/status:statusor", + "@org_tensorflow//tensorflow/lite/core/api:op_resolver", + ], +) + +# TODO: mediapipe/tasks/cc/audio/utils:test_utils does not compile in the OSS build diff --git a/mediapipe/tasks/cc/audio/audio_embedder/audio_embedder.cc b/mediapipe/tasks/cc/audio/audio_embedder/audio_embedder.cc new file mode 100644 index 000000000..1c4a524d6 --- /dev/null +++ b/mediapipe/tasks/cc/audio/audio_embedder/audio_embedder.cc @@ -0,0 +1,156 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/tasks/cc/audio/audio_embedder/audio_embedder.h" + +#include +#include +#include +#include + +#include "absl/status/statusor.h" +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/formats/matrix.h" +#include "mediapipe/tasks/cc/audio/audio_embedder/proto/audio_embedder_graph_options.pb.h" +#include "mediapipe/tasks/cc/audio/core/audio_task_api_factory.h" +#include "mediapipe/tasks/cc/components/containers/embedding_result.h" +#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h" +#include "mediapipe/tasks/cc/components/processors/embedder_options.h" +#include "mediapipe/tasks/cc/components/processors/proto/embedder_options.pb.h" +#include "mediapipe/tasks/cc/components/utils/cosine_similarity.h" +#include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h" +#include "mediapipe/tasks/cc/core/task_runner.h" +#include "tensorflow/lite/core/api/op_resolver.h" + +namespace mediapipe::tasks::audio::audio_embedder { +namespace { +using ::mediapipe::tasks::components::containers::ConvertToEmbeddingResult; +using ::mediapipe::tasks::components::containers::proto::EmbeddingResult; +constexpr char kAudioStreamName[] = "audio_in"; +constexpr char kAudioTag[] = "AUDIO"; +constexpr char kEmbeddingsTag[] = "EMBEDDINGS"; +constexpr char kTimestampedEmbeddingsTag[] = "TIMESTAMPED_EMBEDDINGS"; +constexpr char kEmbeddingsName[] = "embeddings_out"; +constexpr char kTimestampedEmbeddingsName[] = "timestamped_embeddings_out"; +constexpr char kSampleRateName[] = "sample_rate_in"; +constexpr char kSampleRateTag[] = "SAMPLE_RATE"; +constexpr char kSubgraphTypeName[] = + "mediapipe.tasks.audio.audio_embedder.AudioEmbedderGraph"; +constexpr int kMicroSecondsPerMilliSecond = 1000; + +// Creates a MediaPipe graph config that only contains a single subgraph node of +// type "AudioEmbedderGraph". +CalculatorGraphConfig CreateGraphConfig( + std::unique_ptr options_proto) { + api2::builder::Graph graph; + auto& subgraph = graph.AddNode(kSubgraphTypeName); + graph.In(kAudioTag).SetName(kAudioStreamName) >> subgraph.In(kAudioTag); + graph.In(kSampleRateTag).SetName(kSampleRateName) >> + subgraph.In(kSampleRateTag); + subgraph.GetOptions().Swap( + options_proto.get()); + subgraph.Out(kEmbeddingsTag).SetName(kEmbeddingsName) >> + graph.Out(kEmbeddingsTag); + subgraph.Out(kTimestampedEmbeddingsTag).SetName(kTimestampedEmbeddingsName) >> + graph.Out(kTimestampedEmbeddingsTag); + return graph.GetConfig(); +} + +// Converts the user-facing AudioEmbedderOptions struct to the internal +// AudioEmbedderGraphOptions proto. +std::unique_ptr +ConvertAudioEmbedderOptionsToProto(AudioEmbedderOptions* options) { + auto options_proto = std::make_unique(); + auto base_options_proto = std::make_unique( + tasks::core::ConvertBaseOptionsToProto(&(options->base_options))); + options_proto->mutable_base_options()->Swap(base_options_proto.get()); + options_proto->mutable_base_options()->set_use_stream_mode( + options->running_mode == core::RunningMode::AUDIO_STREAM); + auto embedder_options_proto = + std::make_unique( + components::processors::ConvertEmbedderOptionsToProto( + &(options->embedder_options))); + options_proto->mutable_embedder_options()->Swap(embedder_options_proto.get()); + return options_proto; +} + +absl::StatusOr> ConvertOutputPackets( + absl::StatusOr status_or_packets) { + if (!status_or_packets.ok()) { + return status_or_packets.status(); + } + auto embedding_results = status_or_packets.value()[kTimestampedEmbeddingsName] + .Get>(); + std::vector results; + results.reserve(embedding_results.size()); + for (const auto& embedding_result : embedding_results) { + results.emplace_back(ConvertToEmbeddingResult(embedding_result)); + } + return results; +} + +absl::StatusOr ConvertAsyncOutputPackets( + absl::StatusOr status_or_packets) { + if (!status_or_packets.ok()) { + return status_or_packets.status(); + } + return ConvertToEmbeddingResult( + status_or_packets.value()[kEmbeddingsName].Get()); +} +} // namespace + +/* static */ +absl::StatusOr> AudioEmbedder::Create( + std::unique_ptr options) { + auto options_proto = ConvertAudioEmbedderOptionsToProto(options.get()); + tasks::core::PacketsCallback packets_callback = nullptr; + if (options->result_callback) { + auto result_callback = options->result_callback; + packets_callback = + [=](absl::StatusOr status_or_packets) { + result_callback(ConvertAsyncOutputPackets(status_or_packets)); + }; + } + return core::AudioTaskApiFactory::Create( + CreateGraphConfig(std::move(options_proto)), + std::move(options->base_options.op_resolver), options->running_mode, + std::move(packets_callback)); +} + +absl::StatusOr> AudioEmbedder::Embed( + Matrix audio_clip, double audio_sample_rate) { + return ConvertOutputPackets(ProcessAudioClip( + {{kAudioStreamName, MakePacket(std::move(audio_clip))}, + {kSampleRateName, MakePacket(audio_sample_rate)}})); +} + +absl::Status AudioEmbedder::EmbedAsync(Matrix audio_block, + double audio_sample_rate, + int64 timestamp_ms) { + MP_RETURN_IF_ERROR(CheckOrSetSampleRate(kSampleRateName, audio_sample_rate)); + return SendAudioStreamData( + {{kAudioStreamName, + MakePacket(std::move(audio_block)) + .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}}); +} + +absl::StatusOr AudioEmbedder::CosineSimilarity( + const components::containers::Embedding& u, + const components::containers::Embedding& v) { + return components::utils::CosineSimilarity(u, v); +} + +} // namespace mediapipe::tasks::audio::audio_embedder diff --git a/mediapipe/tasks/cc/audio/audio_embedder/audio_embedder.h b/mediapipe/tasks/cc/audio/audio_embedder/audio_embedder.h new file mode 100644 index 000000000..4e7e20530 --- /dev/null +++ b/mediapipe/tasks/cc/audio/audio_embedder/audio_embedder.h @@ -0,0 +1,139 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_CC_AUDIO_AUDIO_EMBEDDER_AUDIO_EMBEDDER_H_ +#define MEDIAPIPE_TASKS_CC_AUDIO_AUDIO_EMBEDDER_AUDIO_EMBEDDER_H_ + +#include +#include +#include + +#include "absl/status/statusor.h" +#include "mediapipe/framework/formats/matrix.h" +#include "mediapipe/tasks/cc/audio/core/base_audio_task_api.h" +#include "mediapipe/tasks/cc/audio/core/running_mode.h" +#include "mediapipe/tasks/cc/components/containers/embedding_result.h" +#include "mediapipe/tasks/cc/components/processors/embedder_options.h" +#include "mediapipe/tasks/cc/core/base_options.h" + +namespace mediapipe::tasks::audio::audio_embedder { + +// Alias the shared EmbeddingResult struct as result type. +using AudioEmbedderResult = + ::mediapipe::tasks::components::containers::EmbeddingResult; + +struct AudioEmbedderOptions { + // Base options for configuring Task library, such as specifying the TfLite + // model file with metadata, accelerator options, op resolver, etc. + tasks::core::BaseOptions base_options; + + // Options for configuring the embedder behavior, such as score threshold, + // number of results, etc. + components::processors::EmbedderOptions embedder_options; + + // The running mode of the audio embedder. Default to the audio clips mode. + // Audio embedder has two running modes: + // 1) The audio clips mode for running embedding on independent audio clips. + // 2) The audio stream mode for running embedding on the audio stream, + // such as from microphone. In this mode, the "result_callback" below must + // be specified to receive the embedding results asynchronously. + core::RunningMode running_mode = core::RunningMode::AUDIO_CLIPS; + + // The user-defined result callback for processing audio stream data. + // The result callback should only be specified when the running mode is set + // to RunningMode::AUDIO_STREAM. + std::function)> result_callback = + nullptr; +}; + +// Performs embedding extraction on audio clips or audio stream. +// +// The API expects a TFLite model with TFLite Model Metadata. +// +// Input tensor: +// (kTfLiteFloat32) +// - input audio buffer of size `[batch * samples]`. +// - batch inference is not supported (`batch` is required to be 1). +// - for multi-channel models, the channels need be interleaved. +// At least one output tensor with: +// (kTfLiteUInt8/kTfLiteFloat32) +// - `N` components corresponding to the `N` dimensions of the returned +// feature vector for this output layer. +// - Either 2 or 4 dimensions, i.e. `[1 x N]` or `[1 x 1 x 1 x N]`. +class AudioEmbedder : core::BaseAudioTaskApi { + public: + using BaseAudioTaskApi::BaseAudioTaskApi; + + // Creates an AudioEmbedder from the provided options. A non-default + // OpResolver can be specified in the BaseOptions in order to support custom + // Ops or specify a subset of built-in Ops. + static absl::StatusOr> Create( + std::unique_ptr options); + + // Performs embedding extraction on the provided audio clips. Only use this + // method when the AudioEmbedder is created with the audio clips running mode. + // + // The audio clip is represented as a MediaPipe Matrix that has the number of + // channels rows and the number of samples per channel columns. The method + // accepts audio clips with various length and audio sample rate. It's + // required to provide the corresponding audio sample rate along with the + // input audio clips. + // + // The input audio clip may be longer than what the model is able to process + // in a single inference. When this occurs, the input audio clip is split into + // multiple chunks starting at different timestamps. For this reason, this + // function returns a vector of EmbeddingResult objects, each associated + // with a timestamp corresponding to the start (in milliseconds) of the chunk + // data that was extracted. + absl::StatusOr> Embed( + Matrix audio_clip, double audio_sample_rate); + + // Sends audio stream data to embedder, and the results will be available via + // the "result_callback" provided in the AudioEmbedderOptions. Only use this + // method when the AudioEmbedder is created with the audio stream running + // mode. + // + // The audio block is represented as a MediaPipe Matrix that has the number + // of channels rows and the number of samples per channel columns. The audio + // data will be resampled, accumulated, and framed to the proper size for the + // underlying model to consume. It's required to provide the corresponding + // audio sample rate along with the input audio block as well as a timestamp + // (in milliseconds) to indicate the start time of the input audio block. The + // timestamps must be monotonically increasing. + // + // The input audio block may be longer than what the model is able to process + // in a single inference. When this occurs, the input audio block is split + // into multiple chunks. For this reason, the callback may be called multiple + // times (once per chunk) for each call to this function. + absl::Status EmbedAsync(Matrix audio_block, double audio_sample_rate, + int64 timestamp_ms); + + // Shuts down the AudioEmbedder when all works are done. + absl::Status Close() { return runner_->Close(); } + + // Utility function to compute cosine similarity [1] between two embeddings. + // May return an InvalidArgumentError if e.g. the embeddings are of different + // types (quantized vs. float), have different sizes, or have a an L2-norm of + // 0. + // + // [1]: https://en.wikipedia.org/wiki/Cosine_similarity + static absl::StatusOr CosineSimilarity( + const components::containers::Embedding& u, + const components::containers::Embedding& v); +}; + +} // namespace mediapipe::tasks::audio::audio_embedder + +#endif // MEDIAPIPE_TASKS_CC_AUDIO_AUDIO_EMBEDDER_AUDIO_EMBEDDER_H_ diff --git a/mediapipe/tasks/cc/audio/audio_embedder/audio_embedder_graph.cc b/mediapipe/tasks/cc/audio/audio_embedder/audio_embedder_graph.cc new file mode 100644 index 000000000..7667feaa3 --- /dev/null +++ b/mediapipe/tasks/cc/audio/audio_embedder/audio_embedder_graph.cc @@ -0,0 +1,186 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/types/optional.h" +#include "mediapipe/calculators/core/constant_side_packet_calculator.pb.h" +#include "mediapipe/calculators/tensor/audio_to_tensor_calculator.pb.h" +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/api2/port.h" +#include "mediapipe/framework/calculator.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/matrix.h" +#include "mediapipe/tasks/cc/audio/audio_embedder/proto/audio_embedder_graph_options.pb.h" +#include "mediapipe/tasks/cc/audio/utils/audio_tensor_specs.h" +#include "mediapipe/tasks/cc/common.h" +#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h" +#include "mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.h" +#include "mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options.pb.h" +#include "mediapipe/tasks/cc/core/model_resources.h" +#include "mediapipe/tasks/cc/core/model_task_graph.h" +#include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h" +#include "mediapipe/tasks/cc/metadata/metadata_extractor.h" +#include "mediapipe/tasks/metadata/metadata_schema_generated.h" +#include "tensorflow/lite/schema/schema_generated.h" + +namespace mediapipe::tasks::audio::audio_embedder { +namespace { + +using ::mediapipe::api2::Input; +using ::mediapipe::api2::Output; +using ::mediapipe::api2::builder::Graph; +using ::mediapipe::api2::builder::Source; +using ::mediapipe::tasks::components::containers::proto::EmbeddingResult; + +constexpr char kAudioTag[] = "AUDIO"; +constexpr char kEmbeddingsTag[] = "EMBEDDINGS"; +constexpr char kTimestampedEmbeddingsTag[] = "TIMESTAMPED_EMBEDDINGS"; +constexpr char kSampleRateTag[] = "SAMPLE_RATE"; +constexpr char kTensorsTag[] = "TENSORS"; +constexpr char kTimestampsTag[] = "TIMESTAMPS"; + +// Struct holding the different output streams produced by the audio embedder +// graph. +struct AudioEmbedderOutputStreams { + Source embeddings; + Source> timestamped_embeddings; +}; + +// Builds an AudioTensorSpecs for configuring the preprocessing calculators. +absl::StatusOr BuildPreprocessingSpecs( + const core::ModelResources& model_resources) { + const tflite::Model& model = *model_resources.GetTfLiteModel(); + if (model.subgraphs()->size() != 1) { + return CreateStatusWithPayload(absl::StatusCode::kInvalidArgument, + "Audio embedding tflite models are " + "assumed to have a single subgraph.", + MediaPipeTasksStatus::kInvalidArgumentError); + } + const auto* primary_subgraph = (*model.subgraphs())[0]; + if (primary_subgraph->inputs()->size() != 1) { + return CreateStatusWithPayload(absl::StatusCode::kInvalidArgument, + "Audio embedding tflite models are " + "assumed to have a single input.", + MediaPipeTasksStatus::kInvalidArgumentError); + } + const auto* input_tensor = + (*primary_subgraph->tensors())[(*primary_subgraph->inputs())[0]]; + ASSIGN_OR_RETURN( + const auto* audio_tensor_metadata, + GetAudioTensorMetadataIfAny(*model_resources.GetMetadataExtractor(), 0)); + return BuildInputAudioTensorSpecs(*input_tensor, audio_tensor_metadata); +} + +// Fills in the AudioToTensorCalculatorOptions based on the AudioTensorSpecs. +void ConfigureAudioToTensorCalculator( + const AudioTensorSpecs& audio_tensor_specs, bool use_stream_mode, + AudioToTensorCalculatorOptions* options) { + options->set_num_channels(audio_tensor_specs.num_channels); + options->set_num_samples(audio_tensor_specs.num_samples); + options->set_target_sample_rate(audio_tensor_specs.sample_rate); + options->set_stream_mode(use_stream_mode); +} +} // namespace + +class AudioEmbedderGraph : public core::ModelTaskGraph { + public: + absl::StatusOr GetConfig( + SubgraphContext* sc) override { + ASSIGN_OR_RETURN( + const auto* model_resources, + CreateModelResources(sc)); + Graph graph; + ASSIGN_OR_RETURN( + auto output_streams, + BuildAudioEmbeddingTask( + sc->Options(), *model_resources, + graph[Input(kAudioTag)], + absl::make_optional(graph[Input(kSampleRateTag)]), graph)); + output_streams.embeddings >> graph[Output(kEmbeddingsTag)]; + output_streams.timestamped_embeddings >> + graph[Output>(kTimestampedEmbeddingsTag)]; + return graph.GetConfig(); + } + + private: + absl::StatusOr BuildAudioEmbeddingTask( + const proto::AudioEmbedderGraphOptions& task_options, + const core::ModelResources& model_resources, Source audio_in, + absl::optional> sample_rate_in, Graph& graph) { + const bool use_stream_mode = task_options.base_options().use_stream_mode(); + const auto* metadata_extractor = model_resources.GetMetadataExtractor(); + // Checks that metadata is available. + if (metadata_extractor->GetModelMetadata() == nullptr || + metadata_extractor->GetModelMetadata()->subgraph_metadata() == + nullptr) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + "Audio embedder models require TFLite Model Metadata but none was " + "found", + MediaPipeTasksStatus::kMetadataNotFoundError); + } + // Adds AudioToTensorCalculator and connects it to the graph input streams. + ASSIGN_OR_RETURN(auto audio_tensor_specs, + BuildPreprocessingSpecs(model_resources)); + auto& audio_to_tensor = graph.AddNode("AudioToTensorCalculator"); + ConfigureAudioToTensorCalculator( + audio_tensor_specs, use_stream_mode, + &audio_to_tensor.GetOptions()); + audio_in >> audio_to_tensor.In(kAudioTag); + if (sample_rate_in.has_value()) { + sample_rate_in.value() >> audio_to_tensor.In(kSampleRateTag); + } + + // Adds inference subgraph and connects its input stream to the output + // tensors produced by the AudioToTensorCalculator. + auto& inference = AddInference( + model_resources, task_options.base_options().acceleration(), graph); + audio_to_tensor.Out(kTensorsTag) >> inference.In(kTensorsTag); + // Adds postprocessing calculators and connects its input stream to the + // inference results. + auto& postprocessing = graph.AddNode( + "mediapipe.tasks.components.processors.EmbeddingPostprocessingGraph"); + MP_RETURN_IF_ERROR(components::processors::ConfigureEmbeddingPostprocessing( + model_resources, task_options.embedder_options(), + &postprocessing.GetOptions())); + inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag); + // Time aggregation is only needed for performing audio embedding on + // audio files. Disables timestamp aggregation by not connecting the + // "TIMESTAMPS" streams. + if (!use_stream_mode) { + audio_to_tensor.Out(kTimestampsTag) >> postprocessing.In(kTimestampsTag); + } + + // Outputs both streams as graph output streams/ + return AudioEmbedderOutputStreams{ + /*embeddings=*/postprocessing[Output(kEmbeddingsTag)], + /*timestamped_embeddings=*/ + postprocessing[Output>( + kTimestampedEmbeddingsTag)], + }; + } +}; + +REGISTER_MEDIAPIPE_GRAPH( + ::mediapipe::tasks::audio::audio_embedder::AudioEmbedderGraph); + +} // namespace mediapipe::tasks::audio::audio_embedder diff --git a/mediapipe/tasks/cc/audio/audio_embedder/audio_embedder_test.cc b/mediapipe/tasks/cc/audio/audio_embedder/audio_embedder_test.cc new file mode 100644 index 000000000..749066ead --- /dev/null +++ b/mediapipe/tasks/cc/audio/audio_embedder/audio_embedder_test.cc @@ -0,0 +1,289 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/tasks/cc/audio/audio_embedder/audio_embedder.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/flags/flag.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/cord.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "mediapipe/framework/deps/file_path.h" +#include "mediapipe/framework/formats/matrix.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/tasks/cc/audio/core/running_mode.h" +#include "mediapipe/tasks/cc/audio/utils/test_utils.h" +#include "mediapipe/tasks/cc/common.h" +#include "mediapipe/tasks/cc/components/containers/embedding_result.h" +#include "tensorflow/lite/core/shims/cc/shims_test_util.h" + +namespace mediapipe { +namespace tasks { +namespace audio { +namespace audio_embedder { +namespace { + +using ::absl::StatusOr; +using ::mediapipe::file::JoinPath; +using ::testing::HasSubstr; +using ::testing::Optional; + +constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/audio"; +constexpr char kModelWithMetadata[] = "yamnet_embedding_metadata.tflite"; +constexpr char k16kTestWavFilename[] = "speech_16000_hz_mono.wav"; +constexpr char k48kTestWavFilename[] = "speech_48000_hz_mono.wav"; +constexpr char k16kTestWavForTwoHeadsFilename[] = "two_heads_16000_hz_mono.wav"; +constexpr float kSpeechSimilarities[] = {0.985359, 0.994349, 0.993227, 0.996658, + 0.996384}; +constexpr int kMilliSecondsPerSecond = 1000; +constexpr int kYamnetNumOfAudioSamples = 15600; +constexpr int kYamnetAudioSampleRate = 16000; + +Matrix GetAudioData(absl::string_view filename) { + std::string wav_file_path = JoinPath("./", kTestDataDirectory, filename); + int buffer_size; + auto audio_data = internal::ReadWavFile(wav_file_path, &buffer_size); + Eigen::Map matrix_mapping(audio_data->get(), 1, buffer_size); + return matrix_mapping.matrix(); +} + +class CreateFromOptionsTest : public tflite_shims::testing::Test {}; + +TEST_F(CreateFromOptionsTest, FailsWithMissingModel) { + auto audio_embedder = + AudioEmbedder::Create(std::make_unique()); + + EXPECT_EQ(audio_embedder.status().code(), absl::StatusCode::kInvalidArgument); + EXPECT_THAT( + audio_embedder.status().message(), + HasSubstr("ExternalFile must specify at least one of 'file_content', " + "'file_name', 'file_pointer_meta' or 'file_descriptor_meta'.")); + EXPECT_THAT(audio_embedder.status().GetPayload(kMediaPipeTasksPayload), + Optional(absl::Cord(absl::StrCat( + MediaPipeTasksStatus::kRunnerInitializationError)))); +} + +TEST_F(CreateFromOptionsTest, SucceedsForModelWithMetadata) { + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kModelWithMetadata); + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr audio_embedder, + AudioEmbedder::Create(std::move(options))); +} + +TEST_F(CreateFromOptionsTest, FailsWithIllegalCallbackInAudioClipsMode) { + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kModelWithMetadata); + options->running_mode = core::RunningMode::AUDIO_CLIPS; + options->result_callback = [](absl::StatusOr) {}; + + auto audio_embedder = AudioEmbedder::Create(std::move(options)); + + EXPECT_EQ(audio_embedder.status().code(), absl::StatusCode::kInvalidArgument); + EXPECT_THAT( + audio_embedder.status().message(), + HasSubstr("a user-defined result callback shouldn't be provided")); + EXPECT_THAT(audio_embedder.status().GetPayload(kMediaPipeTasksPayload), + Optional(absl::Cord(absl::StrCat( + MediaPipeTasksStatus::kInvalidTaskGraphConfigError)))); +} + +TEST_F(CreateFromOptionsTest, FailsWithMissingCallbackInAudioStreamMode) { + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kModelWithMetadata); + options->running_mode = core::RunningMode::AUDIO_STREAM; + + auto audio_embedder = AudioEmbedder::Create(std::move(options)); + + EXPECT_EQ(audio_embedder.status().code(), absl::StatusCode::kInvalidArgument); + EXPECT_THAT(audio_embedder.status().message(), + HasSubstr("a user-defined result callback must be provided")); + EXPECT_THAT(audio_embedder.status().GetPayload(kMediaPipeTasksPayload), + Optional(absl::Cord(absl::StrCat( + MediaPipeTasksStatus::kInvalidTaskGraphConfigError)))); +} + +class EmbedTest : public tflite_shims::testing::Test {}; + +TEST_F(EmbedTest, SucceedsWithSilentAudio) { + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kModelWithMetadata); + options->running_mode = core::RunningMode::AUDIO_CLIPS; + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr audio_embedder, + AudioEmbedder::Create(std::move(options))); + Matrix silent_data(1, kYamnetNumOfAudioSamples); + silent_data.setZero(); + MP_ASSERT_OK_AND_ASSIGN( + auto result, audio_embedder->Embed(silent_data, kYamnetAudioSampleRate)); + EXPECT_EQ(result.size(), 1); + EXPECT_EQ(result[0].embeddings[0].float_embedding.size(), 1024); + constexpr float kValueDiffTolerance = 3e-6; + EXPECT_NEAR(result[0].embeddings[0].float_embedding[0], 2.07613f, + kValueDiffTolerance); + EXPECT_NEAR(result[0].embeddings[0].float_embedding[1], 0.392721f, + kValueDiffTolerance); + EXPECT_NEAR(result[0].embeddings[0].float_embedding[2], 0.543622f, + kValueDiffTolerance); +} + +TEST_F(EmbedTest, SucceedsWithSameAudioAtDifferentSampleRates) { + auto audio_buffer1 = GetAudioData(k16kTestWavFilename); + auto audio_buffer2 = GetAudioData(k48kTestWavFilename); + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kModelWithMetadata); + options->running_mode = core::RunningMode::AUDIO_CLIPS; + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr audio_embedder, + AudioEmbedder::Create(std::move(options))); + MP_ASSERT_OK_AND_ASSIGN(auto result1, + audio_embedder->Embed(audio_buffer1, 16000)); + MP_ASSERT_OK_AND_ASSIGN(auto result2, + audio_embedder->Embed(audio_buffer2, 48000)); + int expected_size = sizeof(kSpeechSimilarities) / sizeof(float); + ASSERT_EQ(result1.size(), expected_size); + ASSERT_EQ(result2.size(), expected_size); + for (int i = 0; i < expected_size; ++i) { + MP_ASSERT_OK_AND_ASSIGN(double similarity, AudioEmbedder::CosineSimilarity( + result1[i].embeddings[0], + result2[i].embeddings[0])); + EXPECT_NEAR(similarity, kSpeechSimilarities[i], 1e-6); + } + MP_EXPECT_OK(audio_embedder->Close()); +} + +TEST_F(EmbedTest, SucceedsWithDifferentAudios) { + auto audio_buffer1 = GetAudioData(k16kTestWavFilename); + auto audio_buffer2 = GetAudioData(k16kTestWavForTwoHeadsFilename); + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kModelWithMetadata); + options->running_mode = core::RunningMode::AUDIO_CLIPS; + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr audio_embedder, + AudioEmbedder::Create(std::move(options))); + MP_ASSERT_OK_AND_ASSIGN( + auto result1, + audio_embedder->Embed(audio_buffer1, kYamnetAudioSampleRate)); + MP_ASSERT_OK_AND_ASSIGN( + auto result2, + audio_embedder->Embed(audio_buffer2, kYamnetAudioSampleRate)); + ASSERT_EQ(result1.size(), 5); + ASSERT_EQ(result2.size(), 1); + MP_ASSERT_OK_AND_ASSIGN(double similarity, AudioEmbedder::CosineSimilarity( + result1[0].embeddings[0], + result2[0].embeddings[0])); + EXPECT_NEAR(similarity, 0.09017f, 1e-6); + MP_EXPECT_OK(audio_embedder->Close()); +} + +class EmbedAsyncTest : public tflite_shims::testing::Test { + protected: + void RunAudioEmbedderInStreamMode(std::string audio_file_name, + int sample_rate_hz, + std::vector* result) { + auto audio_buffer = GetAudioData(audio_file_name); + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kModelWithMetadata); + options->running_mode = core::RunningMode::AUDIO_STREAM; + options->result_callback = + [result](absl::StatusOr status_or_result) { + MP_ASSERT_OK_AND_ASSIGN(result->emplace_back(), status_or_result); + }; + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr audio_embedder, + AudioEmbedder::Create(std::move(options))); + int start_col = 0; + static unsigned int rseed = 0; + while (start_col < audio_buffer.cols()) { + int num_samples = std::min( + (int)(audio_buffer.cols() - start_col), + rand_r(&rseed) % 10 + kYamnetNumOfAudioSamples * sample_rate_hz / + kYamnetAudioSampleRate); + MP_ASSERT_OK(audio_embedder->EmbedAsync( + audio_buffer.block(0, start_col, 1, num_samples), sample_rate_hz, + start_col * kMilliSecondsPerSecond / sample_rate_hz)); + start_col += num_samples; + } + MP_ASSERT_OK(audio_embedder->Close()); + } +}; + +TEST_F(EmbedAsyncTest, FailsWithOutOfOrderInputTimestamps) { + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kModelWithMetadata); + options->running_mode = core::RunningMode::AUDIO_STREAM; + options->result_callback = + [](absl::StatusOr status_or_result) { return; }; + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr audio_embedder, + AudioEmbedder::Create(std::move(options))); + MP_ASSERT_OK(audio_embedder->EmbedAsync(Matrix(1, kYamnetNumOfAudioSamples), + kYamnetAudioSampleRate, 100)); + auto status = audio_embedder->EmbedAsync(Matrix(1, kYamnetNumOfAudioSamples), + kYamnetAudioSampleRate, 0); + EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); + EXPECT_THAT(status.message(), + HasSubstr("timestamp must be monotonically increasing")); + EXPECT_THAT(status.GetPayload(kMediaPipeTasksPayload), + Optional(absl::Cord(absl::StrCat( + MediaPipeTasksStatus::kRunnerInvalidTimestampError)))); + MP_ASSERT_OK(audio_embedder->Close()); +} + +TEST_F(EmbedAsyncTest, SucceedsWithSameAudioAtDifferentSampleRates) { + std::vector result1; + RunAudioEmbedderInStreamMode(k16kTestWavFilename, 16000, &result1); + std::vector result2; + RunAudioEmbedderInStreamMode(k48kTestWavFilename, 48000, &result2); + int expected_size = sizeof(kSpeechSimilarities) / sizeof(float); + ASSERT_EQ(result1.size(), expected_size); + ASSERT_EQ(result2.size(), expected_size); + for (int i = 0; i < expected_size; ++i) { + MP_ASSERT_OK_AND_ASSIGN(double similarity, AudioEmbedder::CosineSimilarity( + result1[i].embeddings[0], + result2[i].embeddings[0])); + EXPECT_NEAR(similarity, kSpeechSimilarities[i], 1e-6); + } +} + +TEST_F(EmbedAsyncTest, SucceedsWithDifferentAudios) { + std::vector result1; + RunAudioEmbedderInStreamMode(k16kTestWavFilename, 16000, &result1); + std::vector result2; + RunAudioEmbedderInStreamMode(k16kTestWavForTwoHeadsFilename, 16000, &result2); + ASSERT_EQ(result1.size(), 5); + ASSERT_EQ(result2.size(), 1); + MP_ASSERT_OK_AND_ASSIGN(double similarity, AudioEmbedder::CosineSimilarity( + result1[0].embeddings[0], + result2[0].embeddings[0])); + EXPECT_NEAR(similarity, 0.09017f, 1e-6); +} + +} // namespace +} // namespace audio_embedder +} // namespace audio +} // namespace tasks +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/audio/audio_embedder/proto/BUILD b/mediapipe/tasks/cc/audio/audio_embedder/proto/BUILD new file mode 100644 index 000000000..38df8fb44 --- /dev/null +++ b/mediapipe/tasks/cc/audio/audio_embedder/proto/BUILD @@ -0,0 +1,30 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +mediapipe_proto_library( + name = "audio_embedder_graph_options_proto", + srcs = ["audio_embedder_graph_options.proto"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + "//mediapipe/tasks/cc/components/processors/proto:embedder_options_proto", + "//mediapipe/tasks/cc/core/proto:base_options_proto", + ], +) diff --git a/mediapipe/tasks/cc/audio/audio_embedder/proto/audio_embedder_graph_options.proto b/mediapipe/tasks/cc/audio/audio_embedder/proto/audio_embedder_graph_options.proto new file mode 100644 index 000000000..25c5d5474 --- /dev/null +++ b/mediapipe/tasks/cc/audio/audio_embedder/proto/audio_embedder_graph_options.proto @@ -0,0 +1,39 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto2"; + +package mediapipe.tasks.audio.audio_embedder.proto; + +import "mediapipe/framework/calculator.proto"; +import "mediapipe/tasks/cc/components/processors/proto/embedder_options.proto"; +import "mediapipe/tasks/cc/core/proto/base_options.proto"; + +option java_package = "com.google.mediapipe.tasks.audio.audioembedder.proto"; +option java_outer_classname = "AudioEmbedderGraphOptionsProto"; + +message AudioEmbedderGraphOptions { + extend mediapipe.CalculatorOptions { + optional AudioEmbedderGraphOptions ext = 487277289; + } + + // Base options for configuring MediaPipe Tasks, such as specifying the TfLite + // model file with metadata, accelerator options, etc. + optional core.proto.BaseOptions base_options = 1; + + // Options for configuring the embedder behavior, such as normalization or + // quantization. + optional components.processors.proto.EmbedderOptions embedder_options = 2; +} diff --git a/mediapipe/tasks/testdata/audio/BUILD b/mediapipe/tasks/testdata/audio/BUILD index ce4ab2dd9..32b812023 100644 --- a/mediapipe/tasks/testdata/audio/BUILD +++ b/mediapipe/tasks/testdata/audio/BUILD @@ -30,6 +30,7 @@ mediapipe_files(srcs = [ "two_heads_16000_hz_mono.wav", "two_heads_44100_hz_mono.wav", "yamnet_audio_classifier_with_metadata.tflite", + "yamnet_embedding_metadata.tflite", ]) filegroup( @@ -38,6 +39,7 @@ filegroup( "model_without_metadata.tflite", "two_heads.tflite", "yamnet_audio_classifier_with_metadata.tflite", + "yamnet_embedding_metadata.tflite", ], ) diff --git a/third_party/external_files.bzl b/third_party/external_files.bzl index b4ec3b36c..1f0b00289 100644 --- a/third_party/external_files.bzl +++ b/third_party/external_files.bzl @@ -982,6 +982,12 @@ def external_files(): urls = ["https://storage.googleapis.com/mediapipe-assets/yamnet_audio_classifier_with_metadata.tflite?generation=1661875980774466"], ) + http_file( + name = "com_google_mediapipe_yamnet_embedding_metadata_tflite", + sha256 = "7baa72708e3919bae5a5dc78d932847bc28008af14febd083eff62d28af9c72a", + urls = ["https://storage.googleapis.com/mediapipe-assets/yamnet_embedding_metadata.tflite?generation=1668295071595506"], + ) + http_file( name = "com_google_mediapipe_gesture_embedder_keras_metadata_pb", sha256 = "24268b69429be4e307f9ab099ba20d1de7c40e4191a53f6a92dcbbd97a7047d3",