Change gesture classification output type to ClassificationList.
PiperOrigin-RevId: 479441969
parent 71a4680a16
commit 65d001c6a0
@@ -282,6 +282,20 @@ absl::Status ConfigureScoreCalibrationIfAny(
  return absl::OkStatus();
}

void ConfigureClassificationAggregationCalculator(
    const ModelMetadataExtractor& metadata_extractor,
    ClassificationAggregationCalculatorOptions* options) {
  auto* output_tensors_metadata = metadata_extractor.GetOutputTensorMetadata();
  if (output_tensors_metadata == nullptr) {
    return;
  }
  for (const auto& metadata : *output_tensors_metadata) {
    options->add_head_names(metadata->name()->str());
  }
}

} // namespace

// Fills in the TensorsToClassificationCalculatorOptions based on the
// classifier options and the (optional) output tensor metadata.
absl::Status ConfigureTensorsToClassificationCalculator(

@@ -333,20 +347,6 @@ absl::Status ConfigureTensorsToClassificationCalculator(
  return absl::OkStatus();
}

void ConfigureClassificationAggregationCalculator(
    const ModelMetadataExtractor& metadata_extractor,
    ClassificationAggregationCalculatorOptions* options) {
  auto* output_tensors_metadata = metadata_extractor.GetOutputTensorMetadata();
  if (output_tensors_metadata == nullptr) {
    return;
  }
  for (const auto& metadata : *output_tensors_metadata) {
    options->add_head_names(metadata->name()->str());
  }
}

} // namespace

absl::Status ConfigureClassificationPostprocessingGraph(
    const ModelResources& model_resources,
    const proto::ClassifierOptions& classifier_options,

@@ -20,6 +20,7 @@ limitations under the License.
#include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"
#include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"

namespace mediapipe {
namespace tasks {

@@ -55,6 +56,16 @@ absl::Status ConfigureClassificationPostprocessingGraph(
    const proto::ClassifierOptions& classifier_options,
    proto::ClassificationPostprocessingGraphOptions* options);

// Utility function to fill in the TensorsToClassificationCalculatorOptions
// based on the classifier options and the (optional) output tensor metadata.
// This is meant to be used by other graphs that may also rely on this
// calculator.
absl::Status ConfigureTensorsToClassificationCalculator(
    const proto::ClassifierOptions& options,
    const metadata::ModelMetadataExtractor& metadata_extractor,
    int tensor_index,
    mediapipe::TensorsToClassificationCalculatorOptions* calculator_options);

} // namespace processors
} // namespace components
} // namespace tasks

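With ConfigureTensorsToClassificationCalculator now declared in this header, other task graphs can configure a bare TensorsToClassificationCalculator from model metadata without pulling in the full classification postprocessing subgraph. A minimal sketch of such a call site, modeled on the gesture recognizer change further down; `graph`, `graph_options`, `model_resources`, and `inference_output_tensors` are assumed to exist in the enclosing graph builder, and the caller is assumed to live under namespace mediapipe::tasks:

// Sketch only, not part of this commit: wiring a TensorsToClassificationCalculator
// inside a ModelTaskGraph builder using the newly exported utility.
auto& tensors_to_classification =
    graph.AddNode("TensorsToClassificationCalculator");
MP_RETURN_IF_ERROR(
    components::processors::ConfigureTensorsToClassificationCalculator(
        graph_options.classifier_options(),
        *model_resources.GetMetadataExtractor(), /*tensor_index=*/0,
        &tensors_to_classification.GetOptions<
            mediapipe::TensorsToClassificationCalculatorOptions>()));
// `inference_output_tensors` is assumed to be the output stream of the
// InferenceCalculator for this model.
inference_output_tensors >> tensors_to_classification.In("TENSORS");
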
@@ -46,8 +46,9 @@ cc_library(
    deps = [
        "//mediapipe/calculators/core:begin_loop_calculator",
        "//mediapipe/calculators/core:concatenate_vector_calculator",
        "//mediapipe/calculators/core:get_vector_item_calculator",
        "//mediapipe/calculators/core:end_loop_calculator",
        "//mediapipe/calculators/tensor:tensor_converter_calculator",
        "//mediapipe/calculators/tensor:tensors_to_classification_calculator",
        "//mediapipe/calculators/tensor:tensors_to_classification_calculator_cc_proto",
        "//mediapipe/framework/api2:builder",
        "//mediapipe/framework/api2:port",

@@ -57,7 +58,6 @@ cc_library(
        "//mediapipe/framework/formats:tensor",
        "//mediapipe/tasks/cc:common",
        "//mediapipe/tasks/cc/components:image_preprocessing",
        "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
        "//mediapipe/tasks/cc/components/processors:classification_postprocessing_graph",
        "//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto",
        "//mediapipe/tasks/cc/core:model_resources",

@@ -27,7 +27,6 @@ limitations under the License.
#include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
#include "mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h"
#include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/core/model_resources.h"

@@ -49,7 +48,8 @@ using ::mediapipe::api2::Input;
using ::mediapipe::api2::Output;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source;
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
using ::mediapipe::tasks::components::processors::
    ConfigureTensorsToClassificationCalculator;
using ::mediapipe::tasks::vision::gesture_recognizer::proto::
    HandGestureRecognizerGraphOptions;

@@ -94,7 +94,7 @@ Source<std::vector<Tensor>> ConvertMatrixToTensor(Source<Matrix> matrix,
//     The size of image from which the landmarks detected from.
//
// Outputs:
//   HAND_GESTURES - ClassificationResult
//   HAND_GESTURES - ClassificationList
//     Recognized hand gestures with sorted order such that the winning label is
//     the first item in the list.
//

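After this change the HAND_GESTURES stream carries a plain mediapipe::ClassificationList instead of the task-level ClassificationResult container. A hedged sketch of reading the winning gesture from that payload; only the accessors from mediapipe/framework/formats/classification.proto are relied on, and the packet handling that produces the payload is assumed:

#include <iostream>

#include "mediapipe/framework/formats/classification.pb.h"

// Sketch only: prints the winning gesture from a HAND_GESTURES payload.
void PrintTopGesture(const mediapipe::ClassificationList& gestures) {
  if (gestures.classification_size() == 0) return;
  // Entries are sorted, so the first entry is the winning label.
  const mediapipe::Classification& top = gestures.classification(0);
  std::cout << "Gesture: " << top.label() << " score: " << top.score() << "\n";
}
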
@@ -134,12 +134,12 @@ class SingleHandGestureRecognizerGraph : public core::ModelTaskGraph {
        graph[Input<NormalizedLandmarkList>(kLandmarksTag)],
        graph[Input<LandmarkList>(kWorldLandmarksTag)],
        graph[Input<std::pair<int, int>>(kImageSizeTag)], graph));
    hand_gestures >> graph[Output<ClassificationResult>(kHandGesturesTag)];
    hand_gestures >> graph[Output<ClassificationList>(kHandGesturesTag)];
    return graph.GetConfig();
  }

 private:
  absl::StatusOr<Source<ClassificationResult>> BuildGestureRecognizerGraph(
  absl::StatusOr<Source<ClassificationList>> BuildGestureRecognizerGraph(
      const HandGestureRecognizerGraphOptions& graph_options,
      const core::ModelResources& model_resources,
      Source<ClassificationList> handedness,

@@ -205,20 +205,18 @@ class SingleHandGestureRecognizerGraph : public core::ModelTaskGraph {
    concatenated_tensors >> inference.In(kTensorsTag);
    auto inference_output_tensors = inference.Out(kTensorsTag);

    auto& postprocessing = graph.AddNode(
        "mediapipe.tasks.components.processors."
        "ClassificationPostprocessingGraph");
    MP_RETURN_IF_ERROR(
        components::processors::ConfigureClassificationPostprocessingGraph(
            model_resources, graph_options.classifier_options(),
            &postprocessing
                 .GetOptions<components::processors::proto::
                                 ClassificationPostprocessingGraphOptions>()));
    inference_output_tensors >> postprocessing.In(kTensorsTag);
    auto classification_result =
        postprocessing[Output<ClassificationResult>("CLASSIFICATION_RESULT")];

    return classification_result;
    auto& tensors_to_classification =
        graph.AddNode("TensorsToClassificationCalculator");
    MP_RETURN_IF_ERROR(ConfigureTensorsToClassificationCalculator(
        graph_options.classifier_options(),
        *model_resources.GetMetadataExtractor(), 0,
        &tensors_to_classification.GetOptions<
            mediapipe::TensorsToClassificationCalculatorOptions>()));
    inference_output_tensors >> tensors_to_classification.In(kTensorsTag);
    auto classification_list =
        tensors_to_classification[Output<ClassificationList>(
            "CLASSIFICATIONS")];
    return classification_list;
  }
};

@@ -246,9 +244,9 @@ REGISTER_MEDIAPIPE_GRAPH(
//     index corresponding to the same hand if the graph runs multiple times.
//
// Outputs:
//   HAND_GESTURES - std::vector<ClassificationResult>
//   HAND_GESTURES - std::vector<ClassificationList>
//     A vector of recognized hand gestures. Each vector element is the
//     ClassificationResult of the hand in input vector.
//     ClassificationList of the hand in input vector.
//
//
// Example:

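For the multi-hand graph, each element of the output vector is the ClassificationList of one hand, in the same order as the input hands. A hedged sketch of iterating that output; extracting the std::vector payload from the HAND_GESTURES output packet is assumed to have happened already:

#include <iostream>
#include <vector>

#include "mediapipe/framework/formats/classification.pb.h"

// Sketch only: `multi_hand_gestures` is assumed to be the
// std::vector<mediapipe::ClassificationList> payload of the HAND_GESTURES
// output of MultipleHandGestureRecognizerGraph.
void PrintGesturesPerHand(
    const std::vector<mediapipe::ClassificationList>& multi_hand_gestures) {
  for (size_t hand = 0; hand < multi_hand_gestures.size(); ++hand) {
    const auto& gestures = multi_hand_gestures[hand];
    if (gestures.classification_size() == 0) continue;
    std::cout << "Hand " << hand << ": " << gestures.classification(0).label()
              << "\n";
  }
}
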
@@ -287,12 +285,12 @@ class MultipleHandGestureRecognizerGraph : public core::ModelTaskGraph {
        graph[Input<std::pair<int, int>>(kImageSizeTag)],
        graph[Input<std::vector<int>>(kHandTrackingIdsTag)], graph));
    multi_hand_gestures >>
        graph[Output<std::vector<ClassificationResult>>(kHandGesturesTag)];
        graph[Output<std::vector<ClassificationList>>(kHandGesturesTag)];
    return graph.GetConfig();
  }

 private:
  absl::StatusOr<Source<std::vector<ClassificationResult>>>
  absl::StatusOr<Source<std::vector<ClassificationList>>>
  BuildMultiGestureRecognizerSubraph(
      const HandGestureRecognizerGraphOptions& graph_options,
      Source<std::vector<ClassificationList>> multi_handedness,

@@ -345,12 +343,13 @@ class MultipleHandGestureRecognizerGraph : public core::ModelTaskGraph {
    image_size_clone >> hand_gesture_recognizer_graph.In(kImageSizeTag);
    auto hand_gestures = hand_gesture_recognizer_graph.Out(kHandGesturesTag);

    auto& end_loop_classification_results =
        graph.AddNode("mediapipe.tasks.EndLoopClassificationResultCalculator");
    batch_end >> end_loop_classification_results.In(kBatchEndTag);
    hand_gestures >> end_loop_classification_results.In(kItemTag);
    auto multi_hand_gestures = end_loop_classification_results
        [Output<std::vector<ClassificationResult>>(kIterableTag)];
    auto& end_loop_classification_lists =
        graph.AddNode("EndLoopClassificationListCalculator");
    batch_end >> end_loop_classification_lists.In(kBatchEndTag);
    hand_gestures >> end_loop_classification_lists.In(kItemTag);
    auto multi_hand_gestures =
        end_loop_classification_lists[Output<std::vector<ClassificationList>>(
            kIterableTag)];

    return multi_hand_gestures;
  }