From a8ca669f050103fa4bcd7b6090d5e83b29394d4e Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Thu, 29 Sep 2022 06:03:59 -0700 Subject: [PATCH] Refactor classifiers public APIs for namespace and naming consistency. PiperOrigin-RevId: 477705647 --- .../tasks/cc/audio/audio_classifier/BUILD | 4 +-- .../audio_classifier/audio_classifier.cc | 25 ++++++------- .../audio/audio_classifier/audio_classifier.h | 2 ++ .../audio_classifier_graph.cc | 36 ++++++++++--------- .../audio_classifier/audio_classifier_test.cc | 2 ++ .../cc/audio/audio_classifier/proto/BUILD | 4 +-- ...o => audio_classifier_graph_options.proto} | 4 +-- .../tasks/cc/vision/image_classifier/BUILD | 4 +-- .../image_classifier/image_classifier.cc | 26 +++++++------- .../image_classifier/image_classifier.h | 2 ++ .../image_classifier_graph.cc | 27 +++++++------- .../image_classifier/image_classifier_test.cc | 2 ++ .../cc/vision/image_classifier/proto/BUILD | 4 +-- ...o => image_classifier_graph_options.proto} | 4 +-- 14 files changed, 80 insertions(+), 66 deletions(-) rename mediapipe/tasks/cc/audio/audio_classifier/proto/{audio_classifier_options.proto => audio_classifier_graph_options.proto} (94%) rename mediapipe/tasks/cc/vision/image_classifier/proto/{image_classifier_options.proto => image_classifier_graph_options.proto} (93%) diff --git a/mediapipe/tasks/cc/audio/audio_classifier/BUILD b/mediapipe/tasks/cc/audio/audio_classifier/BUILD index 363dc89a9..20ccf68f0 100644 --- a/mediapipe/tasks/cc/audio/audio_classifier/BUILD +++ b/mediapipe/tasks/cc/audio/audio_classifier/BUILD @@ -33,7 +33,7 @@ cc_library( "//mediapipe/framework/api2:port", "//mediapipe/framework/formats:matrix", "//mediapipe/tasks/cc:common", - "//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_options_cc_proto", + "//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_graph_options_cc_proto", "//mediapipe/tasks/cc/audio/utils:audio_tensor_specs", "//mediapipe/tasks/cc/components:classification_postprocessing", "//mediapipe/tasks/cc/components:classification_postprocessing_options_cc_proto", @@ -60,7 +60,7 @@ cc_library( ":audio_classifier_graph", "//mediapipe/framework/api2:builder", "//mediapipe/framework/formats:matrix", - "//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_options_cc_proto", + "//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_graph_options_cc_proto", "//mediapipe/tasks/cc/audio/core:audio_task_api_factory", "//mediapipe/tasks/cc/audio/core:base_audio_task_api", "//mediapipe/tasks/cc/audio/core:running_mode", diff --git a/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier.cc b/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier.cc index db939f341..9a8075f77 100644 --- a/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier.cc +++ b/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier.cc @@ -22,7 +22,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "mediapipe/framework/api2/builder.h" #include "mediapipe/framework/formats/matrix.h" -#include "mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_options.pb.h" +#include "mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options.pb.h" #include "mediapipe/tasks/cc/audio/core/audio_task_api_factory.h" #include "mediapipe/tasks/cc/components/classifier_options.h" #include "mediapipe/tasks/cc/components/containers/classifications.pb.h" @@ -33,6 +33,8 @@ limitations under the License. namespace mediapipe { namespace tasks { namespace audio { +namespace audio_classifier { + namespace { constexpr char kAudioStreamName[] = "audio_in"; @@ -42,16 +44,13 @@ constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT"; constexpr char kSampleRateName[] = "sample_rate_in"; constexpr char kSampleRateTag[] = "SAMPLE_RATE"; constexpr char kSubgraphTypeName[] = - "mediapipe.tasks.audio.AudioClassifierGraph"; + "mediapipe.tasks.audio.audio_classifier.AudioClassifierGraph"; constexpr int kMicroSecondsPerMilliSecond = 1000; -using AudioClassifierOptionsProto = - audio_classifier::proto::AudioClassifierOptions; - // Creates a MediaPipe graph config that only contains a single subgraph node of -// "mediapipe.tasks.audio.AudioClassifierGraph". +// type "AudioClassifierGraph". CalculatorGraphConfig CreateGraphConfig( - std::unique_ptr options_proto) { + std::unique_ptr options_proto) { api2::builder::Graph graph; auto& subgraph = graph.AddNode(kSubgraphTypeName); graph.In(kAudioTag).SetName(kAudioStreamName) >> subgraph.In(kAudioTag); @@ -59,7 +58,8 @@ CalculatorGraphConfig CreateGraphConfig( graph.In(kSampleRateTag).SetName(kSampleRateName) >> subgraph.In(kSampleRateTag); } - subgraph.GetOptions().Swap(options_proto.get()); + subgraph.GetOptions().Swap( + options_proto.get()); subgraph.Out(kClassificationResultTag) .SetName(kClassificationResultStreamName) >> graph.Out(kClassificationResultTag); @@ -67,10 +67,10 @@ CalculatorGraphConfig CreateGraphConfig( } // Converts the user-facing AudioClassifierOptions struct to the internal -// AudioClassifierOptions proto. -std::unique_ptr +// AudioClassifierGraphOptions proto. +std::unique_ptr ConvertAudioClassifierOptionsToProto(AudioClassifierOptions* options) { - auto options_proto = std::make_unique(); + auto options_proto = std::make_unique(); auto base_options_proto = std::make_unique( tasks::core::ConvertBaseOptionsToProto(&(options->base_options))); options_proto->mutable_base_options()->Swap(base_options_proto.get()); @@ -119,7 +119,7 @@ absl::StatusOr> AudioClassifier::Create( }; } return core::AudioTaskApiFactory::Create( + proto::AudioClassifierGraphOptions>( CreateGraphConfig(std::move(options_proto)), std::move(options->base_options.op_resolver), options->running_mode, std::move(packets_callback)); @@ -140,6 +140,7 @@ absl::Status AudioClassifier::ClassifyAsync(Matrix audio_block, .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}}); } +} // namespace audio_classifier } // namespace audio } // namespace tasks } // namespace mediapipe diff --git a/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier.h b/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier.h index 688bb60e3..bd8bd5e0c 100644 --- a/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier.h +++ b/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier.h @@ -30,6 +30,7 @@ limitations under the License. namespace mediapipe { namespace tasks { namespace audio { +namespace audio_classifier { // The options for configuring a mediapipe audio classifier task. struct AudioClassifierOptions { @@ -162,6 +163,7 @@ class AudioClassifier : tasks::audio::core::BaseAudioTaskApi { absl::Status Close() { return runner_->Close(); } }; +} // namespace audio_classifier } // namespace audio } // namespace tasks } // namespace mediapipe diff --git a/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier_graph.cc b/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier_graph.cc index 0f40b59a4..810fb2da5 100644 --- a/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier_graph.cc +++ b/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier_graph.cc @@ -28,7 +28,7 @@ limitations under the License. #include "mediapipe/framework/calculator.pb.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/formats/matrix.h" -#include "mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_options.pb.h" +#include "mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options.pb.h" #include "mediapipe/tasks/cc/audio/utils/audio_tensor_specs.h" #include "mediapipe/tasks/cc/common.h" #include "mediapipe/tasks/cc/components/classification_postprocessing.h" @@ -44,6 +44,7 @@ limitations under the License. namespace mediapipe { namespace tasks { namespace audio { +namespace audio_classifier { namespace { @@ -60,10 +61,9 @@ constexpr char kPacketTag[] = "PACKET"; constexpr char kSampleRateTag[] = "SAMPLE_RATE"; constexpr char kTensorsTag[] = "TENSORS"; constexpr char kTimestampsTag[] = "TIMESTAMPS"; -using AudioClassifierOptionsProto = - audio_classifier::proto::AudioClassifierOptions; -absl::Status SanityCheckOptions(const AudioClassifierOptionsProto& options) { +absl::Status SanityCheckOptions( + const proto::AudioClassifierGraphOptions& options) { if (options.base_options().use_stream_mode() && !options.has_default_input_audio_sample_rate()) { return CreateStatusWithPayload(absl::StatusCode::kInvalidArgument, @@ -111,7 +111,7 @@ void ConfigureAudioToTensorCalculator( } // namespace -// A "mediapipe.tasks.audio.AudioClassifierGraph" performs audio classification. +// An "AudioClassifierGraph" performs audio classification. // - Accepts CPU audio buffer and outputs classification results on CPU. // // Inputs: @@ -129,12 +129,12 @@ void ConfigureAudioToTensorCalculator( // // Example: // node { -// calculator: "mediapipe.tasks.audio.AudioClassifierGraph" +// calculator: "mediapipe.tasks.audio.audio_classifier.AudioClassifierGraph" // input_stream: "AUDIO:audio_in" // input_stream: "SAMPLE_RATE:sample_rate_in" // output_stream: "CLASSIFICATION_RESULT:classification_result_out" // options { -// [mediapipe.tasks.audio.audio_classifier.proto.AudioClassifierOptions.ext] +// [mediapipe.tasks.audio.audio_classifier.proto.AudioClassifierGraphOptions.ext] // { // base_options { // model_asset { @@ -152,16 +152,18 @@ class AudioClassifierGraph : public core::ModelTaskGraph { public: absl::StatusOr GetConfig( SubgraphContext* sc) override { - ASSIGN_OR_RETURN(const auto* model_resources, - CreateModelResources(sc)); + ASSIGN_OR_RETURN( + const auto* model_resources, + CreateModelResources(sc)); Graph graph; - const bool use_stream_mode = sc->Options() - .base_options() - .use_stream_mode(); + const bool use_stream_mode = + sc->Options() + .base_options() + .use_stream_mode(); ASSIGN_OR_RETURN( auto classification_result_out, BuildAudioClassificationTask( - sc->Options(), *model_resources, + sc->Options(), *model_resources, graph[Input(kAudioTag)], use_stream_mode ? absl::nullopt @@ -178,14 +180,14 @@ class AudioClassifierGraph : public core::ModelTaskGraph { // buffer (mediapipe::Matrix) and the corresponding sample rate (double) as // the inputs and returns one classification result per input audio buffer. // - // task_options: the mediapipe tasks AudioClassifierOptions proto. + // task_options: the mediapipe tasks AudioClassifierGraphOptions proto. // model_resources: the ModelSources object initialized from an audio // classifier model file with model metadata. // audio_in: (mediapipe::Matrix) stream to run audio classification on. // sample_rate_in: (double) optional stream of the input audio sample rate. // graph: the mediapipe builder::Graph instance to be updated. absl::StatusOr> BuildAudioClassificationTask( - const AudioClassifierOptionsProto& task_options, + const proto::AudioClassifierGraphOptions& task_options, const core::ModelResources& model_resources, Source audio_in, absl::optional> sample_rate_in, Graph& graph) { MP_RETURN_IF_ERROR(SanityCheckOptions(task_options)); @@ -257,8 +259,10 @@ class AudioClassifierGraph : public core::ModelTaskGraph { } }; -REGISTER_MEDIAPIPE_GRAPH(::mediapipe::tasks::audio::AudioClassifierGraph); +REGISTER_MEDIAPIPE_GRAPH( + ::mediapipe::tasks::audio::audio_classifier::AudioClassifierGraph); +} // namespace audio_classifier } // namespace audio } // namespace tasks } // namespace mediapipe diff --git a/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier_test.cc b/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier_test.cc index dd56c4ff1..4e874b520 100644 --- a/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier_test.cc +++ b/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier_test.cc @@ -44,6 +44,7 @@ limitations under the License. namespace mediapipe { namespace tasks { namespace audio { +namespace audio_classifier { namespace { using ::absl::StatusOr; @@ -557,6 +558,7 @@ TEST_F(ClassifyAsyncTest, SucceedsWithNonDeterministicNumAudioSamples) { } } // namespace +} // namespace audio_classifier } // namespace audio } // namespace tasks } // namespace mediapipe diff --git a/mediapipe/tasks/cc/audio/audio_classifier/proto/BUILD b/mediapipe/tasks/cc/audio/audio_classifier/proto/BUILD index 7b1952e06..033bb51ac 100644 --- a/mediapipe/tasks/cc/audio/audio_classifier/proto/BUILD +++ b/mediapipe/tasks/cc/audio/audio_classifier/proto/BUILD @@ -19,8 +19,8 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) licenses(["notice"]) mediapipe_proto_library( - name = "audio_classifier_options_proto", - srcs = ["audio_classifier_options.proto"], + name = "audio_classifier_graph_options_proto", + srcs = ["audio_classifier_graph_options.proto"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", diff --git a/mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_options.proto b/mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options.proto similarity index 94% rename from mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_options.proto rename to mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options.proto index a76ccdcab..63b4b3293 100644 --- a/mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_options.proto +++ b/mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options.proto @@ -21,9 +21,9 @@ import "mediapipe/framework/calculator.proto"; import "mediapipe/tasks/cc/components/proto/classifier_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; -message AudioClassifierOptions { +message AudioClassifierGraphOptions { extend mediapipe.CalculatorOptions { - optional AudioClassifierOptions ext = 451755788; + optional AudioClassifierGraphOptions ext = 451755788; } // Base options for configuring MediaPipe Tasks, such as specifying the TfLite // model file with metadata, accelerator options, etc. diff --git a/mediapipe/tasks/cc/vision/image_classifier/BUILD b/mediapipe/tasks/cc/vision/image_classifier/BUILD index 4dcecdbbe..e7c8a6586 100644 --- a/mediapipe/tasks/cc/vision/image_classifier/BUILD +++ b/mediapipe/tasks/cc/vision/image_classifier/BUILD @@ -33,7 +33,7 @@ cc_library( "//mediapipe/tasks/cc/components/containers:classifications_cc_proto", "//mediapipe/tasks/cc/core:model_resources", "//mediapipe/tasks/cc/core:model_task_graph", - "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_options_cc_proto", + "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_cc_proto", "@com_google_absl//absl/status:statusor", ], alwayslink = 1, @@ -61,7 +61,7 @@ cc_library( "//mediapipe/tasks/cc/vision/core:base_vision_task_api", "//mediapipe/tasks/cc/vision/core:running_mode", "//mediapipe/tasks/cc/vision/core:vision_task_api_factory", - "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_options_cc_proto", + "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_cc_proto", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", ], diff --git a/mediapipe/tasks/cc/vision/image_classifier/image_classifier.cc b/mediapipe/tasks/cc/vision/image_classifier/image_classifier.cc index eb74c3d98..1e092e85a 100644 --- a/mediapipe/tasks/cc/vision/image_classifier/image_classifier.cc +++ b/mediapipe/tasks/cc/vision/image_classifier/image_classifier.cc @@ -36,11 +36,12 @@ limitations under the License. #include "mediapipe/tasks/cc/core/utils.h" #include "mediapipe/tasks/cc/vision/core/running_mode.h" #include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h" -#include "mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_options.pb.h" +#include "mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_graph_options.pb.h" namespace mediapipe { namespace tasks { namespace vision { +namespace image_classifier { namespace { @@ -52,12 +53,10 @@ constexpr char kImageTag[] = "IMAGE"; constexpr char kNormRectName[] = "norm_rect_in"; constexpr char kNormRectTag[] = "NORM_RECT"; constexpr char kSubgraphTypeName[] = - "mediapipe.tasks.vision.ImageClassifierGraph"; + "mediapipe.tasks.vision.image_classifier.ImageClassifierGraph"; constexpr int kMicroSecondsPerMilliSecond = 1000; using ::mediapipe::tasks::core::PacketMap; -using ImageClassifierOptionsProto = - image_classifier::proto::ImageClassifierOptions; // Builds a NormalizedRect covering the entire image. NormalizedRect BuildFullImageNormRect() { @@ -70,17 +69,17 @@ NormalizedRect BuildFullImageNormRect() { } // Creates a MediaPipe graph config that contains a subgraph node of -// "mediapipe.tasks.vision.ImageClassifierGraph". If the task is running in the -// live stream mode, a "FlowLimiterCalculator" will be added to limit the number -// of frames in flight. +// type "ImageClassifierGraph". If the task is running in the live stream mode, +// a "FlowLimiterCalculator" will be added to limit the number of frames in +// flight. CalculatorGraphConfig CreateGraphConfig( - std::unique_ptr options_proto, + std::unique_ptr options_proto, bool enable_flow_limiting) { api2::builder::Graph graph; graph.In(kImageTag).SetName(kImageInStreamName); graph.In(kNormRectTag).SetName(kNormRectName); auto& task_subgraph = graph.AddNode(kSubgraphTypeName); - task_subgraph.GetOptions().Swap( + task_subgraph.GetOptions().Swap( options_proto.get()); task_subgraph.Out(kClassificationResultTag) .SetName(kClassificationResultStreamName) >> @@ -98,10 +97,10 @@ CalculatorGraphConfig CreateGraphConfig( } // Converts the user-facing ImageClassifierOptions struct to the internal -// ImageClassifierOptions proto. -std::unique_ptr +// ImageClassifierGraphOptions proto. +std::unique_ptr ConvertImageClassifierOptionsToProto(ImageClassifierOptions* options) { - auto options_proto = std::make_unique(); + auto options_proto = std::make_unique(); auto base_options_proto = std::make_unique( tasks::core::ConvertBaseOptionsToProto(&(options->base_options))); options_proto->mutable_base_options()->Swap(base_options_proto.get()); @@ -145,7 +144,7 @@ absl::StatusOr> ImageClassifier::Create( }; } return core::VisionTaskApiFactory::Create( + proto::ImageClassifierGraphOptions>( CreateGraphConfig( std::move(options_proto), options->running_mode == core::RunningMode::LIVE_STREAM), @@ -214,6 +213,7 @@ absl::Status ImageClassifier::ClassifyAsync(Image image, int64 timestamp_ms, .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}}); } +} // namespace image_classifier } // namespace vision } // namespace tasks } // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/image_classifier/image_classifier.h b/mediapipe/tasks/cc/vision/image_classifier/image_classifier.h index 2fbac71b2..8ff11413e 100644 --- a/mediapipe/tasks/cc/vision/image_classifier/image_classifier.h +++ b/mediapipe/tasks/cc/vision/image_classifier/image_classifier.h @@ -32,6 +32,7 @@ limitations under the License. namespace mediapipe { namespace tasks { namespace vision { +namespace image_classifier { // The options for configuring a Mediapipe image classifier task. struct ImageClassifierOptions { @@ -161,6 +162,7 @@ class ImageClassifier : tasks::vision::core::BaseVisionTaskApi { absl::Status Close() { return runner_->Close(); } }; +} // namespace image_classifier } // namespace vision } // namespace tasks } // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/image_classifier/image_classifier_graph.cc b/mediapipe/tasks/cc/vision/image_classifier/image_classifier_graph.cc index 532b7db45..0d7b60c99 100644 --- a/mediapipe/tasks/cc/vision/image_classifier/image_classifier_graph.cc +++ b/mediapipe/tasks/cc/vision/image_classifier/image_classifier_graph.cc @@ -29,11 +29,12 @@ limitations under the License. #include "mediapipe/tasks/cc/components/image_preprocessing_options.pb.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_task_graph.h" -#include "mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_options.pb.h" +#include "mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_graph_options.pb.h" namespace mediapipe { namespace tasks { namespace vision { +namespace image_classifier { namespace { @@ -42,8 +43,6 @@ using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::GenericNode; using ::mediapipe::api2::builder::Graph; using ::mediapipe::api2::builder::Source; -using ImageClassifierOptionsProto = - image_classifier::proto::ImageClassifierOptions; constexpr float kDefaultScoreThreshold = std::numeric_limits::lowest(); @@ -61,8 +60,7 @@ struct ImageClassifierOutputStreams { } // namespace -// A "mediapipe.tasks.vision.ImageClassifierGraph" performs image -// classification. +// An "ImageClassifierGraph" performs image classification. // - Accepts CPU input images and outputs classifications on CPU. // // Inputs: @@ -80,12 +78,12 @@ struct ImageClassifierOutputStreams { // // Example: // node { -// calculator: "mediapipe.tasks.vision.ImageClassifierGraph" +// calculator: "mediapipe.tasks.vision.image_classifier.ImageClassifierGraph" // input_stream: "IMAGE:image_in" // output_stream: "CLASSIFICATION_RESULT:classification_result_out" // output_stream: "IMAGE:image_out" // options { -// [mediapipe.tasks.vision.image_classifier.proto.ImageClassifierOptions.ext] +// [mediapipe.tasks.vision.image_classifier.proto.ImageClassifierGraphOptions.ext] // { // base_options { // model_asset { @@ -104,13 +102,14 @@ class ImageClassifierGraph : public core::ModelTaskGraph { public: absl::StatusOr GetConfig( SubgraphContext* sc) override { - ASSIGN_OR_RETURN(const auto* model_resources, - CreateModelResources(sc)); + ASSIGN_OR_RETURN( + const auto* model_resources, + CreateModelResources(sc)); Graph graph; ASSIGN_OR_RETURN( auto output_streams, BuildImageClassificationTask( - sc->Options(), *model_resources, + sc->Options(), *model_resources, graph[Input(kImageTag)], graph[Input::Optional(kNormRectTag)], graph)); output_streams.classification_result >> @@ -125,13 +124,13 @@ class ImageClassifierGraph : public core::ModelTaskGraph { // (mediapipe::Image) as input and returns one classification result per input // image. // - // task_options: the mediapipe tasks ImageClassifierOptions. + // task_options: the mediapipe tasks ImageClassifierGraphOptions. // model_resources: the ModelSources object initialized from an image // classification model file with model metadata. // image_in: (mediapipe::Image) stream to run classification on. // graph: the mediapipe builder::Graph instance to be updated. absl::StatusOr BuildImageClassificationTask( - const ImageClassifierOptionsProto& task_options, + const proto::ImageClassifierGraphOptions& task_options, const core::ModelResources& model_resources, Source image_in, Source norm_rect_in, Graph& graph) { // Adds preprocessing calculators and connects them to the graph input image @@ -168,8 +167,10 @@ class ImageClassifierGraph : public core::ModelTaskGraph { /*image=*/preprocessing[Output(kImageTag)]}; } }; -REGISTER_MEDIAPIPE_GRAPH(::mediapipe::tasks::vision::ImageClassifierGraph); +REGISTER_MEDIAPIPE_GRAPH( + ::mediapipe::tasks::vision::image_classifier::ImageClassifierGraph); +} // namespace image_classifier } // namespace vision } // namespace tasks } // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/image_classifier/image_classifier_test.cc b/mediapipe/tasks/cc/vision/image_classifier/image_classifier_test.cc index 7cf6414bf..edbb851c0 100644 --- a/mediapipe/tasks/cc/vision/image_classifier/image_classifier_test.cc +++ b/mediapipe/tasks/cc/vision/image_classifier/image_classifier_test.cc @@ -44,6 +44,7 @@ limitations under the License. namespace mediapipe { namespace tasks { namespace vision { +namespace image_classifier { namespace { using ::mediapipe::file::JoinPath; @@ -814,6 +815,7 @@ TEST_F(LiveStreamModeTest, SucceedsWithRegionOfInterest) { } } // namespace +} // namespace image_classifier } // namespace vision } // namespace tasks } // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/image_classifier/proto/BUILD b/mediapipe/tasks/cc/vision/image_classifier/proto/BUILD index dc8241799..a6f5791e3 100644 --- a/mediapipe/tasks/cc/vision/image_classifier/proto/BUILD +++ b/mediapipe/tasks/cc/vision/image_classifier/proto/BUILD @@ -19,8 +19,8 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) licenses(["notice"]) mediapipe_proto_library( - name = "image_classifier_options_proto", - srcs = ["image_classifier_options.proto"], + name = "image_classifier_graph_options_proto", + srcs = ["image_classifier_graph_options.proto"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", diff --git a/mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_options.proto b/mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_graph_options.proto similarity index 93% rename from mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_options.proto rename to mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_graph_options.proto index 8aa8b4615..3da047110 100644 --- a/mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_options.proto +++ b/mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_graph_options.proto @@ -21,9 +21,9 @@ import "mediapipe/framework/calculator.proto"; import "mediapipe/tasks/cc/components/proto/classifier_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; -message ImageClassifierOptions { +message ImageClassifierGraphOptions { extend mediapipe.CalculatorOptions { - optional ImageClassifierOptions ext = 456383383; + optional ImageClassifierGraphOptions ext = 456383383; } // Base options for configuring MediaPipe Tasks, such as specifying the TfLite // model file with metadata, accelerator options, etc.