Refactor classifiers public APIs for namespace and naming consistency.

PiperOrigin-RevId: 477705647
This commit is contained in:
MediaPipe Team 2022-09-29 06:03:59 -07:00 committed by Copybara-Service
parent 554e2a9d69
commit a8ca669f05
14 changed files with 80 additions and 66 deletions

View File

@ -33,7 +33,7 @@ cc_library(
"//mediapipe/framework/api2:port", "//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:matrix", "//mediapipe/framework/formats:matrix",
"//mediapipe/tasks/cc:common", "//mediapipe/tasks/cc:common",
"//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_options_cc_proto", "//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_graph_options_cc_proto",
"//mediapipe/tasks/cc/audio/utils:audio_tensor_specs", "//mediapipe/tasks/cc/audio/utils:audio_tensor_specs",
"//mediapipe/tasks/cc/components:classification_postprocessing", "//mediapipe/tasks/cc/components:classification_postprocessing",
"//mediapipe/tasks/cc/components:classification_postprocessing_options_cc_proto", "//mediapipe/tasks/cc/components:classification_postprocessing_options_cc_proto",
@ -60,7 +60,7 @@ cc_library(
":audio_classifier_graph", ":audio_classifier_graph",
"//mediapipe/framework/api2:builder", "//mediapipe/framework/api2:builder",
"//mediapipe/framework/formats:matrix", "//mediapipe/framework/formats:matrix",
"//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_options_cc_proto", "//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_graph_options_cc_proto",
"//mediapipe/tasks/cc/audio/core:audio_task_api_factory", "//mediapipe/tasks/cc/audio/core:audio_task_api_factory",
"//mediapipe/tasks/cc/audio/core:base_audio_task_api", "//mediapipe/tasks/cc/audio/core:base_audio_task_api",
"//mediapipe/tasks/cc/audio/core:running_mode", "//mediapipe/tasks/cc/audio/core:running_mode",

View File

@ -22,7 +22,7 @@ limitations under the License.
#include "absl/status/statusor.h" #include "absl/status/statusor.h"
#include "mediapipe/framework/api2/builder.h" #include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/formats/matrix.h" #include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_options.pb.h" #include "mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options.pb.h"
#include "mediapipe/tasks/cc/audio/core/audio_task_api_factory.h" #include "mediapipe/tasks/cc/audio/core/audio_task_api_factory.h"
#include "mediapipe/tasks/cc/components/classifier_options.h" #include "mediapipe/tasks/cc/components/classifier_options.h"
#include "mediapipe/tasks/cc/components/containers/classifications.pb.h" #include "mediapipe/tasks/cc/components/containers/classifications.pb.h"
@ -33,6 +33,8 @@ limitations under the License.
namespace mediapipe { namespace mediapipe {
namespace tasks { namespace tasks {
namespace audio { namespace audio {
namespace audio_classifier {
namespace { namespace {
constexpr char kAudioStreamName[] = "audio_in"; constexpr char kAudioStreamName[] = "audio_in";
@ -42,16 +44,13 @@ constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT";
constexpr char kSampleRateName[] = "sample_rate_in"; constexpr char kSampleRateName[] = "sample_rate_in";
constexpr char kSampleRateTag[] = "SAMPLE_RATE"; constexpr char kSampleRateTag[] = "SAMPLE_RATE";
constexpr char kSubgraphTypeName[] = constexpr char kSubgraphTypeName[] =
"mediapipe.tasks.audio.AudioClassifierGraph"; "mediapipe.tasks.audio.audio_classifier.AudioClassifierGraph";
constexpr int kMicroSecondsPerMilliSecond = 1000; constexpr int kMicroSecondsPerMilliSecond = 1000;
using AudioClassifierOptionsProto =
audio_classifier::proto::AudioClassifierOptions;
// Creates a MediaPipe graph config that only contains a single subgraph node of // Creates a MediaPipe graph config that only contains a single subgraph node of
// "mediapipe.tasks.audio.AudioClassifierGraph". // type "AudioClassifierGraph".
CalculatorGraphConfig CreateGraphConfig( CalculatorGraphConfig CreateGraphConfig(
std::unique_ptr<AudioClassifierOptionsProto> options_proto) { std::unique_ptr<proto::AudioClassifierGraphOptions> options_proto) {
api2::builder::Graph graph; api2::builder::Graph graph;
auto& subgraph = graph.AddNode(kSubgraphTypeName); auto& subgraph = graph.AddNode(kSubgraphTypeName);
graph.In(kAudioTag).SetName(kAudioStreamName) >> subgraph.In(kAudioTag); graph.In(kAudioTag).SetName(kAudioStreamName) >> subgraph.In(kAudioTag);
@ -59,7 +58,8 @@ CalculatorGraphConfig CreateGraphConfig(
graph.In(kSampleRateTag).SetName(kSampleRateName) >> graph.In(kSampleRateTag).SetName(kSampleRateName) >>
subgraph.In(kSampleRateTag); subgraph.In(kSampleRateTag);
} }
subgraph.GetOptions<AudioClassifierOptionsProto>().Swap(options_proto.get()); subgraph.GetOptions<proto::AudioClassifierGraphOptions>().Swap(
options_proto.get());
subgraph.Out(kClassificationResultTag) subgraph.Out(kClassificationResultTag)
.SetName(kClassificationResultStreamName) >> .SetName(kClassificationResultStreamName) >>
graph.Out(kClassificationResultTag); graph.Out(kClassificationResultTag);
@ -67,10 +67,10 @@ CalculatorGraphConfig CreateGraphConfig(
} }
// Converts the user-facing AudioClassifierOptions struct to the internal // Converts the user-facing AudioClassifierOptions struct to the internal
// AudioClassifierOptions proto. // AudioClassifierGraphOptions proto.
std::unique_ptr<AudioClassifierOptionsProto> std::unique_ptr<proto::AudioClassifierGraphOptions>
ConvertAudioClassifierOptionsToProto(AudioClassifierOptions* options) { ConvertAudioClassifierOptionsToProto(AudioClassifierOptions* options) {
auto options_proto = std::make_unique<AudioClassifierOptionsProto>(); auto options_proto = std::make_unique<proto::AudioClassifierGraphOptions>();
auto base_options_proto = std::make_unique<tasks::core::proto::BaseOptions>( auto base_options_proto = std::make_unique<tasks::core::proto::BaseOptions>(
tasks::core::ConvertBaseOptionsToProto(&(options->base_options))); tasks::core::ConvertBaseOptionsToProto(&(options->base_options)));
options_proto->mutable_base_options()->Swap(base_options_proto.get()); options_proto->mutable_base_options()->Swap(base_options_proto.get());
@ -119,7 +119,7 @@ absl::StatusOr<std::unique_ptr<AudioClassifier>> AudioClassifier::Create(
}; };
} }
return core::AudioTaskApiFactory::Create<AudioClassifier, return core::AudioTaskApiFactory::Create<AudioClassifier,
AudioClassifierOptionsProto>( proto::AudioClassifierGraphOptions>(
CreateGraphConfig(std::move(options_proto)), CreateGraphConfig(std::move(options_proto)),
std::move(options->base_options.op_resolver), options->running_mode, std::move(options->base_options.op_resolver), options->running_mode,
std::move(packets_callback)); std::move(packets_callback));
@ -140,6 +140,7 @@ absl::Status AudioClassifier::ClassifyAsync(Matrix audio_block,
.At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}}); .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}});
} }
} // namespace audio_classifier
} // namespace audio } // namespace audio
} // namespace tasks } // namespace tasks
} // namespace mediapipe } // namespace mediapipe

