Change gesture classification output type to ClassificationList.

PiperOrigin-RevId: 479441969
MediaPipe Team 2022-10-06 16:36:26 -07:00 committed by Copybara-Service
parent 71a4680a16
commit 65d001c6a0
4 changed files with 55 additions and 45 deletions

View File

@@ -282,6 +282,20 @@ absl::Status ConfigureScoreCalibrationIfAny(
   return absl::OkStatus();
 }
 
+void ConfigureClassificationAggregationCalculator(
+    const ModelMetadataExtractor& metadata_extractor,
+    ClassificationAggregationCalculatorOptions* options) {
+  auto* output_tensors_metadata = metadata_extractor.GetOutputTensorMetadata();
+  if (output_tensors_metadata == nullptr) {
+    return;
+  }
+  for (const auto& metadata : *output_tensors_metadata) {
+    options->add_head_names(metadata->name()->str());
+  }
+}
+
+}  // namespace
+
 // Fills in the TensorsToClassificationCalculatorOptions based on the
 // classifier options and the (optional) output tensor metadata.
 absl::Status ConfigureTensorsToClassificationCalculator(
@@ -333,20 +347,6 @@ absl::Status ConfigureTensorsToClassificationCalculator(
   return absl::OkStatus();
 }
 
-void ConfigureClassificationAggregationCalculator(
-    const ModelMetadataExtractor& metadata_extractor,
-    ClassificationAggregationCalculatorOptions* options) {
-  auto* output_tensors_metadata = metadata_extractor.GetOutputTensorMetadata();
-  if (output_tensors_metadata == nullptr) {
-    return;
-  }
-  for (const auto& metadata : *output_tensors_metadata) {
-    options->add_head_names(metadata->name()->str());
-  }
-}
-
-}  // namespace
-
 absl::Status ConfigureClassificationPostprocessingGraph(
     const ModelResources& model_resources,
     const proto::ClassifierOptions& classifier_options,

View File

@@ -20,6 +20,7 @@ limitations under the License.
 #include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h"
 #include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"
 #include "mediapipe/tasks/cc/core/model_resources.h"
+#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
 
 namespace mediapipe {
 namespace tasks {
@@ -55,6 +56,16 @@ absl::Status ConfigureClassificationPostprocessingGraph(
     const proto::ClassifierOptions& classifier_options,
     proto::ClassificationPostprocessingGraphOptions* options);
 
+// Utility function to fill in the TensorsToClassificationCalculatorOptions
+// based on the classifier options and the (optional) output tensor metadata.
+// This is meant to be used by other graphs that may also rely on this
+// calculator.
+absl::Status ConfigureTensorsToClassificationCalculator(
+    const proto::ClassifierOptions& options,
+    const metadata::ModelMetadataExtractor& metadata_extractor,
+    int tensor_index,
+    mediapipe::TensorsToClassificationCalculatorOptions* calculator_options);
+
 }  // namespace processors
 }  // namespace components
 }  // namespace tasks
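
The helper exported here lets a task graph configure a bare TensorsToClassificationCalculator node directly, instead of pulling in the whole ClassificationPostprocessingGraph. A minimal sketch of that usage follows; it mirrors the call added to the gesture recognizer graph later in this commit, but the enclosing function, its name, and the surrounding graph-builder plumbing are illustrative assumptions, not part of the change.

// Sketch only: feeding model output tensors into a
// TensorsToClassificationCalculator node configured by the newly exported
// helper. Assumes the usual tasks-graph using-declarations (Graph, Source,
// Output, ClassificationList, Tensor, etc.) are in scope.
absl::StatusOr<Source<ClassificationList>> AddGestureClassifier(
    const core::ModelResources& model_resources,
    const proto::ClassifierOptions& classifier_options,
    Source<std::vector<Tensor>> output_tensors, Graph& graph) {
  auto& node = graph.AddNode("TensorsToClassificationCalculator");
  // Populates the label map, score thresholding, etc. from the classifier
  // options and the metadata of output tensor 0.
  MP_RETURN_IF_ERROR(ConfigureTensorsToClassificationCalculator(
      classifier_options, *model_resources.GetMetadataExtractor(),
      /*tensor_index=*/0,
      &node.GetOptions<mediapipe::TensorsToClassificationCalculatorOptions>()));
  output_tensors >> node.In("TENSORS");
  return node[Output<ClassificationList>("CLASSIFICATIONS")];
}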

View File

@@ -46,8 +46,9 @@ cc_library(
     deps = [
         "//mediapipe/calculators/core:begin_loop_calculator",
         "//mediapipe/calculators/core:concatenate_vector_calculator",
-        "//mediapipe/calculators/core:get_vector_item_calculator",
+        "//mediapipe/calculators/core:end_loop_calculator",
         "//mediapipe/calculators/tensor:tensor_converter_calculator",
+        "//mediapipe/calculators/tensor:tensors_to_classification_calculator",
         "//mediapipe/calculators/tensor:tensors_to_classification_calculator_cc_proto",
         "//mediapipe/framework/api2:builder",
         "//mediapipe/framework/api2:port",
@@ -57,7 +58,6 @@ cc_library(
         "//mediapipe/framework/formats:tensor",
         "//mediapipe/tasks/cc:common",
         "//mediapipe/tasks/cc/components:image_preprocessing",
-        "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
         "//mediapipe/tasks/cc/components/processors:classification_postprocessing_graph",
         "//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto",
         "//mediapipe/tasks/cc/core:model_resources",

View File

@@ -27,7 +27,6 @@ limitations under the License.
 #include "mediapipe/framework/formats/matrix.h"
 #include "mediapipe/framework/formats/tensor.h"
 #include "mediapipe/tasks/cc/common.h"
-#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
 #include "mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h"
 #include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h"
 #include "mediapipe/tasks/cc/core/model_resources.h"
@@ -49,7 +48,8 @@ using ::mediapipe::api2::Input;
 using ::mediapipe::api2::Output;
 using ::mediapipe::api2::builder::Graph;
 using ::mediapipe::api2::builder::Source;
-using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
+using ::mediapipe::tasks::components::processors::
+    ConfigureTensorsToClassificationCalculator;
 using ::mediapipe::tasks::vision::gesture_recognizer::proto::
     HandGestureRecognizerGraphOptions;
 
@@ -94,7 +94,7 @@ Source<std::vector<Tensor>> ConvertMatrixToTensor(Source<Matrix> matrix,
 //     The size of image from which the landmarks detected from.
 //
 // Outputs:
-//   HAND_GESTURES - ClassificationResult
+//   HAND_GESTURES - ClassificationList
 //     Recognized hand gestures with sorted order such that the winning label is
 //     the first item in the list.
 //
@@ -134,12 +134,12 @@ class SingleHandGestureRecognizerGraph : public core::ModelTaskGraph {
             graph[Input<NormalizedLandmarkList>(kLandmarksTag)],
             graph[Input<LandmarkList>(kWorldLandmarksTag)],
             graph[Input<std::pair<int, int>>(kImageSizeTag)], graph));
-    hand_gestures >> graph[Output<ClassificationResult>(kHandGesturesTag)];
+    hand_gestures >> graph[Output<ClassificationList>(kHandGesturesTag)];
     return graph.GetConfig();
   }
 
  private:
-  absl::StatusOr<Source<ClassificationResult>> BuildGestureRecognizerGraph(
+  absl::StatusOr<Source<ClassificationList>> BuildGestureRecognizerGraph(
       const HandGestureRecognizerGraphOptions& graph_options,
       const core::ModelResources& model_resources,
       Source<ClassificationList> handedness,
@@ -205,20 +205,18 @@ class SingleHandGestureRecognizerGraph : public core::ModelTaskGraph {
     concatenated_tensors >> inference.In(kTensorsTag);
     auto inference_output_tensors = inference.Out(kTensorsTag);
 
-    auto& postprocessing = graph.AddNode(
-        "mediapipe.tasks.components.processors."
-        "ClassificationPostprocessingGraph");
-    MP_RETURN_IF_ERROR(
-        components::processors::ConfigureClassificationPostprocessingGraph(
-            model_resources, graph_options.classifier_options(),
-            &postprocessing
-                 .GetOptions<components::processors::proto::
-                                 ClassificationPostprocessingGraphOptions>()));
-    inference_output_tensors >> postprocessing.In(kTensorsTag);
-    auto classification_result =
-        postprocessing[Output<ClassificationResult>("CLASSIFICATION_RESULT")];
-    return classification_result;
+    auto& tensors_to_classification =
+        graph.AddNode("TensorsToClassificationCalculator");
+    MP_RETURN_IF_ERROR(ConfigureTensorsToClassificationCalculator(
+        graph_options.classifier_options(),
+        *model_resources.GetMetadataExtractor(), 0,
+        &tensors_to_classification.GetOptions<
+            mediapipe::TensorsToClassificationCalculatorOptions>()));
+    inference_output_tensors >> tensors_to_classification.In(kTensorsTag);
+    auto classification_list =
+        tensors_to_classification[Output<ClassificationList>(
+            "CLASSIFICATIONS")];
+    return classification_list;
   }
 };
@@ -246,9 +244,9 @@ REGISTER_MEDIAPIPE_GRAPH(
 //     index corresponding to the same hand if the graph runs multiple times.
 //
 // Outputs:
-//   HAND_GESTURES - std::vector<ClassificationResult>
+//   HAND_GESTURES - std::vector<ClassificationList>
 //     A vector of recognized hand gestures. Each vector element is the
-//     ClassificationResult of the hand in input vector.
+//     ClassificationList of the hand in input vector.
 //
 //
 // Example:
@@ -287,12 +285,12 @@ class MultipleHandGestureRecognizerGraph : public core::ModelTaskGraph {
             graph[Input<std::pair<int, int>>(kImageSizeTag)],
             graph[Input<std::vector<int>>(kHandTrackingIdsTag)], graph));
     multi_hand_gestures >>
-        graph[Output<std::vector<ClassificationResult>>(kHandGesturesTag)];
+        graph[Output<std::vector<ClassificationList>>(kHandGesturesTag)];
     return graph.GetConfig();
   }
 
  private:
-  absl::StatusOr<Source<std::vector<ClassificationResult>>>
+  absl::StatusOr<Source<std::vector<ClassificationList>>>
   BuildMultiGestureRecognizerSubraph(
       const HandGestureRecognizerGraphOptions& graph_options,
       Source<std::vector<ClassificationList>> multi_handedness,
@@ -345,12 +343,13 @@ class MultipleHandGestureRecognizerGraph : public core::ModelTaskGraph {
     image_size_clone >> hand_gesture_recognizer_graph.In(kImageSizeTag);
     auto hand_gestures = hand_gesture_recognizer_graph.Out(kHandGesturesTag);
 
-    auto& end_loop_classification_results =
-        graph.AddNode("mediapipe.tasks.EndLoopClassificationResultCalculator");
-    batch_end >> end_loop_classification_results.In(kBatchEndTag);
-    hand_gestures >> end_loop_classification_results.In(kItemTag);
-    auto multi_hand_gestures = end_loop_classification_results
-        [Output<std::vector<ClassificationResult>>(kIterableTag)];
+    auto& end_loop_classification_lists =
+        graph.AddNode("EndLoopClassificationListCalculator");
+    batch_end >> end_loop_classification_lists.In(kBatchEndTag);
+    hand_gestures >> end_loop_classification_lists.In(kItemTag);
+    auto multi_hand_gestures =
+        end_loop_classification_lists[Output<std::vector<ClassificationList>>(
+            kIterableTag)];
 
     return multi_hand_gestures;
   }