Make NORM_RECT optional for GestureRecognizerGraph and add PALM_DETECTIONS output port

PiperOrigin-RevId: 505712542
This commit is contained in:
MediaPipe Team 2023-01-30 09:15:06 -08:00 committed by Copybara-Service
parent ee2f940e1f
commit f9f6acffed
5 changed files with 58 additions and 28 deletions

View File

@ -140,6 +140,7 @@ cc_library(
"//mediapipe/framework/api2:builder", "//mediapipe/framework/api2:builder",
"//mediapipe/framework/api2:port", "//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:classification_cc_proto", "//mediapipe/framework/formats:classification_cc_proto",
"//mediapipe/framework/formats:detection_cc_proto",
"//mediapipe/framework/formats:image", "//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:landmark_cc_proto", "//mediapipe/framework/formats:landmark_cc_proto",
"//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework/formats:rect_cc_proto",

View File

@ -22,6 +22,7 @@ limitations under the License.
#include "mediapipe/framework/api2/builder.h" #include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/port.h" #include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/formats/classification.pb.h" #include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/landmark.pb.h" #include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/framework/formats/rect.pb.h"
@ -68,6 +69,9 @@ constexpr char kHandednessTag[] = "HANDEDNESS";
constexpr char kImageSizeTag[] = "IMAGE_SIZE"; constexpr char kImageSizeTag[] = "IMAGE_SIZE";
constexpr char kHandGesturesTag[] = "HAND_GESTURES"; constexpr char kHandGesturesTag[] = "HAND_GESTURES";
constexpr char kHandTrackingIdsTag[] = "HAND_TRACKING_IDS"; constexpr char kHandTrackingIdsTag[] = "HAND_TRACKING_IDS";
constexpr char kRectNextFrameTag[] = "HAND_RECT_NEXT_FRAME";
constexpr char kPalmRectsTag[] = "PALM_RECTS";
constexpr char kPalmDetectionsTag[] = "PALM_DETECTIONS";
constexpr char kHandLandmarkerBundleAssetName[] = "hand_landmarker.task"; constexpr char kHandLandmarkerBundleAssetName[] = "hand_landmarker.task";
constexpr char kHandGestureRecognizerBundleAssetName[] = constexpr char kHandGestureRecognizerBundleAssetName[] =
"hand_gesture_recognizer.task"; "hand_gesture_recognizer.task";
@ -77,6 +81,9 @@ struct GestureRecognizerOutputs {
Source<std::vector<ClassificationList>> handedness; Source<std::vector<ClassificationList>> handedness;
Source<std::vector<NormalizedLandmarkList>> hand_landmarks; Source<std::vector<NormalizedLandmarkList>> hand_landmarks;
Source<std::vector<LandmarkList>> hand_world_landmarks; Source<std::vector<LandmarkList>> hand_world_landmarks;
Source<std::vector<NormalizedRect>> hand_rects_next_frame;
Source<std::vector<NormalizedRect>> palm_rects;
Source<std::vector<Detection>> palm_detections;
Source<Image> image; Source<Image> image;
}; };
@ -135,9 +142,10 @@ absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources,
// Inputs: // Inputs:
// IMAGE - Image // IMAGE - Image
// Image to perform hand gesture recognition on. // Image to perform hand gesture recognition on.
// NORM_RECT - NormalizedRect // NORM_RECT - NormalizedRect @Optional
// Describes image rotation and region of image to perform landmarks // Describes image rotation and region of image to perform landmarks
// detection on. // detection on. If not provided, whole image is used for gesture
// recognition.
// //
// Outputs: // Outputs:
// HAND_GESTURES - std::vector<ClassificationList> // HAND_GESTURES - std::vector<ClassificationList>
@ -208,11 +216,12 @@ class GestureRecognizerGraph : public core::ModelTaskGraph {
!sc->Service(::mediapipe::tasks::core::kModelResourcesCacheService) !sc->Service(::mediapipe::tasks::core::kModelResourcesCacheService)
.IsAvailable())); .IsAvailable()));
} }
ASSIGN_OR_RETURN(auto hand_gesture_recognition_output, ASSIGN_OR_RETURN(
BuildGestureRecognizerGraph( auto hand_gesture_recognition_output,
*sc->MutableOptions<GestureRecognizerGraphOptions>(), BuildGestureRecognizerGraph(
graph[Input<Image>(kImageTag)], *sc->MutableOptions<GestureRecognizerGraphOptions>(),
graph[Input<NormalizedRect>(kNormRectTag)], graph)); graph[Input<Image>(kImageTag)],
graph[Input<NormalizedRect>::Optional(kNormRectTag)], graph));
hand_gesture_recognition_output.gesture >> hand_gesture_recognition_output.gesture >>
graph[Output<std::vector<ClassificationList>>(kHandGesturesTag)]; graph[Output<std::vector<ClassificationList>>(kHandGesturesTag)];
hand_gesture_recognition_output.handedness >> hand_gesture_recognition_output.handedness >>
@ -222,6 +231,12 @@ class GestureRecognizerGraph : public core::ModelTaskGraph {
hand_gesture_recognition_output.hand_world_landmarks >> hand_gesture_recognition_output.hand_world_landmarks >>
graph[Output<std::vector<LandmarkList>>(kWorldLandmarksTag)]; graph[Output<std::vector<LandmarkList>>(kWorldLandmarksTag)];
hand_gesture_recognition_output.image >> graph[Output<Image>(kImageTag)]; hand_gesture_recognition_output.image >> graph[Output<Image>(kImageTag)];
hand_gesture_recognition_output.hand_rects_next_frame >>
graph[Output<std::vector<NormalizedRect>>(kRectNextFrameTag)];
hand_gesture_recognition_output.palm_rects >>
graph[Output<std::vector<NormalizedRect>>(kPalmRectsTag)];
hand_gesture_recognition_output.palm_detections >>
graph[Output<std::vector<Detection>>(kPalmDetectionsTag)];
return graph.GetConfig(); return graph.GetConfig();
} }
@ -279,7 +294,17 @@ class GestureRecognizerGraph : public core::ModelTaskGraph {
/*handedness=*/handedness, /*handedness=*/handedness,
/*hand_landmarks=*/hand_landmarks, /*hand_landmarks=*/hand_landmarks,
/*hand_world_landmarks=*/hand_world_landmarks, /*hand_world_landmarks=*/hand_world_landmarks,
/*image=*/hand_landmarker_graph[Output<Image>(kImageTag)]}; /*hand_rects_next_frame =*/
hand_landmarker_graph[Output<std::vector<NormalizedRect>>(
kRectNextFrameTag)],
/*palm_rects =*/
hand_landmarker_graph[Output<std::vector<NormalizedRect>>(
kPalmRectsTag)],
/*palm_detections =*/
hand_landmarker_graph[Output<std::vector<Detection>>(
kPalmDetectionsTag)],
/*image=*/hand_landmarker_graph[Output<Image>(kImageTag)],
};
} }
}; };

