MediaPipe Tasks Audio embedder C++ API.
PiperOrigin-RevId: 488273381
This commit is contained in:
parent
0dfa91a166
commit
6c0ca947de
79
mediapipe/tasks/cc/audio/audio_embedder/BUILD
Normal file
79
mediapipe/tasks/cc/audio/audio_embedder/BUILD
Normal file
|
@ -0,0 +1,79 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

package(default_visibility = ["//mediapipe/tasks:internal"])

licenses(["notice"])

# Internal task subgraph that wires preprocessing (AudioToTensorCalculator),
# inference, and embedding postprocessing into a single graph.
cc_library(
    name = "audio_embedder_graph",
    srcs = ["audio_embedder_graph.cc"],
    deps = [
        "//mediapipe/calculators/audio:time_series_framer_calculator",
        "//mediapipe/calculators/core:constant_side_packet_calculator",
        "//mediapipe/calculators/core:constant_side_packet_calculator_cc_proto",
        "//mediapipe/calculators/core:side_packet_to_stream_calculator",
        "//mediapipe/calculators/tensor:audio_to_tensor_calculator",
        "//mediapipe/calculators/tensor:audio_to_tensor_calculator_cc_proto",
        "//mediapipe/calculators/tensor:inference_calculator",
        "//mediapipe/framework:calculator_cc_proto",
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework/api2:builder",
        "//mediapipe/framework/api2:port",
        "//mediapipe/framework/formats:matrix",
        "//mediapipe/tasks/cc:common",
        "//mediapipe/tasks/cc/audio/audio_embedder/proto:audio_embedder_graph_options_cc_proto",
        "//mediapipe/tasks/cc/audio/utils:audio_tensor_specs",
        "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
        "//mediapipe/tasks/cc/components/processors:embedding_postprocessing_graph",
        "//mediapipe/tasks/cc/components/processors/proto:embedding_postprocessing_graph_options_cc_proto",
        "//mediapipe/tasks/cc/core:model_resources",
        "//mediapipe/tasks/cc/core:model_task_graph",
        "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto",
        "//mediapipe/tasks/cc/metadata:metadata_extractor",
        "//mediapipe/tasks/metadata:metadata_schema_cc",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/types:optional",
        "@org_tensorflow//tensorflow/lite/schema:schema_fbs",
    ],
    # The graph registers itself via REGISTER_MEDIAPIPE_GRAPH; alwayslink
    # keeps the registration symbols from being stripped by the linker.
    alwayslink = 1,
)

# Public C++ API (AudioEmbedder) that wraps :audio_embedder_graph.
cc_library(
    name = "audio_embedder",
    srcs = ["audio_embedder.cc"],
    hdrs = ["audio_embedder.h"],
    deps = [
        ":audio_embedder_graph",
        "//mediapipe/framework/api2:builder",
        "//mediapipe/framework/formats:matrix",
        "//mediapipe/tasks/cc/audio/audio_embedder/proto:audio_embedder_graph_options_cc_proto",
        "//mediapipe/tasks/cc/audio/core:audio_task_api_factory",
        "//mediapipe/tasks/cc/audio/core:base_audio_task_api",
        "//mediapipe/tasks/cc/audio/core:running_mode",
        "//mediapipe/tasks/cc/components/containers:embedding_result",
        "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
        "//mediapipe/tasks/cc/components/processors:embedder_options",
        "//mediapipe/tasks/cc/components/processors/proto:embedder_options_cc_proto",
        "//mediapipe/tasks/cc/components/utils:cosine_similarity",
        "//mediapipe/tasks/cc/core:base_options",
        "//mediapipe/tasks/cc/core:task_runner",
        "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto",
        "@com_google_absl//absl/status:statusor",
        "@org_tensorflow//tensorflow/lite/core/api:op_resolver",
    ],
)

