From 03c8ac3641a84a2dd03167ee23f99942d09ea40e Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Mon, 3 Oct 2022 01:58:41 -0700 Subject: [PATCH] Refactor ClassificationResult and ClassificationPostprocessing. PiperOrigin-RevId: 478444264 --- .../tasks/cc/audio/audio_classifier/BUILD | 12 +- .../audio_classifier/audio_classifier.cc | 11 +- .../audio/audio_classifier/audio_classifier.h | 15 ++- .../audio_classifier_graph.cc | 20 +-- .../audio_classifier/audio_classifier_test.cc | 5 +- .../cc/audio/audio_classifier/proto/BUILD | 2 +- .../audio_classifier_graph_options.proto | 4 +- mediapipe/tasks/cc/components/BUILD | 59 --------- .../tasks/cc/components/calculators/BUILD | 6 +- .../classification_aggregation_calculator.cc | 8 +- .../calculators/end_loop_calculator.cc | 5 +- .../cc/components/containers/proto/BUILD | 23 +++- .../containers/{ => proto}/category.proto | 2 +- .../{ => proto}/classifications.proto | 4 +- .../tasks/cc/components/processors/BUILD | 64 ++++++++++ .../classification_postprocessing_graph.cc} | 54 ++++---- .../classification_postprocessing_graph.h} | 28 ++-- ...assification_postprocessing_graph_test.cc} | 120 +++++++++--------- .../{ => processors}/classifier_options.cc | 10 +- .../{ => processors}/classifier_options.h | 12 +- .../{containers => processors/proto}/BUILD | 14 +- ...cation_postprocessing_graph_options.proto} | 6 +- .../proto/classifier_options.proto | 2 +- mediapipe/tasks/cc/components/proto/BUILD | 5 - .../cc/vision/hand_gesture_recognizer/BUILD | 6 +- .../hand_gesture_recognizer_subgraph.cc | 20 +-- .../hand_gesture_recognizer/proto/BUILD | 4 +- ..._gesture_recognizer_subgraph_options.proto | 4 +- .../tasks/cc/vision/image_classifier/BUILD | 12 +- .../image_classifier/image_classifier.cc | 11 +- .../image_classifier/image_classifier.h | 18 +-- .../image_classifier_graph.cc | 20 +-- .../image_classifier/image_classifier_test.cc | 7 +- .../cc/vision/image_classifier/proto/BUILD | 2 +- .../image_classifier_graph_options.proto | 4 +- .../tasks/python/components/containers/BUILD | 2 +- .../python/components/containers/category.py | 2 +- 37 files changed, 329 insertions(+), 274 deletions(-) rename mediapipe/tasks/cc/components/containers/{ => proto}/category.proto (96%) rename mediapipe/tasks/cc/components/containers/{ => proto}/classifications.proto (93%) create mode 100644 mediapipe/tasks/cc/components/processors/BUILD rename mediapipe/tasks/cc/components/{classification_postprocessing.cc => processors/classification_postprocessing_graph.cc} (92%) rename mediapipe/tasks/cc/components/{classification_postprocessing.h => processors/classification_postprocessing_graph.h} (59%) rename mediapipe/tasks/cc/components/{classification_postprocessing_test.cc => processors/classification_postprocessing_graph_test.cc} (88%) rename mediapipe/tasks/cc/components/{ => processors}/classifier_options.cc (81%) rename mediapipe/tasks/cc/components/{ => processors}/classifier_options.h (83%) rename mediapipe/tasks/cc/components/{containers => processors/proto}/BUILD (58%) rename mediapipe/tasks/cc/components/{classification_postprocessing_options.proto => processors/proto/classification_postprocessing_graph_options.proto} (91%) rename mediapipe/tasks/cc/components/{ => processors}/proto/classifier_options.proto (97%) diff --git a/mediapipe/tasks/cc/audio/audio_classifier/BUILD b/mediapipe/tasks/cc/audio/audio_classifier/BUILD index 20ccf68f0..ac238bfda 100644 --- a/mediapipe/tasks/cc/audio/audio_classifier/BUILD +++ b/mediapipe/tasks/cc/audio/audio_classifier/BUILD @@ -35,9 +35,10 @@ cc_library( "//mediapipe/tasks/cc:common", "//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_graph_options_cc_proto", "//mediapipe/tasks/cc/audio/utils:audio_tensor_specs", - "//mediapipe/tasks/cc/components:classification_postprocessing", - "//mediapipe/tasks/cc/components:classification_postprocessing_options_cc_proto", - "//mediapipe/tasks/cc/components/containers:classifications_cc_proto", + "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", + "//mediapipe/tasks/cc/components/processors:classification_postprocessing_graph", + "//mediapipe/tasks/cc/components/processors:classifier_options", + "//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto", "//mediapipe/tasks/cc/core:model_resources", "//mediapipe/tasks/cc/core:model_task_graph", "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", @@ -64,8 +65,9 @@ cc_library( "//mediapipe/tasks/cc/audio/core:audio_task_api_factory", "//mediapipe/tasks/cc/audio/core:base_audio_task_api", "//mediapipe/tasks/cc/audio/core:running_mode", - "//mediapipe/tasks/cc/components:classifier_options", - "//mediapipe/tasks/cc/components/containers:classifications_cc_proto", + "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", + "//mediapipe/tasks/cc/components/processors:classifier_options", + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto", "//mediapipe/tasks/cc/core:base_options", "//mediapipe/tasks/cc/core:task_runner", "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", diff --git a/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier.cc b/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier.cc index 9a8075f77..702d802c5 100644 --- a/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier.cc +++ b/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier.cc @@ -24,8 +24,9 @@ limitations under the License. #include "mediapipe/framework/formats/matrix.h" #include "mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options.pb.h" #include "mediapipe/tasks/cc/audio/core/audio_task_api_factory.h" -#include "mediapipe/tasks/cc/components/classifier_options.h" -#include "mediapipe/tasks/cc/components/containers/classifications.pb.h" +#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" +#include "mediapipe/tasks/cc/components/processors/classifier_options.h" +#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h" #include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h" #include "mediapipe/tasks/cc/core/task_runner.h" #include "tensorflow/lite/core/api/op_resolver.h" @@ -37,6 +38,8 @@ namespace audio_classifier { namespace { +using ::mediapipe::tasks::components::containers::proto::ClassificationResult; + constexpr char kAudioStreamName[] = "audio_in"; constexpr char kAudioTag[] = "AUDIO"; constexpr char kClassificationResultStreamName[] = "classification_result_out"; @@ -77,8 +80,8 @@ ConvertAudioClassifierOptionsToProto(AudioClassifierOptions* options) { options_proto->mutable_base_options()->set_use_stream_mode( options->running_mode == core::RunningMode::AUDIO_STREAM); auto classifier_options_proto = - std::make_unique( - components::ConvertClassifierOptionsToProto( + std::make_unique( + components::processors::ConvertClassifierOptionsToProto( &(options->classifier_options))); options_proto->mutable_classifier_options()->Swap( classifier_options_proto.get()); diff --git a/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier.h b/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier.h index bd8bd5e0c..200cffb8c 100644 --- a/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier.h +++ b/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier.h @@ -23,8 +23,8 @@ limitations under the License. #include "mediapipe/framework/formats/matrix.h" #include "mediapipe/tasks/cc/audio/core/base_audio_task_api.h" #include "mediapipe/tasks/cc/audio/core/running_mode.h" -#include "mediapipe/tasks/cc/components/classifier_options.h" -#include "mediapipe/tasks/cc/components/containers/classifications.pb.h" +#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" +#include "mediapipe/tasks/cc/components/processors/classifier_options.h" #include "mediapipe/tasks/cc/core/base_options.h" namespace mediapipe { @@ -40,7 +40,7 @@ struct AudioClassifierOptions { // Options for configuring the classifier behavior, such as score threshold, // number of results, etc. - components::ClassifierOptions classifier_options; + components::processors::ClassifierOptions classifier_options; // The running mode of the audio classifier. Default to the audio clips mode. // Audio classifier has two running modes: @@ -59,8 +59,9 @@ struct AudioClassifierOptions { // The user-defined result callback for processing audio stream data. // The result callback should only be specified when the running mode is set // to RunningMode::AUDIO_STREAM. - std::function)> result_callback = - nullptr; + std::function)> + result_callback = nullptr; }; // Performs audio classification on audio clips or audio stream. @@ -132,8 +133,8 @@ class AudioClassifier : tasks::audio::core::BaseAudioTaskApi { // framed audio clip. // TODO: Use `sample_rate` in AudioClassifierOptions by default // and makes `audio_sample_rate` optional. - absl::StatusOr Classify(mediapipe::Matrix audio_clip, - double audio_sample_rate); + absl::StatusOr Classify( + mediapipe::Matrix audio_clip, double audio_sample_rate); // Sends audio data (a block in a continuous audio stream) to perform audio // classification. Only use this method when the AudioClassifier is created diff --git a/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier_graph.cc b/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier_graph.cc index 810fb2da5..12f8ce31a 100644 --- a/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier_graph.cc +++ b/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier_graph.cc @@ -31,9 +31,9 @@ limitations under the License. #include "mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options.pb.h" #include "mediapipe/tasks/cc/audio/utils/audio_tensor_specs.h" #include "mediapipe/tasks/cc/common.h" -#include "mediapipe/tasks/cc/components/classification_postprocessing.h" -#include "mediapipe/tasks/cc/components/classification_postprocessing_options.pb.h" -#include "mediapipe/tasks/cc/components/containers/classifications.pb.h" +#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" +#include "mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h" +#include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_task_graph.h" #include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h" @@ -53,6 +53,7 @@ using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::GenericNode; using ::mediapipe::api2::builder::Graph; using ::mediapipe::api2::builder::Source; +using ::mediapipe::tasks::components::containers::proto::ClassificationResult; constexpr char kAtPrestreamTag[] = "AT_PRESTREAM"; constexpr char kAudioTag[] = "AUDIO"; @@ -238,11 +239,14 @@ class AudioClassifierGraph : public core::ModelTaskGraph { // Adds postprocessing calculators and connects them to the graph output. auto& postprocessing = graph.AddNode( - "mediapipe.tasks.components.ClassificationPostprocessingSubgraph"); - MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessing( - model_resources, task_options.classifier_options(), - &postprocessing.GetOptions< - tasks::components::ClassificationPostprocessingOptions>())); + "mediapipe.tasks.components.processors." + "ClassificationPostprocessingGraph"); + MP_RETURN_IF_ERROR( + components::processors::ConfigureClassificationPostprocessingGraph( + model_resources, task_options.classifier_options(), + &postprocessing + .GetOptions())); inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag); // Time aggregation is only needed for performing audio classification on diff --git a/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier_test.cc b/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier_test.cc index 4e874b520..4b64d2231 100644 --- a/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier_test.cc +++ b/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier_test.cc @@ -37,8 +37,8 @@ limitations under the License. #include "mediapipe/tasks/cc/audio/core/running_mode.h" #include "mediapipe/tasks/cc/audio/utils/test_utils.h" #include "mediapipe/tasks/cc/common.h" -#include "mediapipe/tasks/cc/components/containers/category.pb.h" -#include "mediapipe/tasks/cc/components/containers/classifications.pb.h" +#include "mediapipe/tasks/cc/components/containers/proto/category.pb.h" +#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" #include "tensorflow/lite/core/shims/cc/shims_test_util.h" namespace mediapipe { @@ -49,6 +49,7 @@ namespace { using ::absl::StatusOr; using ::mediapipe::file::JoinPath; +using ::mediapipe::tasks::components::containers::proto::ClassificationResult; using ::testing::HasSubstr; using ::testing::Optional; diff --git a/mediapipe/tasks/cc/audio/audio_classifier/proto/BUILD b/mediapipe/tasks/cc/audio/audio_classifier/proto/BUILD index 033bb51ac..bfe37ec01 100644 --- a/mediapipe/tasks/cc/audio/audio_classifier/proto/BUILD +++ b/mediapipe/tasks/cc/audio/audio_classifier/proto/BUILD @@ -24,7 +24,7 @@ mediapipe_proto_library( deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", - "//mediapipe/tasks/cc/components/proto:classifier_options_proto", + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_proto", "//mediapipe/tasks/cc/core/proto:base_options_proto", ], ) diff --git a/mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options.proto b/mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options.proto index 63b4b3293..16aa86aeb 100644 --- a/mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options.proto +++ b/mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options.proto @@ -18,7 +18,7 @@ syntax = "proto2"; package mediapipe.tasks.audio.audio_classifier.proto; import "mediapipe/framework/calculator.proto"; -import "mediapipe/tasks/cc/components/proto/classifier_options.proto"; +import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; message AudioClassifierGraphOptions { @@ -31,7 +31,7 @@ message AudioClassifierGraphOptions { // Options for configuring the classifier behavior, such as score threshold, // number of results, etc. - optional components.proto.ClassifierOptions classifier_options = 2; + optional components.processors.proto.ClassifierOptions classifier_options = 2; // The default sample rate of the input audio. Must be set when the // AudioClassifier is configured to process audio stream data. diff --git a/mediapipe/tasks/cc/components/BUILD b/mediapipe/tasks/cc/components/BUILD index 4de32ce9b..7939e4e39 100644 --- a/mediapipe/tasks/cc/components/BUILD +++ b/mediapipe/tasks/cc/components/BUILD @@ -58,65 +58,6 @@ cc_library( # TODO: Enable this test -cc_library( - name = "classifier_options", - srcs = ["classifier_options.cc"], - hdrs = ["classifier_options.h"], - deps = ["//mediapipe/tasks/cc/components/proto:classifier_options_cc_proto"], -) - -mediapipe_proto_library( - name = "classification_postprocessing_options_proto", - srcs = ["classification_postprocessing_options.proto"], - deps = [ - "//mediapipe/calculators/tensor:tensors_to_classification_calculator_proto", - "//mediapipe/framework:calculator_options_proto", - "//mediapipe/framework:calculator_proto", - "//mediapipe/tasks/cc/components/calculators:classification_aggregation_calculator_proto", - "//mediapipe/tasks/cc/components/calculators:score_calibration_calculator_proto", - ], -) - -cc_library( - name = "classification_postprocessing", - srcs = ["classification_postprocessing.cc"], - hdrs = ["classification_postprocessing.h"], - deps = [ - ":classification_postprocessing_options_cc_proto", - "//mediapipe/calculators/core:split_vector_calculator", - "//mediapipe/calculators/core:split_vector_calculator_cc_proto", - "//mediapipe/calculators/tensor:tensors_dequantization_calculator", - "//mediapipe/calculators/tensor:tensors_to_classification_calculator", - "//mediapipe/calculators/tensor:tensors_to_classification_calculator_cc_proto", - "//mediapipe/framework:calculator_framework", - "//mediapipe/framework/api2:builder", - "//mediapipe/framework/api2:packet", - "//mediapipe/framework/api2:port", - "//mediapipe/framework/formats:tensor", - "//mediapipe/tasks/cc:common", - "//mediapipe/tasks/cc/components/calculators:classification_aggregation_calculator", - "//mediapipe/tasks/cc/components/calculators:classification_aggregation_calculator_cc_proto", - "//mediapipe/tasks/cc/components/calculators:score_calibration_calculator", - "//mediapipe/tasks/cc/components/calculators:score_calibration_calculator_cc_proto", - "//mediapipe/tasks/cc/components/calculators:score_calibration_utils", - "//mediapipe/tasks/cc/components/containers:classifications_cc_proto", - "//mediapipe/tasks/cc/components/proto:classifier_options_cc_proto", - "//mediapipe/tasks/cc/components/utils:source_or_node_output", - "//mediapipe/tasks/cc/core:model_resources", - "//mediapipe/tasks/cc/metadata:metadata_extractor", - "//mediapipe/tasks/metadata:metadata_schema_cc", - "//mediapipe/util:label_map_cc_proto", - "//mediapipe/util:label_map_util", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@org_tensorflow//tensorflow/lite/schema:schema_fbs", - ], - alwayslink = 1, -) - cc_library( name = "embedder_options", srcs = ["embedder_options.cc"], diff --git a/mediapipe/tasks/cc/components/calculators/BUILD b/mediapipe/tasks/cc/components/calculators/BUILD index 13ca6b496..7d01e4dfe 100644 --- a/mediapipe/tasks/cc/components/calculators/BUILD +++ b/mediapipe/tasks/cc/components/calculators/BUILD @@ -37,8 +37,8 @@ cc_library( "//mediapipe/framework/api2:packet", "//mediapipe/framework/api2:port", "//mediapipe/framework/formats:classification_cc_proto", - "//mediapipe/tasks/cc/components/containers:category_cc_proto", - "//mediapipe/tasks/cc/components/containers:classifications_cc_proto", + "//mediapipe/tasks/cc/components/containers/proto:category_cc_proto", + "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", "@com_google_absl//absl/status", ], alwayslink = 1, @@ -128,7 +128,7 @@ cc_library( "//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", - "//mediapipe/tasks/cc/components/containers:classifications_cc_proto", + "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", ], alwayslink = 1, ) diff --git a/mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.cc b/mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.cc index b2848bc3f..e1f69e607 100644 --- a/mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.cc +++ b/mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.cc @@ -25,15 +25,15 @@ #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/formats/classification.pb.h" #include "mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.pb.h" -#include "mediapipe/tasks/cc/components/containers/category.pb.h" -#include "mediapipe/tasks/cc/components/containers/classifications.pb.h" +#include "mediapipe/tasks/cc/components/containers/proto/category.pb.h" +#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" namespace mediapipe { namespace api2 { using ::mediapipe::tasks::ClassificationAggregationCalculatorOptions; -using ::mediapipe::tasks::ClassificationResult; -using ::mediapipe::tasks::Classifications; +using ::mediapipe::tasks::components::containers::proto::ClassificationResult; +using ::mediapipe::tasks::components::containers::proto::Classifications; // Aggregates ClassificationLists into a single ClassificationResult that has // 3 dimensions: (classification head, classification timestamp, classification diff --git a/mediapipe/tasks/cc/components/calculators/end_loop_calculator.cc b/mediapipe/tasks/cc/components/calculators/end_loop_calculator.cc index b688cda91..10eb962dd 100644 --- a/mediapipe/tasks/cc/components/calculators/end_loop_calculator.cc +++ b/mediapipe/tasks/cc/components/calculators/end_loop_calculator.cc @@ -17,12 +17,13 @@ limitations under the License. #include -#include "mediapipe/tasks/cc/components/containers/classifications.pb.h" +#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" // Specialized EndLoopCalculator for Tasks specific types. namespace mediapipe::tasks { -typedef EndLoopCalculator> +typedef EndLoopCalculator< + std::vector> EndLoopClassificationResultCalculator; REGISTER_CALCULATOR(::mediapipe::tasks::EndLoopClassificationResultCalculator); diff --git a/mediapipe/tasks/cc/components/containers/proto/BUILD b/mediapipe/tasks/cc/components/containers/proto/BUILD index 9c6402e64..633b5b369 100644 --- a/mediapipe/tasks/cc/components/containers/proto/BUILD +++ b/mediapipe/tasks/cc/components/containers/proto/BUILD @@ -18,6 +18,24 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) licenses(["notice"]) +mediapipe_proto_library( + name = "category_proto", + srcs = ["category.proto"], +) + +mediapipe_proto_library( + name = "classifications_proto", + srcs = ["classifications.proto"], + deps = [ + ":category_proto", + ], +) + +mediapipe_proto_library( + name = "embeddings_proto", + srcs = ["embeddings.proto"], +) + mediapipe_proto_library( name = "landmarks_detection_result_proto", srcs = [ @@ -29,8 +47,3 @@ mediapipe_proto_library( "//mediapipe/framework/formats:rect_proto", ], ) - -mediapipe_proto_library( - name = "embeddings_proto", - srcs = ["embeddings.proto"], -) diff --git a/mediapipe/tasks/cc/components/containers/category.proto b/mediapipe/tasks/cc/components/containers/proto/category.proto similarity index 96% rename from mediapipe/tasks/cc/components/containers/category.proto rename to mediapipe/tasks/cc/components/containers/proto/category.proto index 47f38b75a..a44fb5b15 100644 --- a/mediapipe/tasks/cc/components/containers/category.proto +++ b/mediapipe/tasks/cc/components/containers/proto/category.proto @@ -15,7 +15,7 @@ limitations under the License. syntax = "proto2"; -package mediapipe.tasks; +package mediapipe.tasks.components.containers.proto; // A single classification result. message Category { diff --git a/mediapipe/tasks/cc/components/containers/classifications.proto b/mediapipe/tasks/cc/components/containers/proto/classifications.proto similarity index 93% rename from mediapipe/tasks/cc/components/containers/classifications.proto rename to mediapipe/tasks/cc/components/containers/proto/classifications.proto index 469c67fc9..e0ccad7a1 100644 --- a/mediapipe/tasks/cc/components/containers/classifications.proto +++ b/mediapipe/tasks/cc/components/containers/proto/classifications.proto @@ -15,9 +15,9 @@ limitations under the License. syntax = "proto2"; -package mediapipe.tasks; +package mediapipe.tasks.components.containers.proto; -import "mediapipe/tasks/cc/components/containers/category.proto"; +import "mediapipe/tasks/cc/components/containers/proto/category.proto"; // List of predicted categories with an optional timestamp. message ClassificationEntry { diff --git a/mediapipe/tasks/cc/components/processors/BUILD b/mediapipe/tasks/cc/components/processors/BUILD new file mode 100644 index 000000000..62f04dcb7 --- /dev/null +++ b/mediapipe/tasks/cc/components/processors/BUILD @@ -0,0 +1,64 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +cc_library( + name = "classifier_options", + srcs = ["classifier_options.cc"], + hdrs = ["classifier_options.h"], + deps = ["//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto"], +) + +cc_library( + name = "classification_postprocessing_graph", + srcs = ["classification_postprocessing_graph.cc"], + hdrs = ["classification_postprocessing_graph.h"], + deps = [ + "//mediapipe/calculators/core:split_vector_calculator", + "//mediapipe/calculators/core:split_vector_calculator_cc_proto", + "//mediapipe/calculators/tensor:tensors_dequantization_calculator", + "//mediapipe/calculators/tensor:tensors_to_classification_calculator", + "//mediapipe/calculators/tensor:tensors_to_classification_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/api2:packet", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/formats:tensor", + "//mediapipe/tasks/cc:common", + "//mediapipe/tasks/cc/components/calculators:classification_aggregation_calculator", + "//mediapipe/tasks/cc/components/calculators:classification_aggregation_calculator_cc_proto", + "//mediapipe/tasks/cc/components/calculators:score_calibration_calculator", + "//mediapipe/tasks/cc/components/calculators:score_calibration_calculator_cc_proto", + "//mediapipe/tasks/cc/components/calculators:score_calibration_utils", + "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", + "//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto", + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto", + "//mediapipe/tasks/cc/components/utils:source_or_node_output", + "//mediapipe/tasks/cc/core:model_resources", + "//mediapipe/tasks/cc/metadata:metadata_extractor", + "//mediapipe/tasks/metadata:metadata_schema_cc", + "//mediapipe/util:label_map_cc_proto", + "//mediapipe/util:label_map_util", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@org_tensorflow//tensorflow/lite/schema:schema_fbs", + ], + alwayslink = 1, +) diff --git a/mediapipe/tasks/cc/components/classification_postprocessing.cc b/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.cc similarity index 92% rename from mediapipe/tasks/cc/components/classification_postprocessing.cc rename to mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.cc index 871476e8f..35adab687 100644 --- a/mediapipe/tasks/cc/components/classification_postprocessing.cc +++ b/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "mediapipe/tasks/cc/components/classification_postprocessing.h" +#include "mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h" #include @@ -37,9 +37,9 @@ limitations under the License. #include "mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.pb.h" #include "mediapipe/tasks/cc/components/calculators/score_calibration_calculator.pb.h" #include "mediapipe/tasks/cc/components/calculators/score_calibration_utils.h" -#include "mediapipe/tasks/cc/components/classification_postprocessing_options.pb.h" -#include "mediapipe/tasks/cc/components/containers/classifications.pb.h" -#include "mediapipe/tasks/cc/components/proto/classifier_options.pb.h" +#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" +#include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h" +#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h" #include "mediapipe/tasks/cc/components/utils/source_or_node_output.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/metadata/metadata_extractor.h" @@ -51,6 +51,7 @@ limitations under the License. namespace mediapipe { namespace tasks { namespace components { +namespace processors { namespace { @@ -61,7 +62,7 @@ using ::mediapipe::api2::Timestamp; using ::mediapipe::api2::builder::GenericNode; using ::mediapipe::api2::builder::Graph; using ::mediapipe::api2::builder::Source; -using ::mediapipe::tasks::components::proto::ClassifierOptions; +using ::mediapipe::tasks::components::containers::proto::ClassificationResult; using ::mediapipe::tasks::core::ModelResources; using ::mediapipe::tasks::metadata::ModelMetadataExtractor; using ::tflite::ProcessUnit; @@ -79,7 +80,8 @@ constexpr char kTensorsTag[] = "TENSORS"; constexpr char kTimestampsTag[] = "TIMESTAMPS"; // Performs sanity checks on provided ClassifierOptions. -absl::Status SanityCheckClassifierOptions(const ClassifierOptions& options) { +absl::Status SanityCheckClassifierOptions( + const proto::ClassifierOptions& options) { if (options.max_results() == 0) { return CreateStatusWithPayload( absl::StatusCode::kInvalidArgument, @@ -203,7 +205,7 @@ absl::StatusOr GetScoreThreshold( // Gets the category allowlist or denylist (if any) as a set of indices. absl::StatusOr> GetAllowOrDenyCategoryIndicesIfAny( - const ClassifierOptions& options, const LabelItems& label_items) { + const proto::ClassifierOptions& options, const LabelItems& label_items) { absl::flat_hash_set category_indices; // Exit early if no denylist/allowlist. if (options.category_denylist_size() == 0 && @@ -239,7 +241,7 @@ absl::StatusOr> GetAllowOrDenyCategoryIndicesIfAny( absl::Status ConfigureScoreCalibrationIfAny( const ModelMetadataExtractor& metadata_extractor, int tensor_index, - ClassificationPostprocessingOptions* options) { + proto::ClassificationPostprocessingGraphOptions* options) { const auto* tensor_metadata = metadata_extractor.GetOutputTensorMetadata(tensor_index); if (tensor_metadata == nullptr) { @@ -283,7 +285,7 @@ absl::Status ConfigureScoreCalibrationIfAny( // Fills in the TensorsToClassificationCalculatorOptions based on the // classifier options and the (optional) output tensor metadata. absl::Status ConfigureTensorsToClassificationCalculator( - const ClassifierOptions& options, + const proto::ClassifierOptions& options, const ModelMetadataExtractor& metadata_extractor, int tensor_index, TensorsToClassificationCalculatorOptions* calculator_options) { const auto* tensor_metadata = @@ -345,10 +347,10 @@ void ConfigureClassificationAggregationCalculator( } // namespace -absl::Status ConfigureClassificationPostprocessing( +absl::Status ConfigureClassificationPostprocessingGraph( const ModelResources& model_resources, - const ClassifierOptions& classifier_options, - ClassificationPostprocessingOptions* options) { + const proto::ClassifierOptions& classifier_options, + proto::ClassificationPostprocessingGraphOptions* options) { MP_RETURN_IF_ERROR(SanityCheckClassifierOptions(classifier_options)); ASSIGN_OR_RETURN(const auto heads_properties, GetClassificationHeadsProperties(model_resources)); @@ -366,8 +368,8 @@ absl::Status ConfigureClassificationPostprocessing( return absl::OkStatus(); } -// A "mediapipe.tasks.components.ClassificationPostprocessingSubgraph" converts -// raw tensors into ClassificationResult objects. +// A "ClassificationPostprocessingGraph" converts raw tensors into +// ClassificationResult objects. // - Accepts CPU input tensors. // // Inputs: @@ -381,10 +383,10 @@ absl::Status ConfigureClassificationPostprocessing( // CLASSIFICATION_RESULT - ClassificationResult // The output aggregated classification results. // -// The recommended way of using this subgraph is through the GraphBuilder API -// using the 'ConfigureClassificationPostprocessing()' function. See header file -// for more details. -class ClassificationPostprocessingSubgraph : public mediapipe::Subgraph { +// The recommended way of using this graph is through the GraphBuilder API +// using the 'ConfigureClassificationPostprocessingGraph()' function. See header +// file for more details. +class ClassificationPostprocessingGraph : public mediapipe::Subgraph { public: absl::StatusOr GetConfig( mediapipe::SubgraphContext* sc) override { @@ -392,7 +394,7 @@ class ClassificationPostprocessingSubgraph : public mediapipe::Subgraph { ASSIGN_OR_RETURN( auto classification_result_out, BuildClassificationPostprocessing( - sc->Options(), + sc->Options(), graph[Input>(kTensorsTag)], graph[Input>(kTimestampsTag)], graph)); classification_result_out >> @@ -401,19 +403,19 @@ class ClassificationPostprocessingSubgraph : public mediapipe::Subgraph { } private: - // Adds an on-device classification postprocessing subgraph into the provided - // builder::Graph instance. The classification postprocessing subgraph takes + // Adds an on-device classification postprocessing graph into the provided + // builder::Graph instance. The classification postprocessing graph takes // tensors (std::vector) as input and returns one output // stream containing the output classification results (ClassificationResult). // - // options: the on-device ClassificationPostprocessingOptions. + // options: the on-device ClassificationPostprocessingGraphOptions. // tensors_in: (std::vector>) tensors to postprocess. // timestamps_in: (std::vector) optional collection of // timestamps that a single ClassificationResult should aggregate. // graph: the mediapipe builder::Graph instance to be updated. absl::StatusOr> BuildClassificationPostprocessing( - const ClassificationPostprocessingOptions& options, + const proto::ClassificationPostprocessingGraphOptions& options, Source> tensors_in, Source> timestamps_in, Graph& graph) { const int num_heads = options.tensors_to_classifications_options_size(); @@ -504,9 +506,11 @@ class ClassificationPostprocessingSubgraph : public mediapipe::Subgraph { kClassificationResultTag)]; } }; -REGISTER_MEDIAPIPE_GRAPH( - ::mediapipe::tasks::components::ClassificationPostprocessingSubgraph); +REGISTER_MEDIAPIPE_GRAPH(::mediapipe::tasks::components::processors:: + ClassificationPostprocessingGraph); // NOLINT + +} // namespace processors } // namespace components } // namespace tasks } // namespace mediapipe diff --git a/mediapipe/tasks/cc/components/classification_postprocessing.h b/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h similarity index 59% rename from mediapipe/tasks/cc/components/classification_postprocessing.h rename to mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h index eb638bd60..8aedad46d 100644 --- a/mediapipe/tasks/cc/components/classification_postprocessing.h +++ b/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h @@ -13,32 +13,33 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CLASSIFICATION_POSTPROCESSING_H_ -#define MEDIAPIPE_TASKS_CC_COMPONENTS_CLASSIFICATION_POSTPROCESSING_H_ +#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_CLASSIFICATION_POSTPROCESSING_GRAPH_H_ +#define MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_CLASSIFICATION_POSTPROCESSING_GRAPH_H_ #include "absl/status/status.h" -#include "mediapipe/tasks/cc/components/classification_postprocessing_options.pb.h" -#include "mediapipe/tasks/cc/components/proto/classifier_options.pb.h" +#include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h" +#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h" #include "mediapipe/tasks/cc/core/model_resources.h" namespace mediapipe { namespace tasks { namespace components { +namespace processors { -// Configures a ClassificationPostprocessing subgraph using the provided model +// Configures a ClassificationPostprocessingGraph using the provided model // resources and ClassifierOptions. // - Accepts CPU input tensors. // // Example usage: // // auto& postprocessing = -// graph.AddNode("mediapipe.tasks.components.ClassificationPostprocessingSubgraph"); -// MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessing( +// graph.AddNode("mediapipe.tasks.components.processors.ClassificationPostprocessingGraph"); +// MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessingGraph( // model_resources, // classifier_options, -// &preprocessing.GetOptions())); +// &preprocessing.GetOptions())); // -// The resulting ClassificationPostprocessing subgraph has the following I/O: +// The resulting ClassificationPostprocessingGraph has the following I/O: // Inputs: // TENSORS - std::vector // The output tensors of an InferenceCalculator. @@ -49,13 +50,14 @@ namespace components { // Outputs: // CLASSIFICATION_RESULT - ClassificationResult // The output aggregated classification results. -absl::Status ConfigureClassificationPostprocessing( +absl::Status ConfigureClassificationPostprocessingGraph( const tasks::core::ModelResources& model_resources, - const tasks::components::proto::ClassifierOptions& classifier_options, - ClassificationPostprocessingOptions* options); + const proto::ClassifierOptions& classifier_options, + proto::ClassificationPostprocessingGraphOptions* options); +} // namespace processors } // namespace components } // namespace tasks } // namespace mediapipe -#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_CLASSIFICATION_POSTPROCESSING_H_ +#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_CLASSIFICATION_POSTPROCESSING_GRAPH_H_ diff --git a/mediapipe/tasks/cc/components/classification_postprocessing_test.cc b/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph_test.cc similarity index 88% rename from mediapipe/tasks/cc/components/classification_postprocessing_test.cc rename to mediapipe/tasks/cc/components/processors/classification_postprocessing_graph_test.cc index 67223050f..bb03e2530 100644 --- a/mediapipe/tasks/cc/components/classification_postprocessing_test.cc +++ b/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "mediapipe/tasks/cc/components/classification_postprocessing.h" +#include "mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h" #include #include @@ -42,9 +42,9 @@ limitations under the License. #include "mediapipe/framework/timestamp.h" #include "mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.pb.h" #include "mediapipe/tasks/cc/components/calculators/score_calibration_calculator.pb.h" -#include "mediapipe/tasks/cc/components/classification_postprocessing_options.pb.h" -#include "mediapipe/tasks/cc/components/containers/classifications.pb.h" -#include "mediapipe/tasks/cc/components/proto/classifier_options.pb.h" +#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" +#include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h" +#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/proto/external_file.pb.h" #include "mediapipe/util/label_map.pb.h" @@ -53,6 +53,7 @@ limitations under the License. namespace mediapipe { namespace tasks { namespace components { +namespace processors { namespace { using ::mediapipe::api2::Input; @@ -60,7 +61,7 @@ using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::Graph; using ::mediapipe::api2::builder::Source; using ::mediapipe::file::JoinPath; -using ::mediapipe::tasks::components::proto::ClassifierOptions; +using ::mediapipe::tasks::components::containers::proto::ClassificationResult; using ::mediapipe::tasks::core::ModelResources; using ::testing::HasSubstr; using ::testing::proto::Approximately; @@ -101,12 +102,12 @@ TEST_F(ConfigureTest, FailsWithInvalidMaxResults) { MP_ASSERT_OK_AND_ASSIGN( auto model_resources, CreateModelResourcesForModel(kQuantizedImageClassifierWithMetadata)); - ClassifierOptions options_in; + proto::ClassifierOptions options_in; options_in.set_max_results(0); - ClassificationPostprocessingOptions options_out; - auto status = ConfigureClassificationPostprocessing(*model_resources, - options_in, &options_out); + proto::ClassificationPostprocessingGraphOptions options_out; + auto status = ConfigureClassificationPostprocessingGraph( + *model_resources, options_in, &options_out); EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); EXPECT_THAT(status.message(), HasSubstr("Invalid `max_results` option")); @@ -116,13 +117,13 @@ TEST_F(ConfigureTest, FailsWithBothAllowlistAndDenylist) { MP_ASSERT_OK_AND_ASSIGN( auto model_resources, CreateModelResourcesForModel(kQuantizedImageClassifierWithMetadata)); - ClassifierOptions options_in; + proto::ClassifierOptions options_in; options_in.add_category_allowlist("foo"); options_in.add_category_denylist("bar"); - ClassificationPostprocessingOptions options_out; - auto status = ConfigureClassificationPostprocessing(*model_resources, - options_in, &options_out); + proto::ClassificationPostprocessingGraphOptions options_out; + auto status = ConfigureClassificationPostprocessingGraph( + *model_resources, options_in, &options_out); EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); EXPECT_THAT(status.message(), HasSubstr("mutually exclusive options")); @@ -132,12 +133,12 @@ TEST_F(ConfigureTest, FailsWithAllowlistAndNoMetadata) { MP_ASSERT_OK_AND_ASSIGN( auto model_resources, CreateModelResourcesForModel(kQuantizedImageClassifierWithoutMetadata)); - ClassifierOptions options_in; + proto::ClassifierOptions options_in; options_in.add_category_allowlist("foo"); - ClassificationPostprocessingOptions options_out; - auto status = ConfigureClassificationPostprocessing(*model_resources, - options_in, &options_out); + proto::ClassificationPostprocessingGraphOptions options_out; + auto status = ConfigureClassificationPostprocessingGraph( + *model_resources, options_in, &options_out); EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); EXPECT_THAT( @@ -149,11 +150,11 @@ TEST_F(ConfigureTest, SucceedsWithoutMetadata) { MP_ASSERT_OK_AND_ASSIGN( auto model_resources, CreateModelResourcesForModel(kQuantizedImageClassifierWithoutMetadata)); - ClassifierOptions options_in; + proto::ClassifierOptions options_in; - ClassificationPostprocessingOptions options_out; - MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources, - options_in, &options_out)); + proto::ClassificationPostprocessingGraphOptions options_out; + MP_ASSERT_OK(ConfigureClassificationPostprocessingGraph( + *model_resources, options_in, &options_out)); EXPECT_THAT(options_out, Approximately(EqualsProto( R"pb(score_calibration_options: [] @@ -171,12 +172,12 @@ TEST_F(ConfigureTest, SucceedsWithMaxResults) { MP_ASSERT_OK_AND_ASSIGN( auto model_resources, CreateModelResourcesForModel(kQuantizedImageClassifierWithoutMetadata)); - ClassifierOptions options_in; + proto::ClassifierOptions options_in; options_in.set_max_results(3); - ClassificationPostprocessingOptions options_out; - MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources, - options_in, &options_out)); + proto::ClassificationPostprocessingGraphOptions options_out; + MP_ASSERT_OK(ConfigureClassificationPostprocessingGraph( + *model_resources, options_in, &options_out)); EXPECT_THAT(options_out, Approximately(EqualsProto( R"pb(score_calibration_options: [] @@ -194,12 +195,12 @@ TEST_F(ConfigureTest, SucceedsWithScoreThreshold) { MP_ASSERT_OK_AND_ASSIGN( auto model_resources, CreateModelResourcesForModel(kQuantizedImageClassifierWithoutMetadata)); - ClassifierOptions options_in; + proto::ClassifierOptions options_in; options_in.set_score_threshold(0.5); - ClassificationPostprocessingOptions options_out; - MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources, - options_in, &options_out)); + proto::ClassificationPostprocessingGraphOptions options_out; + MP_ASSERT_OK(ConfigureClassificationPostprocessingGraph( + *model_resources, options_in, &options_out)); EXPECT_THAT(options_out, Approximately(EqualsProto( R"pb(score_calibration_options: [] @@ -217,11 +218,11 @@ TEST_F(ConfigureTest, SucceedsWithMetadata) { MP_ASSERT_OK_AND_ASSIGN( auto model_resources, CreateModelResourcesForModel(kQuantizedImageClassifierWithMetadata)); - ClassifierOptions options_in; + proto::ClassifierOptions options_in; - ClassificationPostprocessingOptions options_out; - MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources, - options_in, &options_out)); + proto::ClassificationPostprocessingGraphOptions options_out; + MP_ASSERT_OK(ConfigureClassificationPostprocessingGraph( + *model_resources, options_in, &options_out)); // Check label map size and two first elements. EXPECT_EQ( @@ -254,12 +255,12 @@ TEST_F(ConfigureTest, SucceedsWithAllowlist) { MP_ASSERT_OK_AND_ASSIGN( auto model_resources, CreateModelResourcesForModel(kQuantizedImageClassifierWithMetadata)); - ClassifierOptions options_in; + proto::ClassifierOptions options_in; options_in.add_category_allowlist("tench"); - ClassificationPostprocessingOptions options_out; - MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources, - options_in, &options_out)); + proto::ClassificationPostprocessingGraphOptions options_out; + MP_ASSERT_OK(ConfigureClassificationPostprocessingGraph( + *model_resources, options_in, &options_out)); // Clear label map and compare the rest of the options. options_out.mutable_tensors_to_classifications_options(0) @@ -283,12 +284,12 @@ TEST_F(ConfigureTest, SucceedsWithDenylist) { MP_ASSERT_OK_AND_ASSIGN( auto model_resources, CreateModelResourcesForModel(kQuantizedImageClassifierWithMetadata)); - ClassifierOptions options_in; + proto::ClassifierOptions options_in; options_in.add_category_denylist("background"); - ClassificationPostprocessingOptions options_out; - MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources, - options_in, &options_out)); + proto::ClassificationPostprocessingGraphOptions options_out; + MP_ASSERT_OK(ConfigureClassificationPostprocessingGraph( + *model_resources, options_in, &options_out)); // Clear label map and compare the rest of the options. options_out.mutable_tensors_to_classifications_options(0) @@ -313,11 +314,11 @@ TEST_F(ConfigureTest, SucceedsWithScoreCalibration) { auto model_resources, CreateModelResourcesForModel( kQuantizedImageClassifierWithDummyScoreCalibration)); - ClassifierOptions options_in; + proto::ClassifierOptions options_in; - ClassificationPostprocessingOptions options_out; - MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources, - options_in, &options_out)); + proto::ClassificationPostprocessingGraphOptions options_out; + MP_ASSERT_OK(ConfigureClassificationPostprocessingGraph( + *model_resources, options_in, &options_out)); // Check label map size and two first elements. EXPECT_EQ( @@ -362,11 +363,11 @@ TEST_F(ConfigureTest, SucceedsWithMultipleHeads) { MP_ASSERT_OK_AND_ASSIGN( auto model_resources, CreateModelResourcesForModel(kFloatTwoHeadsAudioClassifierWithMetadata)); - ClassifierOptions options_in; + proto::ClassifierOptions options_in; - ClassificationPostprocessingOptions options_out; - MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources, - options_in, &options_out)); + proto::ClassificationPostprocessingGraphOptions options_out; + MP_ASSERT_OK(ConfigureClassificationPostprocessingGraph( + *model_resources, options_in, &options_out)); // Check label maps sizes and first two elements. EXPECT_EQ( options_out.tensors_to_classifications_options(0).label_items_size(), @@ -414,17 +415,19 @@ TEST_F(ConfigureTest, SucceedsWithMultipleHeads) { class PostprocessingTest : public tflite_shims::testing::Test { protected: absl::StatusOr BuildGraph( - absl::string_view model_name, const ClassifierOptions& options, + absl::string_view model_name, const proto::ClassifierOptions& options, bool connect_timestamps = false) { ASSIGN_OR_RETURN(auto model_resources, CreateModelResourcesForModel(model_name)); Graph graph; auto& postprocessing = graph.AddNode( - "mediapipe.tasks.components.ClassificationPostprocessingSubgraph"); - MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessing( + "mediapipe.tasks.components.processors." + "ClassificationPostprocessingGraph"); + MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessingGraph( *model_resources, options, - &postprocessing.GetOptions())); + &postprocessing + .GetOptions())); graph[Input>(kTensorsTag)].SetName(kTensorsName) >> postprocessing.In(kTensorsTag); if (connect_timestamps) { @@ -495,7 +498,7 @@ class PostprocessingTest : public tflite_shims::testing::Test { TEST_F(PostprocessingTest, SucceedsWithoutMetadata) { // Build graph. - ClassifierOptions options; + proto::ClassifierOptions options; options.set_max_results(3); options.set_score_threshold(0.5); MP_ASSERT_OK_AND_ASSIGN( @@ -524,7 +527,7 @@ TEST_F(PostprocessingTest, SucceedsWithoutMetadata) { TEST_F(PostprocessingTest, SucceedsWithMetadata) { // Build graph. - ClassifierOptions options; + proto::ClassifierOptions options; options.set_max_results(3); MP_ASSERT_OK_AND_ASSIGN( auto poller, BuildGraph(kQuantizedImageClassifierWithMetadata, options)); @@ -567,7 +570,7 @@ TEST_F(PostprocessingTest, SucceedsWithMetadata) { TEST_F(PostprocessingTest, SucceedsWithScoreCalibration) { // Build graph. - ClassifierOptions options; + proto::ClassifierOptions options; options.set_max_results(3); MP_ASSERT_OK_AND_ASSIGN( auto poller, @@ -613,7 +616,7 @@ TEST_F(PostprocessingTest, SucceedsWithScoreCalibration) { TEST_F(PostprocessingTest, SucceedsWithMultipleHeads) { // Build graph. - ClassifierOptions options; + proto::ClassifierOptions options; options.set_max_results(2); MP_ASSERT_OK_AND_ASSIGN( auto poller, @@ -673,7 +676,7 @@ TEST_F(PostprocessingTest, SucceedsWithMultipleHeads) { TEST_F(PostprocessingTest, SucceedsWithTimestamps) { // Build graph. - ClassifierOptions options; + proto::ClassifierOptions options; options.set_max_results(2); MP_ASSERT_OK_AND_ASSIGN( auto poller, BuildGraph(kQuantizedImageClassifierWithMetadata, options, @@ -729,6 +732,7 @@ TEST_F(PostprocessingTest, SucceedsWithTimestamps) { } } // namespace +} // namespace processors } // namespace components } // namespace tasks } // namespace mediapipe diff --git a/mediapipe/tasks/cc/components/classifier_options.cc b/mediapipe/tasks/cc/components/processors/classifier_options.cc similarity index 81% rename from mediapipe/tasks/cc/components/classifier_options.cc rename to mediapipe/tasks/cc/components/processors/classifier_options.cc index c54db5f88..349bb569d 100644 --- a/mediapipe/tasks/cc/components/classifier_options.cc +++ b/mediapipe/tasks/cc/components/processors/classifier_options.cc @@ -13,17 +13,18 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "mediapipe/tasks/cc/components/classifier_options.h" +#include "mediapipe/tasks/cc/components/processors/classifier_options.h" -#include "mediapipe/tasks/cc/components/proto/classifier_options.pb.h" +#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h" namespace mediapipe { namespace tasks { namespace components { +namespace processors { -tasks::components::proto::ClassifierOptions ConvertClassifierOptionsToProto( +proto::ClassifierOptions ConvertClassifierOptionsToProto( ClassifierOptions* options) { - tasks::components::proto::ClassifierOptions options_proto; + proto::ClassifierOptions options_proto; options_proto.set_display_names_locale(options->display_names_locale); options_proto.set_max_results(options->max_results); options_proto.set_score_threshold(options->score_threshold); @@ -36,6 +37,7 @@ tasks::components::proto::ClassifierOptions ConvertClassifierOptionsToProto( return options_proto; } +} // namespace processors } // namespace components } // namespace tasks } // namespace mediapipe diff --git a/mediapipe/tasks/cc/components/classifier_options.h b/mediapipe/tasks/cc/components/processors/classifier_options.h similarity index 83% rename from mediapipe/tasks/cc/components/classifier_options.h rename to mediapipe/tasks/cc/components/processors/classifier_options.h index e15bf5e69..189b42e60 100644 --- a/mediapipe/tasks/cc/components/classifier_options.h +++ b/mediapipe/tasks/cc/components/processors/classifier_options.h @@ -13,14 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CLASSIFIER_OPTIONS_H_ -#define MEDIAPIPE_TASKS_CC_COMPONENTS_CLASSIFIER_OPTIONS_H_ +#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_CLASSIFIER_OPTIONS_H_ +#define MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_CLASSIFIER_OPTIONS_H_ -#include "mediapipe/tasks/cc/components/proto/classifier_options.pb.h" +#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h" namespace mediapipe { namespace tasks { namespace components { +namespace processors { // Classifier options for MediaPipe C++ classification Tasks. struct ClassifierOptions { @@ -49,11 +50,12 @@ struct ClassifierOptions { }; // Converts a ClassifierOptions to a ClassifierOptionsProto. -tasks::components::proto::ClassifierOptions ConvertClassifierOptionsToProto( +proto::ClassifierOptions ConvertClassifierOptionsToProto( ClassifierOptions* classifier_options); +} // namespace processors } // namespace components } // namespace tasks } // namespace mediapipe -#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_CLASSIFIER_OPTIONS_H_ +#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_CLASSIFIER_OPTIONS_H_ diff --git a/mediapipe/tasks/cc/components/containers/BUILD b/mediapipe/tasks/cc/components/processors/proto/BUILD similarity index 58% rename from mediapipe/tasks/cc/components/containers/BUILD rename to mediapipe/tasks/cc/components/processors/proto/BUILD index 701f84824..d7cbe47ff 100644 --- a/mediapipe/tasks/cc/components/containers/BUILD +++ b/mediapipe/tasks/cc/components/processors/proto/BUILD @@ -19,14 +19,18 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) licenses(["notice"]) mediapipe_proto_library( - name = "category_proto", - srcs = ["category.proto"], + name = "classifier_options_proto", + srcs = ["classifier_options.proto"], ) mediapipe_proto_library( - name = "classifications_proto", - srcs = ["classifications.proto"], + name = "classification_postprocessing_graph_options_proto", + srcs = ["classification_postprocessing_graph_options.proto"], deps = [ - ":category_proto", + "//mediapipe/calculators/tensor:tensors_to_classification_calculator_proto", + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + "//mediapipe/tasks/cc/components/calculators:classification_aggregation_calculator_proto", + "//mediapipe/tasks/cc/components/calculators:score_calibration_calculator_proto", ], ) diff --git a/mediapipe/tasks/cc/components/classification_postprocessing_options.proto b/mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.proto similarity index 91% rename from mediapipe/tasks/cc/components/classification_postprocessing_options.proto rename to mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.proto index 9b67e2f75..1de788eab 100644 --- a/mediapipe/tasks/cc/components/classification_postprocessing_options.proto +++ b/mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.proto @@ -15,16 +15,16 @@ limitations under the License. syntax = "proto2"; -package mediapipe.tasks.components; +package mediapipe.tasks.components.processors.proto; import "mediapipe/calculators/tensor/tensors_to_classification_calculator.proto"; import "mediapipe/framework/calculator.proto"; import "mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.proto"; import "mediapipe/tasks/cc/components/calculators/score_calibration_calculator.proto"; -message ClassificationPostprocessingOptions { +message ClassificationPostprocessingGraphOptions { extend mediapipe.CalculatorOptions { - optional ClassificationPostprocessingOptions ext = 460416950; + optional ClassificationPostprocessingGraphOptions ext = 460416950; } // Optional mapping between output tensor index and corresponding score diff --git a/mediapipe/tasks/cc/components/proto/classifier_options.proto b/mediapipe/tasks/cc/components/processors/proto/classifier_options.proto similarity index 97% rename from mediapipe/tasks/cc/components/proto/classifier_options.proto rename to mediapipe/tasks/cc/components/processors/proto/classifier_options.proto index ea1491bb8..7afbfc14e 100644 --- a/mediapipe/tasks/cc/components/proto/classifier_options.proto +++ b/mediapipe/tasks/cc/components/processors/proto/classifier_options.proto @@ -15,7 +15,7 @@ limitations under the License. syntax = "proto2"; -package mediapipe.tasks.components.proto; +package mediapipe.tasks.components.processors.proto; // Shared options used by all classification tasks. message ClassifierOptions { diff --git a/mediapipe/tasks/cc/components/proto/BUILD b/mediapipe/tasks/cc/components/proto/BUILD index 8c4dcdad9..c11d6f95a 100644 --- a/mediapipe/tasks/cc/components/proto/BUILD +++ b/mediapipe/tasks/cc/components/proto/BUILD @@ -23,11 +23,6 @@ mediapipe_proto_library( srcs = ["segmenter_options.proto"], ) -mediapipe_proto_library( - name = "classifier_options_proto", - srcs = ["classifier_options.proto"], -) - mediapipe_proto_library( name = "embedder_options_proto", srcs = ["embedder_options.proto"], diff --git a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/BUILD b/mediapipe/tasks/cc/vision/hand_gesture_recognizer/BUILD index bb5b86212..9e2d9bd17 100644 --- a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/BUILD +++ b/mediapipe/tasks/cc/vision/hand_gesture_recognizer/BUILD @@ -54,10 +54,10 @@ cc_library( "//mediapipe/framework/formats:matrix", "//mediapipe/framework/formats:tensor", "//mediapipe/tasks/cc:common", - "//mediapipe/tasks/cc/components:classification_postprocessing", - "//mediapipe/tasks/cc/components:classification_postprocessing_options_cc_proto", "//mediapipe/tasks/cc/components:image_preprocessing", - "//mediapipe/tasks/cc/components/containers:classifications_cc_proto", + "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", + "//mediapipe/tasks/cc/components/processors:classification_postprocessing_graph", + "//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto", "//mediapipe/tasks/cc/core:model_resources", "//mediapipe/tasks/cc/core:model_task_graph", "//mediapipe/tasks/cc/core:utils", diff --git a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/hand_gesture_recognizer_subgraph.cc b/mediapipe/tasks/cc/vision/hand_gesture_recognizer/hand_gesture_recognizer_subgraph.cc index e124d3410..247d8453d 100644 --- a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/hand_gesture_recognizer_subgraph.cc +++ b/mediapipe/tasks/cc/vision/hand_gesture_recognizer/hand_gesture_recognizer_subgraph.cc @@ -27,9 +27,9 @@ limitations under the License. #include "mediapipe/framework/formats/matrix.h" #include "mediapipe/framework/formats/tensor.h" #include "mediapipe/tasks/cc/common.h" -#include "mediapipe/tasks/cc/components/classification_postprocessing.h" -#include "mediapipe/tasks/cc/components/classification_postprocessing_options.pb.h" -#include "mediapipe/tasks/cc/components/containers/classifications.pb.h" +#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" +#include "mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h" +#include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_task_graph.h" #include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h" @@ -49,6 +49,7 @@ using ::mediapipe::api2::Input; using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::Graph; using ::mediapipe::api2::builder::Source; +using ::mediapipe::tasks::components::containers::proto::ClassificationResult; using ::mediapipe::tasks::vision::hand_gesture_recognizer::proto:: HandGestureRecognizerSubgraphOptions; using ::mediapipe::tasks::vision::proto::LandmarksToMatrixCalculatorOptions; @@ -218,11 +219,14 @@ class SingleHandGestureRecognizerSubgraph : public core::ModelTaskGraph { auto inference_output_tensors = inference.Out(kTensorsTag); auto& postprocessing = graph.AddNode( - "mediapipe.tasks.components.ClassificationPostprocessingSubgraph"); - MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessing( - model_resources, graph_options.classifier_options(), - &postprocessing.GetOptions< - tasks::components::ClassificationPostprocessingOptions>())); + "mediapipe.tasks.components.processors." + "ClassificationPostprocessingGraph"); + MP_RETURN_IF_ERROR( + components::processors::ConfigureClassificationPostprocessingGraph( + model_resources, graph_options.classifier_options(), + &postprocessing + .GetOptions())); inference_output_tensors >> postprocessing.In(kTensorsTag); auto classification_result = postprocessing[Output("CLASSIFICATION_RESULT")]; diff --git a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto/BUILD b/mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto/BUILD index f3927727e..44ec611b2 100644 --- a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto/BUILD +++ b/mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto/BUILD @@ -26,7 +26,7 @@ mediapipe_proto_library( deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", - "//mediapipe/tasks/cc/components/proto:classifier_options_proto", + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_proto", "//mediapipe/tasks/cc/core/proto:base_options_proto", ], ) @@ -37,7 +37,5 @@ mediapipe_proto_library( deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", - "//mediapipe/tasks/cc/components/proto:classifier_options_proto", - "//mediapipe/tasks/cc/core/proto:base_options_proto", ], ) diff --git a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto/hand_gesture_recognizer_subgraph_options.proto b/mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto/hand_gesture_recognizer_subgraph_options.proto index f73443eaf..d8ee95037 100644 --- a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto/hand_gesture_recognizer_subgraph_options.proto +++ b/mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto/hand_gesture_recognizer_subgraph_options.proto @@ -18,7 +18,7 @@ syntax = "proto2"; package mediapipe.tasks.vision.hand_gesture_recognizer.proto; import "mediapipe/framework/calculator.proto"; -import "mediapipe/tasks/cc/components/proto/classifier_options.proto"; +import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; message HandGestureRecognizerSubgraphOptions { @@ -31,7 +31,7 @@ message HandGestureRecognizerSubgraphOptions { // Options for configuring the gesture classifier behavior, such as score // threshold, number of results, etc. - optional components.proto.ClassifierOptions classifier_options = 2; + optional components.processors.proto.ClassifierOptions classifier_options = 2; // Minimum confidence value ([0.0, 1.0]) for the hand landmarks to be // considered tracked successfully diff --git a/mediapipe/tasks/cc/vision/image_classifier/BUILD b/mediapipe/tasks/cc/vision/image_classifier/BUILD index e7c8a6586..dfa77cb96 100644 --- a/mediapipe/tasks/cc/vision/image_classifier/BUILD +++ b/mediapipe/tasks/cc/vision/image_classifier/BUILD @@ -26,11 +26,11 @@ cc_library( "//mediapipe/framework/api2:port", "//mediapipe/framework/formats:image", "//mediapipe/framework/formats:rect_cc_proto", - "//mediapipe/tasks/cc/components:classification_postprocessing", - "//mediapipe/tasks/cc/components:classification_postprocessing_options_cc_proto", "//mediapipe/tasks/cc/components:image_preprocessing", "//mediapipe/tasks/cc/components:image_preprocessing_options_cc_proto", - "//mediapipe/tasks/cc/components/containers:classifications_cc_proto", + "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", + "//mediapipe/tasks/cc/components/processors:classification_postprocessing_graph", + "//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto", "//mediapipe/tasks/cc/core:model_resources", "//mediapipe/tasks/cc/core:model_task_graph", "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_cc_proto", @@ -50,9 +50,9 @@ cc_library( "//mediapipe/framework/api2:builder", "//mediapipe/framework/formats:image", "//mediapipe/framework/formats:rect_cc_proto", - "//mediapipe/tasks/cc/components:classifier_options", - "//mediapipe/tasks/cc/components/containers:classifications_cc_proto", - "//mediapipe/tasks/cc/components/proto:classifier_options_cc_proto", + "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", + "//mediapipe/tasks/cc/components/processors:classifier_options", + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto", "//mediapipe/tasks/cc/core:base_options", "//mediapipe/tasks/cc/core:task_runner", "//mediapipe/tasks/cc/core:utils", diff --git a/mediapipe/tasks/cc/vision/image_classifier/image_classifier.cc b/mediapipe/tasks/cc/vision/image_classifier/image_classifier.cc index 1e092e85a..0338b2ee2 100644 --- a/mediapipe/tasks/cc/vision/image_classifier/image_classifier.cc +++ b/mediapipe/tasks/cc/vision/image_classifier/image_classifier.cc @@ -26,9 +26,9 @@ limitations under the License. #include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/framework/packet.h" #include "mediapipe/framework/timestamp.h" -#include "mediapipe/tasks/cc/components/classifier_options.h" -#include "mediapipe/tasks/cc/components/containers/classifications.pb.h" -#include "mediapipe/tasks/cc/components/proto/classifier_options.pb.h" +#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" +#include "mediapipe/tasks/cc/components/processors/classifier_options.h" +#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h" #include "mediapipe/tasks/cc/core/base_options.h" #include "mediapipe/tasks/cc/core/proto/base_options.pb.h" #include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h" @@ -56,6 +56,7 @@ constexpr char kSubgraphTypeName[] = "mediapipe.tasks.vision.image_classifier.ImageClassifierGraph"; constexpr int kMicroSecondsPerMilliSecond = 1000; +using ::mediapipe::tasks::components::containers::proto::ClassificationResult; using ::mediapipe::tasks::core::PacketMap; // Builds a NormalizedRect covering the entire image. @@ -107,8 +108,8 @@ ConvertImageClassifierOptionsToProto(ImageClassifierOptions* options) { options_proto->mutable_base_options()->set_use_stream_mode( options->running_mode != core::RunningMode::IMAGE); auto classifier_options_proto = - std::make_unique( - components::ConvertClassifierOptionsToProto( + std::make_unique( + components::processors::ConvertClassifierOptionsToProto( &(options->classifier_options))); options_proto->mutable_classifier_options()->Swap( classifier_options_proto.get()); diff --git a/mediapipe/tasks/cc/vision/image_classifier/image_classifier.h b/mediapipe/tasks/cc/vision/image_classifier/image_classifier.h index 8ff11413e..24f36017a 100644 --- a/mediapipe/tasks/cc/vision/image_classifier/image_classifier.h +++ b/mediapipe/tasks/cc/vision/image_classifier/image_classifier.h @@ -23,8 +23,8 @@ limitations under the License. #include "absl/status/statusor.h" #include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/rect.pb.h" -#include "mediapipe/tasks/cc/components/classifier_options.h" -#include "mediapipe/tasks/cc/components/containers/classifications.pb.h" +#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" +#include "mediapipe/tasks/cc/components/processors/classifier_options.h" #include "mediapipe/tasks/cc/core/base_options.h" #include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h" #include "mediapipe/tasks/cc/vision/core/running_mode.h" @@ -51,12 +51,14 @@ struct ImageClassifierOptions { // Options for configuring the classifier behavior, such as score threshold, // number of results, etc. - components::ClassifierOptions classifier_options; + components::processors::ClassifierOptions classifier_options; // The user-defined result callback for processing live stream data. // The result callback should only be specified when the running mode is set // to RunningMode::LIVE_STREAM. - std::function, const Image&, int64)> + std::function, + const Image&, int64)> result_callback = nullptr; }; @@ -113,7 +115,7 @@ class ImageClassifier : tasks::vision::core::BaseVisionTaskApi { // The image can be of any size with format RGB or RGBA. // TODO: describe exact preprocessing steps once // YUVToImageCalculator is integrated. - absl::StatusOr Classify( + absl::StatusOr Classify( mediapipe::Image image, std::optional roi = std::nullopt); @@ -127,9 +129,9 @@ class ImageClassifier : tasks::vision::core::BaseVisionTaskApi { // The image can be of any size with format RGB or RGBA. It's required to // provide the video frame's timestamp (in milliseconds). The input timestamps // must be monotonically increasing. - absl::StatusOr ClassifyForVideo( - mediapipe::Image image, int64 timestamp_ms, - std::optional roi = std::nullopt); + absl::StatusOr + ClassifyForVideo(mediapipe::Image image, int64 timestamp_ms, + std::optional roi = std::nullopt); // Sends live image data to image classification, and the results will be // available via the "result_callback" provided in the ImageClassifierOptions. diff --git a/mediapipe/tasks/cc/vision/image_classifier/image_classifier_graph.cc b/mediapipe/tasks/cc/vision/image_classifier/image_classifier_graph.cc index 0d7b60c99..9a0078c5c 100644 --- a/mediapipe/tasks/cc/vision/image_classifier/image_classifier_graph.cc +++ b/mediapipe/tasks/cc/vision/image_classifier/image_classifier_graph.cc @@ -22,11 +22,11 @@ limitations under the License. #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/rect.pb.h" -#include "mediapipe/tasks/cc/components/classification_postprocessing.h" -#include "mediapipe/tasks/cc/components/classification_postprocessing_options.pb.h" -#include "mediapipe/tasks/cc/components/containers/classifications.pb.h" +#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" #include "mediapipe/tasks/cc/components/image_preprocessing.h" #include "mediapipe/tasks/cc/components/image_preprocessing_options.pb.h" +#include "mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h" +#include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_task_graph.h" #include "mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_graph_options.pb.h" @@ -43,6 +43,7 @@ using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::GenericNode; using ::mediapipe::api2::builder::Graph; using ::mediapipe::api2::builder::Source; +using ::mediapipe::tasks::components::containers::proto::ClassificationResult; constexpr float kDefaultScoreThreshold = std::numeric_limits::lowest(); @@ -152,11 +153,14 @@ class ImageClassifierGraph : public core::ModelTaskGraph { // Adds postprocessing calculators and connects them to the graph output. auto& postprocessing = graph.AddNode( - "mediapipe.tasks.components.ClassificationPostprocessingSubgraph"); - MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessing( - model_resources, task_options.classifier_options(), - &postprocessing.GetOptions< - tasks::components::ClassificationPostprocessingOptions>())); + "mediapipe.tasks.components.processors." + "ClassificationPostprocessingGraph"); + MP_RETURN_IF_ERROR( + components::processors::ConfigureClassificationPostprocessingGraph( + model_resources, task_options.classifier_options(), + &postprocessing + .GetOptions())); inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag); // Outputs the aggregated classification result as the subgraph output diff --git a/mediapipe/tasks/cc/vision/image_classifier/image_classifier_test.cc b/mediapipe/tasks/cc/vision/image_classifier/image_classifier_test.cc index edbb851c0..070a5a034 100644 --- a/mediapipe/tasks/cc/vision/image_classifier/image_classifier_test.cc +++ b/mediapipe/tasks/cc/vision/image_classifier/image_classifier_test.cc @@ -32,8 +32,8 @@ limitations under the License. #include "mediapipe/framework/port/parse_text_proto.h" #include "mediapipe/framework/port/status_matchers.h" #include "mediapipe/tasks/cc/common.h" -#include "mediapipe/tasks/cc/components/containers/category.pb.h" -#include "mediapipe/tasks/cc/components/containers/classifications.pb.h" +#include "mediapipe/tasks/cc/components/containers/proto/category.pb.h" +#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" #include "mediapipe/tasks/cc/vision/core/running_mode.h" #include "mediapipe/tasks/cc/vision/utils/image_utils.h" #include "tensorflow/lite/core/api/op_resolver.h" @@ -48,6 +48,9 @@ namespace image_classifier { namespace { using ::mediapipe::file::JoinPath; +using ::mediapipe::tasks::components::containers::proto::ClassificationEntry; +using ::mediapipe::tasks::components::containers::proto::ClassificationResult; +using ::mediapipe::tasks::components::containers::proto::Classifications; using ::testing::HasSubstr; using ::testing::Optional; diff --git a/mediapipe/tasks/cc/vision/image_classifier/proto/BUILD b/mediapipe/tasks/cc/vision/image_classifier/proto/BUILD index a6f5791e3..29638bebd 100644 --- a/mediapipe/tasks/cc/vision/image_classifier/proto/BUILD +++ b/mediapipe/tasks/cc/vision/image_classifier/proto/BUILD @@ -24,7 +24,7 @@ mediapipe_proto_library( deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", - "//mediapipe/tasks/cc/components/proto:classifier_options_proto", + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_proto", "//mediapipe/tasks/cc/core/proto:base_options_proto", ], ) diff --git a/mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_graph_options.proto b/mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_graph_options.proto index 3da047110..b307a66b6 100644 --- a/mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_graph_options.proto +++ b/mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_graph_options.proto @@ -18,7 +18,7 @@ syntax = "proto2"; package mediapipe.tasks.vision.image_classifier.proto; import "mediapipe/framework/calculator.proto"; -import "mediapipe/tasks/cc/components/proto/classifier_options.proto"; +import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; message ImageClassifierGraphOptions { @@ -31,5 +31,5 @@ message ImageClassifierGraphOptions { // Options for configuring the classifier behavior, such as score threshold, // number of results, etc. - optional components.proto.ClassifierOptions classifier_options = 2; + optional components.processors.proto.ClassifierOptions classifier_options = 2; } diff --git a/mediapipe/tasks/python/components/containers/BUILD b/mediapipe/tasks/python/components/containers/BUILD index 2bc951220..8dd9fcd60 100644 --- a/mediapipe/tasks/python/components/containers/BUILD +++ b/mediapipe/tasks/python/components/containers/BUILD @@ -31,7 +31,7 @@ py_library( name = "category", srcs = ["category.py"], deps = [ - "//mediapipe/tasks/cc/components/containers:category_py_pb2", + "//mediapipe/tasks/cc/components/containers/proto:category_py_pb2", "//mediapipe/tasks/python/core:optional_dependencies", ], ) diff --git a/mediapipe/tasks/python/components/containers/category.py b/mediapipe/tasks/python/components/containers/category.py index 00f68e532..0b347fc10 100644 --- a/mediapipe/tasks/python/components/containers/category.py +++ b/mediapipe/tasks/python/components/containers/category.py @@ -16,7 +16,7 @@ import dataclasses from typing import Any -from mediapipe.tasks.cc.components.containers import category_pb2 +from mediapipe.tasks.cc.components.containers.proto import category_pb2 from mediapipe.tasks.python.core.optional_dependencies import doc_controls _CategoryProto = category_pb2.Category