View File

@ -150,9 +150,9 @@ void ConfigureRectTransformationCalculator(
// Inputs: // Inputs:
// IMAGE - Image // IMAGE - Image
// Image to perform detection on. // Image to perform detection on.
// NORM_RECT - NormalizedRect // NORM_RECT - NormalizedRect @Optional
// Describes image rotation and region of image to perform detection // Describes image rotation and region of image to perform detection on. If
// on. // not provided, whole image is used for hand detection.
// //
// Outputs: // Outputs:
// PALM_DETECTIONS - std::vector<Detection> // PALM_DETECTIONS - std::vector<Detection>
@ -197,11 +197,12 @@ class HandDetectorGraph : public core::ModelTaskGraph {
ASSIGN_OR_RETURN(const auto* model_resources, ASSIGN_OR_RETURN(const auto* model_resources,
CreateModelResources<HandDetectorGraphOptions>(sc)); CreateModelResources<HandDetectorGraphOptions>(sc));
Graph graph; Graph graph;
ASSIGN_OR_RETURN(auto hand_detection_outs, ASSIGN_OR_RETURN(
BuildHandDetectionSubgraph( auto hand_detection_outs,
sc->Options<HandDetectorGraphOptions>(), BuildHandDetectionSubgraph(
*model_resources, graph[Input<Image>(kImageTag)], sc->Options<HandDetectorGraphOptions>(), *model_resources,
graph[Input<NormalizedRect>(kNormRectTag)], graph)); graph[Input<Image>(kImageTag)],
graph[Input<NormalizedRect>::Optional(kNormRectTag)], graph));
hand_detection_outs.palm_detections >> hand_detection_outs.palm_detections >>
graph[Output<std::vector<Detection>>(kPalmDetectionsTag)]; graph[Output<std::vector<Detection>>(kPalmDetectionsTag)];
hand_detection_outs.hand_rects >> hand_detection_outs.hand_rects >>

View File

@ -136,9 +136,10 @@ absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources,
// Inputs: // Inputs:
// IMAGE - Image // IMAGE - Image
// Image to perform hand landmarks detection on. // Image to perform hand landmarks detection on.
// NORM_RECT - NormalizedRect // NORM_RECT - NormalizedRect @Optional
// Describes image rotation and region of image to perform landmarks // Describes image rotation and region of image to perform landmarks
// detection on. // detection on. If not provided, whole image is used for hand landmarks
// detection.
// //
// Outputs: // Outputs:
// LANDMARKS: - std::vector<NormalizedLandmarkList> // LANDMARKS: - std::vector<NormalizedLandmarkList>
@ -218,11 +219,12 @@ class HandLandmarkerGraph : public core::ModelTaskGraph {
!sc->Service(::mediapipe::tasks::core::kModelResourcesCacheService) !sc->Service(::mediapipe::tasks::core::kModelResourcesCacheService)
.IsAvailable())); .IsAvailable()));
} }
ASSIGN_OR_RETURN(auto hand_landmarker_outputs, ASSIGN_OR_RETURN(
BuildHandLandmarkerGraph( auto hand_landmarker_outputs,
sc->Options<HandLandmarkerGraphOptions>(), BuildHandLandmarkerGraph(
graph[Input<Image>(kImageTag)], sc->Options<HandLandmarkerGraphOptions>(),
graph[Input<NormalizedRect>(kNormRectTag)], graph)); graph[Input<Image>(kImageTag)],
graph[Input<NormalizedRect>::Optional(kNormRectTag)], graph));
hand_landmarker_outputs.landmark_lists >> hand_landmarker_outputs.landmark_lists >>
graph[Output<std::vector<NormalizedLandmarkList>>(kLandmarksTag)]; graph[Output<std::vector<NormalizedLandmarkList>>(kLandmarksTag)];
hand_landmarker_outputs.world_landmark_lists >> hand_landmarker_outputs.world_landmark_lists >>

