Change gesture classification output type to ClassificationList.

PiperOrigin-RevId: 479441969
This commit is contained in:
MediaPipe Team 2022-10-06 16:36:26 -07:00 committed by Copybara-Service
parent 71a4680a16
commit 65d001c6a0
4 changed files with 55 additions and 45 deletions

View File

@ -282,6 +282,20 @@ absl::Status ConfigureScoreCalibrationIfAny(
return absl::OkStatus();
}
// Populates `options` with one head name per output tensor described by the
// model metadata.
//
// Does nothing when the extractor reports no output tensor metadata. A tensor
// whose metadata carries no name contributes an empty head name, so there is
// still exactly one entry per output tensor.
// NOTE(review): confirm that consumers of `head_names` accept empty entries.
void ConfigureClassificationAggregationCalculator(
    const ModelMetadataExtractor& metadata_extractor,
    ClassificationAggregationCalculatorOptions* options) {
  const auto* output_tensors_metadata =
      metadata_extractor.GetOutputTensorMetadata();
  if (output_tensors_metadata == nullptr) {
    return;
  }
  for (const auto& metadata : *output_tensors_metadata) {
    // NOTE(review): `name()` looks like a FlatBuffers string accessor, which
    // returns nullptr when the field is unset; guard before dereferencing.
    const auto* name = metadata->name();
    if (name != nullptr) {
      options->add_head_names(name->str());
    } else {
      options->add_head_names("");
    }
  }
}
} // namespace
// Fills in the TensorsToClassificationCalculatorOptions based on the
// classifier options and the (optional) output tensor metadata.
absl::Status ConfigureTensorsToClassificationCalculator(
@ -333,20 +347,6 @@ absl::Status ConfigureTensorsToClassificationCalculator(
return absl::OkStatus();
}
// Populates `options` with one head name per output tensor described by the
// model metadata.
//
// Does nothing when the extractor reports no output tensor metadata. A tensor
// whose metadata carries no name contributes an empty head name, so there is
// still exactly one entry per output tensor.
// NOTE(review): confirm that consumers of `head_names` accept empty entries.
void ConfigureClassificationAggregationCalculator(
    const ModelMetadataExtractor& metadata_extractor,
    ClassificationAggregationCalculatorOptions* options) {
  const auto* output_tensors_metadata =
      metadata_extractor.GetOutputTensorMetadata();
  if (output_tensors_metadata == nullptr) {
    return;
  }
  for (const auto& metadata : *output_tensors_metadata) {
    // NOTE(review): `name()` looks like a FlatBuffers string accessor, which
    // returns nullptr when the field is unset; guard before dereferencing.
    const auto* name = metadata->name();
    if (name != nullptr) {
      options->add_head_names(name->str());
    } else {
      options->add_head_names("");
    }
  }
}
} // namespace
absl::Status ConfigureClassificationPostprocessingGraph(
const ModelResources& model_resources,
const proto::ClassifierOptions& classifier_options,

View File

@ -20,6 +20,7 @@ limitations under the License.
#include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"
#include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
namespace mediapipe {
namespace tasks {
@ -55,6 +56,16 @@ absl::Status ConfigureClassificationPostprocessingGraph(
const proto::ClassifierOptions& classifier_options,
proto::ClassificationPostprocessingGraphOptions* options);
// Utility function to fill in the TensorsToClassificationCalculatorOptions
// based on the classifier options and the (optional) output tensor metadata.
// This is meant to be used by other graphs that may also rely on this
// calculator.
//
// `options` holds the classifier options to read from, `metadata_extractor`
// gives access to the model metadata, and `calculator_options` is the output
// parameter that gets filled in.
// NOTE(review): `tensor_index` presumably selects which output tensor's
// metadata is used — confirm against the definition.
absl::Status ConfigureTensorsToClassificationCalculator(
const proto::ClassifierOptions& options,
const metadata::ModelMetadataExtractor& metadata_extractor,
int tensor_index,
mediapipe::TensorsToClassificationCalculatorOptions* calculator_options);
} // namespace processors
} // namespace components
} // namespace tasks

View File

