Internal change

PiperOrigin-RevId: 479235176
This commit is contained in:
MediaPipe Team 2022-10-05 23:54:46 -07:00 committed by Copybara-Service
parent 42978d3e69
commit b72219f1e5
4 changed files with 252 additions and 3 deletions

View File

@ -44,7 +44,9 @@ cc_library(
name = "hand_gesture_recognizer_graph",
srcs = ["hand_gesture_recognizer_graph.cc"],
deps = [
"//mediapipe/calculators/core:begin_loop_calculator",
"//mediapipe/calculators/core:concatenate_vector_calculator",
"//mediapipe/calculators/core:get_vector_item_calculator",
"//mediapipe/calculators/tensor:tensor_converter_calculator",
"//mediapipe/calculators/tensor:tensors_to_classification_calculator_cc_proto",
"//mediapipe/framework/api2:builder",
@ -74,3 +76,35 @@ cc_library(
],
alwayslink = 1,
)
# End-to-end gesture recognizer graph: combines the hand detector / hand
# landmarker subgraphs with the hand gesture recognizer subgraph (registered
# as GestureRecognizerGraph in gesture_recognizer_graph.cc).
cc_library(
name = "gesture_recognizer_graph",
srcs = ["gesture_recognizer_graph.cc"],
deps = [
":hand_gesture_recognizer_graph",
"//mediapipe/calculators/core:vector_indices_calculator",
"//mediapipe/calculators/image:image_properties_calculator",
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:classification_cc_proto",
"//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:landmark_cc_proto",
"//mediapipe/tasks/cc:common",
"//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto",
"//mediapipe/tasks/cc/core:model_task_graph",
"//mediapipe/tasks/cc/core:utils",
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/hand_detector:hand_detector_graph",
"//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/hand_landmarker:hand_landmarker_graph",
"//mediapipe/tasks/cc/vision/hand_landmarker:hand_landmarks_detector_graph",
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_cc_proto",
"//mediapipe/tasks/metadata:metadata_schema_cc",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
],
# Subgraph registration happens via static initializers; keep the library
# linked in even when no symbol is referenced directly.
alwayslink = 1,
)

View File

@ -0,0 +1,215 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include <type_traits>
#include <vector>
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
#include "mediapipe/tasks/cc/core/model_task_graph.h"
#include "mediapipe/tasks/cc/core/utils.h"
#include "mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_recognizer_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.pb.h"
#include "mediapipe/tasks/metadata/metadata_schema_generated.h"
namespace mediapipe {
namespace tasks {
namespace vision {
namespace gesture_recognizer {
namespace {
using ::mediapipe::api2::Input;
using ::mediapipe::api2::Output;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source;
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
using ::mediapipe::tasks::vision::gesture_recognizer::proto::
GestureRecognizerGraphOptions;
using ::mediapipe::tasks::vision::gesture_recognizer::proto::
HandGestureRecognizerGraphOptions;
using ::mediapipe::tasks::vision::hand_landmarker::proto::
HandLandmarkerGraphOptions;
// Stream tag names shared between this graph's public interface and the
// subgraphs it wires together below.
constexpr char kImageTag[] = "IMAGE";
constexpr char kLandmarksTag[] = "LANDMARKS";
constexpr char kWorldLandmarksTag[] = "WORLD_LANDMARKS";
constexpr char kHandednessTag[] = "HANDEDNESS";
constexpr char kImageSizeTag[] = "IMAGE_SIZE";
constexpr char kHandGesturesTag[] = "HAND_GESTURES";
constexpr char kHandTrackingIdsTag[] = "HAND_TRACKING_IDS";
// Bundle of the streams produced while building the graph; each Source is
// connected to a corresponding graph output stream by GetConfig.
struct GestureRecognizerOutputs {
// Recognized gestures, one ClassificationResult per detected hand.
Source<std::vector<ClassificationResult>> gesture;
// Handedness classification (e.g. left/right), one entry per hand.
Source<std::vector<mediapipe::ClassificationList>> handedness;
// Hand landmarks in normalized image coordinates, one list per hand.
Source<std::vector<mediapipe::NormalizedLandmarkList>> hand_landmarks;
// Hand landmarks in world coordinates, one list per hand.
Source<std::vector<mediapipe::LandmarkList>> hand_world_landmarks;
// The input image passed through (pixel data on the target storage).
Source<Image> image;
};
}  // namespace
// A "mediapipe.tasks.vision.gesture_recognizer.GestureRecognizerGraph" performs
// hand gesture recognition.
//
// Inputs:
// IMAGE - Image
// Image to perform hand gesture recognition on.
//
// Outputs:
// HAND_GESTURES - std::vector<ClassificationResult>
// Recognized hand gestures with sorted order such that the winning label is
// the first item in the list.
// LANDMARKS - std::vector<NormalizedLandmarkList>
// Detected hand landmarks.
// WORLD_LANDMARKS - std::vector<LandmarkList>
// Detected hand landmarks in world coordinates.
// HAND_RECT_NEXT_FRAME - std::vector<NormalizedRect>
// The predicted Rect enclosing the hand RoI for landmark detection on the
// next frame.
// NOTE(review): HAND_RECT_NEXT_FRAME is documented here and in the example
// below, but GetConfig in this file does not currently expose it as a graph
// output — confirm before relying on this stream.
// HANDEDNESS - std::vector<ClassificationList>
// Classification of handedness.
// IMAGE - mediapipe::Image
// The image that gesture recognizer runs on and has the pixel data stored
// on the target storage (CPU vs GPU).
//
//
// Example:
// node {
// calculator:
// "mediapipe.tasks.vision.gesture_recognizer.GestureRecognizerGraph"
// input_stream: "IMAGE:image_in"
// output_stream: "HAND_GESTURES:hand_gestures"
// output_stream: "LANDMARKS:hand_landmarks"
// output_stream: "WORLD_LANDMARKS:world_hand_landmarks"
// output_stream: "HAND_RECT_NEXT_FRAME:hand_rect_next_frame"
// output_stream: "HANDEDNESS:handedness"
// output_stream: "IMAGE:image_out"
// options {
// [mediapipe.tasks.vision.gesture_recognizer.proto.GestureRecognizerGraphOptions.ext]
// {
// base_options {
// model_asset {
// file_name: "hand_gesture.tflite"
// }
// }
// hand_landmark_detector_options {
// base_options {
// model_asset {
// file_name: "hand_landmark.tflite"
// }
// }
// }
// }
// }
// }
// Task graph that chains hand landmark detection and multi-hand gesture
// classification into a single end-to-end gesture recognizer. See the file
// comment above for the stream-level interface and an example config.
class GestureRecognizerGraph : public core::ModelTaskGraph {
 public:
  // Assembles the CalculatorGraphConfig: routes the IMAGE input through the
  // internal subgraphs and exposes gestures, handedness, landmarks, world
  // landmarks and the pass-through image as graph outputs.
  absl::StatusOr<CalculatorGraphConfig> GetConfig(
      SubgraphContext* sc) override {
    Graph graph;
    ASSIGN_OR_RETURN(auto outs,
                     BuildGestureRecognizerGraph(
                         *sc->MutableOptions<GestureRecognizerGraphOptions>(),
                         graph[Input<Image>(kImageTag)], graph));
    outs.gesture >>
        graph[Output<std::vector<ClassificationResult>>(kHandGesturesTag)];
    outs.handedness >>
        graph[Output<std::vector<mediapipe::ClassificationList>>(
            kHandednessTag)];
    outs.hand_landmarks >>
        graph[Output<std::vector<mediapipe::NormalizedLandmarkList>>(
            kLandmarksTag)];
    outs.hand_world_landmarks >>
        graph[Output<std::vector<mediapipe::LandmarkList>>(kWorldLandmarksTag)];
    outs.image >> graph[Output<Image>(kImageTag)];
    return graph.GetConfig();
  }

 private:
  // Wires the hand landmarker subgraph and the multi-hand gesture recognizer
  // subgraph together, moving the relevant options from `graph_options` into
  // each subgraph, and returns the resulting output streams.
  absl::StatusOr<GestureRecognizerOutputs> BuildGestureRecognizerGraph(
      GestureRecognizerGraphOptions& graph_options, Source<Image> image_in,
      Graph& graph) {
    // The gesture recognizer consumes the image dimensions alongside the
    // normalized landmarks.
    auto& props_node = graph.AddNode("ImagePropertiesCalculator");
    image_in >> props_node.In("IMAGE");
    auto size_stream = props_node.Out("SIZE");

    // Hand landmarker subgraph: produces per-hand landmarks, world landmarks
    // and handedness from the input image.
    auto& landmarker_node = graph.AddNode(
        "mediapipe.tasks.vision.hand_landmarker.HandLandmarkerGraph");
    landmarker_node.GetOptions<HandLandmarkerGraphOptions>().Swap(
        graph_options.mutable_hand_landmarker_graph_options());
    image_in >> landmarker_node.In(kImageTag);
    auto landmarks_stream =
        landmarker_node[Output<std::vector<NormalizedLandmarkList>>(
            kLandmarksTag)];
    auto world_landmarks_stream =
        landmarker_node[Output<std::vector<LandmarkList>>(kWorldLandmarksTag)];
    auto handedness_stream =
        landmarker_node[Output<std::vector<ClassificationList>>(
            kHandednessTag)];

    // Derive per-hand tracking ids from the indices of the landmark vector.
    auto& indices_node =
        graph.AddNode("NormalizedLandmarkListVectorIndicesCalculator");
    landmarks_stream >> indices_node.In("VECTOR");
    auto tracking_ids_stream = indices_node.Out("INDICES");

    // Multi-hand gesture recognizer subgraph: classifies a gesture per hand.
    auto& gesture_node = graph.AddNode(
        "mediapipe.tasks.vision.gesture_recognizer."
        "MultipleHandGestureRecognizerGraph");
    gesture_node.GetOptions<HandGestureRecognizerGraphOptions>().Swap(
        graph_options.mutable_hand_gesture_recognizer_graph_options());
    landmarks_stream >> gesture_node.In(kLandmarksTag);
    world_landmarks_stream >> gesture_node.In(kWorldLandmarksTag);
    handedness_stream >> gesture_node.In(kHandednessTag);
    size_stream >> gesture_node.In(kImageSizeTag);
    tracking_ids_stream >> gesture_node.In(kHandTrackingIdsTag);
    auto gestures_stream =
        gesture_node[Output<std::vector<ClassificationResult>>(
            kHandGesturesTag)];

    return {{.gesture = gestures_stream,
             .handedness = handedness_stream,
             .hand_landmarks = landmarks_stream,
             .hand_world_landmarks = world_landmarks_stream,
             .image = landmarker_node[Output<Image>(kImageTag)]}};
  }
};
// clang-format off
// Registers the subgraph so it can be instantiated by its fully qualified
// name from a CalculatorGraphConfig (see the example in the class comment).
REGISTER_MEDIAPIPE_GRAPH(
::mediapipe::tasks::vision::gesture_recognizer::GestureRecognizerGraph); // NOLINT
// clang-format on
} // namespace gesture_recognizer
} // namespace vision
} // namespace tasks
} // namespace mediapipe

View File

@ -36,7 +36,6 @@ limitations under the License.
#include "mediapipe/tasks/cc/core/utils.h"
#include "mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator.pb.h"
#include "mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/utils/image_tensor_specs.h"
#include "mediapipe/tasks/metadata/metadata_schema_generated.h"
namespace mediapipe {
@ -200,6 +199,7 @@ class SingleHandGestureRecognizerGraph : public core::ModelTaskGraph {
auto concatenated_tensors = concatenate_tensor_vector.Out("");
// Inference for static hand gesture recognition.
// TODO add embedding step.
auto& inference = AddInference(
model_resources, graph_options.base_options().acceleration(), graph);
concatenated_tensors >> inference.In(kTensorsTag);

View File

@ -24,13 +24,13 @@ import "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_op
message GestureRecognizerGraphOptions {
extend mediapipe.CalculatorOptions {
optional GestureRecognizerGraphOptions ext = 478371831;
optional GestureRecognizerGraphOptions ext = 479097054;
}
// Base options for configuring gesture recognizer graph, such as specifying
// the TfLite model file with metadata, accelerator options, etc.
optional core.proto.BaseOptions base_options = 1;
// Options for configuring hand landmarker subgraph.
// Options for configuring hand landmarker graph.
optional hand_landmarker.proto.HandLandmarkerGraphOptions
hand_landmarker_graph_options = 2;