# TODO: mediapipe/tasks/cc/audio/utils:test_utils does not compile in the OSS build
|
156
mediapipe/tasks/cc/audio/audio_embedder/audio_embedder.cc
Normal file
156
mediapipe/tasks/cc/audio/audio_embedder/audio_embedder.cc
Normal file
|
@ -0,0 +1,156 @@
|
|||
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mediapipe/tasks/cc/audio/audio_embedder/audio_embedder.h"
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/status/statusor.h"
|
||||
#include "mediapipe/framework/api2/builder.h"
|
||||
#include "mediapipe/framework/formats/matrix.h"
|
||||
#include "mediapipe/tasks/cc/audio/audio_embedder/proto/audio_embedder_graph_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/audio/core/audio_task_api_factory.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/embedding_result.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/processors/embedder_options.h"
|
||||
#include "mediapipe/tasks/cc/components/processors/proto/embedder_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/utils/cosine_similarity.h"
|
||||
#include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h"
|
||||
#include "mediapipe/tasks/cc/core/task_runner.h"
|
||||
#include "tensorflow/lite/core/api/op_resolver.h"
|
||||
|
||||
namespace mediapipe::tasks::audio::audio_embedder {
namespace {
using ::mediapipe::tasks::components::containers::ConvertToEmbeddingResult;
using ::mediapipe::tasks::components::containers::proto::EmbeddingResult;

// Stream names/tags used to wire the wrapper graph built in
// CreateGraphConfig() below.
constexpr char kAudioStreamName[] = "audio_in";
constexpr char kAudioTag[] = "AUDIO";
constexpr char kEmbeddingsTag[] = "EMBEDDINGS";
constexpr char kTimestampedEmbeddingsTag[] = "TIMESTAMPED_EMBEDDINGS";
constexpr char kEmbeddingsName[] = "embeddings_out";
constexpr char kTimestampedEmbeddingsName[] = "timestamped_embeddings_out";
constexpr char kSampleRateName[] = "sample_rate_in";
constexpr char kSampleRateTag[] = "SAMPLE_RATE";
// Registered type name of the task subgraph (see audio_embedder_graph.cc).
constexpr char kSubgraphTypeName[] =
    "mediapipe.tasks.audio.audio_embedder.AudioEmbedderGraph";
// EmbedAsync() takes timestamps in milliseconds; MediaPipe Timestamps are
// expressed in microseconds.
constexpr int kMicroSecondsPerMilliSecond = 1000;
|
||||
|
||||
// Creates a MediaPipe graph config that only contains a single subgraph node of
|
||||
// type "AudioEmbedderGraph".
|
||||
CalculatorGraphConfig CreateGraphConfig(
|
||||
std::unique_ptr<proto::AudioEmbedderGraphOptions> options_proto) {
|
||||
api2::builder::Graph graph;
|
||||
auto& subgraph = graph.AddNode(kSubgraphTypeName);
|
||||
graph.In(kAudioTag).SetName(kAudioStreamName) >> subgraph.In(kAudioTag);
|
||||
graph.In(kSampleRateTag).SetName(kSampleRateName) >>
|
||||
subgraph.In(kSampleRateTag);
|
||||
subgraph.GetOptions<proto::AudioEmbedderGraphOptions>().Swap(
|
||||
options_proto.get());
|
||||
subgraph.Out(kEmbeddingsTag).SetName(kEmbeddingsName) >>
|
||||
graph.Out(kEmbeddingsTag);
|
||||
subgraph.Out(kTimestampedEmbeddingsTag).SetName(kTimestampedEmbeddingsName) >>
|
||||
graph.Out(kTimestampedEmbeddingsTag);
|
||||
return graph.GetConfig();
|
||||
}
|
||||
|
||||
// Converts the user-facing AudioEmbedderOptions struct to the internal
|
||||
// AudioEmbedderGraphOptions proto.
|
||||
std::unique_ptr<proto::AudioEmbedderGraphOptions>
|
||||
ConvertAudioEmbedderOptionsToProto(AudioEmbedderOptions* options) {
|
||||
auto options_proto = std::make_unique<proto::AudioEmbedderGraphOptions>();
|
||||
auto base_options_proto = std::make_unique<tasks::core::proto::BaseOptions>(
|
||||
tasks::core::ConvertBaseOptionsToProto(&(options->base_options)));
|
||||
options_proto->mutable_base_options()->Swap(base_options_proto.get());
|
||||
options_proto->mutable_base_options()->set_use_stream_mode(
|
||||
options->running_mode == core::RunningMode::AUDIO_STREAM);
|
||||
auto embedder_options_proto =
|
||||
std::make_unique<components::processors::proto::EmbedderOptions>(
|
||||
components::processors::ConvertEmbedderOptionsToProto(
|
||||
&(options->embedder_options)));
|
||||
options_proto->mutable_embedder_options()->Swap(embedder_options_proto.get());
|
||||
return options_proto;
|
||||
}
|
||||
|
||||
// Converts the raw output packets of the audio clips mode into a vector of
// AudioEmbedderResult, one entry per processed audio chunk. Propagates any
// error status unchanged.
absl::StatusOr<std::vector<AudioEmbedderResult>> ConvertOutputPackets(
    absl::StatusOr<tasks::core::PacketMap> status_or_packets) {
  if (!status_or_packets.ok()) {
    return status_or_packets.status();
  }
  const auto& proto_results =
      status_or_packets.value()[kTimestampedEmbeddingsName]
          .Get<std::vector<EmbeddingResult>>();
  std::vector<AudioEmbedderResult> converted_results;
  converted_results.reserve(proto_results.size());
  for (const EmbeddingResult& proto_result : proto_results) {
    converted_results.emplace_back(ConvertToEmbeddingResult(proto_result));
  }
  return converted_results;
}
|
||||
|
||||
// Converts the raw output packets delivered in the audio stream mode into a
// single AudioEmbedderResult. Propagates any error status unchanged.
absl::StatusOr<AudioEmbedderResult> ConvertAsyncOutputPackets(
    absl::StatusOr<tasks::core::PacketMap> status_or_packets) {
  if (!status_or_packets.ok()) {
    return status_or_packets.status();
  }
  const auto& proto_result =
      status_or_packets.value()[kEmbeddingsName].Get<EmbeddingResult>();
  return ConvertToEmbeddingResult(proto_result);
}
|
||||
} // namespace
|
||||
|
||||
/* static */
// Builds an AudioEmbedder from the user-provided options: converts the
// options to the graph-options proto, wraps the user's result callback (if
// any) so it receives converted results, and delegates runner construction
// to AudioTaskApiFactory.
absl::StatusOr<std::unique_ptr<AudioEmbedder>> AudioEmbedder::Create(
    std::unique_ptr<AudioEmbedderOptions> options) {
  auto options_proto = ConvertAudioEmbedderOptionsToProto(options.get());
  tasks::core::PacketsCallback packets_callback = nullptr;
  if (options->result_callback) {
    // Copy the std::function so the lambda below stays valid after
    // `options` is consumed by the factory call.
    auto result_callback = options->result_callback;
    packets_callback =
        [=](absl::StatusOr<tasks::core::PacketMap> status_or_packets) {
          result_callback(ConvertAsyncOutputPackets(status_or_packets));
        };
  }
  // The factory checks the running-mode/callback combination (see the
  // "FailsWithIllegalCallbackInAudioClipsMode" test) and builds the runner.
  return core::AudioTaskApiFactory::Create<AudioEmbedder,
                                           proto::AudioEmbedderGraphOptions>(
      CreateGraphConfig(std::move(options_proto)),
      std::move(options->base_options.op_resolver), options->running_mode,
      std::move(packets_callback));
}
|
||||
|
||||
// Runs embedding extraction on a standalone audio clip (audio clips mode)
// and converts the resulting packets into per-chunk AudioEmbedderResults.
absl::StatusOr<std::vector<AudioEmbedderResult>> AudioEmbedder::Embed(
    Matrix audio_clip, double audio_sample_rate) {
  tasks::core::PacketMap input_packets = {
      {kAudioStreamName, MakePacket<Matrix>(std::move(audio_clip))},
      {kSampleRateName, MakePacket<double>(audio_sample_rate)}};
  return ConvertOutputPackets(ProcessAudioClip(std::move(input_packets)));
}
|
||||
|
||||
// Feeds one block of audio stream data to the task (audio stream mode).
// Results arrive asynchronously via the result_callback supplied at
// creation time.
absl::Status AudioEmbedder::EmbedAsync(Matrix audio_block,
                                       double audio_sample_rate,
                                       int64 timestamp_ms) {
  // NOTE(review): CheckOrSetSampleRate presumably pins the stream's sample
  // rate on first use and rejects mismatching later rates — confirm against
  // BaseAudioTaskApi.
  MP_RETURN_IF_ERROR(CheckOrSetSampleRate(kSampleRateName, audio_sample_rate));
  // Millisecond timestamps from the caller are converted to MediaPipe's
  // microsecond Timestamps.
  return SendAudioStreamData(
      {{kAudioStreamName,
        MakePacket<Matrix>(std::move(audio_block))
            .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}});
}
|
||||
|
||||
// Thin delegate to the shared cosine-similarity utility used by all
// embedder tasks; see the header for the error contract.
absl::StatusOr<double> AudioEmbedder::CosineSimilarity(
    const components::containers::Embedding& u,
    const components::containers::Embedding& v) {
  return components::utils::CosineSimilarity(u, v);
}
|
||||
|
||||
} // namespace mediapipe::tasks::audio::audio_embedder
|
139
mediapipe/tasks/cc/audio/audio_embedder/audio_embedder.h
Normal file
139
mediapipe/tasks/cc/audio/audio_embedder/audio_embedder.h
Normal file
|
@ -0,0 +1,139 @@
|
|||
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef MEDIAPIPE_TASKS_CC_AUDIO_AUDIO_EMBEDDER_AUDIO_EMBEDDER_H_
|
||||
#define MEDIAPIPE_TASKS_CC_AUDIO_AUDIO_EMBEDDER_AUDIO_EMBEDDER_H_
|
||||
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/status/statusor.h"
|
||||
#include "mediapipe/framework/formats/matrix.h"
|
||||
#include "mediapipe/tasks/cc/audio/core/base_audio_task_api.h"
|
||||
#include "mediapipe/tasks/cc/audio/core/running_mode.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/embedding_result.h"
|
||||
#include "mediapipe/tasks/cc/components/processors/embedder_options.h"
|
||||
#include "mediapipe/tasks/cc/core/base_options.h"
|
||||
|
||||
namespace mediapipe::tasks::audio::audio_embedder {
|
||||
|
||||
// Alias the shared EmbeddingResult struct as result type.
using AudioEmbedderResult =
    ::mediapipe::tasks::components::containers::EmbeddingResult;

// User-facing options for creating an AudioEmbedder task.
struct AudioEmbedderOptions {
  // Base options for configuring Task library, such as specifying the TfLite
  // model file with metadata, accelerator options, op resolver, etc.
  tasks::core::BaseOptions base_options;