@ -46,8 +46,9 @@ cc_library(
deps = [
"//mediapipe/calculators/core:begin_loop_calculator",
"//mediapipe/calculators/core:concatenate_vector_calculator",
"//mediapipe/calculators/core:get_vector_item_calculator",
"//mediapipe/calculators/core:end_loop_calculator",
"//mediapipe/calculators/tensor:tensor_converter_calculator",
"//mediapipe/calculators/tensor:tensors_to_classification_calculator",
"//mediapipe/calculators/tensor:tensors_to_classification_calculator_cc_proto",
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/api2:port",
@ -57,7 +58,6 @@ cc_library(
"//mediapipe/framework/formats:tensor",
"//mediapipe/tasks/cc:common",
"//mediapipe/tasks/cc/components:image_preprocessing",
"//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
"//mediapipe/tasks/cc/components/processors:classification_postprocessing_graph",
"//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto",
"//mediapipe/tasks/cc/core:model_resources",

View File

@ -27,7 +27,6 @@ limitations under the License.
#include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
#include "mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h"
#include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/core/model_resources.h"
@ -49,7 +48,8 @@ using ::mediapipe::api2::Input;
using ::mediapipe::api2::Output;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source;
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
using ::mediapipe::tasks::components::processors::
ConfigureTensorsToClassificationCalculator;
using ::mediapipe::tasks::vision::gesture_recognizer::proto::
HandGestureRecognizerGraphOptions;
@ -94,7 +94,7 @@ Source<std::vector<Tensor>> ConvertMatrixToTensor(Source<Matrix> matrix,
// The size of the image from which the landmarks were detected.
//
// Outputs:
// HAND_GESTURES - ClassificationResult
// HAND_GESTURES - ClassificationList
// Recognized hand gestures with sorted order such that the winning label is
// the first item in the list.
//
@ -134,12 +134,12 @@ class SingleHandGestureRecognizerGraph : public core::ModelTaskGraph {
graph[Input<NormalizedLandmarkList>(kLandmarksTag)],
graph[Input<LandmarkList>(kWorldLandmarksTag)],
graph[Input<std::pair<int, int>>(kImageSizeTag)], graph));
hand_gestures >> graph[Output<ClassificationResult>(kHandGesturesTag)];
hand_gestures >> graph[Output<ClassificationList>(kHandGesturesTag)];
return graph.GetConfig();
}
private:
absl::StatusOr<Source<ClassificationResult>> BuildGestureRecognizerGraph(
absl::StatusOr<Source<ClassificationList>> BuildGestureRecognizerGraph(
const HandGestureRecognizerGraphOptions& graph_options,
const core::ModelResources& model_resources,
Source<ClassificationList> handedness,
@ -205,20 +205,18 @@ class SingleHandGestureRecognizerGraph : public core::ModelTaskGraph {
concatenated_tensors >> inference.In(kTensorsTag);
auto inference_output_tensors = inference.Out(kTensorsTag);
auto& postprocessing = graph.AddNode(
"mediapipe.tasks.components.processors."
"ClassificationPostprocessingGraph");
MP_RETURN_IF_ERROR(
components::processors::ConfigureClassificationPostprocessingGraph(
model_resources, graph_options.classifier_options(),
&postprocessing
.GetOptions<components::processors::proto::
ClassificationPostprocessingGraphOptions>()));
inference_output_tensors >> postprocessing.In(kTensorsTag);
auto classification_result =
postprocessing[Output<ClassificationResult>("CLASSIFICATION_RESULT")];
return classification_result;
auto& tensors_to_classification =
graph.AddNode("TensorsToClassificationCalculator");
MP_RETURN_IF_ERROR(ConfigureTensorsToClassificationCalculator(
graph_options.classifier_options(),
*model_resources.GetMetadataExtractor(), 0,
&tensors_to_classification.GetOptions<
mediapipe::TensorsToClassificationCalculatorOptions>()));
inference_output_tensors >> tensors_to_classification.In(kTensorsTag);
auto classification_list =
tensors_to_classification[Output<ClassificationList>(
"CLASSIFICATIONS")];
return classification_list;
}
};
@ -246,9 +244,9 @@ REGISTER_MEDIAPIPE_GRAPH(
// index corresponding to the same hand if the graph runs multiple times.
//
// Outputs:
// HAND_GESTURES - std::vector<ClassificationResult>
// HAND_GESTURES - std::vector<ClassificationList>
// A vector of recognized hand gestures. Each vector element is the
// ClassificationResult of the hand in input vector.
// ClassificationList of the hand in input vector.
//
//
// Example:
@ -287,12 +285,12 @@ class MultipleHandGestureRecognizerGraph : public core::ModelTaskGraph {
graph[Input<std::pair<int, int>>(kImageSizeTag)],
graph[Input<std::vector<int>>(kHandTrackingIdsTag)], graph));
multi_hand_gestures >>
graph[Output<std::vector<ClassificationResult>>(kHandGesturesTag)];
graph[Output<std::vector<ClassificationList>>(kHandGesturesTag)];
return graph.GetConfig();
}
private:
absl::StatusOr<Source<std::vector<ClassificationResult>>>
absl::StatusOr<Source<std::vector<ClassificationList>>>
BuildMultiGestureRecognizerSubraph(
const HandGestureRecognizerGraphOptions& graph_options,
Source<std::vector<ClassificationList>> multi_handedness,
@ -345,12 +343,13 @@ class MultipleHandGestureRecognizerGraph : public core::ModelTaskGraph {
image_size_clone >> hand_gesture_recognizer_graph.In(kImageSizeTag);
auto hand_gestures = hand_gesture_recognizer_graph.Out(kHandGesturesTag);
auto& end_loop_classification_results =
graph.AddNode("mediapipe.tasks.EndLoopClassificationResultCalculator");
batch_end >> end_loop_classification_results.In(kBatchEndTag);
hand_gestures >> end_loop_classification_results.In(kItemTag);
auto multi_hand_gestures = end_loop_classification_results
[Output<std::vector<ClassificationResult>>(kIterableTag)];
auto& end_loop_classification_lists =
graph.AddNode("EndLoopClassificationListCalculator");
batch_end >> end_loop_classification_lists.In(kBatchEndTag);
hand_gestures >> end_loop_classification_lists.In(kItemTag);
auto multi_hand_gestures =
end_loop_classification_lists[Output<std::vector<ClassificationList>>(
kIterableTag)];
return multi_hand_gestures;
}