View File

@ -30,6 +30,7 @@ limitations under the License.
namespace mediapipe { namespace mediapipe {
namespace tasks { namespace tasks {
namespace audio { namespace audio {
namespace audio_classifier {
// The options for configuring a mediapipe audio classifier task. // The options for configuring a mediapipe audio classifier task.
struct AudioClassifierOptions { struct AudioClassifierOptions {
@ -162,6 +163,7 @@ class AudioClassifier : tasks::audio::core::BaseAudioTaskApi {
absl::Status Close() { return runner_->Close(); } absl::Status Close() { return runner_->Close(); }
}; };
} // namespace audio_classifier
} // namespace audio } // namespace audio
} // namespace tasks } // namespace tasks
} // namespace mediapipe } // namespace mediapipe

View File

@ -28,7 +28,7 @@ limitations under the License.
#include "mediapipe/framework/calculator.pb.h" #include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/matrix.h" #include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_options.pb.h" #include "mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options.pb.h"
#include "mediapipe/tasks/cc/audio/utils/audio_tensor_specs.h" #include "mediapipe/tasks/cc/audio/utils/audio_tensor_specs.h"
#include "mediapipe/tasks/cc/common.h" #include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/classification_postprocessing.h" #include "mediapipe/tasks/cc/components/classification_postprocessing.h"
@ -44,6 +44,7 @@ limitations under the License.
namespace mediapipe { namespace mediapipe {
namespace tasks { namespace tasks {
namespace audio { namespace audio {
namespace audio_classifier {
namespace { namespace {
@ -60,10 +61,9 @@ constexpr char kPacketTag[] = "PACKET";
constexpr char kSampleRateTag[] = "SAMPLE_RATE"; constexpr char kSampleRateTag[] = "SAMPLE_RATE";
constexpr char kTensorsTag[] = "TENSORS"; constexpr char kTensorsTag[] = "TENSORS";
constexpr char kTimestampsTag[] = "TIMESTAMPS"; constexpr char kTimestampsTag[] = "TIMESTAMPS";
using AudioClassifierOptionsProto =
audio_classifier::proto::AudioClassifierOptions;
absl::Status SanityCheckOptions(const AudioClassifierOptionsProto& options) { absl::Status SanityCheckOptions(
const proto::AudioClassifierGraphOptions& options) {
if (options.base_options().use_stream_mode() && if (options.base_options().use_stream_mode() &&
!options.has_default_input_audio_sample_rate()) { !options.has_default_input_audio_sample_rate()) {
return CreateStatusWithPayload(absl::StatusCode::kInvalidArgument, return CreateStatusWithPayload(absl::StatusCode::kInvalidArgument,
@ -111,7 +111,7 @@ void ConfigureAudioToTensorCalculator(
} // namespace } // namespace
// A "mediapipe.tasks.audio.AudioClassifierGraph" performs audio classification. // An "AudioClassifierGraph" performs audio classification.
// - Accepts CPU audio buffer and outputs classification results on CPU. // - Accepts CPU audio buffer and outputs classification results on CPU.
// //
// Inputs: // Inputs:
@ -129,12 +129,12 @@ void ConfigureAudioToTensorCalculator(
// //
// Example: // Example:
// node { // node {
// calculator: "mediapipe.tasks.audio.AudioClassifierGraph" // calculator: "mediapipe.tasks.audio.audio_classifier.AudioClassifierGraph"
// input_stream: "AUDIO:audio_in" // input_stream: "AUDIO:audio_in"
// input_stream: "SAMPLE_RATE:sample_rate_in" // input_stream: "SAMPLE_RATE:sample_rate_in"
// output_stream: "CLASSIFICATION_RESULT:classification_result_out" // output_stream: "CLASSIFICATION_RESULT:classification_result_out"
// options { // options {
// [mediapipe.tasks.audio.audio_classifier.proto.AudioClassifierOptions.ext] // [mediapipe.tasks.audio.audio_classifier.proto.AudioClassifierGraphOptions.ext]
// { // {
// base_options { // base_options {
// model_asset { // model_asset {
@ -152,16 +152,18 @@ class AudioClassifierGraph : public core::ModelTaskGraph {
public: public:
absl::StatusOr<CalculatorGraphConfig> GetConfig( absl::StatusOr<CalculatorGraphConfig> GetConfig(
SubgraphContext* sc) override { SubgraphContext* sc) override {
ASSIGN_OR_RETURN(const auto* model_resources, ASSIGN_OR_RETURN(
CreateModelResources<AudioClassifierOptionsProto>(sc)); const auto* model_resources,
CreateModelResources<proto::AudioClassifierGraphOptions>(sc));
Graph graph; Graph graph;
const bool use_stream_mode = sc->Options<AudioClassifierOptionsProto>() const bool use_stream_mode =
.base_options() sc->Options<proto::AudioClassifierGraphOptions>()
.use_stream_mode(); .base_options()
.use_stream_mode();
ASSIGN_OR_RETURN( ASSIGN_OR_RETURN(
auto classification_result_out, auto classification_result_out,
BuildAudioClassificationTask( BuildAudioClassificationTask(
sc->Options<AudioClassifierOptionsProto>(), *model_resources, sc->Options<proto::AudioClassifierGraphOptions>(), *model_resources,
graph[Input<Matrix>(kAudioTag)], graph[Input<Matrix>(kAudioTag)],
use_stream_mode use_stream_mode
? absl::nullopt ? absl::nullopt
@ -178,14 +180,14 @@ class AudioClassifierGraph : public core::ModelTaskGraph {
// buffer (mediapipe::Matrix) and the corresponding sample rate (double) as // buffer (mediapipe::Matrix) and the corresponding sample rate (double) as
// the inputs and returns one classification result per input audio buffer. // the inputs and returns one classification result per input audio buffer.
// //
// task_options: the mediapipe tasks AudioClassifierOptions proto. // task_options: the mediapipe tasks AudioClassifierGraphOptions proto.
// model_resources: the ModelSources object initialized from an audio // model_resources: the ModelSources object initialized from an audio
// classifier model file with model metadata. // classifier model file with model metadata.
// audio_in: (mediapipe::Matrix) stream to run audio classification on. // audio_in: (mediapipe::Matrix) stream to run audio classification on.
// sample_rate_in: (double) optional stream of the input audio sample rate. // sample_rate_in: (double) optional stream of the input audio sample rate.
// graph: the mediapipe builder::Graph instance to be updated. // graph: the mediapipe builder::Graph instance to be updated.
absl::StatusOr<Source<ClassificationResult>> BuildAudioClassificationTask( absl::StatusOr<Source<ClassificationResult>> BuildAudioClassificationTask(
const AudioClassifierOptionsProto& task_options, const proto::AudioClassifierGraphOptions& task_options,
const core::ModelResources& model_resources, Source<Matrix> audio_in, const core::ModelResources& model_resources, Source<Matrix> audio_in,
absl::optional<Source<double>> sample_rate_in, Graph& graph) { absl::optional<Source<double>> sample_rate_in, Graph& graph) {
MP_RETURN_IF_ERROR(SanityCheckOptions(task_options)); MP_RETURN_IF_ERROR(SanityCheckOptions(task_options));
@ -257,8 +259,10 @@ class AudioClassifierGraph : public core::ModelTaskGraph {
} }
}; };
REGISTER_MEDIAPIPE_GRAPH(::mediapipe::tasks::audio::AudioClassifierGraph); REGISTER_MEDIAPIPE_GRAPH(
::mediapipe::tasks::audio::audio_classifier::AudioClassifierGraph);
} // namespace audio_classifier
} // namespace audio } // namespace audio
} // namespace tasks } // namespace tasks
} // namespace mediapipe } // namespace mediapipe

View File

@ -44,6 +44,7 @@ limitations under the License.
namespace mediapipe { namespace mediapipe {
namespace tasks { namespace tasks {
namespace audio { namespace audio {
namespace audio_classifier {
namespace { namespace {
using ::absl::StatusOr; using ::absl::StatusOr;
@ -557,6 +558,7 @@ TEST_F(ClassifyAsyncTest, SucceedsWithNonDeterministicNumAudioSamples) {
} }
} // namespace } // namespace
} // namespace audio_classifier
} // namespace audio } // namespace audio
} // namespace tasks } // namespace tasks
} // namespace mediapipe } // namespace mediapipe

View File

@ -19,8 +19,8 @@ package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"]) licenses(["notice"])
mediapipe_proto_library( mediapipe_proto_library(
name = "audio_classifier_options_proto", name = "audio_classifier_graph_options_proto",
srcs = ["audio_classifier_options.proto"], srcs = ["audio_classifier_graph_options.proto"],
deps = [ deps = [
"//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto", "//mediapipe/framework:calculator_proto",

View File

@ -21,9 +21,9 @@ import "mediapipe/framework/calculator.proto";
import "mediapipe/tasks/cc/components/proto/classifier_options.proto"; import "mediapipe/tasks/cc/components/proto/classifier_options.proto";
import "mediapipe/tasks/cc/core/proto/base_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto";
message AudioClassifierOptions { message AudioClassifierGraphOptions {
extend mediapipe.CalculatorOptions { extend mediapipe.CalculatorOptions {
optional AudioClassifierOptions ext = 451755788; optional AudioClassifierGraphOptions ext = 451755788;
} }
// Base options for configuring MediaPipe Tasks, such as specifying the TfLite // Base options for configuring MediaPipe Tasks, such as specifying the TfLite
// model file with metadata, accelerator options, etc. // model file with metadata, accelerator options, etc.

View File

@ -33,7 +33,7 @@ cc_library(
"//mediapipe/tasks/cc/components/containers:classifications_cc_proto", "//mediapipe/tasks/cc/components/containers:classifications_cc_proto",
"//mediapipe/tasks/cc/core:model_resources", "//mediapipe/tasks/cc/core:model_resources",
"//mediapipe/tasks/cc/core:model_task_graph", "//mediapipe/tasks/cc/core:model_task_graph",
"//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_options_cc_proto", "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_cc_proto",
"@com_google_absl//absl/status:statusor", "@com_google_absl//absl/status:statusor",
], ],
alwayslink = 1, alwayslink = 1,
@ -61,7 +61,7 @@ cc_library(
"//mediapipe/tasks/cc/vision/core:base_vision_task_api", "//mediapipe/tasks/cc/vision/core:base_vision_task_api",
"//mediapipe/tasks/cc/vision/core:running_mode", "//mediapipe/tasks/cc/vision/core:running_mode",
"//mediapipe/tasks/cc/vision/core:vision_task_api_factory", "//mediapipe/tasks/cc/vision/core:vision_task_api_factory",
"//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_options_cc_proto", "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_cc_proto",
"@com_google_absl//absl/status", "@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor", "@com_google_absl//absl/status:statusor",
], ],

View File

@ -36,11 +36,12 @@ limitations under the License.
#include "mediapipe/tasks/cc/core/utils.h" #include "mediapipe/tasks/cc/core/utils.h"
#include "mediapipe/tasks/cc/vision/core/running_mode.h" #include "mediapipe/tasks/cc/vision/core/running_mode.h"
#include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h" #include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h"
#include "mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_options.pb.h" #include "mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_graph_options.pb.h"
namespace mediapipe { namespace mediapipe {
namespace tasks { namespace tasks {
namespace vision { namespace vision {
namespace image_classifier {
namespace { namespace {
@ -52,12 +53,10 @@ constexpr char kImageTag[] = "IMAGE";
constexpr char kNormRectName[] = "norm_rect_in"; constexpr char kNormRectName[] = "norm_rect_in";
constexpr char kNormRectTag[] = "NORM_RECT"; constexpr char kNormRectTag[] = "NORM_RECT";
constexpr char kSubgraphTypeName[] = constexpr char kSubgraphTypeName[] =
"mediapipe.tasks.vision.ImageClassifierGraph"; "mediapipe.tasks.vision.image_classifier.ImageClassifierGraph";
constexpr int kMicroSecondsPerMilliSecond = 1000; constexpr int kMicroSecondsPerMilliSecond = 1000;
using ::mediapipe::tasks::core::PacketMap; using ::mediapipe::tasks::core::PacketMap;
using ImageClassifierOptionsProto =
image_classifier::proto::ImageClassifierOptions;
// Builds a NormalizedRect covering the entire image. // Builds a NormalizedRect covering the entire image.
NormalizedRect BuildFullImageNormRect() { NormalizedRect BuildFullImageNormRect() {
@ -70,17 +69,17 @@ NormalizedRect BuildFullImageNormRect() {
} }
// Creates a MediaPipe graph config that contains a subgraph node of // Creates a MediaPipe graph config that contains a subgraph node of
// "mediapipe.tasks.vision.ImageClassifierGraph". If the task is running in the // type "ImageClassifierGraph". If the task is running in the live stream mode,
// live stream mode, a "FlowLimiterCalculator" will be added to limit the number // a "FlowLimiterCalculator" will be added to limit the number of frames in
// of frames in flight. // flight.
CalculatorGraphConfig CreateGraphConfig( CalculatorGraphConfig CreateGraphConfig(
std::unique_ptr<ImageClassifierOptionsProto> options_proto, std::unique_ptr<proto::ImageClassifierGraphOptions> options_proto,
bool enable_flow_limiting) { bool enable_flow_limiting) {
api2::builder::Graph graph; api2::builder::Graph graph;
graph.In(kImageTag).SetName(kImageInStreamName); graph.In(kImageTag).SetName(kImageInStreamName);
graph.In(kNormRectTag).SetName(kNormRectName); graph.In(kNormRectTag).SetName(kNormRectName);
auto& task_subgraph = graph.AddNode(kSubgraphTypeName); auto& task_subgraph = graph.AddNode(kSubgraphTypeName);
task_subgraph.GetOptions<ImageClassifierOptionsProto>().Swap( task_subgraph.GetOptions<proto::ImageClassifierGraphOptions>().Swap(
options_proto.get()); options_proto.get());
task_subgraph.Out(kClassificationResultTag) task_subgraph.Out(kClassificationResultTag)
.SetName(kClassificationResultStreamName) >> .SetName(kClassificationResultStreamName) >>
@ -98,10 +97,10 @@ CalculatorGraphConfig CreateGraphConfig(
} }
// Converts the user-facing ImageClassifierOptions struct to the internal // Converts the user-facing ImageClassifierOptions struct to the internal
// ImageClassifierOptions proto. // ImageClassifierGraphOptions proto.
std::unique_ptr<ImageClassifierOptionsProto> std::unique_ptr<proto::ImageClassifierGraphOptions>
ConvertImageClassifierOptionsToProto(ImageClassifierOptions* options) { ConvertImageClassifierOptionsToProto(ImageClassifierOptions* options) {
auto options_proto = std::make_unique<ImageClassifierOptionsProto>(); auto options_proto = std::make_unique<proto::ImageClassifierGraphOptions>();
auto base_options_proto = std::make_unique<tasks::core::proto::BaseOptions>( auto base_options_proto = std::make_unique<tasks::core::proto::BaseOptions>(
tasks::core::ConvertBaseOptionsToProto(&(options->base_options))); tasks::core::ConvertBaseOptionsToProto(&(options->base_options)));
options_proto->mutable_base_options()->Swap(base_options_proto.get()); options_proto->mutable_base_options()->Swap(base_options_proto.get());
@ -145,7 +144,7 @@ absl::StatusOr<std::unique_ptr<ImageClassifier>> ImageClassifier::Create(
}; };
} }
return core::VisionTaskApiFactory::Create<ImageClassifier, return core::VisionTaskApiFactory::Create<ImageClassifier,
ImageClassifierOptionsProto>( proto::ImageClassifierGraphOptions>(
CreateGraphConfig( CreateGraphConfig(
std::move(options_proto), std::move(options_proto),
options->running_mode == core::RunningMode::LIVE_STREAM), options->running_mode == core::RunningMode::LIVE_STREAM),
@ -214,6 +213,7 @@ absl::Status ImageClassifier::ClassifyAsync(Image image, int64 timestamp_ms,
.At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}}); .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}});
} }
} // namespace image_classifier
} // namespace vision } // namespace vision
} // namespace tasks } // namespace tasks
} // namespace mediapipe } // namespace mediapipe

View File

@ -32,6 +32,7 @@ limitations under the License.
namespace mediapipe { namespace mediapipe {
namespace tasks { namespace tasks {
namespace vision { namespace vision {
namespace image_classifier {
// The options for configuring a Mediapipe image classifier task. // The options for configuring a Mediapipe image classifier task.
struct ImageClassifierOptions { struct ImageClassifierOptions {
@ -161,6 +162,7 @@ class ImageClassifier : tasks::vision::core::BaseVisionTaskApi {
absl::Status Close() { return runner_->Close(); } absl::Status Close() { return runner_->Close(); }
}; };
} // namespace image_classifier
} // namespace vision } // namespace vision
} // namespace tasks } // namespace tasks
} // namespace mediapipe } // namespace mediapipe

View File

@ -29,11 +29,12 @@ limitations under the License.
#include "mediapipe/tasks/cc/components/image_preprocessing_options.pb.h" #include "mediapipe/tasks/cc/components/image_preprocessing_options.pb.h"
#include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/core/model_task_graph.h" #include "mediapipe/tasks/cc/core/model_task_graph.h"
#include "mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_options.pb.h" #include "mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_graph_options.pb.h"
namespace mediapipe { namespace mediapipe {
namespace tasks { namespace tasks {
namespace vision { namespace vision {
namespace image_classifier {
namespace { namespace {
@ -42,8 +43,6 @@ using ::mediapipe::api2::Output;
using ::mediapipe::api2::builder::GenericNode; using ::mediapipe::api2::builder::GenericNode;
using ::mediapipe::api2::builder::Graph; using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source; using ::mediapipe::api2::builder::Source;
using ImageClassifierOptionsProto =
image_classifier::proto::ImageClassifierOptions;
constexpr float kDefaultScoreThreshold = std::numeric_limits<float>::lowest(); constexpr float kDefaultScoreThreshold = std::numeric_limits<float>::lowest();
@ -61,8 +60,7 @@ struct ImageClassifierOutputStreams {
} // namespace } // namespace
// A "mediapipe.tasks.vision.ImageClassifierGraph" performs image // An "ImageClassifierGraph" performs image classification.
// classification.
// - Accepts CPU input images and outputs classifications on CPU. // - Accepts CPU input images and outputs classifications on CPU.
// //
// Inputs: // Inputs:
@ -80,12 +78,12 @@ struct ImageClassifierOutputStreams {
// //
// Example: // Example:
// node { // node {
// calculator: "mediapipe.tasks.vision.ImageClassifierGraph" // calculator: "mediapipe.tasks.vision.image_classifier.ImageClassifierGraph"
// input_stream: "IMAGE:image_in" // input_stream: "IMAGE:image_in"
// output_stream: "CLASSIFICATION_RESULT:classification_result_out" // output_stream: "CLASSIFICATION_RESULT:classification_result_out"
// output_stream: "IMAGE:image_out" // output_stream: "IMAGE:image_out"
// options { // options {
// [mediapipe.tasks.vision.image_classifier.proto.ImageClassifierOptions.ext] // [mediapipe.tasks.vision.image_classifier.proto.ImageClassifierGraphOptions.ext]
// { // {
// base_options { // base_options {
// model_asset { // model_asset {
@ -104,13 +102,14 @@ class ImageClassifierGraph : public core::ModelTaskGraph {
public: public:
absl::StatusOr<CalculatorGraphConfig> GetConfig( absl::StatusOr<CalculatorGraphConfig> GetConfig(
SubgraphContext* sc) override { SubgraphContext* sc) override {
ASSIGN_OR_RETURN(const auto* model_resources, ASSIGN_OR_RETURN(
CreateModelResources<ImageClassifierOptionsProto>(sc)); const auto* model_resources,
CreateModelResources<proto::ImageClassifierGraphOptions>(sc));
Graph graph; Graph graph;
ASSIGN_OR_RETURN( ASSIGN_OR_RETURN(
auto output_streams, auto output_streams,
BuildImageClassificationTask( BuildImageClassificationTask(
sc->Options<ImageClassifierOptionsProto>(), *model_resources, sc->Options<proto::ImageClassifierGraphOptions>(), *model_resources,
graph[Input<Image>(kImageTag)], graph[Input<Image>(kImageTag)],
graph[Input<NormalizedRect>::Optional(kNormRectTag)], graph)); graph[Input<NormalizedRect>::Optional(kNormRectTag)], graph));
output_streams.classification_result >> output_streams.classification_result >>
@ -125,13 +124,13 @@ class ImageClassifierGraph : public core::ModelTaskGraph {
// (mediapipe::Image) as input and returns one classification result per input // (mediapipe::Image) as input and returns one classification result per input
// image. // image.
// //
// task_options: the mediapipe tasks ImageClassifierOptions. // task_options: the mediapipe tasks ImageClassifierGraphOptions.
// model_resources: the ModelSources object initialized from an image // model_resources: the ModelSources object initialized from an image
// classification model file with model metadata. // classification model file with model metadata.
// image_in: (mediapipe::Image) stream to run classification on. // image_in: (mediapipe::Image) stream to run classification on.
// graph: the mediapipe builder::Graph instance to be updated. // graph: the mediapipe builder::Graph instance to be updated.
absl::StatusOr<ImageClassifierOutputStreams> BuildImageClassificationTask( absl::StatusOr<ImageClassifierOutputStreams> BuildImageClassificationTask(
const ImageClassifierOptionsProto& task_options, const proto::ImageClassifierGraphOptions& task_options,
const core::ModelResources& model_resources, Source<Image> image_in, const core::ModelResources& model_resources, Source<Image> image_in,
Source<NormalizedRect> norm_rect_in, Graph& graph) { Source<NormalizedRect> norm_rect_in, Graph& graph) {
// Adds preprocessing calculators and connects them to the graph input image // Adds preprocessing calculators and connects them to the graph input image
@ -168,8 +167,10 @@ class ImageClassifierGraph : public core::ModelTaskGraph {
/*image=*/preprocessing[Output<Image>(kImageTag)]}; /*image=*/preprocessing[Output<Image>(kImageTag)]};
} }
}; };
REGISTER_MEDIAPIPE_GRAPH(::mediapipe::tasks::vision::ImageClassifierGraph); REGISTER_MEDIAPIPE_GRAPH(
::mediapipe::tasks::vision::image_classifier::ImageClassifierGraph);
} // namespace image_classifier
} // namespace vision } // namespace vision
} // namespace tasks } // namespace tasks
} // namespace mediapipe } // namespace mediapipe

View File

@ -44,6 +44,7 @@ limitations under the License.
namespace mediapipe { namespace mediapipe {
namespace tasks { namespace tasks {
namespace vision { namespace vision {
namespace image_classifier {
namespace { namespace {
using ::mediapipe::file::JoinPath; using ::mediapipe::file::JoinPath;
@ -814,6 +815,7 @@ TEST_F(LiveStreamModeTest, SucceedsWithRegionOfInterest) {
} }
} // namespace } // namespace
} // namespace image_classifier
} // namespace vision } // namespace vision
} // namespace tasks } // namespace tasks
} // namespace mediapipe } // namespace mediapipe

View File

@ -19,8 +19,8 @@ package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"]) licenses(["notice"])
mediapipe_proto_library( mediapipe_proto_library(
name = "image_classifier_options_proto", name = "image_classifier_graph_options_proto",
srcs = ["image_classifier_options.proto"], srcs = ["image_classifier_graph_options.proto"],
deps = [ deps = [
"//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto", "//mediapipe/framework:calculator_proto",

View File

@ -21,9 +21,9 @@ import "mediapipe/framework/calculator.proto";
import "mediapipe/tasks/cc/components/proto/classifier_options.proto"; import "mediapipe/tasks/cc/components/proto/classifier_options.proto";
import "mediapipe/tasks/cc/core/proto/base_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto";
message ImageClassifierOptions { message ImageClassifierGraphOptions {
extend mediapipe.CalculatorOptions { extend mediapipe.CalculatorOptions {
optional ImageClassifierOptions ext = 456383383; optional ImageClassifierGraphOptions ext = 456383383;
} }
// Base options for configuring MediaPipe Tasks, such as specifying the TfLite // Base options for configuring MediaPipe Tasks, such as specifying the TfLite
// model file with metadata, accelerator options, etc. // model file with metadata, accelerator options, etc.