  // Options for configuring the embedder behavior, such as score threshold,
  // number of results, etc.
  components::processors::EmbedderOptions embedder_options;

  // The running mode of the audio embedder. Default to the audio clips mode.
  // Audio embedder has two running modes:
  // 1) The audio clips mode for running embedding on independent audio clips.
  // 2) The audio stream mode for running embedding on the audio stream,
  //    such as from microphone. In this mode, the "result_callback" below must
  //    be specified to receive the embedding results asynchronously.
  core::RunningMode running_mode = core::RunningMode::AUDIO_CLIPS;

  // The user-defined result callback for processing audio stream data.
  // The result callback should only be specified when the running mode is set
  // to RunningMode::AUDIO_STREAM.
  std::function<void(absl::StatusOr<AudioEmbedderResult>)> result_callback =
      nullptr;
};
|
||||
|
||||
// Performs embedding extraction on audio clips or audio stream.
//
// The API expects a TFLite model with TFLite Model Metadata.
//
// Input tensor:
//   (kTfLiteFloat32)
//    - input audio buffer of size `[batch * samples]`.
//    - batch inference is not supported (`batch` is required to be 1).
//    - for multi-channel models, the channels need be interleaved.
// At least one output tensor with:
//   (kTfLiteUInt8/kTfLiteFloat32)
//    - `N` components corresponding to the `N` dimensions of the returned
//      feature vector for this output layer.
//    - Either 2 or 4 dimensions, i.e. `[1 x N]` or `[1 x 1 x 1 x N]`.
class AudioEmbedder : core::BaseAudioTaskApi {
 public:
  // Inherit the base-class constructors; instances are built via Create().
  using BaseAudioTaskApi::BaseAudioTaskApi;

  // Creates an AudioEmbedder from the provided options. A non-default
  // OpResolver can be specified in the BaseOptions in order to support custom
  // Ops or specify a subset of built-in Ops.
  static absl::StatusOr<std::unique_ptr<AudioEmbedder>> Create(
      std::unique_ptr<AudioEmbedderOptions> options);

  // Performs embedding extraction on the provided audio clips. Only use this
  // method when the AudioEmbedder is created with the audio clips running mode.
  //
  // The audio clip is represented as a MediaPipe Matrix that has the number of
  // channels rows and the number of samples per channel columns. The method
  // accepts audio clips with various length and audio sample rate. It's
  // required to provide the corresponding audio sample rate along with the
  // input audio clips.
  //
  // The input audio clip may be longer than what the model is able to process
  // in a single inference. When this occurs, the input audio clip is split into
  // multiple chunks starting at different timestamps. For this reason, this
  // function returns a vector of EmbeddingResult objects, each associated
  // with a timestamp corresponding to the start (in milliseconds) of the chunk
  // data that was extracted.
  absl::StatusOr<std::vector<AudioEmbedderResult>> Embed(
      Matrix audio_clip, double audio_sample_rate);

  // Sends audio stream data to embedder, and the results will be available via
  // the "result_callback" provided in the AudioEmbedderOptions. Only use this
  // method when the AudioEmbedder is created with the audio stream running
  // mode.
  //
  // The audio block is represented as a MediaPipe Matrix that has the number
  // of channels rows and the number of samples per channel columns. The audio
  // data will be resampled, accumulated, and framed to the proper size for the
  // underlying model to consume. It's required to provide the corresponding
  // audio sample rate along with the input audio block as well as a timestamp
  // (in milliseconds) to indicate the start time of the input audio block. The
  // timestamps must be monotonically increasing.
  //
  // The input audio block may be longer than what the model is able to process
  // in a single inference. When this occurs, the input audio block is split
  // into multiple chunks. For this reason, the callback may be called multiple
  // times (once per chunk) for each call to this function.
  absl::Status EmbedAsync(Matrix audio_block, double audio_sample_rate,
                          int64 timestamp_ms);

  // Shuts down the AudioEmbedder when all works are done.
  absl::Status Close() { return runner_->Close(); }

