Refactor classifiers public APIs for namespace and naming consistency.
PiperOrigin-RevId: 477705647
This commit is contained in:
parent
554e2a9d69
commit
a8ca669f05
|
@ -33,7 +33,7 @@ cc_library(
|
||||||
"//mediapipe/framework/api2:port",
|
"//mediapipe/framework/api2:port",
|
||||||
"//mediapipe/framework/formats:matrix",
|
"//mediapipe/framework/formats:matrix",
|
||||||
"//mediapipe/tasks/cc:common",
|
"//mediapipe/tasks/cc:common",
|
||||||
"//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_options_cc_proto",
|
"//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_graph_options_cc_proto",
|
||||||
"//mediapipe/tasks/cc/audio/utils:audio_tensor_specs",
|
"//mediapipe/tasks/cc/audio/utils:audio_tensor_specs",
|
||||||
"//mediapipe/tasks/cc/components:classification_postprocessing",
|
"//mediapipe/tasks/cc/components:classification_postprocessing",
|
||||||
"//mediapipe/tasks/cc/components:classification_postprocessing_options_cc_proto",
|
"//mediapipe/tasks/cc/components:classification_postprocessing_options_cc_proto",
|
||||||
|
@ -60,7 +60,7 @@ cc_library(
|
||||||
":audio_classifier_graph",
|
":audio_classifier_graph",
|
||||||
"//mediapipe/framework/api2:builder",
|
"//mediapipe/framework/api2:builder",
|
||||||
"//mediapipe/framework/formats:matrix",
|
"//mediapipe/framework/formats:matrix",
|
||||||
"//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_options_cc_proto",
|
"//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_graph_options_cc_proto",
|
||||||
"//mediapipe/tasks/cc/audio/core:audio_task_api_factory",
|
"//mediapipe/tasks/cc/audio/core:audio_task_api_factory",
|
||||||
"//mediapipe/tasks/cc/audio/core:base_audio_task_api",
|
"//mediapipe/tasks/cc/audio/core:base_audio_task_api",
|
||||||
"//mediapipe/tasks/cc/audio/core:running_mode",
|
"//mediapipe/tasks/cc/audio/core:running_mode",
|
||||||
|
|
|
@ -22,7 +22,7 @@ limitations under the License.
|
||||||
#include "absl/status/statusor.h"
|
#include "absl/status/statusor.h"
|
||||||
#include "mediapipe/framework/api2/builder.h"
|
#include "mediapipe/framework/api2/builder.h"
|
||||||
#include "mediapipe/framework/formats/matrix.h"
|
#include "mediapipe/framework/formats/matrix.h"
|
||||||
#include "mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_options.pb.h"
|
#include "mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options.pb.h"
|
||||||
#include "mediapipe/tasks/cc/audio/core/audio_task_api_factory.h"
|
#include "mediapipe/tasks/cc/audio/core/audio_task_api_factory.h"
|
||||||
#include "mediapipe/tasks/cc/components/classifier_options.h"
|
#include "mediapipe/tasks/cc/components/classifier_options.h"
|
||||||
#include "mediapipe/tasks/cc/components/containers/classifications.pb.h"
|
#include "mediapipe/tasks/cc/components/containers/classifications.pb.h"
|
||||||
|
@ -33,6 +33,8 @@ limitations under the License.
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace tasks {
|
namespace tasks {
|
||||||
namespace audio {
|
namespace audio {
|
||||||
|
namespace audio_classifier {
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
constexpr char kAudioStreamName[] = "audio_in";
|
constexpr char kAudioStreamName[] = "audio_in";
|
||||||
|
@ -42,16 +44,13 @@ constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT";
|
||||||
constexpr char kSampleRateName[] = "sample_rate_in";
|
constexpr char kSampleRateName[] = "sample_rate_in";
|
||||||
constexpr char kSampleRateTag[] = "SAMPLE_RATE";
|
constexpr char kSampleRateTag[] = "SAMPLE_RATE";
|
||||||
constexpr char kSubgraphTypeName[] =
|
constexpr char kSubgraphTypeName[] =
|
||||||
"mediapipe.tasks.audio.AudioClassifierGraph";
|
"mediapipe.tasks.audio.audio_classifier.AudioClassifierGraph";
|
||||||
constexpr int kMicroSecondsPerMilliSecond = 1000;
|
constexpr int kMicroSecondsPerMilliSecond = 1000;
|
||||||
|
|
||||||
using AudioClassifierOptionsProto =
|
|
||||||
audio_classifier::proto::AudioClassifierOptions;
|
|
||||||
|
|
||||||
// Creates a MediaPipe graph config that only contains a single subgraph node of
|
// Creates a MediaPipe graph config that only contains a single subgraph node of
|
||||||
// "mediapipe.tasks.audio.AudioClassifierGraph".
|
// type "AudioClassifierGraph".
|
||||||
CalculatorGraphConfig CreateGraphConfig(
|
CalculatorGraphConfig CreateGraphConfig(
|
||||||
std::unique_ptr<AudioClassifierOptionsProto> options_proto) {
|
std::unique_ptr<proto::AudioClassifierGraphOptions> options_proto) {
|
||||||
api2::builder::Graph graph;
|
api2::builder::Graph graph;
|
||||||
auto& subgraph = graph.AddNode(kSubgraphTypeName);
|
auto& subgraph = graph.AddNode(kSubgraphTypeName);
|
||||||
graph.In(kAudioTag).SetName(kAudioStreamName) >> subgraph.In(kAudioTag);
|
graph.In(kAudioTag).SetName(kAudioStreamName) >> subgraph.In(kAudioTag);
|
||||||
|
@ -59,7 +58,8 @@ CalculatorGraphConfig CreateGraphConfig(
|
||||||
graph.In(kSampleRateTag).SetName(kSampleRateName) >>
|
graph.In(kSampleRateTag).SetName(kSampleRateName) >>
|
||||||
subgraph.In(kSampleRateTag);
|
subgraph.In(kSampleRateTag);
|
||||||
}
|
}
|
||||||
subgraph.GetOptions<AudioClassifierOptionsProto>().Swap(options_proto.get());
|
subgraph.GetOptions<proto::AudioClassifierGraphOptions>().Swap(
|
||||||
|
options_proto.get());
|
||||||
subgraph.Out(kClassificationResultTag)
|
subgraph.Out(kClassificationResultTag)
|
||||||
.SetName(kClassificationResultStreamName) >>
|
.SetName(kClassificationResultStreamName) >>
|
||||||
graph.Out(kClassificationResultTag);
|
graph.Out(kClassificationResultTag);
|
||||||
|
@ -67,10 +67,10 @@ CalculatorGraphConfig CreateGraphConfig(
|
||||||
}
|
}
|
||||||
|
|
||||||
// Converts the user-facing AudioClassifierOptions struct to the internal
|
// Converts the user-facing AudioClassifierOptions struct to the internal
|
||||||
// AudioClassifierOptions proto.
|
// AudioClassifierGraphOptions proto.
|
||||||
std::unique_ptr<AudioClassifierOptionsProto>
|
std::unique_ptr<proto::AudioClassifierGraphOptions>
|
||||||
ConvertAudioClassifierOptionsToProto(AudioClassifierOptions* options) {
|
ConvertAudioClassifierOptionsToProto(AudioClassifierOptions* options) {
|
||||||
auto options_proto = std::make_unique<AudioClassifierOptionsProto>();
|
auto options_proto = std::make_unique<proto::AudioClassifierGraphOptions>();
|
||||||
auto base_options_proto = std::make_unique<tasks::core::proto::BaseOptions>(
|
auto base_options_proto = std::make_unique<tasks::core::proto::BaseOptions>(
|
||||||
tasks::core::ConvertBaseOptionsToProto(&(options->base_options)));
|
tasks::core::ConvertBaseOptionsToProto(&(options->base_options)));
|
||||||
options_proto->mutable_base_options()->Swap(base_options_proto.get());
|
options_proto->mutable_base_options()->Swap(base_options_proto.get());
|
||||||
|
@ -119,7 +119,7 @@ absl::StatusOr<std::unique_ptr<AudioClassifier>> AudioClassifier::Create(
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
return core::AudioTaskApiFactory::Create<AudioClassifier,
|
return core::AudioTaskApiFactory::Create<AudioClassifier,
|
||||||
AudioClassifierOptionsProto>(
|
proto::AudioClassifierGraphOptions>(
|
||||||
CreateGraphConfig(std::move(options_proto)),
|
CreateGraphConfig(std::move(options_proto)),
|
||||||
std::move(options->base_options.op_resolver), options->running_mode,
|
std::move(options->base_options.op_resolver), options->running_mode,
|
||||||
std::move(packets_callback));
|
std::move(packets_callback));
|
||||||
|
@ -140,6 +140,7 @@ absl::Status AudioClassifier::ClassifyAsync(Matrix audio_block,
|
||||||
.At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}});
|
.At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
} // namespace audio_classifier
|
||||||
} // namespace audio
|
} // namespace audio
|
||||||
} // namespace tasks
|
} // namespace tasks
|
||||||
} // namespace mediapipe
|
} // namespace mediapipe
|
||||||
|
|
|
@ -30,6 +30,7 @@ limitations under the License.
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace tasks {
|
namespace tasks {
|
||||||
namespace audio {
|
namespace audio {
|
||||||
|
namespace audio_classifier {
|
||||||
|
|
||||||
// The options for configuring a mediapipe audio classifier task.
|
// The options for configuring a mediapipe audio classifier task.
|
||||||
struct AudioClassifierOptions {
|
struct AudioClassifierOptions {
|
||||||
|
@ -162,6 +163,7 @@ class AudioClassifier : tasks::audio::core::BaseAudioTaskApi {
|
||||||
absl::Status Close() { return runner_->Close(); }
|
absl::Status Close() { return runner_->Close(); }
|
||||||
};
|
};
|
||||||
|
|
||||||
|
} // namespace audio_classifier
|
||||||
} // namespace audio
|
} // namespace audio
|
||||||
} // namespace tasks
|
} // namespace tasks
|
||||||
} // namespace mediapipe
|
} // namespace mediapipe
|
||||||
|
|
|
@ -28,7 +28,7 @@ limitations under the License.
|
||||||
#include "mediapipe/framework/calculator.pb.h"
|
#include "mediapipe/framework/calculator.pb.h"
|
||||||
#include "mediapipe/framework/calculator_framework.h"
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
#include "mediapipe/framework/formats/matrix.h"
|
#include "mediapipe/framework/formats/matrix.h"
|
||||||
#include "mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_options.pb.h"
|
#include "mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options.pb.h"
|
||||||
#include "mediapipe/tasks/cc/audio/utils/audio_tensor_specs.h"
|
#include "mediapipe/tasks/cc/audio/utils/audio_tensor_specs.h"
|
||||||
#include "mediapipe/tasks/cc/common.h"
|
#include "mediapipe/tasks/cc/common.h"
|
||||||
#include "mediapipe/tasks/cc/components/classification_postprocessing.h"
|
#include "mediapipe/tasks/cc/components/classification_postprocessing.h"
|
||||||
|
@ -44,6 +44,7 @@ limitations under the License.
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace tasks {
|
namespace tasks {
|
||||||
namespace audio {
|
namespace audio {
|
||||||
|
namespace audio_classifier {
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
@ -60,10 +61,9 @@ constexpr char kPacketTag[] = "PACKET";
|
||||||
constexpr char kSampleRateTag[] = "SAMPLE_RATE";
|
constexpr char kSampleRateTag[] = "SAMPLE_RATE";
|
||||||
constexpr char kTensorsTag[] = "TENSORS";
|
constexpr char kTensorsTag[] = "TENSORS";
|
||||||
constexpr char kTimestampsTag[] = "TIMESTAMPS";
|
constexpr char kTimestampsTag[] = "TIMESTAMPS";
|
||||||
using AudioClassifierOptionsProto =
|
|
||||||
audio_classifier::proto::AudioClassifierOptions;
|
|
||||||
|
|
||||||
absl::Status SanityCheckOptions(const AudioClassifierOptionsProto& options) {
|
absl::Status SanityCheckOptions(
|
||||||
|
const proto::AudioClassifierGraphOptions& options) {
|
||||||
if (options.base_options().use_stream_mode() &&
|
if (options.base_options().use_stream_mode() &&
|
||||||
!options.has_default_input_audio_sample_rate()) {
|
!options.has_default_input_audio_sample_rate()) {
|
||||||
return CreateStatusWithPayload(absl::StatusCode::kInvalidArgument,
|
return CreateStatusWithPayload(absl::StatusCode::kInvalidArgument,
|
||||||
|
@ -111,7 +111,7 @@ void ConfigureAudioToTensorCalculator(
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
// A "mediapipe.tasks.audio.AudioClassifierGraph" performs audio classification.
|
// An "AudioClassifierGraph" performs audio classification.
|
||||||
// - Accepts CPU audio buffer and outputs classification results on CPU.
|
// - Accepts CPU audio buffer and outputs classification results on CPU.
|
||||||
//
|
//
|
||||||
// Inputs:
|
// Inputs:
|
||||||
|
@ -129,12 +129,12 @@ void ConfigureAudioToTensorCalculator(
|
||||||
//
|
//
|
||||||
// Example:
|
// Example:
|
||||||
// node {
|
// node {
|
||||||
// calculator: "mediapipe.tasks.audio.AudioClassifierGraph"
|
// calculator: "mediapipe.tasks.audio.audio_classifier.AudioClassifierGraph"
|
||||||
// input_stream: "AUDIO:audio_in"
|
// input_stream: "AUDIO:audio_in"
|
||||||
// input_stream: "SAMPLE_RATE:sample_rate_in"
|
// input_stream: "SAMPLE_RATE:sample_rate_in"
|
||||||
// output_stream: "CLASSIFICATION_RESULT:classification_result_out"
|
// output_stream: "CLASSIFICATION_RESULT:classification_result_out"
|
||||||
// options {
|
// options {
|
||||||
// [mediapipe.tasks.audio.audio_classifier.proto.AudioClassifierOptions.ext]
|
// [mediapipe.tasks.audio.audio_classifier.proto.AudioClassifierGraphOptions.ext]
|
||||||
// {
|
// {
|
||||||
// base_options {
|
// base_options {
|
||||||
// model_asset {
|
// model_asset {
|
||||||
|
@ -152,16 +152,18 @@ class AudioClassifierGraph : public core::ModelTaskGraph {
|
||||||
public:
|
public:
|
||||||
absl::StatusOr<CalculatorGraphConfig> GetConfig(
|
absl::StatusOr<CalculatorGraphConfig> GetConfig(
|
||||||
SubgraphContext* sc) override {
|
SubgraphContext* sc) override {
|
||||||
ASSIGN_OR_RETURN(const auto* model_resources,
|
ASSIGN_OR_RETURN(
|
||||||
CreateModelResources<AudioClassifierOptionsProto>(sc));
|
const auto* model_resources,
|
||||||
|
CreateModelResources<proto::AudioClassifierGraphOptions>(sc));
|
||||||
Graph graph;
|
Graph graph;
|
||||||
const bool use_stream_mode = sc->Options<AudioClassifierOptionsProto>()
|
const bool use_stream_mode =
|
||||||
.base_options()
|
sc->Options<proto::AudioClassifierGraphOptions>()
|
||||||
.use_stream_mode();
|
.base_options()
|
||||||
|
.use_stream_mode();
|
||||||
ASSIGN_OR_RETURN(
|
ASSIGN_OR_RETURN(
|
||||||
auto classification_result_out,
|
auto classification_result_out,
|
||||||
BuildAudioClassificationTask(
|
BuildAudioClassificationTask(
|
||||||
sc->Options<AudioClassifierOptionsProto>(), *model_resources,
|
sc->Options<proto::AudioClassifierGraphOptions>(), *model_resources,
|
||||||
graph[Input<Matrix>(kAudioTag)],
|
graph[Input<Matrix>(kAudioTag)],
|
||||||
use_stream_mode
|
use_stream_mode
|
||||||
? absl::nullopt
|
? absl::nullopt
|
||||||
|
@ -178,14 +180,14 @@ class AudioClassifierGraph : public core::ModelTaskGraph {
|
||||||
// buffer (mediapipe::Matrix) and the corresponding sample rate (double) as
|
// buffer (mediapipe::Matrix) and the corresponding sample rate (double) as
|
||||||
// the inputs and returns one classification result per input audio buffer.
|
// the inputs and returns one classification result per input audio buffer.
|
||||||
//
|
//
|
||||||
// task_options: the mediapipe tasks AudioClassifierOptions proto.
|
// task_options: the mediapipe tasks AudioClassifierGraphOptions proto.
|
||||||
// model_resources: the ModelSources object initialized from an audio
|
// model_resources: the ModelSources object initialized from an audio
|
||||||
// classifier model file with model metadata.
|
// classifier model file with model metadata.
|
||||||
// audio_in: (mediapipe::Matrix) stream to run audio classification on.
|
// audio_in: (mediapipe::Matrix) stream to run audio classification on.
|
||||||
// sample_rate_in: (double) optional stream of the input audio sample rate.
|
// sample_rate_in: (double) optional stream of the input audio sample rate.
|
||||||
// graph: the mediapipe builder::Graph instance to be updated.
|
// graph: the mediapipe builder::Graph instance to be updated.
|
||||||
absl::StatusOr<Source<ClassificationResult>> BuildAudioClassificationTask(
|
absl::StatusOr<Source<ClassificationResult>> BuildAudioClassificationTask(
|
||||||
const AudioClassifierOptionsProto& task_options,
|
const proto::AudioClassifierGraphOptions& task_options,
|
||||||
const core::ModelResources& model_resources, Source<Matrix> audio_in,
|
const core::ModelResources& model_resources, Source<Matrix> audio_in,
|
||||||
absl::optional<Source<double>> sample_rate_in, Graph& graph) {
|
absl::optional<Source<double>> sample_rate_in, Graph& graph) {
|
||||||
MP_RETURN_IF_ERROR(SanityCheckOptions(task_options));
|
MP_RETURN_IF_ERROR(SanityCheckOptions(task_options));
|
||||||
|
@ -257,8 +259,10 @@ class AudioClassifierGraph : public core::ModelTaskGraph {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
REGISTER_MEDIAPIPE_GRAPH(::mediapipe::tasks::audio::AudioClassifierGraph);
|
REGISTER_MEDIAPIPE_GRAPH(
|
||||||
|
::mediapipe::tasks::audio::audio_classifier::AudioClassifierGraph);
|
||||||
|
|
||||||
|
} // namespace audio_classifier
|
||||||
} // namespace audio
|
} // namespace audio
|
||||||
} // namespace tasks
|
} // namespace tasks
|
||||||
} // namespace mediapipe
|
} // namespace mediapipe
|
||||||
|
|
|
@ -44,6 +44,7 @@ limitations under the License.
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace tasks {
|
namespace tasks {
|
||||||
namespace audio {
|
namespace audio {
|
||||||
|
namespace audio_classifier {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
using ::absl::StatusOr;
|
using ::absl::StatusOr;
|
||||||
|
@ -557,6 +558,7 @@ TEST_F(ClassifyAsyncTest, SucceedsWithNonDeterministicNumAudioSamples) {
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
} // namespace audio_classifier
|
||||||
} // namespace audio
|
} // namespace audio
|
||||||
} // namespace tasks
|
} // namespace tasks
|
||||||
} // namespace mediapipe
|
} // namespace mediapipe
|
||||||
|
|
|
@ -19,8 +19,8 @@ package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||||
licenses(["notice"])
|
licenses(["notice"])
|
||||||
|
|
||||||
mediapipe_proto_library(
|
mediapipe_proto_library(
|
||||||
name = "audio_classifier_options_proto",
|
name = "audio_classifier_graph_options_proto",
|
||||||
srcs = ["audio_classifier_options.proto"],
|
srcs = ["audio_classifier_graph_options.proto"],
|
||||||
deps = [
|
deps = [
|
||||||
"//mediapipe/framework:calculator_options_proto",
|
"//mediapipe/framework:calculator_options_proto",
|
||||||
"//mediapipe/framework:calculator_proto",
|
"//mediapipe/framework:calculator_proto",
|
||||||
|
|
|
@ -21,9 +21,9 @@ import "mediapipe/framework/calculator.proto";
|
||||||
import "mediapipe/tasks/cc/components/proto/classifier_options.proto";
|
import "mediapipe/tasks/cc/components/proto/classifier_options.proto";
|
||||||
import "mediapipe/tasks/cc/core/proto/base_options.proto";
|
import "mediapipe/tasks/cc/core/proto/base_options.proto";
|
||||||
|
|
||||||
message AudioClassifierOptions {
|
message AudioClassifierGraphOptions {
|
||||||
extend mediapipe.CalculatorOptions {
|
extend mediapipe.CalculatorOptions {
|
||||||
optional AudioClassifierOptions ext = 451755788;
|
optional AudioClassifierGraphOptions ext = 451755788;
|
||||||
}
|
}
|
||||||
// Base options for configuring MediaPipe Tasks, such as specifying the TfLite
|
// Base options for configuring MediaPipe Tasks, such as specifying the TfLite
|
||||||
// model file with metadata, accelerator options, etc.
|
// model file with metadata, accelerator options, etc.
|
|
@ -33,7 +33,7 @@ cc_library(
|
||||||
"//mediapipe/tasks/cc/components/containers:classifications_cc_proto",
|
"//mediapipe/tasks/cc/components/containers:classifications_cc_proto",
|
||||||
"//mediapipe/tasks/cc/core:model_resources",
|
"//mediapipe/tasks/cc/core:model_resources",
|
||||||
"//mediapipe/tasks/cc/core:model_task_graph",
|
"//mediapipe/tasks/cc/core:model_task_graph",
|
||||||
"//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_options_cc_proto",
|
"//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_cc_proto",
|
||||||
"@com_google_absl//absl/status:statusor",
|
"@com_google_absl//absl/status:statusor",
|
||||||
],
|
],
|
||||||
alwayslink = 1,
|
alwayslink = 1,
|
||||||
|
@ -61,7 +61,7 @@ cc_library(
|
||||||
"//mediapipe/tasks/cc/vision/core:base_vision_task_api",
|
"//mediapipe/tasks/cc/vision/core:base_vision_task_api",
|
||||||
"//mediapipe/tasks/cc/vision/core:running_mode",
|
"//mediapipe/tasks/cc/vision/core:running_mode",
|
||||||
"//mediapipe/tasks/cc/vision/core:vision_task_api_factory",
|
"//mediapipe/tasks/cc/vision/core:vision_task_api_factory",
|
||||||
"//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_options_cc_proto",
|
"//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_cc_proto",
|
||||||
"@com_google_absl//absl/status",
|
"@com_google_absl//absl/status",
|
||||||
"@com_google_absl//absl/status:statusor",
|
"@com_google_absl//absl/status:statusor",
|
||||||
],
|
],
|
||||||
|
|
|
@ -36,11 +36,12 @@ limitations under the License.
|
||||||
#include "mediapipe/tasks/cc/core/utils.h"
|
#include "mediapipe/tasks/cc/core/utils.h"
|
||||||
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
|
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
|
||||||
#include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h"
|
#include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h"
|
||||||
#include "mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_options.pb.h"
|
#include "mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_graph_options.pb.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace tasks {
|
namespace tasks {
|
||||||
namespace vision {
|
namespace vision {
|
||||||
|
namespace image_classifier {
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
@ -52,12 +53,10 @@ constexpr char kImageTag[] = "IMAGE";
|
||||||
constexpr char kNormRectName[] = "norm_rect_in";
|
constexpr char kNormRectName[] = "norm_rect_in";
|
||||||
constexpr char kNormRectTag[] = "NORM_RECT";
|
constexpr char kNormRectTag[] = "NORM_RECT";
|
||||||
constexpr char kSubgraphTypeName[] =
|
constexpr char kSubgraphTypeName[] =
|
||||||
"mediapipe.tasks.vision.ImageClassifierGraph";
|
"mediapipe.tasks.vision.image_classifier.ImageClassifierGraph";
|
||||||
constexpr int kMicroSecondsPerMilliSecond = 1000;
|
constexpr int kMicroSecondsPerMilliSecond = 1000;
|
||||||
|
|
||||||
using ::mediapipe::tasks::core::PacketMap;
|
using ::mediapipe::tasks::core::PacketMap;
|
||||||
using ImageClassifierOptionsProto =
|
|
||||||
image_classifier::proto::ImageClassifierOptions;
|
|
||||||
|
|
||||||
// Builds a NormalizedRect covering the entire image.
|
// Builds a NormalizedRect covering the entire image.
|
||||||
NormalizedRect BuildFullImageNormRect() {
|
NormalizedRect BuildFullImageNormRect() {
|
||||||
|
@ -70,17 +69,17 @@ NormalizedRect BuildFullImageNormRect() {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Creates a MediaPipe graph config that contains a subgraph node of
|
// Creates a MediaPipe graph config that contains a subgraph node of
|
||||||
// "mediapipe.tasks.vision.ImageClassifierGraph". If the task is running in the
|
// type "ImageClassifierGraph". If the task is running in the live stream mode,
|
||||||
// live stream mode, a "FlowLimiterCalculator" will be added to limit the number
|
// a "FlowLimiterCalculator" will be added to limit the number of frames in
|
||||||
// of frames in flight.
|
// flight.
|
||||||
CalculatorGraphConfig CreateGraphConfig(
|
CalculatorGraphConfig CreateGraphConfig(
|
||||||
std::unique_ptr<ImageClassifierOptionsProto> options_proto,
|
std::unique_ptr<proto::ImageClassifierGraphOptions> options_proto,
|
||||||
bool enable_flow_limiting) {
|
bool enable_flow_limiting) {
|
||||||
api2::builder::Graph graph;
|
api2::builder::Graph graph;
|
||||||
graph.In(kImageTag).SetName(kImageInStreamName);
|
graph.In(kImageTag).SetName(kImageInStreamName);
|
||||||
graph.In(kNormRectTag).SetName(kNormRectName);
|
graph.In(kNormRectTag).SetName(kNormRectName);
|
||||||
auto& task_subgraph = graph.AddNode(kSubgraphTypeName);
|
auto& task_subgraph = graph.AddNode(kSubgraphTypeName);
|
||||||
task_subgraph.GetOptions<ImageClassifierOptionsProto>().Swap(
|
task_subgraph.GetOptions<proto::ImageClassifierGraphOptions>().Swap(
|
||||||
options_proto.get());
|
options_proto.get());
|
||||||
task_subgraph.Out(kClassificationResultTag)
|
task_subgraph.Out(kClassificationResultTag)
|
||||||
.SetName(kClassificationResultStreamName) >>
|
.SetName(kClassificationResultStreamName) >>
|
||||||
|
@ -98,10 +97,10 @@ CalculatorGraphConfig CreateGraphConfig(
|
||||||
}
|
}
|
||||||
|
|
||||||
// Converts the user-facing ImageClassifierOptions struct to the internal
|
// Converts the user-facing ImageClassifierOptions struct to the internal
|
||||||
// ImageClassifierOptions proto.
|
// ImageClassifierGraphOptions proto.
|
||||||
std::unique_ptr<ImageClassifierOptionsProto>
|
std::unique_ptr<proto::ImageClassifierGraphOptions>
|
||||||
ConvertImageClassifierOptionsToProto(ImageClassifierOptions* options) {
|
ConvertImageClassifierOptionsToProto(ImageClassifierOptions* options) {
|
||||||
auto options_proto = std::make_unique<ImageClassifierOptionsProto>();
|
auto options_proto = std::make_unique<proto::ImageClassifierGraphOptions>();
|
||||||
auto base_options_proto = std::make_unique<tasks::core::proto::BaseOptions>(
|
auto base_options_proto = std::make_unique<tasks::core::proto::BaseOptions>(
|
||||||
tasks::core::ConvertBaseOptionsToProto(&(options->base_options)));
|
tasks::core::ConvertBaseOptionsToProto(&(options->base_options)));
|
||||||
options_proto->mutable_base_options()->Swap(base_options_proto.get());
|
options_proto->mutable_base_options()->Swap(base_options_proto.get());
|
||||||
|
@ -145,7 +144,7 @@ absl::StatusOr<std::unique_ptr<ImageClassifier>> ImageClassifier::Create(
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
return core::VisionTaskApiFactory::Create<ImageClassifier,
|
return core::VisionTaskApiFactory::Create<ImageClassifier,
|
||||||
ImageClassifierOptionsProto>(
|
proto::ImageClassifierGraphOptions>(
|
||||||
CreateGraphConfig(
|
CreateGraphConfig(
|
||||||
std::move(options_proto),
|
std::move(options_proto),
|
||||||
options->running_mode == core::RunningMode::LIVE_STREAM),
|
options->running_mode == core::RunningMode::LIVE_STREAM),
|
||||||
|
@ -214,6 +213,7 @@ absl::Status ImageClassifier::ClassifyAsync(Image image, int64 timestamp_ms,
|
||||||
.At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}});
|
.At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
} // namespace image_classifier
|
||||||
} // namespace vision
|
} // namespace vision
|
||||||
} // namespace tasks
|
} // namespace tasks
|
||||||
} // namespace mediapipe
|
} // namespace mediapipe
|
||||||
|
|
|
@ -32,6 +32,7 @@ limitations under the License.
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace tasks {
|
namespace tasks {
|
||||||
namespace vision {
|
namespace vision {
|
||||||
|
namespace image_classifier {
|
||||||
|
|
||||||
// The options for configuring a Mediapipe image classifier task.
|
// The options for configuring a Mediapipe image classifier task.
|
||||||
struct ImageClassifierOptions {
|
struct ImageClassifierOptions {
|
||||||
|
@ -161,6 +162,7 @@ class ImageClassifier : tasks::vision::core::BaseVisionTaskApi {
|
||||||
absl::Status Close() { return runner_->Close(); }
|
absl::Status Close() { return runner_->Close(); }
|
||||||
};
|
};
|
||||||
|
|
||||||
|
} // namespace image_classifier
|
||||||
} // namespace vision
|
} // namespace vision
|
||||||
} // namespace tasks
|
} // namespace tasks
|
||||||
} // namespace mediapipe
|
} // namespace mediapipe
|
||||||
|
|
|
@ -29,11 +29,12 @@ limitations under the License.
|
||||||
#include "mediapipe/tasks/cc/components/image_preprocessing_options.pb.h"
|
#include "mediapipe/tasks/cc/components/image_preprocessing_options.pb.h"
|
||||||
#include "mediapipe/tasks/cc/core/model_resources.h"
|
#include "mediapipe/tasks/cc/core/model_resources.h"
|
||||||
#include "mediapipe/tasks/cc/core/model_task_graph.h"
|
#include "mediapipe/tasks/cc/core/model_task_graph.h"
|
||||||
#include "mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_options.pb.h"
|
#include "mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_graph_options.pb.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace tasks {
|
namespace tasks {
|
||||||
namespace vision {
|
namespace vision {
|
||||||
|
namespace image_classifier {
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
@ -42,8 +43,6 @@ using ::mediapipe::api2::Output;
|
||||||
using ::mediapipe::api2::builder::GenericNode;
|
using ::mediapipe::api2::builder::GenericNode;
|
||||||
using ::mediapipe::api2::builder::Graph;
|
using ::mediapipe::api2::builder::Graph;
|
||||||
using ::mediapipe::api2::builder::Source;
|
using ::mediapipe::api2::builder::Source;
|
||||||
using ImageClassifierOptionsProto =
|
|
||||||
image_classifier::proto::ImageClassifierOptions;
|
|
||||||
|
|
||||||
constexpr float kDefaultScoreThreshold = std::numeric_limits<float>::lowest();
|
constexpr float kDefaultScoreThreshold = std::numeric_limits<float>::lowest();
|
||||||
|
|
||||||
|
@ -61,8 +60,7 @@ struct ImageClassifierOutputStreams {
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
// A "mediapipe.tasks.vision.ImageClassifierGraph" performs image
|
// An "ImageClassifierGraph" performs image classification.
|
||||||
// classification.
|
|
||||||
// - Accepts CPU input images and outputs classifications on CPU.
|
// - Accepts CPU input images and outputs classifications on CPU.
|
||||||
//
|
//
|
||||||
// Inputs:
|
// Inputs:
|
||||||
|
@ -80,12 +78,12 @@ struct ImageClassifierOutputStreams {
|
||||||
//
|
//
|
||||||
// Example:
|
// Example:
|
||||||
// node {
|
// node {
|
||||||
// calculator: "mediapipe.tasks.vision.ImageClassifierGraph"
|
// calculator: "mediapipe.tasks.vision.image_classifier.ImageClassifierGraph"
|
||||||
// input_stream: "IMAGE:image_in"
|
// input_stream: "IMAGE:image_in"
|
||||||
// output_stream: "CLASSIFICATION_RESULT:classification_result_out"
|
// output_stream: "CLASSIFICATION_RESULT:classification_result_out"
|
||||||
// output_stream: "IMAGE:image_out"
|
// output_stream: "IMAGE:image_out"
|
||||||
// options {
|
// options {
|
||||||
// [mediapipe.tasks.vision.image_classifier.proto.ImageClassifierOptions.ext]
|
// [mediapipe.tasks.vision.image_classifier.proto.ImageClassifierGraphOptions.ext]
|
||||||
// {
|
// {
|
||||||
// base_options {
|
// base_options {
|
||||||
// model_asset {
|
// model_asset {
|
||||||
|
@ -104,13 +102,14 @@ class ImageClassifierGraph : public core::ModelTaskGraph {
|
||||||
public:
|
public:
|
||||||
absl::StatusOr<CalculatorGraphConfig> GetConfig(
|
absl::StatusOr<CalculatorGraphConfig> GetConfig(
|
||||||
SubgraphContext* sc) override {
|
SubgraphContext* sc) override {
|
||||||
ASSIGN_OR_RETURN(const auto* model_resources,
|
ASSIGN_OR_RETURN(
|
||||||
CreateModelResources<ImageClassifierOptionsProto>(sc));
|
const auto* model_resources,
|
||||||
|
CreateModelResources<proto::ImageClassifierGraphOptions>(sc));
|
||||||
Graph graph;
|
Graph graph;
|
||||||
ASSIGN_OR_RETURN(
|
ASSIGN_OR_RETURN(
|
||||||
auto output_streams,
|
auto output_streams,
|
||||||
BuildImageClassificationTask(
|
BuildImageClassificationTask(
|
||||||
sc->Options<ImageClassifierOptionsProto>(), *model_resources,
|
sc->Options<proto::ImageClassifierGraphOptions>(), *model_resources,
|
||||||
graph[Input<Image>(kImageTag)],
|
graph[Input<Image>(kImageTag)],
|
||||||
graph[Input<NormalizedRect>::Optional(kNormRectTag)], graph));
|
graph[Input<NormalizedRect>::Optional(kNormRectTag)], graph));
|
||||||
output_streams.classification_result >>
|
output_streams.classification_result >>
|
||||||
|
@ -125,13 +124,13 @@ class ImageClassifierGraph : public core::ModelTaskGraph {
|
||||||
// (mediapipe::Image) as input and returns one classification result per input
|
// (mediapipe::Image) as input and returns one classification result per input
|
||||||
// image.
|
// image.
|
||||||
//
|
//
|
||||||
// task_options: the mediapipe tasks ImageClassifierOptions.
|
// task_options: the mediapipe tasks ImageClassifierGraphOptions.
|
||||||
// model_resources: the ModelSources object initialized from an image
|
// model_resources: the ModelSources object initialized from an image
|
||||||
// classification model file with model metadata.
|
// classification model file with model metadata.
|
||||||
// image_in: (mediapipe::Image) stream to run classification on.
|
// image_in: (mediapipe::Image) stream to run classification on.
|
||||||
// graph: the mediapipe builder::Graph instance to be updated.
|
// graph: the mediapipe builder::Graph instance to be updated.
|
||||||
absl::StatusOr<ImageClassifierOutputStreams> BuildImageClassificationTask(
|
absl::StatusOr<ImageClassifierOutputStreams> BuildImageClassificationTask(
|
||||||
const ImageClassifierOptionsProto& task_options,
|
const proto::ImageClassifierGraphOptions& task_options,
|
||||||
const core::ModelResources& model_resources, Source<Image> image_in,
|
const core::ModelResources& model_resources, Source<Image> image_in,
|
||||||
Source<NormalizedRect> norm_rect_in, Graph& graph) {
|
Source<NormalizedRect> norm_rect_in, Graph& graph) {
|
||||||
// Adds preprocessing calculators and connects them to the graph input image
|
// Adds preprocessing calculators and connects them to the graph input image
|
||||||
|
@ -168,8 +167,10 @@ class ImageClassifierGraph : public core::ModelTaskGraph {
|
||||||
/*image=*/preprocessing[Output<Image>(kImageTag)]};
|
/*image=*/preprocessing[Output<Image>(kImageTag)]};
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
REGISTER_MEDIAPIPE_GRAPH(::mediapipe::tasks::vision::ImageClassifierGraph);
|
REGISTER_MEDIAPIPE_GRAPH(
|
||||||
|
::mediapipe::tasks::vision::image_classifier::ImageClassifierGraph);
|
||||||
|
|
||||||
|
} // namespace image_classifier
|
||||||
} // namespace vision
|
} // namespace vision
|
||||||
} // namespace tasks
|
} // namespace tasks
|
||||||
} // namespace mediapipe
|
} // namespace mediapipe
|
||||||
|
|
|
@ -44,6 +44,7 @@ limitations under the License.
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace tasks {
|
namespace tasks {
|
||||||
namespace vision {
|
namespace vision {
|
||||||
|
namespace image_classifier {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
using ::mediapipe::file::JoinPath;
|
using ::mediapipe::file::JoinPath;
|
||||||
|
@ -814,6 +815,7 @@ TEST_F(LiveStreamModeTest, SucceedsWithRegionOfInterest) {
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
} // namespace image_classifier
|
||||||
} // namespace vision
|
} // namespace vision
|
||||||
} // namespace tasks
|
} // namespace tasks
|
||||||
} // namespace mediapipe
|
} // namespace mediapipe
|
||||||
|
|
|
@ -19,8 +19,8 @@ package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||||
licenses(["notice"])
|
licenses(["notice"])
|
||||||
|
|
||||||
mediapipe_proto_library(
|
mediapipe_proto_library(
|
||||||
name = "image_classifier_options_proto",
|
name = "image_classifier_graph_options_proto",
|
||||||
srcs = ["image_classifier_options.proto"],
|
srcs = ["image_classifier_graph_options.proto"],
|
||||||
deps = [
|
deps = [
|
||||||
"//mediapipe/framework:calculator_options_proto",
|
"//mediapipe/framework:calculator_options_proto",
|
||||||
"//mediapipe/framework:calculator_proto",
|
"//mediapipe/framework:calculator_proto",
|
||||||
|
|
|
@ -21,9 +21,9 @@ import "mediapipe/framework/calculator.proto";
|
||||||
import "mediapipe/tasks/cc/components/proto/classifier_options.proto";
|
import "mediapipe/tasks/cc/components/proto/classifier_options.proto";
|
||||||
import "mediapipe/tasks/cc/core/proto/base_options.proto";
|
import "mediapipe/tasks/cc/core/proto/base_options.proto";
|
||||||
|
|
||||||
message ImageClassifierOptions {
|
message ImageClassifierGraphOptions {
|
||||||
extend mediapipe.CalculatorOptions {
|
extend mediapipe.CalculatorOptions {
|
||||||
optional ImageClassifierOptions ext = 456383383;
|
optional ImageClassifierGraphOptions ext = 456383383;
|
||||||
}
|
}
|
||||||
// Base options for configuring MediaPipe Tasks, such as specifying the TfLite
|
// Base options for configuring MediaPipe Tasks, such as specifying the TfLite
|
||||||
// model file with metadata, accelerator options, etc.
|
// model file with metadata, accelerator options, etc.
|
Loading…
Reference in New Issue
Block a user