commit 83608d4670
Merge branch 'master' into image-embedder-python
mediapipe/framework/api2/builder.h
@@ -1,8 +1,10 @@
 #ifndef MEDIAPIPE_FRAMEWORK_API2_BUILDER_H_
 #define MEDIAPIPE_FRAMEWORK_API2_BUILDER_H_
 
+#include <memory>
 #include <string>
 #include <type_traits>
+#include <vector>
 
 #include "absl/container/btree_map.h"
 #include "absl/strings/string_view.h"
@@ -74,6 +76,7 @@ class TagIndexMap {
 
 class Graph;
 class NodeBase;
+class PacketGenerator;
 
 // These structs are used internally to store information about the endpoints
 // of a connection.
@@ -146,6 +149,7 @@ template <bool IsSide, typename T>
 class SourceImpl {
  public:
   using Base = SourceBase;
+  using PayloadT = T;
 
   // Src is used as the return type of fluent methods below. Since these are
   // single-port methods, it is desirable to always decay to a reference to the
@@ -201,10 +205,61 @@ class SourceImpl {
 // when building the graph.
 template <typename T = internal::Generic>
 using Source = SourceImpl<false, T>;
 
+// Represents a stream of packets of a particular type.
+//
+// The intended use:
+// - decouple input/output streams from graph/node during graph construction
+// - pass streams around and connect them as needed, extracting reusable parts
+//   to utility/convenience functions or classes.
+//
+// For example:
+//   Stream<Image> Resize(Stream<Image> image, const Size& size, Graph& graph) {
+//     auto& scaler_node = graph.AddNode("GlScalerCalculator");
+//     auto& opts = scaler_node.GetOptions<GlScalerCalculatorOptions>();
+//     opts.set_output_width(size.width);
+//     opts.set_output_height(size.height);
+//     image >> scaler_node.In("IMAGE");
+//     return scaler_node.Out("IMAGE").Cast<Image>();
+//   }
+//
+// Where graph can use it as:
+//   Graph graph;
+//   Stream<Image> input_image = graph.In("INPUT_IMAGE").Cast<Image>();
+//   Stream<Image> resized_image = Resize(input_image, {64, 64}, graph);
+template <typename T>
+using Stream = Source<T>;
 
 template <typename T = internal::Generic>
 using MultiSource = MultiPort<Source<T>>;
 
 template <typename T = internal::Generic>
 using SideSource = SourceImpl<true, T>;
 
+// Represents a side packet of a particular type.
+//
+// The intended use:
+// - decouple input/output side packets from graph/node during graph
+//   construction
+// - pass side packets around and connect them as needed, extracting reusable
+//   parts to utility/convenience functions or classes.
+//
+// For example:
+//   SidePacket<TfLiteModelPtr> GetModel(SidePacket<std::string> model_blob,
+//                                       Graph& graph) {
+//     auto& model_node = graph.AddNode("TfLiteModelCalculator");
+//     model_blob >> model_node.SideIn("MODEL_BLOB");
+//     return model_node.SideOut("MODEL").Cast<TfLiteModelPtr>();
+//   }
+//
+// Where graph can use it as:
+//   Graph graph;
+//   SidePacket<std::string> model_blob =
+//       graph.SideIn("MODEL_BLOB").Cast<std::string>();
+//   SidePacket<TfLiteModelPtr> model = GetModel(model_blob, graph);
+template <typename T>
+using SidePacket = SideSource<T>;
 
 template <typename T = internal::Generic>
 using MultiSideSource = MultiPort<SideSource<T>>;
 
@@ -87,6 +87,7 @@ cc_library(
 cc_library(
     name = "builtin_task_graphs",
     deps = [
+        "//mediapipe/tasks/cc/audio/audio_classifier:audio_classifier_graph",
         "//mediapipe/tasks/cc/vision/gesture_recognizer:gesture_recognizer_graph",
         "//mediapipe/tasks/cc/vision/image_classifier:image_classifier_graph",
         "//mediapipe/tasks/cc/vision/image_segmenter:image_segmenter_graph",
mediapipe/python/pybind/packet_creator.cc
@@ -361,7 +361,7 @@ void PublicPacketCreators(pybind11::module* m) {
       packet = mp.packet_creator.create_float(0.1)
       data = mp.packet_getter.get_float(packet)
 )doc",
-      py::arg().noconvert(), py::return_value_policy::move);
+      py::return_value_policy::move);
 
   m->def(
       "create_double", [](double data) { return MakePacket<double>(data); },
@@ -380,7 +380,7 @@ void PublicPacketCreators(pybind11::module* m) {
       packet = mp.packet_creator.create_double(0.1)
       data = mp.packet_getter.get_float(packet)
 )doc",
-      py::arg().noconvert(), py::return_value_policy::move);
+      py::return_value_policy::move);
 
   m->def(
       "create_int_array",
mediapipe/tasks/cc/audio/audio_classifier/BUILD
@@ -37,7 +37,6 @@ cc_library(
         "//mediapipe/tasks/cc/audio/utils:audio_tensor_specs",
         "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
         "//mediapipe/tasks/cc/components/processors:classification_postprocessing_graph",
-        "//mediapipe/tasks/cc/components/processors:classifier_options",
         "//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto",
         "//mediapipe/tasks/cc/core:model_resources",
         "//mediapipe/tasks/cc/core:model_task_graph",
mediapipe/tasks/cc/audio/audio_classifier/audio_classifier.cc
@@ -63,10 +63,8 @@ CalculatorGraphConfig CreateGraphConfig(
   api2::builder::Graph graph;
   auto& subgraph = graph.AddNode(kSubgraphTypeName);
   graph.In(kAudioTag).SetName(kAudioStreamName) >> subgraph.In(kAudioTag);
-  if (!options_proto->base_options().use_stream_mode()) {
-    graph.In(kSampleRateTag).SetName(kSampleRateName) >>
-        subgraph.In(kSampleRateTag);
-  }
+  graph.In(kSampleRateTag).SetName(kSampleRateName) >>
+      subgraph.In(kSampleRateTag);
   subgraph.GetOptions<proto::AudioClassifierGraphOptions>().Swap(
       options_proto.get());
   subgraph.Out(kClassificationsTag).SetName(kClassificationsName) >>
@@ -93,9 +91,6 @@ ConvertAudioClassifierOptionsToProto(AudioClassifierOptions* options) {
       &(options->classifier_options)));
   options_proto->mutable_classifier_options()->Swap(
       classifier_options_proto.get());
-  if (options->sample_rate > 0) {
-    options_proto->set_default_input_audio_sample_rate(options->sample_rate);
-  }
   return options_proto;
 }
 
@@ -129,14 +124,6 @@ absl::StatusOr<AudioClassifierResult> ConvertAsyncOutputPackets(
 /* static */
 absl::StatusOr<std::unique_ptr<AudioClassifier>> AudioClassifier::Create(
     std::unique_ptr<AudioClassifierOptions> options) {
-  if (options->running_mode == core::RunningMode::AUDIO_STREAM &&
-      options->sample_rate < 0) {
-    return CreateStatusWithPayload(
-        absl::StatusCode::kInvalidArgument,
-        "The audio classifier is in audio stream mode, the sample rate must be "
-        "specified in the AudioClassifierOptions.",
-        MediaPipeTasksStatus::kInvalidTaskGraphConfigError);
-  }
   auto options_proto = ConvertAudioClassifierOptionsToProto(options.get());
   tasks::core::PacketsCallback packets_callback = nullptr;
   if (options->result_callback) {
@@ -161,7 +148,9 @@ absl::StatusOr<std::vector<AudioClassifierResult>> AudioClassifier::Classify(
 }
 
 absl::Status AudioClassifier::ClassifyAsync(Matrix audio_block,
+                                            double audio_sample_rate,
                                             int64 timestamp_ms) {
+  MP_RETURN_IF_ERROR(CheckOrSetSampleRate(kSampleRateName, audio_sample_rate));
   return SendAudioStreamData(
       {{kAudioStreamName,
         MakePacket<Matrix>(std::move(audio_block))
mediapipe/tasks/cc/audio/audio_classifier/audio_classifier.h
@@ -52,15 +52,10 @@ struct AudioClassifierOptions {
   // 1) The audio clips mode for running classification on independent audio
   //    clips.
   // 2) The audio stream mode for running classification on the audio stream,
-  //    such as from microphone. In this mode, the "sample_rate" below must be
-  //    provided, and the "result_callback" below must be specified to receive
-  //    the classification results asynchronously.
+  //    such as from microphone. In this mode, the "result_callback" below must
+  //    be specified to receive the classification results asynchronously.
   core::RunningMode running_mode = core::RunningMode::AUDIO_CLIPS;
 
-  // The sample rate of the input audios. Must be set when the running mode is
-  // set to RunningMode::AUDIO_STREAM.
-  double sample_rate = -1.0;
-
   // The user-defined result callback for processing audio stream data.
   // The result callback should only be specified when the running mode is set
   // to RunningMode::AUDIO_STREAM.
@@ -160,15 +155,17 @@ class AudioClassifier : tasks::audio::core::BaseAudioTaskApi {
   // The audio block is represented as a MediaPipe Matrix that has the number
   // of channels rows and the number of samples per channel columns. The audio
   // data will be resampled, accumulated, and framed to the proper size for the
-  // underlying model to consume. It's required to provide a timestamp (in
-  // milliseconds) to indicate the start time of the input audio block. The
+  // underlying model to consume. It's required to provide the corresponding
+  // audio sample rate along with the input audio block as well as a timestamp
+  // (in milliseconds) to indicate the start time of the input audio block. The
   // timestamps must be monotonically increasing.
   //
   // The input audio block may be longer than what the model is able to process
   // in a single inference. When this occurs, the input audio block is split
   // into multiple chunks. For this reason, the callback may be called multiple
   // times (once per chunk) for each call to this function.
-  absl::Status ClassifyAsync(mediapipe::Matrix audio_block, int64 timestamp_ms);
+  absl::Status ClassifyAsync(mediapipe::Matrix audio_block,
+                             double audio_sample_rate, int64 timestamp_ms);
 
   // Shuts down the AudioClassifier when all works are done.
   absl::Status Close() { return runner_->Close(); }
mediapipe/tasks/cc/audio/audio_classifier/audio_classifier_graph.cc
@@ -72,18 +72,6 @@ struct AudioClassifierOutputStreams {
   Source<std::vector<ClassificationResult>> timestamped_classifications;
 };
 
-absl::Status SanityCheckOptions(
-    const proto::AudioClassifierGraphOptions& options) {
-  if (options.base_options().use_stream_mode() &&
-      !options.has_default_input_audio_sample_rate()) {
-    return CreateStatusWithPayload(absl::StatusCode::kInvalidArgument,
-                                   "In the streaming mode, the default input "
-                                   "audio sample rate must be set.",
-                                   MediaPipeTasksStatus::kInvalidArgumentError);
-  }
-  return absl::OkStatus();
-}
-
 // Builds an AudioTensorSpecs for configuring the preprocessing calculators.
 absl::StatusOr<AudioTensorSpecs> BuildPreprocessingSpecs(
     const core::ModelResources& model_resources) {
@@ -170,19 +158,12 @@ class AudioClassifierGraph : public core::ModelTaskGraph {
         const auto* model_resources,
         CreateModelResources<proto::AudioClassifierGraphOptions>(sc));
     Graph graph;
-    const bool use_stream_mode =
-        sc->Options<proto::AudioClassifierGraphOptions>()
-            .base_options()
-            .use_stream_mode();
     ASSIGN_OR_RETURN(
         auto output_streams,
         BuildAudioClassificationTask(
            sc->Options<proto::AudioClassifierGraphOptions>(), *model_resources,
            graph[Input<Matrix>(kAudioTag)],
-            use_stream_mode
-                ? absl::nullopt
-                : absl::make_optional(graph[Input<double>(kSampleRateTag)]),
-            graph));
+            absl::make_optional(graph[Input<double>(kSampleRateTag)]), graph));
     output_streams.classifications >>
         graph[Output<ClassificationResult>(kClassificationsTag)];
     output_streams.timestamped_classifications >>
@@ -207,7 +188,6 @@ class AudioClassifierGraph : public core::ModelTaskGraph {
       const proto::AudioClassifierGraphOptions& task_options,
       const core::ModelResources& model_resources, Source<Matrix> audio_in,
       absl::optional<Source<double>> sample_rate_in, Graph& graph) {
-    MP_RETURN_IF_ERROR(SanityCheckOptions(task_options));
     const bool use_stream_mode = task_options.base_options().use_stream_mode();
     const auto* metadata_extractor = model_resources.GetMetadataExtractor();
     // Checks that metadata is available.
mediapipe/tasks/cc/audio/audio_classifier/audio_classifier_test.cc
@@ -70,6 +70,8 @@ Matrix GetAudioData(absl::string_view filename) {
   return matrix_mapping.matrix();
 }
 
+// TODO: Compares the exact score values to capture unexpected
+// changes in the inference pipeline.
 void CheckSpeechResult(const std::vector<AudioClassifierResult>& result,
                        int expected_num_categories = 521) {
   EXPECT_EQ(result.size(), 5);
@@ -90,13 +92,15 @@ void CheckSpeechResult(const std::vector<AudioClassifierResult>& result,
   }
 }
 
+// TODO: Compares the exact score values to capture unexpected
+// changes in the inference pipeline.
 void CheckTwoHeadsResult(const std::vector<AudioClassifierResult>& result) {
   EXPECT_GE(result.size(), 1);
   EXPECT_LE(result.size(), 2);
-  // Check first result.
+  // Check the first result.
   EXPECT_EQ(result[0].timestamp_ms, 0);
   EXPECT_EQ(result[0].classifications.size(), 2);
-  // Check first head.
+  // Check the first head.
   EXPECT_EQ(result[0].classifications[0].head_index, 0);
   EXPECT_EQ(result[0].classifications[0].head_name, "yamnet_classification");
   EXPECT_EQ(result[0].classifications[0].categories.size(), 521);
@@ -104,19 +108,19 @@ void CheckTwoHeadsResult(const std::vector<AudioClassifierResult>& result) {
   EXPECT_EQ(result[0].classifications[0].categories[0].category_name,
             "Environmental noise");
   EXPECT_GT(result[0].classifications[0].categories[0].score, 0.5f);
-  // Check second head.
+  // Check the second head.
   EXPECT_EQ(result[0].classifications[1].head_index, 1);
   EXPECT_EQ(result[0].classifications[1].head_name, "bird_classification");
   EXPECT_EQ(result[0].classifications[1].categories.size(), 5);
   EXPECT_EQ(result[0].classifications[1].categories[0].index, 4);
   EXPECT_EQ(result[0].classifications[1].categories[0].category_name,
             "Chestnut-crowned Antpitta");
-  EXPECT_GT(result[0].classifications[1].categories[0].score, 0.9f);
-  // Check second result, if present.
+  EXPECT_GT(result[0].classifications[1].categories[0].score, 0.93f);
+  // Check the second result, if present.
   if (result.size() == 2) {
     EXPECT_EQ(result[1].timestamp_ms, 975);
     EXPECT_EQ(result[1].classifications.size(), 2);
-    // Check first head.
+    // Check the first head.
     EXPECT_EQ(result[1].classifications[0].head_index, 0);
     EXPECT_EQ(result[1].classifications[0].head_name, "yamnet_classification");
     EXPECT_EQ(result[1].classifications[0].categories.size(), 521);
@@ -124,7 +128,7 @@ void CheckTwoHeadsResult(const std::vector<AudioClassifierResult>& result) {
   EXPECT_EQ(result[1].classifications[0].categories[0].category_name,
             "Silence");
   EXPECT_GT(result[1].classifications[0].categories[0].score, 0.99f);
-  // Check second head.
+  // Check the second head.
   EXPECT_EQ(result[1].classifications[1].head_index, 1);
   EXPECT_EQ(result[1].classifications[1].head_name, "bird_classification");
   EXPECT_EQ(result[1].classifications[1].categories.size(), 5);
@@ -234,7 +238,6 @@ TEST_F(CreateFromOptionsTest, FailsWithMissingCallback) {
   options->base_options.model_asset_path =
       JoinPath("./", kTestDataDirectory, kModelWithoutMetadata);
   options->running_mode = core::RunningMode::AUDIO_STREAM;
-  options->sample_rate = 16000;
   StatusOr<std::unique_ptr<AudioClassifier>> audio_classifier_or =
       AudioClassifier::Create(std::move(options));
 
@@ -266,25 +269,6 @@ TEST_F(CreateFromOptionsTest, FailsWithUnnecessaryCallback) {
                   MediaPipeTasksStatus::kInvalidTaskGraphConfigError))));
 }
 
-TEST_F(CreateFromOptionsTest, FailsWithMissingDefaultInputAudioSampleRate) {
-  auto options = std::make_unique<AudioClassifierOptions>();
-  options->base_options.model_asset_path =
-      JoinPath("./", kTestDataDirectory, kModelWithoutMetadata);
-  options->running_mode = core::RunningMode::AUDIO_STREAM;
-  options->result_callback =
-      [](absl::StatusOr<AudioClassifierResult> status_or_result) {};
-  StatusOr<std::unique_ptr<AudioClassifier>> audio_classifier_or =
-      AudioClassifier::Create(std::move(options));
-
-  EXPECT_EQ(audio_classifier_or.status().code(),
-            absl::StatusCode::kInvalidArgument);
-  EXPECT_THAT(audio_classifier_or.status().message(),
-              HasSubstr("the sample rate must be specified"));
-  EXPECT_THAT(audio_classifier_or.status().GetPayload(kMediaPipeTasksPayload),
-              Optional(absl::Cord(absl::StrCat(
-                  MediaPipeTasksStatus::kInvalidTaskGraphConfigError))));
-}
-
 class ClassifyTest : public tflite_shims::testing::Test {};
 
 TEST_F(ClassifyTest, Succeeds) {
@@ -493,7 +477,6 @@ TEST_F(ClassifyAsyncTest, Succeeds) {
   options->classifier_options.max_results = 1;
   options->classifier_options.score_threshold = 0.3f;
   options->running_mode = core::RunningMode::AUDIO_STREAM;
-  options->sample_rate = kSampleRateHz;
   std::vector<AudioClassifierResult> outputs;
   options->result_callback =
       [&outputs](absl::StatusOr<AudioClassifierResult> status_or_result) {
@@ -506,7 +489,7 @@ TEST_F(ClassifyAsyncTest, Succeeds) {
     int num_samples = std::min((int)(audio_buffer.cols() - start_col),
                                kYamnetNumOfAudioSamples * 3);
     MP_ASSERT_OK(audio_classifier->ClassifyAsync(
-        audio_buffer.block(0, start_col, 1, num_samples),
+        audio_buffer.block(0, start_col, 1, num_samples), kSampleRateHz,
        start_col * kMilliSecondsPerSecond / kSampleRateHz));
     start_col += kYamnetNumOfAudioSamples * 3;
   }
@@ -523,7 +506,6 @@ TEST_F(ClassifyAsyncTest, SucceedsWithNonDeterministicNumAudioSamples) {
   options->classifier_options.max_results = 1;
   options->classifier_options.score_threshold = 0.3f;
   options->running_mode = core::RunningMode::AUDIO_STREAM;
-  options->sample_rate = kSampleRateHz;
   std::vector<AudioClassifierResult> outputs;
   options->result_callback =
       [&outputs](absl::StatusOr<AudioClassifierResult> status_or_result) {
@@ -538,7 +520,7 @@ TEST_F(ClassifyAsyncTest, SucceedsWithNonDeterministicNumAudioSamples) {
         std::min((int)(audio_buffer.cols() - start_col),
                  rand_r(&rseed) % 10 + kYamnetNumOfAudioSamples * 3);
     MP_ASSERT_OK(audio_classifier->ClassifyAsync(
-        audio_buffer.block(0, start_col, 1, num_samples),
+        audio_buffer.block(0, start_col, 1, num_samples), kSampleRateHz,
        start_col * kMilliSecondsPerSecond / kSampleRateHz));
     start_col += num_samples;
   }
mediapipe/tasks/cc/audio/core/base_audio_task_api.h
@@ -72,8 +72,39 @@ class BaseAudioTaskApi : public tasks::core::BaseTaskApi {
     return runner_->Send(std::move(inputs));
   }
 
+  // Checks or sets the sample rate in the audio stream mode.
+  absl::Status CheckOrSetSampleRate(std::string sample_rate_stream_name,
+                                    double sample_rate) {
+    if (running_mode_ != RunningMode::AUDIO_STREAM) {
+      return CreateStatusWithPayload(
+          absl::StatusCode::kInvalidArgument,
+          absl::StrCat("Task is not initialized with the audio stream mode. "
+                       "Current running mode:",
+                       GetRunningModeName(running_mode_)),
+          MediaPipeTasksStatus::kRunnerApiCalledInWrongModeError);
+    }
+    if (default_sample_rate_ > 0) {
+      if (std::fabs(sample_rate - default_sample_rate_) >
+          std::numeric_limits<double>::epsilon()) {
+        return CreateStatusWithPayload(
+            absl::StatusCode::kInvalidArgument,
+            absl::StrCat("The input audio sample rate: ", sample_rate,
+                         " is inconsistent with the previously provided: ",
+                         default_sample_rate_),
+            MediaPipeTasksStatus::kInvalidArgumentError);
+      }
+    } else {
+      default_sample_rate_ = sample_rate;
+      MP_RETURN_IF_ERROR(runner_->Send(
+          {{sample_rate_stream_name, MakePacket<double>(default_sample_rate_)
+                                         .At(Timestamp::PreStream())}}));
+    }
+    return absl::OkStatus();
+  }
+
  private:
   RunningMode running_mode_;
+  double default_sample_rate_ = -1.0;
 };
 
 }  // namespace core
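The Python task API merged below enforces the same first-call-wins sample-rate policy in `classify_async`. As a condensed, standalone illustration (plain Python, not a MediaPipe API; the epsilon tolerance mirrors the C++ above):

    import math
    import sys

    _default_sample_rate = None  # mirrors default_sample_rate_ = -1.0 above

    def check_or_set_sample_rate(sample_rate: float) -> None:
      """First call pins the rate; later calls must match within epsilon."""
      global _default_sample_rate
      if _default_sample_rate is None:
        _default_sample_rate = sample_rate
      elif math.fabs(sample_rate - _default_sample_rate) > sys.float_info.epsilon:
        raise ValueError(
            f'The input audio sample rate: {sample_rate} is inconsistent '
            f'with the previously provided: {_default_sample_rate}')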
mediapipe/tasks/python/__init__.py
@@ -14,8 +14,10 @@
 
 """MediaPipe Tasks API."""
 
+from . import audio
 from . import components
 from . import core
+from . import text
 from . import vision
 
 BaseOptions = core.base_options.BaseOptions
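With `audio` (and `text`) now re-exported from the package root, the new task surface is importable from one place. A minimal sketch; the import alias and model file name are illustrative:

    import mediapipe.tasks.python as mp_tasks  # alias is illustrative

    classifier = mp_tasks.audio.AudioClassifier.create_from_model_path(
        'classifier.tflite')  # hypothetical model file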
mediapipe/tasks/python/audio/BUILD (new file, 41 lines)
@@ -0,0 +1,41 @@
+# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Placeholder for internal Python strict library and test compatibility macro.
+
+package(default_visibility = ["//mediapipe/tasks:internal"])
+
+licenses(["notice"])
+
+py_library(
+    name = "audio_classifier",
+    srcs = [
+        "audio_classifier.py",
+    ],
+    deps = [
+        "//mediapipe/python:_framework_bindings",
+        "//mediapipe/python:packet_creator",
+        "//mediapipe/python:packet_getter",
+        "//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_graph_options_py_pb2",
+        "//mediapipe/tasks/cc/components/containers/proto:classifications_py_pb2",
+        "//mediapipe/tasks/python/audio/core:audio_task_running_mode",
+        "//mediapipe/tasks/python/audio/core:base_audio_task_api",
+        "//mediapipe/tasks/python/components/containers:audio_data",
+        "//mediapipe/tasks/python/components/containers:classification_result",
+        "//mediapipe/tasks/python/components/processors:classifier_options",
+        "//mediapipe/tasks/python/core:base_options",
+        "//mediapipe/tasks/python/core:optional_dependencies",
+        "//mediapipe/tasks/python/core:task_info",
+    ],
+)
mediapipe/tasks/python/audio/__init__.py (new file, 27 lines)
@@ -0,0 +1,27 @@
+# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""MediaPipe Tasks Audio API."""
+
+import mediapipe.tasks.python.audio.core
+import mediapipe.tasks.python.audio.audio_classifier
+
+AudioClassifier = audio_classifier.AudioClassifier
+AudioClassifierOptions = audio_classifier.AudioClassifierOptions
+RunningMode = core.audio_task_running_mode.AudioTaskRunningMode
+
+# Remove unnecessary modules to avoid duplication in API docs.
+del audio_classifier
+del core
+del mediapipe
mediapipe/tasks/python/audio/audio_classifier.py (new file, 280 lines)
@@ -0,0 +1,280 @@
+# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""MediaPipe audio classifier task."""
+
+import dataclasses
+from typing import Callable, Mapping, List, Optional
+
+from mediapipe.python import packet_creator
+from mediapipe.python import packet_getter
+from mediapipe.python._framework_bindings import packet
+from mediapipe.tasks.cc.audio.audio_classifier.proto import audio_classifier_graph_options_pb2
+from mediapipe.tasks.cc.components.containers.proto import classifications_pb2
+from mediapipe.tasks.python.audio.core import audio_task_running_mode as running_mode_module
+from mediapipe.tasks.python.audio.core import base_audio_task_api
+from mediapipe.tasks.python.components.containers import audio_data as audio_data_module
+from mediapipe.tasks.python.components.containers import classification_result as classification_result_module
+from mediapipe.tasks.python.components.processors import classifier_options as classifier_options_module
+from mediapipe.tasks.python.core import base_options as base_options_module
+from mediapipe.tasks.python.core import task_info as task_info_module
+from mediapipe.tasks.python.core.optional_dependencies import doc_controls
+
+AudioClassifierResult = classification_result_module.ClassificationResult
+_AudioClassifierGraphOptionsProto = audio_classifier_graph_options_pb2.AudioClassifierGraphOptions
+_AudioData = audio_data_module.AudioData
+_BaseOptions = base_options_module.BaseOptions
+_ClassifierOptions = classifier_options_module.ClassifierOptions
+_RunningMode = running_mode_module.AudioTaskRunningMode
+_TaskInfo = task_info_module.TaskInfo
+
+_AUDIO_IN_STREAM_NAME = 'audio_in'
+_AUDIO_TAG = 'AUDIO'
+_CLASSIFICATIONS_STREAM_NAME = 'classifications_out'
+_CLASSIFICATIONS_TAG = 'CLASSIFICATIONS'
+_SAMPLE_RATE_IN_STREAM_NAME = 'sample_rate_in'
+_SAMPLE_RATE_TAG = 'SAMPLE_RATE'
+_TASK_GRAPH_NAME = 'mediapipe.tasks.audio.audio_classifier.AudioClassifierGraph'
+_TIMESTAMPED_CLASSIFICATIONS_STREAM_NAME = 'timestamped_classifications_out'
+_TIMESTAMPED_CLASSIFICATIONS_TAG = 'TIMESTAMPED_CLASSIFICATIONS'
+_MICRO_SECONDS_PER_MILLISECOND = 1000
+
+
+@dataclasses.dataclass
+class AudioClassifierOptions:
+  """Options for the audio classifier task.
+
+  Attributes:
+    base_options: Base options for the audio classifier task.
+    running_mode: The running mode of the task. Default to the audio clips mode.
+      Audio classifier task has two running modes: 1) The audio clips mode for
+      running classification on independent audio clips. 2) The audio stream
+      mode for running classification on the audio stream, such as from
+      microphone. In this mode, the "result_callback" below must be specified
+      to receive the classification results asynchronously.
+    classifier_options: Options for configuring the classifier behavior, such as
+      score threshold, number of results, etc.
+    result_callback: The user-defined result callback for processing audio
+      stream data. The result callback should only be specified when the running
+      mode is set to the audio stream mode.
+  """
+  base_options: _BaseOptions
+  running_mode: _RunningMode = _RunningMode.AUDIO_CLIPS
+  classifier_options: _ClassifierOptions = _ClassifierOptions()
+  result_callback: Optional[Callable[[AudioClassifierResult, int], None]] = None
+
+  @doc_controls.do_not_generate_docs
+  def to_pb2(self) -> _AudioClassifierGraphOptionsProto:
+    """Generates an AudioClassifierOptions protobuf object."""
+    base_options_proto = self.base_options.to_pb2()
+    base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.AUDIO_CLIPS else True
+    classifier_options_proto = self.classifier_options.to_pb2()
+
+    return _AudioClassifierGraphOptionsProto(
+        base_options=base_options_proto,
+        classifier_options=classifier_options_proto)
+
+
+class AudioClassifier(base_audio_task_api.BaseAudioTaskApi):
+  """Class that performs audio classification on audio data."""
+
+  @classmethod
+  def create_from_model_path(cls, model_path: str) -> 'AudioClassifier':
+    """Creates an `AudioClassifier` object from a TensorFlow Lite model and the default `AudioClassifierOptions`.
+
+    Note that the created `AudioClassifier` instance is in audio clips mode, for
+    classifying on independent audio clips.
+
+    Args:
+      model_path: Path to the model.
+
+    Returns:
+      `AudioClassifier` object that's created from the model file and the
+      default `AudioClassifierOptions`.
+
+    Raises:
+      ValueError: If failed to create `AudioClassifier` object from the provided
+        file such as invalid file path.
+      RuntimeError: If other types of error occurred.
+    """
+    base_options = _BaseOptions(model_asset_path=model_path)
+    options = AudioClassifierOptions(
+        base_options=base_options, running_mode=_RunningMode.AUDIO_CLIPS)
+    return cls.create_from_options(options)
+
+  @classmethod
+  def create_from_options(cls,
+                          options: AudioClassifierOptions) -> 'AudioClassifier':
+    """Creates the `AudioClassifier` object from audio classifier options.
+
+    Args:
+      options: Options for the audio classifier task.
+
+    Returns:
+      `AudioClassifier` object that's created from `options`.
+
+    Raises:
+      ValueError: If failed to create `AudioClassifier` object from
+        `AudioClassifierOptions` such as missing the model.
+      RuntimeError: If other types of error occurred.
+    """
+
+    def packets_callback(output_packets: Mapping[str, packet.Packet]):
+      timestamp_ms = output_packets[
+          _CLASSIFICATIONS_STREAM_NAME].timestamp.value // _MICRO_SECONDS_PER_MILLISECOND
+      if output_packets[_CLASSIFICATIONS_STREAM_NAME].is_empty():
+        options.result_callback(
+            AudioClassifierResult(classifications=[]), timestamp_ms)
+        return
+      classification_result_proto = classifications_pb2.ClassificationResult()
+      classification_result_proto.CopyFrom(
+          packet_getter.get_proto(output_packets[_CLASSIFICATIONS_STREAM_NAME]))
+      options.result_callback(
+          AudioClassifierResult.create_from_pb2(classification_result_proto),
+          timestamp_ms)
+
+    task_info = _TaskInfo(
+        task_graph=_TASK_GRAPH_NAME,
+        input_streams=[
+            ':'.join([_AUDIO_TAG, _AUDIO_IN_STREAM_NAME]),
+            ':'.join([_SAMPLE_RATE_TAG, _SAMPLE_RATE_IN_STREAM_NAME])
+        ],
+        output_streams=[
+            ':'.join([_CLASSIFICATIONS_TAG, _CLASSIFICATIONS_STREAM_NAME]),
+            ':'.join([
+                _TIMESTAMPED_CLASSIFICATIONS_TAG,
+                _TIMESTAMPED_CLASSIFICATIONS_STREAM_NAME
+            ])
+        ],
+        task_options=options)
+    return cls(
+        # Audio tasks should not drop input audio due to flow limiting, which
+        # may cause data inconsistency.
+        task_info.generate_graph_config(enable_flow_limiting=False),
+        options.running_mode,
+        packets_callback if options.result_callback else None)
+
+  def classify(self, audio_clip: _AudioData) -> List[AudioClassifierResult]:
+    """Performs audio classification on the provided audio clip.
+
+    The audio clip is represented as a MediaPipe AudioData. The method accepts
+    audio clips with various length and audio sample rate. It's required to
+    provide the corresponding audio sample rate within the `AudioData` object.
+
+    The input audio clip may be longer than what the model is able to process
+    in a single inference. When this occurs, the input audio clip is split into
+    multiple chunks starting at different timestamps. For this reason, this
+    function returns a vector of ClassificationResult objects, each associated
+    with a timestamp corresponding to the start (in milliseconds) of the chunk
+    data that was classified, e.g.:
+
+    ClassificationResult #0 (first chunk of data):
+      timestamp_ms: 0 (starts at 0ms)
+      classifications #0 (single head model):
+        category #0:
+          category_name: "Speech"
+          score: 0.6
+        category #1:
+          category_name: "Music"
+          score: 0.2
+    ClassificationResult #1 (second chunk of data):
+      timestamp_ms: 800 (starts at 800ms)
+      classifications #0 (single head model):
+        category #0:
+          category_name: "Speech"
+          score: 0.5
+        category #1:
+          category_name: "Silence"
+          score: 0.1
+
+    Args:
+      audio_clip: MediaPipe AudioData.
+
+    Returns:
+      An `AudioClassifierResult` object that contains a list of
+      classification result objects, each associated with a timestamp
+      corresponding to the start (in milliseconds) of the chunk data that was
+      classified.
+
+    Raises:
+      ValueError: If any of the input arguments is invalid, such as the sample
+        rate is not provided in the `AudioData` object.
+      RuntimeError: If audio classification failed to run.
+    """
+    if not audio_clip.audio_format.sample_rate:
+      raise ValueError('Must provide the audio sample rate in audio data.')
+    output_packets = self._process_audio_clip({
+        _AUDIO_IN_STREAM_NAME:
+            packet_creator.create_matrix(audio_clip.buffer, transpose=True),
+        _SAMPLE_RATE_IN_STREAM_NAME:
+            packet_creator.create_double(audio_clip.audio_format.sample_rate)
+    })
+    output_list = []
+    classification_result_proto_list = packet_getter.get_proto_list(
+        output_packets[_TIMESTAMPED_CLASSIFICATIONS_STREAM_NAME])
+    for proto in classification_result_proto_list:
+      classification_result_proto = classifications_pb2.ClassificationResult()
+      classification_result_proto.CopyFrom(proto)
+      output_list.append(
+          AudioClassifierResult.create_from_pb2(classification_result_proto))
+    return output_list
+
+  def classify_async(self, audio_block: _AudioData, timestamp_ms: int) -> None:
+    """Sends audio data (a block in a continuous audio stream) to perform audio classification.
+
+    Only use this method when the AudioClassifier is created with the audio
+    stream running mode. The input timestamps should be monotonically increasing
+    for adjacent calls of this method. This method will return immediately after
+    the input audio data is accepted. The results will be available via the
+    `result_callback` provided in the `AudioClassifierOptions`. The
+    `classify_async` method is designed to process audio stream data such as
+    microphone input.
+
+    The input audio data may be longer than what the model is able to process
+    in a single inference. When this occurs, the input audio block is split
+    into multiple chunks. For this reason, the callback may be called multiple
+    times (once per chunk) for each call to this function.
+
+    The `result_callback` provides:
+      - An `AudioClassifierResult` object that contains a list of
+        classifications.
+      - The input timestamp in milliseconds.
+
+    Args:
+      audio_block: MediaPipe AudioData.
+      timestamp_ms: The timestamp of the input audio data in milliseconds.
+
+    Raises:
+      ValueError: If any of the following:
+        1) The sample rate is not provided in the `AudioData` object or the
+           provided sample rate is inconsistent with the previously received.
+        2) The current input timestamp is smaller than what the audio
+           classifier has already processed.
+    """
+    if not audio_block.audio_format.sample_rate:
+      raise ValueError('Must provide the audio sample rate in audio data.')
+    if not self._default_sample_rate:
+      self._default_sample_rate = audio_block.audio_format.sample_rate
+      self._set_sample_rate(_SAMPLE_RATE_IN_STREAM_NAME,
+                            self._default_sample_rate)
+    elif audio_block.audio_format.sample_rate != self._default_sample_rate:
+      raise ValueError(
+          f'The audio sample rate provided in audio data: '
+          f'{audio_block.audio_format.sample_rate} is inconsistent with '
+          f'the previously received: {self._default_sample_rate}.')
+
+    self._send_audio_stream_data({
+        _AUDIO_IN_STREAM_NAME:
+            packet_creator.create_matrix(audio_block.buffer, transpose=True).at(
+                timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND)
+    })
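Taken together with the containers below, the new task can be driven end to end. A minimal usage sketch for the audio clips mode; the model file and the silent 16 kHz mono buffer are hypothetical stand-ins:

    import numpy as np

    from mediapipe.tasks.python.audio import audio_classifier
    from mediapipe.tasks.python.components.containers import audio_data

    # Wrap raw samples in AudioData; classify() requires the sample rate.
    samples = np.zeros(15600, dtype=np.float32)  # placeholder audio buffer
    clip = audio_data.AudioData(
        buffer_length=len(samples),
        audio_format=audio_data.AudioDataFormat(
            num_channels=1, sample_rate=16000))
    clip.load_from_array(samples)

    classifier = audio_classifier.AudioClassifier.create_from_model_path(
        'classifier.tflite')  # hypothetical model file
    for result in classifier.classify(clip):
      # One AudioClassifierResult per processed chunk, stamped with the
      # chunk's start time in milliseconds.
      print(result.timestamp_ms, result.classifications)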
mediapipe/tasks/python/audio/core/BUILD
@@ -32,6 +32,7 @@ py_library(
         ":audio_task_running_mode",
         "//mediapipe/framework:calculator_py_pb2",
         "//mediapipe/python:_framework_bindings",
+        "//mediapipe/python:packet_creator",
         "//mediapipe/tasks/python/core:optional_dependencies",
     ],
 )
mediapipe/tasks/python/audio/core/base_audio_task_api.py
@@ -16,14 +16,17 @@
 from typing import Callable, Mapping, Optional
 
 from mediapipe.framework import calculator_pb2
+from mediapipe.python import packet_creator
 from mediapipe.python._framework_bindings import packet as packet_module
 from mediapipe.python._framework_bindings import task_runner as task_runner_module
+from mediapipe.python._framework_bindings import timestamp as timestamp_module
 from mediapipe.tasks.python.audio.core import audio_task_running_mode as running_mode_module
 from mediapipe.tasks.python.core.optional_dependencies import doc_controls
 
 _TaskRunner = task_runner_module.TaskRunner
 _Packet = packet_module.Packet
 _RunningMode = running_mode_module.AudioTaskRunningMode
+_Timestamp = timestamp_module.Timestamp
 
 
 class BaseAudioTaskApi(object):
@@ -59,6 +62,7 @@ class BaseAudioTaskApi(object):
           'callback should not be provided.')
     self._runner = _TaskRunner.create(graph_config, packet_callback)
     self._running_mode = running_mode
+    self._default_sample_rate = None
 
   def _process_audio_clip(
       self, inputs: Mapping[str, _Packet]) -> Mapping[str, _Packet]:
@@ -82,6 +86,27 @@ class BaseAudioTaskApi(object):
           + self._running_mode.name)
     return self._runner.process(inputs)
 
+  def _set_sample_rate(self, sample_rate_stream_name: str,
+                       sample_rate: float) -> None:
+    """An asynchronous method to set audio sample rate in the audio stream mode.
+
+    Args:
+      sample_rate_stream_name: The audio sample rate stream name.
+      sample_rate: The audio sample rate.
+
+    Raises:
+      ValueError: If the task's running mode is not set to the audio stream
+        mode.
+    """
+    if self._running_mode != _RunningMode.AUDIO_STREAM:
+      raise ValueError(
+          'Task is not initialized with the audio stream mode. Current running mode:'
+          + self._running_mode.name)
+    self._runner.send({
+        sample_rate_stream_name:
+            packet_creator.create_double(sample_rate).at(_Timestamp.PRESTREAM)
+    })
+
   def _send_audio_stream_data(self, inputs: Mapping[str, _Packet]) -> None:
     """An asynchronous method to send audio stream data to the runner.
 
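`_set_sample_rate` relies on MediaPipe's pre-stream convention: a one-off configuration value is stamped `Timestamp.PRESTREAM` so it reaches the graph before the first audio packet. In isolation, using the same helpers as the diff (the 16000.0 value is illustrative):

    from mediapipe.python import packet_creator
    from mediapipe.python._framework_bindings import timestamp

    # A double packet that will arrive ahead of every timestamped audio packet.
    pkt = packet_creator.create_double(16000.0).at(
        timestamp.Timestamp.PRESTREAM)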
mediapipe/tasks/python/components/containers/BUILD
@@ -95,6 +95,16 @@ py_library(
     ],
 )
 
+py_library(
+    name = "classification_result",
+    srcs = ["classification_result.py"],
+    deps = [
+        ":category",
+        "//mediapipe/tasks/cc/components/containers/proto:classifications_py_pb2",
+        "//mediapipe/tasks/python/core:optional_dependencies",
+    ],
+)
+
 py_library(
     name = "embeddings",
     srcs = ["embeddings.py"],
mediapipe/tasks/python/components/containers/audio_data.py
@@ -20,7 +20,7 @@ import numpy as np
 
 
 @dataclasses.dataclass
-class AudioFormat:
+class AudioDataFormat:
   """Audio format metadata.
 
   Attributes:
@@ -35,8 +35,10 @@ class AudioData(object):
   """MediaPipe Tasks' audio container."""
 
   def __init__(
-      self, buffer_length: int,
-      audio_format: AudioFormat = AudioFormat()) -> None:
+      self,
+      buffer_length: int,
+      audio_format: AudioDataFormat = AudioDataFormat()
+  ) -> None:
     """Initializes the `AudioData` object.
 
     Args:
@@ -113,14 +115,14 @@ class AudioData(object):
     """
     obj = cls(
         buffer_length=src.shape[0],
-        audio_format=AudioFormat(
+        audio_format=AudioDataFormat(
            num_channels=1 if len(src.shape) == 1 else src.shape[1],
            sample_rate=sample_rate))
     obj.load_from_array(src)
     return obj
 
   @property
-  def audio_format(self) -> AudioFormat:
+  def audio_format(self) -> AudioDataFormat:
     """Gets the audio format of the audio."""
     return self._audio_format
 
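The rename is mechanical (`AudioFormat` becomes `AudioDataFormat`); constructing the container under the new name looks like:

    from mediapipe.tasks.python.components.containers import audio_data

    fmt = audio_data.AudioDataFormat(num_channels=2, sample_rate=44100)
    stereo = audio_data.AudioData(buffer_length=44100, audio_format=fmt)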
mediapipe/tasks/python/components/containers/classification_result.py (new file)

@@ -0,0 +1,92 @@
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Classifications data class."""

import dataclasses
from typing import List, Optional

from mediapipe.tasks.cc.components.containers.proto import classifications_pb2
from mediapipe.tasks.python.components.containers import category as category_module
from mediapipe.tasks.python.core.optional_dependencies import doc_controls

_ClassificationsProto = classifications_pb2.Classifications
_ClassificationResultProto = classifications_pb2.ClassificationResult


@dataclasses.dataclass
class Classifications:
  """Represents the classification results for a given classifier head.

  Attributes:
    categories: The array of predicted categories, usually sorted by descending
      scores (e.g. from high to low probability).
    head_index: The index of the classifier head these categories refer to.
      This is useful for multi-head models.
    head_name: The name of the classifier head, which is the corresponding
      tensor metadata name.
  """

  categories: List[category_module.Category]
  head_index: int
  head_name: Optional[str] = None

  @classmethod
  @doc_controls.do_not_generate_docs
  def create_from_pb2(cls, pb2_obj: _ClassificationsProto) -> 'Classifications':
    """Creates a `Classifications` object from the given protobuf object."""
    categories = []
    for entry in pb2_obj.classification_list.classification:
      categories.append(
          category_module.Category(
              index=entry.index,
              score=entry.score,
              display_name=entry.display_name,
              category_name=entry.label))

    return Classifications(
        categories=categories,
        head_index=pb2_obj.head_index,
        head_name=pb2_obj.head_name)


@dataclasses.dataclass
class ClassificationResult:
  """Contains the classification results of a model.

  Attributes:
    classifications: A list of `Classifications` objects, each for a head of
      the model.
    timestamp_ms: The optional timestamp (in milliseconds) of the start of the
      chunk of data corresponding to these results. This is only used for
      classification on time series (e.g. audio classification). In these use
      cases, the amount of data to process might exceed the maximum size that
      the model can process: to solve this, the input data is split into
      multiple chunks starting at different timestamps.
  """

  classifications: List[Classifications]
  timestamp_ms: Optional[int] = None

  @classmethod
  @doc_controls.do_not_generate_docs
  def create_from_pb2(
      cls, pb2_obj: _ClassificationResultProto) -> 'ClassificationResult':
    """Creates a `ClassificationResult` object from the given protobuf object.
    """
    return ClassificationResult(
        classifications=[
            Classifications.create_from_pb2(classification)
            for classification in pb2_obj.classifications
        ],
        timestamp_ms=pb2_obj.timestamp_ms)
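To make the nesting concrete, a hedged sketch of building the container by hand; the field names are taken from the dataclasses above, the values are invented:

    from mediapipe.tasks.python.components.containers import category as category_module
    from mediapipe.tasks.python.components.containers import classification_result as classification_result_module

    # One `Classifications` entry per model head; `timestamp_ms` marks the
    # start of the audio chunk that produced these scores.
    result = classification_result_module.ClassificationResult(
        classifications=[
            classification_result_module.Classifications(
                categories=[
                    category_module.Category(
                        index=0,
                        score=0.97,
                        display_name='',
                        category_name='Speech')
                ],
                head_index=0,
                head_name='scores')
        ],
        timestamp_ms=0)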
mediapipe/tasks/python/test/audio/BUILD (new file)

@@ -0,0 +1,37 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Placeholder for internal Python strict test compatibility macro.

package(default_visibility = ["//mediapipe/tasks:internal"])

licenses(["notice"])

py_test(
    name = "audio_classifier_test",
    srcs = ["audio_classifier_test.py"],
    data = [
        "//mediapipe/tasks/testdata/audio:test_audio_clips",
        "//mediapipe/tasks/testdata/audio:test_models",
    ],
    deps = [
        "//mediapipe/tasks/python/audio:audio_classifier",
        "//mediapipe/tasks/python/audio/core:audio_task_running_mode",
        "//mediapipe/tasks/python/components/containers:audio_data",
        "//mediapipe/tasks/python/components/containers:classification_result",
        "//mediapipe/tasks/python/components/processors:classifier_options",
        "//mediapipe/tasks/python/core:base_options",
        "//mediapipe/tasks/python/test:test_utils",
    ],
)
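Assuming a standard Bazel checkout, the new target should be runnable directly, e.g. `bazel test //mediapipe/tasks/python/test/audio:audio_classifier_test` (the label follows from the package path and target name above).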
mediapipe/tasks/python/test/audio/__init__.py (new file)

@@ -0,0 +1,13 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
mediapipe/tasks/python/test/audio/audio_classifier_test.py (new file)

@@ -0,0 +1,381 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for audio classifier."""

import os
from typing import List, Tuple
from unittest import mock

from absl.testing import absltest
from absl.testing import parameterized

import numpy as np
from scipy.io import wavfile

from mediapipe.tasks.python.audio import audio_classifier
from mediapipe.tasks.python.audio.core import audio_task_running_mode
from mediapipe.tasks.python.components.containers import audio_data as audio_data_module
from mediapipe.tasks.python.components.containers import classification_result as classification_result_module
from mediapipe.tasks.python.components.processors import classifier_options
from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.test import test_utils

_AudioClassifier = audio_classifier.AudioClassifier
_AudioClassifierOptions = audio_classifier.AudioClassifierOptions
_AudioClassifierResult = classification_result_module.ClassificationResult
_AudioData = audio_data_module.AudioData
_BaseOptions = base_options_module.BaseOptions
_ClassifierOptions = classifier_options.ClassifierOptions
_RUNNING_MODE = audio_task_running_mode.AudioTaskRunningMode

_YAMNET_MODEL_FILE = 'yamnet_audio_classifier_with_metadata.tflite'
_YAMNET_MODEL_SAMPLE_RATE = 16000
_TWO_HEADS_MODEL_FILE = 'two_heads.tflite'
_SPEECH_WAV_16K_MONO = 'speech_16000_hz_mono.wav'
_SPEECH_WAV_48K_MONO = 'speech_48000_hz_mono.wav'
_TEST_DATA_DIR = 'mediapipe/tasks/testdata/audio'
_TWO_HEADS_WAV_16K_MONO = 'two_heads_16000_hz_mono.wav'
_TWO_HEADS_WAV_44K_MONO = 'two_heads_44100_hz_mono.wav'
_YAMNET_NUM_OF_SAMPLES = 15600
_MILLSECONDS_PER_SECOND = 1000


class AudioClassifierTest(parameterized.TestCase):

  def setUp(self):
    super().setUp()
    self.yamnet_model_path = test_utils.get_test_data_path(
        os.path.join(_TEST_DATA_DIR, _YAMNET_MODEL_FILE))
    self.two_heads_model_path = test_utils.get_test_data_path(
        os.path.join(_TEST_DATA_DIR, _TWO_HEADS_MODEL_FILE))

  def _read_wav_file(self, file_name) -> _AudioData:
    sample_rate, buffer = wavfile.read(
        test_utils.get_test_data_path(os.path.join(_TEST_DATA_DIR, file_name)))
    return _AudioData.create_from_array(
        buffer.astype(float) / np.iinfo(np.int16).max, sample_rate)

  def _read_wav_file_as_stream(self, file_name) -> List[Tuple[_AudioData, int]]:
    sample_rate, buffer = wavfile.read(
        test_utils.get_test_data_path(os.path.join(_TEST_DATA_DIR, file_name)))
    audio_data_list = []
    start = 0
    step_size = _YAMNET_NUM_OF_SAMPLES * sample_rate / _YAMNET_MODEL_SAMPLE_RATE
    while start < len(buffer):
      end = min(start + (int)(step_size), len(buffer))
      audio_data_list.append((_AudioData.create_from_array(
          buffer[start:end].astype(float) / np.iinfo(np.int16).max,
          sample_rate), (int)(start / sample_rate * _MILLSECONDS_PER_SECOND)))
      start = end
    return audio_data_list

  # TODO: Compares the exact score values to capture unexpected
  # changes in the inference pipeline.
  def _check_yamnet_result(
      self,
      classification_result_list: List[_AudioClassifierResult],
      expected_num_categories=521):
    self.assertLen(classification_result_list, 5)
    for idx, timestamp in enumerate([0, 975, 1950, 2925]):
      classification_result = classification_result_list[idx]
      self.assertEqual(classification_result.timestamp_ms, timestamp)
      self.assertLen(classification_result.classifications, 1)
      classifcation = classification_result.classifications[0]
      self.assertEqual(classifcation.head_index, 0)
      self.assertEqual(classifcation.head_name, 'scores')
      self.assertLen(classifcation.categories, expected_num_categories)
      audio_category = classifcation.categories[0]
      self.assertEqual(audio_category.index, 0)
      self.assertEqual(audio_category.category_name, 'Speech')
      self.assertGreater(audio_category.score, 0.9)

  # TODO: Compares the exact score values to capture unexpected
  # changes in the inference pipeline.
  def _check_two_heads_result(
      self,
      classification_result_list: List[_AudioClassifierResult],
      first_head_expected_num_categories=521,
      second_head_expected_num_categories=5):
    self.assertGreaterEqual(len(classification_result_list), 1)
    self.assertLessEqual(len(classification_result_list), 2)
    # Checks the first result.
    classification_result = classification_result_list[0]
    self.assertEqual(classification_result.timestamp_ms, 0)
    self.assertLen(classification_result.classifications, 2)
    # Checks the first head.
    yamnet_classifcation = classification_result.classifications[0]
    self.assertEqual(yamnet_classifcation.head_index, 0)
    self.assertEqual(yamnet_classifcation.head_name, 'yamnet_classification')
    self.assertLen(yamnet_classifcation.categories,
                   first_head_expected_num_categories)
    yamnet_category = yamnet_classifcation.categories[0]
    self.assertEqual(yamnet_category.index, 508)
    self.assertEqual(yamnet_category.category_name, 'Environmental noise')
    self.assertGreater(yamnet_category.score, 0.5)
    # Checks the second head.
    bird_classifcation = classification_result.classifications[1]
    self.assertEqual(bird_classifcation.head_index, 1)
    self.assertEqual(bird_classifcation.head_name, 'bird_classification')
    self.assertLen(bird_classifcation.categories,
                   second_head_expected_num_categories)
    bird_category = bird_classifcation.categories[0]
    self.assertEqual(bird_category.index, 4)
    self.assertEqual(bird_category.category_name, 'Chestnut-crowned Antpitta')
    self.assertGreater(bird_category.score, 0.93)
    # Checks the second result, if present.
    if len(classification_result_list) == 2:
      classification_result = classification_result_list[1]
      self.assertEqual(classification_result.timestamp_ms, 975)
      self.assertLen(classification_result.classifications, 2)
      # Checks the first head.
      yamnet_classifcation = classification_result.classifications[0]
      self.assertEqual(yamnet_classifcation.head_index, 0)
      self.assertEqual(yamnet_classifcation.head_name, 'yamnet_classification')
      self.assertLen(yamnet_classifcation.categories,
                     first_head_expected_num_categories)
      yamnet_category = yamnet_classifcation.categories[0]
      self.assertEqual(yamnet_category.index, 494)
      self.assertEqual(yamnet_category.category_name, 'Silence')
      self.assertGreater(yamnet_category.score, 0.9)
      # Checks the second head.
      bird_classifcation = classification_result.classifications[1]
      self.assertEqual(bird_classifcation.head_index, 1)
      self.assertEqual(bird_classifcation.head_name, 'bird_classification')
      self.assertLen(bird_classifcation.categories,
                     second_head_expected_num_categories)
      bird_category = bird_classifcation.categories[0]
      self.assertEqual(bird_category.index, 1)
      self.assertEqual(bird_category.category_name, 'White-breasted Wood-Wren')
      self.assertGreater(bird_category.score, 0.99)

  def test_create_from_file_succeeds_with_valid_model_path(self):
    # Creates with default option and valid model file successfully.
    with _AudioClassifier.create_from_model_path(
        self.yamnet_model_path) as classifier:
      self.assertIsInstance(classifier, _AudioClassifier)

  def test_create_from_options_succeeds_with_valid_model_path(self):
    # Creates with options containing model file successfully.
    with _AudioClassifier.create_from_options(
        _AudioClassifierOptions(
            base_options=_BaseOptions(
                model_asset_path=self.yamnet_model_path))) as classifier:
      self.assertIsInstance(classifier, _AudioClassifier)

  def test_create_from_options_fails_with_invalid_model_path(self):
    # Invalid empty model path.
    with self.assertRaisesRegex(
        ValueError,
        r"ExternalFile must specify at least one of 'file_content', "
        r"'file_name', 'file_pointer_meta' or 'file_descriptor_meta'."):
      base_options = _BaseOptions(model_asset_path='')
      options = _AudioClassifierOptions(base_options=base_options)
      _AudioClassifier.create_from_options(options)

  def test_create_from_options_succeeds_with_valid_model_content(self):
    # Creates with options containing model content successfully.
    with open(self.yamnet_model_path, 'rb') as f:
      base_options = _BaseOptions(model_asset_buffer=f.read())
      options = _AudioClassifierOptions(base_options=base_options)
      classifier = _AudioClassifier.create_from_options(options)
      self.assertIsInstance(classifier, _AudioClassifier)

  @parameterized.parameters((_SPEECH_WAV_16K_MONO), (_SPEECH_WAV_48K_MONO))
  def test_classify_with_yamnet_model(self, audio_file):
    with _AudioClassifier.create_from_model_path(
        self.yamnet_model_path) as classifier:
      classification_result_list = classifier.classify(
          self._read_wav_file(audio_file))
      self._check_yamnet_result(classification_result_list)

  def test_classify_with_yamnet_model_and_inputs_at_different_sample_rates(
      self):
    with _AudioClassifier.create_from_model_path(
        self.yamnet_model_path) as classifier:
      for audio_file in [_SPEECH_WAV_16K_MONO, _SPEECH_WAV_48K_MONO]:
        classification_result_list = classifier.classify(
            self._read_wav_file(audio_file))
        self._check_yamnet_result(classification_result_list)

  def test_max_result_options(self):
    with _AudioClassifier.create_from_options(
        _AudioClassifierOptions(
            base_options=_BaseOptions(model_asset_path=self.yamnet_model_path),
            classifier_options=_ClassifierOptions(
                max_results=1))) as classifier:
      for audio_file in [_SPEECH_WAV_16K_MONO, _SPEECH_WAV_16K_MONO]:
        classification_result_list = classifier.classify(
            self._read_wav_file(audio_file))
        self._check_yamnet_result(
            classification_result_list, expected_num_categories=1)

  def test_score_threshold_options(self):
    with _AudioClassifier.create_from_options(
        _AudioClassifierOptions(
            base_options=_BaseOptions(model_asset_path=self.yamnet_model_path),
            classifier_options=_ClassifierOptions(
                score_threshold=0.9))) as classifier:
      for audio_file in [_SPEECH_WAV_16K_MONO, _SPEECH_WAV_16K_MONO]:
        classification_result_list = classifier.classify(
            self._read_wav_file(audio_file))
        self._check_yamnet_result(
            classification_result_list, expected_num_categories=1)

  def test_allow_list_option(self):
    with _AudioClassifier.create_from_options(
        _AudioClassifierOptions(
            base_options=_BaseOptions(model_asset_path=self.yamnet_model_path),
            classifier_options=_ClassifierOptions(
                category_allowlist=['Speech']))) as classifier:
      for audio_file in [_SPEECH_WAV_16K_MONO, _SPEECH_WAV_16K_MONO]:
        classification_result_list = classifier.classify(
            self._read_wav_file(audio_file))
        self._check_yamnet_result(
            classification_result_list, expected_num_categories=1)

  def test_combined_allowlist_and_denylist(self):
    # Fails with combined allowlist and denylist.
    with self.assertRaisesRegex(
        ValueError,
        r'`category_allowlist` and `category_denylist` are mutually '
        r'exclusive options.'):
      options = _AudioClassifierOptions(
          base_options=_BaseOptions(model_asset_path=self.yamnet_model_path),
          classifier_options=_ClassifierOptions(
              category_allowlist=['foo'], category_denylist=['bar']))
      with _AudioClassifier.create_from_options(options) as unused_classifier:
        pass

  @parameterized.parameters((_TWO_HEADS_WAV_16K_MONO),
                            (_TWO_HEADS_WAV_44K_MONO))
  def test_classify_with_two_heads_model_and_inputs_at_different_sample_rates(
      self, audio_file):
    with _AudioClassifier.create_from_model_path(
        self.two_heads_model_path) as classifier:
      classification_result_list = classifier.classify(
          self._read_wav_file(audio_file))
      self._check_two_heads_result(classification_result_list)

  def test_classify_with_two_heads_model(self):
    with _AudioClassifier.create_from_model_path(
        self.two_heads_model_path) as classifier:
      for audio_file in [_TWO_HEADS_WAV_16K_MONO, _TWO_HEADS_WAV_44K_MONO]:
        classification_result_list = classifier.classify(
            self._read_wav_file(audio_file))
        self._check_two_heads_result(classification_result_list)

  def test_classify_with_two_heads_model_with_max_results(self):
    with _AudioClassifier.create_from_options(
        _AudioClassifierOptions(
            base_options=_BaseOptions(
                model_asset_path=self.two_heads_model_path),
            classifier_options=_ClassifierOptions(
                max_results=1))) as classifier:
      for audio_file in [_TWO_HEADS_WAV_16K_MONO, _TWO_HEADS_WAV_44K_MONO]:
        classification_result_list = classifier.classify(
            self._read_wav_file(audio_file))
        self._check_two_heads_result(classification_result_list, 1, 1)

  def test_missing_sample_rate_in_audio_clips_mode(self):
    options = _AudioClassifierOptions(
        base_options=_BaseOptions(model_asset_path=self.yamnet_model_path),
        running_mode=_RUNNING_MODE.AUDIO_CLIPS)
    with self.assertRaisesRegex(ValueError,
                                r'Must provide the audio sample rate'):
      with _AudioClassifier.create_from_options(options) as classifier:
        classifier.classify(_AudioData(buffer_length=100))

  def test_missing_sample_rate_in_audio_stream_mode(self):
    options = _AudioClassifierOptions(
        base_options=_BaseOptions(model_asset_path=self.yamnet_model_path),
        running_mode=_RUNNING_MODE.AUDIO_STREAM,
        result_callback=mock.MagicMock())
    with self.assertRaisesRegex(ValueError,
                                r'provide the audio sample rate in audio data'):
      with _AudioClassifier.create_from_options(options) as classifier:
        classifier.classify(_AudioData(buffer_length=100))

  def test_missing_result_callback(self):
    options = _AudioClassifierOptions(
        base_options=_BaseOptions(model_asset_path=self.yamnet_model_path),
        running_mode=_RUNNING_MODE.AUDIO_STREAM)
    with self.assertRaisesRegex(ValueError,
                                r'result callback must be provided'):
      with _AudioClassifier.create_from_options(options) as unused_classifier:
        pass

  def test_illegal_result_callback(self):
    options = _AudioClassifierOptions(
        base_options=_BaseOptions(model_asset_path=self.yamnet_model_path),
        running_mode=_RUNNING_MODE.AUDIO_CLIPS,
        result_callback=mock.MagicMock())
    with self.assertRaisesRegex(ValueError,
                                r'result callback should not be provided'):
      with _AudioClassifier.create_from_options(options) as unused_classifier:
        pass

  def test_calling_classify_in_audio_stream_mode(self):
    options = _AudioClassifierOptions(
        base_options=_BaseOptions(model_asset_path=self.yamnet_model_path),
        running_mode=_RUNNING_MODE.AUDIO_STREAM,
        result_callback=mock.MagicMock())
    with _AudioClassifier.create_from_options(options) as classifier:
      with self.assertRaisesRegex(ValueError,
                                  r'not initialized with the audio clips mode'):
        classifier.classify(self._read_wav_file(_SPEECH_WAV_16K_MONO))

  def test_calling_classify_async_in_audio_clips_mode(self):
    options = _AudioClassifierOptions(
        base_options=_BaseOptions(model_asset_path=self.yamnet_model_path),
        running_mode=_RUNNING_MODE.AUDIO_CLIPS)
    with _AudioClassifier.create_from_options(options) as classifier:
      with self.assertRaisesRegex(
          ValueError, r'not initialized with the audio stream mode'):
        classifier.classify_async(self._read_wav_file(_SPEECH_WAV_16K_MONO), 0)

  def test_classify_async_calls_with_illegal_timestamp(self):
    options = _AudioClassifierOptions(
        base_options=_BaseOptions(model_asset_path=self.yamnet_model_path),
        running_mode=_RUNNING_MODE.AUDIO_STREAM,
        result_callback=mock.MagicMock())
    with _AudioClassifier.create_from_options(options) as classifier:
      classifier.classify_async(self._read_wav_file(_SPEECH_WAV_16K_MONO), 100)
      with self.assertRaisesRegex(
          ValueError, r'Input timestamp must be monotonically increasing'):
        classifier.classify_async(self._read_wav_file(_SPEECH_WAV_16K_MONO), 0)

  @parameterized.parameters((_SPEECH_WAV_16K_MONO), (_SPEECH_WAV_48K_MONO))
  def test_classify_async(self, audio_file):
    classification_result_list = []

    def save_result(result: _AudioClassifierResult, timestamp_ms: int):
      result.timestamp_ms = timestamp_ms
      classification_result_list.append(result)

    options = _AudioClassifierOptions(
        base_options=_BaseOptions(model_asset_path=self.yamnet_model_path),
        running_mode=_RUNNING_MODE.AUDIO_STREAM,
        classifier_options=_ClassifierOptions(max_results=1),
        result_callback=save_result)
    classifier = _AudioClassifier.create_from_options(options)
    audio_data_list = self._read_wav_file_as_stream(audio_file)
    for audio_data, timestamp_ms in audio_data_list:
      classifier.classify_async(audio_data, timestamp_ms)
    classifier.close()
    self._check_yamnet_result(
        classification_result_list, expected_num_categories=1)


if __name__ == '__main__':
  absltest.main()
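The stream-splitting arithmetic in _read_wav_file_as_stream is worth spelling out: YAMNet consumes 15600 samples at 16 kHz (0.975 s per chunk), so the step size is scaled by the input's sample-rate ratio. A standalone sketch of the same loop, using the constants from the test above (chunk_starts itself is not part of the commit):

    _YAMNET_NUM_OF_SAMPLES = 15600
    _YAMNET_MODEL_SAMPLE_RATE = 16000

    def chunk_starts(num_samples: int, sample_rate: int):
      """Yields (start_index, timestamp_ms) pairs, mirroring the test helper."""
      step_size = _YAMNET_NUM_OF_SAMPLES * sample_rate / _YAMNET_MODEL_SAMPLE_RATE
      start = 0
      while start < num_samples:
        yield start, int(start / sample_rate * 1000)
        start = min(start + int(step_size), num_samples)

    # For a 48 kHz clip the step is 15600 * 48000 / 16000 = 46800 samples, so
    # chunk timestamps land at 0, 975, 1950, ... ms, exactly the values
    # asserted in _check_yamnet_result.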
mediapipe/tasks/python/text/__init__.py

@@ -11,3 +11,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+"""MediaPipe Tasks Text API."""
+
+import mediapipe.tasks.python.text.text_classifier
+
+TextClassifier = text_classifier.TextClassifier
+TextClassifierOptions = text_classifier.TextClassifierOptions
+
+# Remove unnecessary modules to avoid duplication in API docs.
+del mediapipe
+del text_classifier
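The import/alias/del pattern above works because `import mediapipe.tasks.python.text.text_classifier`, executed inside the package's own `__init__.py`, binds both `mediapipe` and the `text_classifier` submodule attribute in the module globals; the aliases capture the two public names, and the `del` statements then scrub the module references so API docs show only `TextClassifier` and `TextClassifierOptions`. A hedged consumption sketch:

    from mediapipe.tasks.python import text

    # Only the two re-exported names survive in the package namespace.
    classifier_cls = text.TextClassifier
    options_cls = text.TextClassifierOptions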
mediapipe/tasks/web/audio/BUILD (new file)

@@ -0,0 +1,11 @@
# This contains the MediaPipe Audio Tasks.

load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library")

mediapipe_ts_library(
    name = "audio_lib",
    srcs = ["index.ts"],
    deps = [
        "//mediapipe/tasks/web/audio/audio_classifier",
    ],
)
mediapipe/tasks/web/audio/index.ts (new file)

@@ -0,0 +1,20 @@
/**
 * Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Audio Classifier
export * from '../../../tasks/web/audio/audio_classifier/audio_classifier_options';
export * from '../../../tasks/web/audio/audio_classifier/audio_classifier_result';
export * from '../../../tasks/web/audio/audio_classifier/audio_classifier';
mediapipe/tasks/web/text/BUILD (new file)

@@ -0,0 +1,11 @@
# This contains the MediaPipe Text Tasks.

load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library")

mediapipe_ts_library(
    name = "text_lib",
    srcs = ["index.ts"],
    deps = [
        "//mediapipe/tasks/web/text/text_classifier",
    ],
)
mediapipe/tasks/web/text/index.ts (new file)

@@ -0,0 +1,20 @@
/**
 * Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Text Classifier
export * from '../../../tasks/web/text/text_classifier/text_classifier_options';
export * from '../../../tasks/web/text/text_classifier/text_classifier_result';
export * from '../../../tasks/web/text/text_classifier/text_classifier';
mediapipe/tasks/web/vision/BUILD (new file)

@@ -0,0 +1,13 @@
# This contains the MediaPipe Vision Tasks.

load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library")

mediapipe_ts_library(
    name = "vision_lib",
    srcs = ["index.ts"],
    deps = [
        "//mediapipe/tasks/web/vision/gesture_recognizer",
        "//mediapipe/tasks/web/vision/image_classifier",
        "//mediapipe/tasks/web/vision/object_detector",
    ],
)
mediapipe/tasks/web/vision/index.ts (new file)

@@ -0,0 +1,30 @@
/**
 * Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Image Classifier
export * from '../../../tasks/web/vision/image_classifier/image_classifier_options';
export * from '../../../tasks/web/vision/image_classifier/image_classifier_result';
export * from '../../../tasks/web/vision/image_classifier/image_classifier';

// Gesture Recognizer
export * from '../../../tasks/web/vision/gesture_recognizer/gesture_recognizer_options';
export * from '../../../tasks/web/vision/gesture_recognizer/gesture_recognizer_result';
export * from '../../../tasks/web/vision/gesture_recognizer/gesture_recognizer';

// Object Detector
export * from '../../../tasks/web/vision/object_detector/object_detector_options';
export * from '../../../tasks/web/vision/object_detector/object_detector_result';
export * from '../../../tasks/web/vision/object_detector/object_detector';