  // Utility function to compute cosine similarity [1] between two embeddings.
  // May return an InvalidArgumentError if e.g. the embeddings are of different
  // types (quantized vs. float), have different sizes, or have an L2-norm of
  // 0.
  //
  // [1]: https://en.wikipedia.org/wiki/Cosine_similarity
  static absl::StatusOr<double> CosineSimilarity(
      const components::containers::Embedding& u,
      const components::containers::Embedding& v);
};
|
||||
|
||||
} // namespace mediapipe::tasks::audio::audio_embedder
|
||||
|
||||
#endif // MEDIAPIPE_TASKS_CC_AUDIO_AUDIO_EMBEDDER_AUDIO_EMBEDDER_H_
|
186
mediapipe/tasks/cc/audio/audio_embedder/audio_embedder_graph.cc
Normal file
186
mediapipe/tasks/cc/audio/audio_embedder/audio_embedder_graph.cc
Normal file
|
@ -0,0 +1,186 @@
|
|||
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <stdint.h>
|
||||
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "absl/types/optional.h"
|
||||
#include "mediapipe/calculators/core/constant_side_packet_calculator.pb.h"
|
||||
#include "mediapipe/calculators/tensor/audio_to_tensor_calculator.pb.h"
|
||||
#include "mediapipe/framework/api2/builder.h"
|
||||
#include "mediapipe/framework/api2/port.h"
|
||||
#include "mediapipe/framework/calculator.pb.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/formats/matrix.h"
|
||||
#include "mediapipe/tasks/cc/audio/audio_embedder/proto/audio_embedder_graph_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/audio/utils/audio_tensor_specs.h"
|
||||
#include "mediapipe/tasks/cc/common.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.h"
|
||||
#include "mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/core/model_resources.h"
|
||||
#include "mediapipe/tasks/cc/core/model_task_graph.h"
|
||||
#include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h"
|
||||
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
|
||||
#include "mediapipe/tasks/metadata/metadata_schema_generated.h"
|
||||
#include "tensorflow/lite/schema/schema_generated.h"
|
||||
|
||||
namespace mediapipe::tasks::audio::audio_embedder {
|
||||
namespace {
|
||||
|
||||
using ::mediapipe::api2::Input;
using ::mediapipe::api2::Output;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source;
using ::mediapipe::tasks::components::containers::proto::EmbeddingResult;

// Stream tags used by the calculators/subgraphs wired below.
constexpr char kAudioTag[] = "AUDIO";
constexpr char kEmbeddingsTag[] = "EMBEDDINGS";
constexpr char kTimestampedEmbeddingsTag[] = "TIMESTAMPED_EMBEDDINGS";
constexpr char kSampleRateTag[] = "SAMPLE_RATE";
constexpr char kTensorsTag[] = "TENSORS";
constexpr char kTimestampsTag[] = "TIMESTAMPS";

// Struct holding the different output streams produced by the audio embedder
// graph.
struct AudioEmbedderOutputStreams {
  // Single aggregated result (audio stream mode).
  Source<EmbeddingResult> embeddings;
  // One result per processed chunk (audio clips mode).
  Source<std::vector<EmbeddingResult>> timestamped_embeddings;
};
|
||||
|
||||
// Builds an AudioTensorSpecs for configuring the preprocessing calculators.
// Validates that the model has exactly one subgraph with exactly one input
// tensor, then combines the input tensor description with its (optional)
// metadata entry.
absl::StatusOr<AudioTensorSpecs> BuildPreprocessingSpecs(
    const core::ModelResources& model_resources) {
  const tflite::Model& model = *model_resources.GetTfLiteModel();
  if (model.subgraphs()->size() != 1) {
    return CreateStatusWithPayload(absl::StatusCode::kInvalidArgument,
                                   "Audio embedding tflite models are "
                                   "assumed to have a single subgraph.",
                                   MediaPipeTasksStatus::kInvalidArgumentError);
  }
  const auto* primary_subgraph = (*model.subgraphs())[0];
  if (primary_subgraph->inputs()->size() != 1) {
    return CreateStatusWithPayload(absl::StatusCode::kInvalidArgument,
                                   "Audio embedding tflite models are "
                                   "assumed to have a single input.",
                                   MediaPipeTasksStatus::kInvalidArgumentError);
  }
  // Resolve the (single) input tensor through the subgraph's input index.
  const auto* input_tensor =
      (*primary_subgraph->tensors())[(*primary_subgraph->inputs())[0]];
  ASSIGN_OR_RETURN(
      const auto* audio_tensor_metadata,
      GetAudioTensorMetadataIfAny(*model_resources.GetMetadataExtractor(), 0));
  return BuildInputAudioTensorSpecs(*input_tensor, audio_tensor_metadata);
}
|
||||
|
||||
// Fills in the AudioToTensorCalculatorOptions based on the AudioTensorSpecs.
|
||||
void ConfigureAudioToTensorCalculator(
|
||||
const AudioTensorSpecs& audio_tensor_specs, bool use_stream_mode,
|
||||
AudioToTensorCalculatorOptions* options) {
|
||||
options->set_num_channels(audio_tensor_specs.num_channels);
|
||||
options->set_num_samples(audio_tensor_specs.num_samples);
|
||||
options->set_target_sample_rate(audio_tensor_specs.sample_rate);
|
||||
options->set_stream_mode(use_stream_mode);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
// Task subgraph that performs audio embedding: audio preprocessing
// (AudioToTensorCalculator) -> inference -> embedding postprocessing.
// Registered below as
// "mediapipe.tasks.audio.audio_embedder.AudioEmbedderGraph".
class AudioEmbedderGraph : public core::ModelTaskGraph {
 public:
  // Builds the graph config from the AudioEmbedderGraphOptions in `sc`,
  // exposing AUDIO/SAMPLE_RATE inputs and EMBEDDINGS/TIMESTAMPED_EMBEDDINGS
  // outputs at graph level.
  absl::StatusOr<CalculatorGraphConfig> GetConfig(
      SubgraphContext* sc) override {
    ASSIGN_OR_RETURN(
        const auto* model_resources,
        CreateModelResources<proto::AudioEmbedderGraphOptions>(sc));
    Graph graph;
    ASSIGN_OR_RETURN(
        auto output_streams,
        BuildAudioEmbeddingTask(
            sc->Options<proto::AudioEmbedderGraphOptions>(), *model_resources,
            graph[Input<Matrix>(kAudioTag)],
            absl::make_optional(graph[Input<double>(kSampleRateTag)]), graph));
    output_streams.embeddings >> graph[Output<EmbeddingResult>(kEmbeddingsTag)];
    output_streams.timestamped_embeddings >>
        graph[Output<std::vector<EmbeddingResult>>(kTimestampedEmbeddingsTag)];
    return graph.GetConfig();
  }

 private:
  // Wires the full embedding pipeline into `graph` and returns its output
  // streams. Fails if the model lacks TFLite Model Metadata or doesn't match
  // the expected single-subgraph/single-input layout.
  absl::StatusOr<AudioEmbedderOutputStreams> BuildAudioEmbeddingTask(
      const proto::AudioEmbedderGraphOptions& task_options,
      const core::ModelResources& model_resources, Source<Matrix> audio_in,
      absl::optional<Source<double>> sample_rate_in, Graph& graph) {
    const bool use_stream_mode = task_options.base_options().use_stream_mode();
    const auto* metadata_extractor = model_resources.GetMetadataExtractor();
    // Checks that metadata is available.
    if (metadata_extractor->GetModelMetadata() == nullptr ||
        metadata_extractor->GetModelMetadata()->subgraph_metadata() ==
            nullptr) {
      return CreateStatusWithPayload(
          absl::StatusCode::kInvalidArgument,
          "Audio embedder models require TFLite Model Metadata but none was "
          "found",
          MediaPipeTasksStatus::kMetadataNotFoundError);
    }
    // Adds AudioToTensorCalculator and connects it to the graph input streams.
    ASSIGN_OR_RETURN(auto audio_tensor_specs,
                     BuildPreprocessingSpecs(model_resources));
    auto& audio_to_tensor = graph.AddNode("AudioToTensorCalculator");
    ConfigureAudioToTensorCalculator(
        audio_tensor_specs, use_stream_mode,
        &audio_to_tensor.GetOptions<AudioToTensorCalculatorOptions>());
    audio_in >> audio_to_tensor.In(kAudioTag);
    if (sample_rate_in.has_value()) {
      sample_rate_in.value() >> audio_to_tensor.In(kSampleRateTag);
    }

    // Adds inference subgraph and connects its input stream to the output
    // tensors produced by the AudioToTensorCalculator.
    auto& inference = AddInference(
        model_resources, task_options.base_options().acceleration(), graph);
    audio_to_tensor.Out(kTensorsTag) >> inference.In(kTensorsTag);
    // Adds postprocessing calculators and connects its input stream to the
    // inference results.
    auto& postprocessing = graph.AddNode(
        "mediapipe.tasks.components.processors.EmbeddingPostprocessingGraph");
    MP_RETURN_IF_ERROR(components::processors::ConfigureEmbeddingPostprocessing(
        model_resources, task_options.embedder_options(),
        &postprocessing.GetOptions<components::processors::proto::
                                       EmbeddingPostprocessingGraphOptions>()));
    inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag);
    // Time aggregation is only needed for performing audio embedding on
    // audio files. Disables timestamp aggregation by not connecting the
    // "TIMESTAMPS" streams.
    if (!use_stream_mode) {
      audio_to_tensor.Out(kTimestampsTag) >> postprocessing.In(kTimestampsTag);
    }

    // Outputs both streams as graph output streams.
    return AudioEmbedderOutputStreams{
        /*embeddings=*/postprocessing[Output<EmbeddingResult>(kEmbeddingsTag)],
        /*timestamped_embeddings=*/
        postprocessing[Output<std::vector<EmbeddingResult>>(
            kTimestampedEmbeddingsTag)],
    };
  }
};
|
||||
|
||||
// Registers the subgraph so CreateGraphConfig() in audio_embedder.cc can
// reference it by its fully-qualified type name.
REGISTER_MEDIAPIPE_GRAPH(
    ::mediapipe::tasks::audio::audio_embedder::AudioEmbedderGraph);
|
||||
|
||||
} // namespace mediapipe::tasks::audio::audio_embedder
|
289
mediapipe/tasks/cc/audio/audio_embedder/audio_embedder_test.cc
Normal file
289
mediapipe/tasks/cc/audio/audio_embedder/audio_embedder_test.cc
Normal file
|
@ -0,0 +1,289 @@
|
|||
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mediapipe/tasks/cc/audio/audio_embedder/audio_embedder.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <new>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/flags/flag.h"
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "absl/strings/cord.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "mediapipe/framework/deps/file_path.h"
|
||||
#include "mediapipe/framework/formats/matrix.h"
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/tasks/cc/audio/core/running_mode.h"
|
||||
#include "mediapipe/tasks/cc/audio/utils/test_utils.h"
|
||||
#include "mediapipe/tasks/cc/common.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/embedding_result.h"
|
||||
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
namespace audio {
|
||||
namespace audio_embedder {
|
||||
namespace {
|
||||
|
||||
using ::absl::StatusOr;
using ::mediapipe::file::JoinPath;
using ::testing::HasSubstr;
using ::testing::Optional;

// Test assets and model-specific constants.
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/audio";
constexpr char kModelWithMetadata[] = "yamnet_embedding_metadata.tflite";
constexpr char k16kTestWavFilename[] = "speech_16000_hz_mono.wav";
constexpr char k48kTestWavFilename[] = "speech_48000_hz_mono.wav";
constexpr char k16kTestWavForTwoHeadsFilename[] = "two_heads_16000_hz_mono.wav";
// Golden cosine similarities between the 16 kHz and 48 kHz speech clips,
// one per processed chunk.
constexpr float kSpeechSimilarities[] = {0.985359, 0.994349, 0.993227, 0.996658,
                                         0.996384};
constexpr int kMilliSecondsPerSecond = 1000;
// YAMNet's fixed input window size (samples) and sample rate (Hz).
constexpr int kYamnetNumOfAudioSamples = 15600;
constexpr int kYamnetAudioSampleRate = 16000;
|
||||
|
||||
// Reads a WAV file from the test data directory into a single-row Matrix of
// audio samples.
Matrix GetAudioData(absl::string_view filename) {
  const std::string wav_file_path =
      JoinPath("./", kTestDataDirectory, filename);
  int num_samples;
  auto wav_samples = internal::ReadWavFile(wav_file_path, &num_samples);
  Eigen::Map<Matrix> sample_matrix(wav_samples->get(), 1, num_samples);
  return sample_matrix.matrix();
}
|
||||
|
||||
// Fixture for AudioEmbedder construction / option-validation tests.
class CreateFromOptionsTest : public tflite_shims::testing::Test {};
|
||||
|
||||
// Creating an embedder without any model reference must fail with a
// descriptive kInvalidArgument status carrying the tasks error payload.
TEST_F(CreateFromOptionsTest, FailsWithMissingModel) {
  auto embedder_or =
      AudioEmbedder::Create(std::make_unique<AudioEmbedderOptions>());

  const absl::Status& status = embedder_or.status();
  EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
  EXPECT_THAT(
      status.message(),
      HasSubstr("ExternalFile must specify at least one of 'file_content', "
                "'file_name', 'file_pointer_meta' or 'file_descriptor_meta'."));
  EXPECT_THAT(status.GetPayload(kMediaPipeTasksPayload),
              Optional(absl::Cord(absl::StrCat(
                  MediaPipeTasksStatus::kRunnerInitializationError))));
}
|
||||
|
||||
// Creation succeeds when pointed at a model that carries TFLite metadata.
TEST_F(CreateFromOptionsTest, SucceedsForModelWithMetadata) {
  auto embedder_options = std::make_unique<AudioEmbedderOptions>();
  embedder_options->base_options.model_asset_path =
      JoinPath("./", kTestDataDirectory, kModelWithMetadata);
  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<AudioEmbedder> embedder,
                          AudioEmbedder::Create(std::move(embedder_options)));
}
|
||||
|
||||
// AUDIO_CLIPS mode is synchronous: supplying a result callback is rejected
// with kInvalidArgument and the task-graph-config payload.
TEST_F(CreateFromOptionsTest, FailsWithIllegalCallbackInAudioClipsMode) {
  auto embedder_options = std::make_unique<AudioEmbedderOptions>();
  embedder_options->base_options.model_asset_path =
      JoinPath("./", kTestDataDirectory, kModelWithMetadata);
  embedder_options->running_mode = core::RunningMode::AUDIO_CLIPS;
  embedder_options->result_callback =
      [](absl::StatusOr<AudioEmbedderResult>) {};

  auto embedder_or = AudioEmbedder::Create(std::move(embedder_options));

  const absl::Status& status = embedder_or.status();
  EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
  EXPECT_THAT(
      status.message(),
      HasSubstr("a user-defined result callback shouldn't be provided"));
  EXPECT_THAT(status.GetPayload(kMediaPipeTasksPayload),
              Optional(absl::Cord(absl::StrCat(
                  MediaPipeTasksStatus::kInvalidTaskGraphConfigError))));
}
|
||||
|
||||
// AUDIO_STREAM mode delivers results asynchronously, so creation without a
// result callback must fail with kInvalidArgument.
TEST_F(CreateFromOptionsTest, FailsWithMissingCallbackInAudioStreamMode) {
  auto embedder_options = std::make_unique<AudioEmbedderOptions>();
  embedder_options->base_options.model_asset_path =
      JoinPath("./", kTestDataDirectory, kModelWithMetadata);
  embedder_options->running_mode = core::RunningMode::AUDIO_STREAM;

  auto embedder_or = AudioEmbedder::Create(std::move(embedder_options));

  const absl::Status& status = embedder_or.status();
  EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
  EXPECT_THAT(status.message(),
              HasSubstr("a user-defined result callback must be provided"));
  EXPECT_THAT(status.GetPayload(kMediaPipeTasksPayload),
              Optional(absl::Cord(absl::StrCat(
                  MediaPipeTasksStatus::kInvalidTaskGraphConfigError))));
}
|
||||
|
||||
// Fixture for synchronous (AUDIO_CLIPS mode) embedding tests.
class EmbedTest : public tflite_shims::testing::Test {};
|
||||
|
||||
// Embedding exactly one all-zeros YAMNet input window yields a single
// 1024-dim embedding whose leading values match the golden numbers.
TEST_F(EmbedTest, SucceedsWithSilentAudio) {
  auto embedder_options = std::make_unique<AudioEmbedderOptions>();
  embedder_options->base_options.model_asset_path =
      JoinPath("./", kTestDataDirectory, kModelWithMetadata);
  embedder_options->running_mode = core::RunningMode::AUDIO_CLIPS;
  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<AudioEmbedder> embedder,
                          AudioEmbedder::Create(std::move(embedder_options)));