View File

@ -243,11 +243,12 @@ class SingleHandLandmarksDetectorGraph : public core::ModelTaskGraph {
const auto* model_resources, const auto* model_resources,
CreateModelResources<HandLandmarksDetectorGraphOptions>(sc)); CreateModelResources<HandLandmarksDetectorGraphOptions>(sc));
Graph graph; Graph graph;
ASSIGN_OR_RETURN(auto hand_landmark_detection_outs, ASSIGN_OR_RETURN(
BuildSingleHandLandmarksDetectorGraph( auto hand_landmark_detection_outs,
sc->Options<HandLandmarksDetectorGraphOptions>(), BuildSingleHandLandmarksDetectorGraph(
*model_resources, graph[Input<Image>(kImageTag)], sc->Options<HandLandmarksDetectorGraphOptions>(), *model_resources,
graph[Input<NormalizedRect>(kHandRectTag)], graph)); graph[Input<Image>(kImageTag)],
graph[Input<NormalizedRect>::Optional(kHandRectTag)], graph));
hand_landmark_detection_outs.hand_landmarks >> hand_landmark_detection_outs.hand_landmarks >>
graph[Output<NormalizedLandmarkList>(kLandmarksTag)]; graph[Output<NormalizedLandmarkList>(kLandmarksTag)];
hand_landmark_detection_outs.world_hand_landmarks >> hand_landmark_detection_outs.world_hand_landmarks >>