Change gesture classification output type to ClassificationList.
PiperOrigin-RevId: 479441969
This commit is contained in:
parent
71a4680a16
commit
65d001c6a0
|
@ -282,6 +282,20 @@ absl::Status ConfigureScoreCalibrationIfAny(
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void ConfigureClassificationAggregationCalculator(
|
||||||
|
const ModelMetadataExtractor& metadata_extractor,
|
||||||
|
ClassificationAggregationCalculatorOptions* options) {
|
||||||
|
auto* output_tensors_metadata = metadata_extractor.GetOutputTensorMetadata();
|
||||||
|
if (output_tensors_metadata == nullptr) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
for (const auto& metadata : *output_tensors_metadata) {
|
||||||
|
options->add_head_names(metadata->name()->str());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
// Fills in the TensorsToClassificationCalculatorOptions based on the
|
// Fills in the TensorsToClassificationCalculatorOptions based on the
|
||||||
// classifier options and the (optional) output tensor metadata.
|
// classifier options and the (optional) output tensor metadata.
|
||||||
absl::Status ConfigureTensorsToClassificationCalculator(
|
absl::Status ConfigureTensorsToClassificationCalculator(
|
||||||
|
@ -333,20 +347,6 @@ absl::Status ConfigureTensorsToClassificationCalculator(
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
void ConfigureClassificationAggregationCalculator(
|
|
||||||
const ModelMetadataExtractor& metadata_extractor,
|
|
||||||
ClassificationAggregationCalculatorOptions* options) {
|
|
||||||
auto* output_tensors_metadata = metadata_extractor.GetOutputTensorMetadata();
|
|
||||||
if (output_tensors_metadata == nullptr) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
for (const auto& metadata : *output_tensors_metadata) {
|
|
||||||
options->add_head_names(metadata->name()->str());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
absl::Status ConfigureClassificationPostprocessingGraph(
|
absl::Status ConfigureClassificationPostprocessingGraph(
|
||||||
const ModelResources& model_resources,
|
const ModelResources& model_resources,
|
||||||
const proto::ClassifierOptions& classifier_options,
|
const proto::ClassifierOptions& classifier_options,
|
||||||
|
|
|
@ -20,6 +20,7 @@ limitations under the License.
|
||||||
#include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h"
|
#include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h"
|
||||||
#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"
|
#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"
|
||||||
#include "mediapipe/tasks/cc/core/model_resources.h"
|
#include "mediapipe/tasks/cc/core/model_resources.h"
|
||||||
|
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace tasks {
|
namespace tasks {
|
||||||
|
@ -55,6 +56,16 @@ absl::Status ConfigureClassificationPostprocessingGraph(
|
||||||
const proto::ClassifierOptions& classifier_options,
|
const proto::ClassifierOptions& classifier_options,
|
||||||
proto::ClassificationPostprocessingGraphOptions* options);
|
proto::ClassificationPostprocessingGraphOptions* options);
|
||||||
|
|
||||||
|
// Utility function to fill in the TensorsToClassificationCalculatorOptions
|
||||||
|
// based on the classifier options and the (optional) output tensor metadata.
|
||||||
|
// This is meant to be used by other graphs that may also rely on this
|
||||||
|
// calculator.
|
||||||
|
absl::Status ConfigureTensorsToClassificationCalculator(
|
||||||
|
const proto::ClassifierOptions& options,
|
||||||
|
const metadata::ModelMetadataExtractor& metadata_extractor,
|
||||||
|
int tensor_index,
|
||||||
|
mediapipe::TensorsToClassificationCalculatorOptions* calculator_options);
|
||||||
|
|
||||||
} // namespace processors
|
} // namespace processors
|
||||||
} // namespace components
|
} // namespace components
|
||||||
} // namespace tasks
|
} // namespace tasks
|
||||||
|
|
|
@ -46,8 +46,9 @@ cc_library(
|
||||||
deps = [
|
deps = [
|
||||||
"//mediapipe/calculators/core:begin_loop_calculator",
|
"//mediapipe/calculators/core:begin_loop_calculator",
|
||||||
"//mediapipe/calculators/core:concatenate_vector_calculator",
|
"//mediapipe/calculators/core:concatenate_vector_calculator",
|
||||||
"//mediapipe/calculators/core:get_vector_item_calculator",
|
"//mediapipe/calculators/core:end_loop_calculator",
|
||||||
"//mediapipe/calculators/tensor:tensor_converter_calculator",
|
"//mediapipe/calculators/tensor:tensor_converter_calculator",
|
||||||
|
"//mediapipe/calculators/tensor:tensors_to_classification_calculator",
|
||||||
"//mediapipe/calculators/tensor:tensors_to_classification_calculator_cc_proto",
|
"//mediapipe/calculators/tensor:tensors_to_classification_calculator_cc_proto",
|
||||||
"//mediapipe/framework/api2:builder",
|
"//mediapipe/framework/api2:builder",
|
||||||
"//mediapipe/framework/api2:port",
|
"//mediapipe/framework/api2:port",
|
||||||
|
@ -57,7 +58,6 @@ cc_library(
|
||||||
"//mediapipe/framework/formats:tensor",
|
"//mediapipe/framework/formats:tensor",
|
||||||
"//mediapipe/tasks/cc:common",
|
"//mediapipe/tasks/cc:common",
|
||||||
"//mediapipe/tasks/cc/components:image_preprocessing",
|
"//mediapipe/tasks/cc/components:image_preprocessing",
|
||||||
"//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
|
|
||||||
"//mediapipe/tasks/cc/components/processors:classification_postprocessing_graph",
|
"//mediapipe/tasks/cc/components/processors:classification_postprocessing_graph",
|
||||||
"//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto",
|
"//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto",
|
||||||
"//mediapipe/tasks/cc/core:model_resources",
|
"//mediapipe/tasks/cc/core:model_resources",
|
||||||
|
|
|
@ -27,7 +27,6 @@ limitations under the License.
|
||||||
#include "mediapipe/framework/formats/matrix.h"
|
#include "mediapipe/framework/formats/matrix.h"
|
||||||
#include "mediapipe/framework/formats/tensor.h"
|
#include "mediapipe/framework/formats/tensor.h"
|
||||||
#include "mediapipe/tasks/cc/common.h"
|
#include "mediapipe/tasks/cc/common.h"
|
||||||
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
|
|
||||||
#include "mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h"
|
#include "mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h"
|
||||||
#include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h"
|
#include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h"
|
||||||
#include "mediapipe/tasks/cc/core/model_resources.h"
|
#include "mediapipe/tasks/cc/core/model_resources.h"
|
||||||
|
@ -49,7 +48,8 @@ using ::mediapipe::api2::Input;
|
||||||
using ::mediapipe::api2::Output;
|
using ::mediapipe::api2::Output;
|
||||||
using ::mediapipe::api2::builder::Graph;
|
using ::mediapipe::api2::builder::Graph;
|
||||||
using ::mediapipe::api2::builder::Source;
|
using ::mediapipe::api2::builder::Source;
|
||||||
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
|
using ::mediapipe::tasks::components::processors::
|
||||||
|
ConfigureTensorsToClassificationCalculator;
|
||||||
using ::mediapipe::tasks::vision::gesture_recognizer::proto::
|
using ::mediapipe::tasks::vision::gesture_recognizer::proto::
|
||||||
HandGestureRecognizerGraphOptions;
|
HandGestureRecognizerGraphOptions;
|
||||||
|
|
||||||
|
@ -94,7 +94,7 @@ Source<std::vector<Tensor>> ConvertMatrixToTensor(Source<Matrix> matrix,
|
||||||
// The size of the image in which the landmarks were detected.
|
// The size of the image in which the landmarks were detected.
|
||||||
//
|
//
|
||||||
// Outputs:
|
// Outputs:
|
||||||
// HAND_GESTURES - ClassificationResult
|
// HAND_GESTURES - ClassificationList
|
||||||
// Recognized hand gestures, sorted so that the winning label is the first
|
// Recognized hand gestures, sorted so that the winning label is the first
|
||||||
// item in the list.
|
// item in the list.
|
||||||
//
|
//
|
||||||
|
@ -134,12 +134,12 @@ class SingleHandGestureRecognizerGraph : public core::ModelTaskGraph {
|
||||||
graph[Input<NormalizedLandmarkList>(kLandmarksTag)],
|
graph[Input<NormalizedLandmarkList>(kLandmarksTag)],
|
||||||
graph[Input<LandmarkList>(kWorldLandmarksTag)],
|
graph[Input<LandmarkList>(kWorldLandmarksTag)],
|
||||||
graph[Input<std::pair<int, int>>(kImageSizeTag)], graph));
|
graph[Input<std::pair<int, int>>(kImageSizeTag)], graph));
|
||||||
hand_gestures >> graph[Output<ClassificationResult>(kHandGesturesTag)];
|
hand_gestures >> graph[Output<ClassificationList>(kHandGesturesTag)];
|
||||||
return graph.GetConfig();
|
return graph.GetConfig();
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
absl::StatusOr<Source<ClassificationResult>> BuildGestureRecognizerGraph(
|
absl::StatusOr<Source<ClassificationList>> BuildGestureRecognizerGraph(
|
||||||
const HandGestureRecognizerGraphOptions& graph_options,
|
const HandGestureRecognizerGraphOptions& graph_options,
|
||||||
const core::ModelResources& model_resources,
|
const core::ModelResources& model_resources,
|
||||||
Source<ClassificationList> handedness,
|
Source<ClassificationList> handedness,
|
||||||
|
@ -205,20 +205,18 @@ class SingleHandGestureRecognizerGraph : public core::ModelTaskGraph {
|
||||||
concatenated_tensors >> inference.In(kTensorsTag);
|
concatenated_tensors >> inference.In(kTensorsTag);
|
||||||
auto inference_output_tensors = inference.Out(kTensorsTag);
|
auto inference_output_tensors = inference.Out(kTensorsTag);
|
||||||
|
|
||||||
auto& postprocessing = graph.AddNode(
|
auto& tensors_to_classification =
|
||||||
"mediapipe.tasks.components.processors."
|
graph.AddNode("TensorsToClassificationCalculator");
|
||||||
"ClassificationPostprocessingGraph");
|
MP_RETURN_IF_ERROR(ConfigureTensorsToClassificationCalculator(
|
||||||
MP_RETURN_IF_ERROR(
|
graph_options.classifier_options(),
|
||||||
components::processors::ConfigureClassificationPostprocessingGraph(
|
*model_resources.GetMetadataExtractor(), 0,
|
||||||
model_resources, graph_options.classifier_options(),
|
&tensors_to_classification.GetOptions<
|
||||||
&postprocessing
|
mediapipe::TensorsToClassificationCalculatorOptions>()));
|
||||||
.GetOptions<components::processors::proto::
|
inference_output_tensors >> tensors_to_classification.In(kTensorsTag);
|
||||||
ClassificationPostprocessingGraphOptions>()));
|
auto classification_list =
|
||||||
inference_output_tensors >> postprocessing.In(kTensorsTag);
|
tensors_to_classification[Output<ClassificationList>(
|
||||||
auto classification_result =
|
"CLASSIFICATIONS")];
|
||||||
postprocessing[Output<ClassificationResult>("CLASSIFICATION_RESULT")];
|
return classification_list;
|
||||||
|
|
||||||
return classification_result;
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -246,9 +244,9 @@ REGISTER_MEDIAPIPE_GRAPH(
|
||||||
// index corresponding to the same hand if the graph runs multiple times.
|
// index corresponding to the same hand if the graph runs multiple times.
|
||||||
//
|
//
|
||||||
// Outputs:
|
// Outputs:
|
||||||
// HAND_GESTURES - std::vector<ClassificationResult>
|
// HAND_GESTURES - std::vector<ClassificationList>
|
||||||
// A vector of recognized hand gestures. Each vector element is the
|
// A vector of recognized hand gestures. Each vector element is the
|
||||||
// ClassificationResult of the corresponding hand in the input vector.
|
// ClassificationList of the corresponding hand in the input vector.
|
||||||
//
|
//
|
||||||
//
|
//
|
||||||
// Example:
|
// Example:
|
||||||
|
@ -287,12 +285,12 @@ class MultipleHandGestureRecognizerGraph : public core::ModelTaskGraph {
|
||||||
graph[Input<std::pair<int, int>>(kImageSizeTag)],
|
graph[Input<std::pair<int, int>>(kImageSizeTag)],
|
||||||
graph[Input<std::vector<int>>(kHandTrackingIdsTag)], graph));
|
graph[Input<std::vector<int>>(kHandTrackingIdsTag)], graph));
|
||||||
multi_hand_gestures >>
|
multi_hand_gestures >>
|
||||||
graph[Output<std::vector<ClassificationResult>>(kHandGesturesTag)];
|
graph[Output<std::vector<ClassificationList>>(kHandGesturesTag)];
|
||||||
return graph.GetConfig();
|
return graph.GetConfig();
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
absl::StatusOr<Source<std::vector<ClassificationResult>>>
|
absl::StatusOr<Source<std::vector<ClassificationList>>>
|
||||||
BuildMultiGestureRecognizerSubraph(
|
BuildMultiGestureRecognizerSubraph(
|
||||||
const HandGestureRecognizerGraphOptions& graph_options,
|
const HandGestureRecognizerGraphOptions& graph_options,
|
||||||
Source<std::vector<ClassificationList>> multi_handedness,
|
Source<std::vector<ClassificationList>> multi_handedness,
|
||||||
|
@ -345,12 +343,13 @@ class MultipleHandGestureRecognizerGraph : public core::ModelTaskGraph {
|
||||||
image_size_clone >> hand_gesture_recognizer_graph.In(kImageSizeTag);
|
image_size_clone >> hand_gesture_recognizer_graph.In(kImageSizeTag);
|
||||||
auto hand_gestures = hand_gesture_recognizer_graph.Out(kHandGesturesTag);
|
auto hand_gestures = hand_gesture_recognizer_graph.Out(kHandGesturesTag);
|
||||||
|
|
||||||
auto& end_loop_classification_results =
|
auto& end_loop_classification_lists =
|
||||||
graph.AddNode("mediapipe.tasks.EndLoopClassificationResultCalculator");
|
graph.AddNode("EndLoopClassificationListCalculator");
|
||||||
batch_end >> end_loop_classification_results.In(kBatchEndTag);
|
batch_end >> end_loop_classification_lists.In(kBatchEndTag);
|
||||||
hand_gestures >> end_loop_classification_results.In(kItemTag);
|
hand_gestures >> end_loop_classification_lists.In(kItemTag);
|
||||||
auto multi_hand_gestures = end_loop_classification_results
|
auto multi_hand_gestures =
|
||||||
[Output<std::vector<ClassificationResult>>(kIterableTag)];
|
end_loop_classification_lists[Output<std::vector<ClassificationList>>(
|
||||||
|
kIterableTag)];
|
||||||
|
|
||||||
return multi_hand_gestures;
|
return multi_hand_gestures;
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue
Block a user