  Matrix silence = Matrix::Zero(1, kYamnetNumOfAudioSamples);
  MP_ASSERT_OK_AND_ASSIGN(auto result,
                          embedder->Embed(silence, kYamnetAudioSampleRate));

  EXPECT_EQ(result.size(), 1);
  const auto& embedding = result[0].embeddings[0].float_embedding;
  EXPECT_EQ(embedding.size(), 1024);
  constexpr float kValueDiffTolerance = 3e-6;
  EXPECT_NEAR(embedding[0], 2.07613f, kValueDiffTolerance);
  EXPECT_NEAR(embedding[1], 0.392721f, kValueDiffTolerance);
  EXPECT_NEAR(embedding[2], 0.543622f, kValueDiffTolerance);
}
|
||||
|
||||
// The same speech clip sampled at 16 kHz and 48 kHz must produce
// near-identical embeddings chunk by chunk (high cosine similarity).
TEST_F(EmbedTest, SucceedsWithSameAudioAtDifferentSampleRates) {
  Matrix audio_16k = GetAudioData(k16kTestWavFilename);
  Matrix audio_48k = GetAudioData(k48kTestWavFilename);
  auto embedder_options = std::make_unique<AudioEmbedderOptions>();
  embedder_options->base_options.model_asset_path =
      JoinPath("./", kTestDataDirectory, kModelWithMetadata);
  embedder_options->running_mode = core::RunningMode::AUDIO_CLIPS;
  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<AudioEmbedder> embedder,
                          AudioEmbedder::Create(std::move(embedder_options)));

  MP_ASSERT_OK_AND_ASSIGN(auto result_16k, embedder->Embed(audio_16k, 16000));
  MP_ASSERT_OK_AND_ASSIGN(auto result_48k, embedder->Embed(audio_48k, 48000));

  constexpr int kExpectedSize =
      sizeof(kSpeechSimilarities) / sizeof(kSpeechSimilarities[0]);
  ASSERT_EQ(result_16k.size(), kExpectedSize);
  ASSERT_EQ(result_48k.size(), kExpectedSize);
  for (int i = 0; i < kExpectedSize; ++i) {
    MP_ASSERT_OK_AND_ASSIGN(
        double similarity,
        AudioEmbedder::CosineSimilarity(result_16k[i].embeddings[0],
                                        result_48k[i].embeddings[0]));
    EXPECT_NEAR(similarity, kSpeechSimilarities[i], 1e-6);
  }
  MP_EXPECT_OK(embedder->Close());
}
|
||||
|
||||
// Two unrelated clips must produce dissimilar embeddings (low cosine
// similarity between their first chunks).
TEST_F(EmbedTest, SucceedsWithDifferentAudios) {
  Matrix speech_audio = GetAudioData(k16kTestWavFilename);
  Matrix two_heads_audio = GetAudioData(k16kTestWavForTwoHeadsFilename);
  auto embedder_options = std::make_unique<AudioEmbedderOptions>();
  embedder_options->base_options.model_asset_path =
      JoinPath("./", kTestDataDirectory, kModelWithMetadata);
  embedder_options->running_mode = core::RunningMode::AUDIO_CLIPS;
  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<AudioEmbedder> embedder,
                          AudioEmbedder::Create(std::move(embedder_options)));

  MP_ASSERT_OK_AND_ASSIGN(
      auto speech_result,
      embedder->Embed(speech_audio, kYamnetAudioSampleRate));
  MP_ASSERT_OK_AND_ASSIGN(
      auto two_heads_result,
      embedder->Embed(two_heads_audio, kYamnetAudioSampleRate));

  ASSERT_EQ(speech_result.size(), 5);
  ASSERT_EQ(two_heads_result.size(), 1);
  MP_ASSERT_OK_AND_ASSIGN(
      double similarity,
      AudioEmbedder::CosineSimilarity(speech_result[0].embeddings[0],
                                      two_heads_result[0].embeddings[0]));
  EXPECT_NEAR(similarity, 0.09017f, 1e-6);
  MP_EXPECT_OK(embedder->Close());
}
|
||||
|
||||
// Fixture for asynchronous (AUDIO_STREAM mode) embedding tests.
class EmbedAsyncTest : public tflite_shims::testing::Test {
 protected:
  // Runs the embedder in AUDIO_STREAM mode over `audio_file_name`, feeding
  // the audio in randomly sized chunks of roughly one YAMNet input window
  // (scaled to `sample_rate_hz`) and appending every callback result to
  // `result`. Blocks until the stream is closed and all results delivered.
  void RunAudioEmbedderInStreamMode(std::string audio_file_name,
                                    int sample_rate_hz,
                                    std::vector<AudioEmbedderResult>* result) {
    auto audio_buffer = GetAudioData(audio_file_name);
    auto options = std::make_unique<AudioEmbedderOptions>();
    options->base_options.model_asset_path =
        JoinPath("./", kTestDataDirectory, kModelWithMetadata);
    options->running_mode = core::RunningMode::AUDIO_STREAM;
    // Each asynchronous result is appended to the caller-owned vector.
    options->result_callback =
        [result](absl::StatusOr<AudioEmbedderResult> status_or_result) {
          MP_ASSERT_OK_AND_ASSIGN(result->emplace_back(), status_or_result);
        };
    MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<AudioEmbedder> audio_embedder,
                            AudioEmbedder::Create(std::move(options)));
    int start_col = 0;
    // NOTE(review): `static` makes the rand_r() sequence continue across
    // calls and tests — chunk sizes are deterministic for the binary as a
    // whole but differ between invocations of this helper.
    static unsigned int rseed = 0;
    while (start_col < audio_buffer.cols()) {
      // Feed about one model input window (rescaled to sample_rate_hz) plus
      // a jitter of [0, 9] samples, clamped to the remaining audio.
      int num_samples = std::min(
          (int)(audio_buffer.cols() - start_col),
          rand_r(&rseed) % 10 + kYamnetNumOfAudioSamples * sample_rate_hz /
                                    kYamnetAudioSampleRate);
      MP_ASSERT_OK(audio_embedder->EmbedAsync(
          audio_buffer.block(0, start_col, 1, num_samples), sample_rate_hz,
          // Timestamp (ms) of the chunk's first sample within the clip.
          start_col * kMilliSecondsPerSecond / sample_rate_hz));
      start_col += num_samples;
    }
    // Close() flushes the stream, so `result` is complete on return.
    MP_ASSERT_OK(audio_embedder->Close());
  }
};
|
||||
|
||||
// A packet whose timestamp precedes the previous one must be rejected with
// kInvalidArgument and the runner invalid-timestamp payload.
TEST_F(EmbedAsyncTest, FailsWithOutOfOrderInputTimestamps) {
  auto embedder_options = std::make_unique<AudioEmbedderOptions>();
  embedder_options->base_options.model_asset_path =
      JoinPath("./", kTestDataDirectory, kModelWithMetadata);
  embedder_options->running_mode = core::RunningMode::AUDIO_STREAM;
  embedder_options->result_callback =
      [](absl::StatusOr<AudioEmbedderResult>) {};
  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<AudioEmbedder> embedder,
                          AudioEmbedder::Create(std::move(embedder_options)));

  // First chunk at t=100 ms is accepted.
  MP_ASSERT_OK(embedder->EmbedAsync(Matrix(1, kYamnetNumOfAudioSamples),
                                    kYamnetAudioSampleRate, 100));
  // Second chunk at t=0 ms goes backwards in time.
  absl::Status status = embedder->EmbedAsync(
      Matrix(1, kYamnetNumOfAudioSamples), kYamnetAudioSampleRate, 0);

  EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
  EXPECT_THAT(status.message(),
              HasSubstr("timestamp must be monotonically increasing"));
  EXPECT_THAT(status.GetPayload(kMediaPipeTasksPayload),
              Optional(absl::Cord(absl::StrCat(
                  MediaPipeTasksStatus::kRunnerInvalidTimestampError))));
  MP_ASSERT_OK(embedder->Close());
}
|
||||
|
||||
// Streaming the same speech clip at 16 kHz and 48 kHz must yield embeddings
// with the golden per-chunk cosine similarities.
TEST_F(EmbedAsyncTest, SucceedsWithSameAudioAtDifferentSampleRates) {
  std::vector<AudioEmbedderResult> results_16k;
  RunAudioEmbedderInStreamMode(k16kTestWavFilename, 16000, &results_16k);
  std::vector<AudioEmbedderResult> results_48k;
  RunAudioEmbedderInStreamMode(k48kTestWavFilename, 48000, &results_48k);

  constexpr int kExpectedSize =
      sizeof(kSpeechSimilarities) / sizeof(kSpeechSimilarities[0]);
  ASSERT_EQ(results_16k.size(), kExpectedSize);
  ASSERT_EQ(results_48k.size(), kExpectedSize);
  for (int i = 0; i < kExpectedSize; ++i) {
    MP_ASSERT_OK_AND_ASSIGN(
        double similarity,
        AudioEmbedder::CosineSimilarity(results_16k[i].embeddings[0],
                                        results_48k[i].embeddings[0]));
    EXPECT_NEAR(similarity, kSpeechSimilarities[i], 1e-6);
  }
}
|
||||
|
||||
// Streaming two unrelated clips must yield dissimilar first embeddings.
TEST_F(EmbedAsyncTest, SucceedsWithDifferentAudios) {
  std::vector<AudioEmbedderResult> speech_results;
  RunAudioEmbedderInStreamMode(k16kTestWavFilename, 16000, &speech_results);
  std::vector<AudioEmbedderResult> two_heads_results;
  RunAudioEmbedderInStreamMode(k16kTestWavForTwoHeadsFilename, 16000,
                               &two_heads_results);

  ASSERT_EQ(speech_results.size(), 5);
  ASSERT_EQ(two_heads_results.size(), 1);
  MP_ASSERT_OK_AND_ASSIGN(
      double similarity,
      AudioEmbedder::CosineSimilarity(speech_results[0].embeddings[0],
                                      two_heads_results[0].embeddings[0]));
  EXPECT_NEAR(similarity, 0.09017f, 1e-6);
}
|
||||
|
||||
} // namespace
|
||||
} // namespace audio_embedder
|
||||
} // namespace audio
|
||||
} // namespace tasks
|
||||
} // namespace mediapipe
|
30
mediapipe/tasks/cc/audio/audio_embedder/proto/BUILD
Normal file
30
mediapipe/tasks/cc/audio/audio_embedder/proto/BUILD
Normal file
|
@ -0,0 +1,30 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library")
|
||||
|
||||
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
# Proto library for the audio embedder task graph options.
mediapipe_proto_library(
    name = "audio_embedder_graph_options_proto",
    srcs = ["audio_embedder_graph_options.proto"],
    deps = [
        "//mediapipe/framework:calculator_options_proto",
        "//mediapipe/framework:calculator_proto",
        "//mediapipe/tasks/cc/components/processors/proto:embedder_options_proto",
        "//mediapipe/tasks/cc/core/proto:base_options_proto",
    ],
)
|
|
@ -0,0 +1,39 @@
|
|||
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
syntax = "proto2";
|
||||
|
||||
package mediapipe.tasks.audio.audio_embedder.proto;
|
||||
|
||||
import "mediapipe/framework/calculator.proto";
|
||||
import "mediapipe/tasks/cc/components/processors/proto/embedder_options.proto";
|
||||
import "mediapipe/tasks/cc/core/proto/base_options.proto";
|
||||
|
||||
option java_package = "com.google.mediapipe.tasks.audio.audioembedder.proto";
|
||||
option java_outer_classname = "AudioEmbedderGraphOptionsProto";
|
||||
|
||||
// Options for the MediaPipe audio embedder task graph, attached to the
// graph node via the CalculatorOptions extension below.
message AudioEmbedderGraphOptions {
  extend mediapipe.CalculatorOptions {
    optional AudioEmbedderGraphOptions ext = 487277289;
  }

  // Base options for configuring MediaPipe Tasks, such as specifying the
  // TfLite model file with metadata, accelerator options, etc.
  optional core.proto.BaseOptions base_options = 1;

  // Options for configuring the embedder behavior, such as normalization or
  // quantization.
  optional components.processors.proto.EmbedderOptions embedder_options = 2;
}
|
2
mediapipe/tasks/testdata/audio/BUILD
vendored
2
mediapipe/tasks/testdata/audio/BUILD
vendored
|
@ -30,6 +30,7 @@ mediapipe_files(srcs = [
|
|||
"two_heads_16000_hz_mono.wav",
|
||||
"two_heads_44100_hz_mono.wav",
|
||||
"yamnet_audio_classifier_with_metadata.tflite",
|
||||
"yamnet_embedding_metadata.tflite",
|
||||
])
|
||||
|
||||
filegroup(
|
||||
|
@ -38,6 +39,7 @@ filegroup(
|
|||
"model_without_metadata.tflite",
|
||||
"two_heads.tflite",
|
||||
"yamnet_audio_classifier_with_metadata.tflite",
|
||||
"yamnet_embedding_metadata.tflite",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
6
third_party/external_files.bzl
vendored
6
third_party/external_files.bzl
vendored
|
@ -982,6 +982,12 @@ def external_files():
|
|||
urls = ["https://storage.googleapis.com/mediapipe-assets/yamnet_audio_classifier_with_metadata.tflite?generation=1661875980774466"],
|
||||
)
|
||||
|
||||
http_file(
|
||||
name = "com_google_mediapipe_yamnet_embedding_metadata_tflite",
|
||||
sha256 = "7baa72708e3919bae5a5dc78d932847bc28008af14febd083eff62d28af9c72a",
|
||||
urls = ["https://storage.googleapis.com/mediapipe-assets/yamnet_embedding_metadata.tflite?generation=1668295071595506"],
|
||||
)
|
||||
|
||||
http_file(
|
||||
name = "com_google_mediapipe_gesture_embedder_keras_metadata_pb",
|
||||
sha256 = "24268b69429be4e307f9ab099ba20d1de7c40e4191a53f6a92dcbbd97a7047d3",
|
||||
|
|
Loading…
Reference in New Issue
Block a user