Make NORM_RECT optional for GestureRecognizerGraph and add PALM_DETECTION output PORT

PiperOrigin-RevId: 505712542
This commit is contained in:
MediaPipe Team 2023-01-30 09:15:06 -08:00 committed by Copybara-Service
parent ee2f940e1f
commit f9f6acffed
5 changed files with 58 additions and 28 deletions

View File

@ -140,6 +140,7 @@ cc_library(
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:classification_cc_proto",
"//mediapipe/framework/formats:detection_cc_proto",
"//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:landmark_cc_proto",
"//mediapipe/framework/formats:rect_cc_proto",

View File

@ -22,6 +22,7 @@ limitations under the License.
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/rect.pb.h"
@ -68,6 +69,9 @@ constexpr char kHandednessTag[] = "HANDEDNESS";
constexpr char kImageSizeTag[] = "IMAGE_SIZE";
constexpr char kHandGesturesTag[] = "HAND_GESTURES";
constexpr char kHandTrackingIdsTag[] = "HAND_TRACKING_IDS";
constexpr char kRectNextFrameTag[] = "HAND_RECT_NEXT_FRAME";
constexpr char kPalmRectsTag[] = "PALM_RECTS";
constexpr char kPalmDetectionsTag[] = "PALM_DETECTIONS";
constexpr char kHandLandmarkerBundleAssetName[] = "hand_landmarker.task";
constexpr char kHandGestureRecognizerBundleAssetName[] =
"hand_gesture_recognizer.task";
@ -77,6 +81,9 @@ struct GestureRecognizerOutputs {
Source<std::vector<ClassificationList>> handedness;
Source<std::vector<NormalizedLandmarkList>> hand_landmarks;
Source<std::vector<LandmarkList>> hand_world_landmarks;
Source<std::vector<NormalizedRect>> hand_rects_next_frame;
Source<std::vector<NormalizedRect>> palm_rects;
Source<std::vector<Detection>> palm_detections;
Source<Image> image;
};
@ -135,9 +142,10 @@ absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources,
// Inputs:
// IMAGE - Image
// Image to perform hand gesture recognition on.
// NORM_RECT - NormalizedRect
// NORM_RECT - NormalizedRect @Optional
// Describes image rotation and region of image to perform landmarks
// detection on.
// detection on. If not provided, whole image is used for gesture
// recognition.
//
// Outputs:
// HAND_GESTURES - std::vector<ClassificationList>
@ -208,11 +216,12 @@ class GestureRecognizerGraph : public core::ModelTaskGraph {
!sc->Service(::mediapipe::tasks::core::kModelResourcesCacheService)
.IsAvailable()));
}
ASSIGN_OR_RETURN(auto hand_gesture_recognition_output,
ASSIGN_OR_RETURN(
auto hand_gesture_recognition_output,
BuildGestureRecognizerGraph(
*sc->MutableOptions<GestureRecognizerGraphOptions>(),
graph[Input<Image>(kImageTag)],
graph[Input<NormalizedRect>(kNormRectTag)], graph));
graph[Input<NormalizedRect>::Optional(kNormRectTag)], graph));
hand_gesture_recognition_output.gesture >>
graph[Output<std::vector<ClassificationList>>(kHandGesturesTag)];
hand_gesture_recognition_output.handedness >>
@ -222,6 +231,12 @@ class GestureRecognizerGraph : public core::ModelTaskGraph {
hand_gesture_recognition_output.hand_world_landmarks >>
graph[Output<std::vector<LandmarkList>>(kWorldLandmarksTag)];
hand_gesture_recognition_output.image >> graph[Output<Image>(kImageTag)];
hand_gesture_recognition_output.hand_rects_next_frame >>
graph[Output<std::vector<NormalizedRect>>(kRectNextFrameTag)];
hand_gesture_recognition_output.palm_rects >>
graph[Output<std::vector<NormalizedRect>>(kPalmRectsTag)];
hand_gesture_recognition_output.palm_detections >>
graph[Output<std::vector<Detection>>(kPalmDetectionsTag)];
return graph.GetConfig();
}
@ -279,7 +294,17 @@ class GestureRecognizerGraph : public core::ModelTaskGraph {
/*handedness=*/handedness,
/*hand_landmarks=*/hand_landmarks,
/*hand_world_landmarks=*/hand_world_landmarks,
/*image=*/hand_landmarker_graph[Output<Image>(kImageTag)]};
/*hand_rects_next_frame =*/
hand_landmarker_graph[Output<std::vector<NormalizedRect>>(
kRectNextFrameTag)],
/*palm_rects =*/
hand_landmarker_graph[Output<std::vector<NormalizedRect>>(
kPalmRectsTag)],
/*palm_detections =*/
hand_landmarker_graph[Output<std::vector<Detection>>(
kPalmDetectionsTag)],
/*image=*/hand_landmarker_graph[Output<Image>(kImageTag)],
};
}
};

View File

@ -150,9 +150,9 @@ void ConfigureRectTransformationCalculator(
// Inputs:
// IMAGE - Image
// Image to perform detection on.
// NORM_RECT - NormalizedRect
// Describes image rotation and region of image to perform detection
// on.
// NORM_RECT - NormalizedRect @Optional
// Describes image rotation and region of image to perform detection on. If
// not provided, whole image is used for hand detection.
//
// Outputs:
// PALM_DETECTIONS - std::vector<Detection>
@ -197,11 +197,12 @@ class HandDetectorGraph : public core::ModelTaskGraph {
ASSIGN_OR_RETURN(const auto* model_resources,
CreateModelResources<HandDetectorGraphOptions>(sc));
Graph graph;
ASSIGN_OR_RETURN(auto hand_detection_outs,
ASSIGN_OR_RETURN(
auto hand_detection_outs,
BuildHandDetectionSubgraph(
sc->Options<HandDetectorGraphOptions>(),
*model_resources, graph[Input<Image>(kImageTag)],
graph[Input<NormalizedRect>(kNormRectTag)], graph));
sc->Options<HandDetectorGraphOptions>(), *model_resources,
graph[Input<Image>(kImageTag)],
graph[Input<NormalizedRect>::Optional(kNormRectTag)], graph));
hand_detection_outs.palm_detections >>
graph[Output<std::vector<Detection>>(kPalmDetectionsTag)];
hand_detection_outs.hand_rects >>

View File

@ -136,9 +136,10 @@ absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources,
// Inputs:
// IMAGE - Image
// Image to perform hand landmarks detection on.
// NORM_RECT - NormalizedRect
// NORM_RECT - NormalizedRect @Optional
// Describes image rotation and region of image to perform landmarks
// detection on.
// detection on. If not provided, whole image is used for hand landmarks
// detection.
//
// Outputs:
// LANDMARKS: - std::vector<NormalizedLandmarkList>
@ -218,11 +219,12 @@ class HandLandmarkerGraph : public core::ModelTaskGraph {
!sc->Service(::mediapipe::tasks::core::kModelResourcesCacheService)
.IsAvailable()));
}
ASSIGN_OR_RETURN(auto hand_landmarker_outputs,
ASSIGN_OR_RETURN(
auto hand_landmarker_outputs,
BuildHandLandmarkerGraph(
sc->Options<HandLandmarkerGraphOptions>(),
graph[Input<Image>(kImageTag)],
graph[Input<NormalizedRect>(kNormRectTag)], graph));
graph[Input<NormalizedRect>::Optional(kNormRectTag)], graph));
hand_landmarker_outputs.landmark_lists >>
graph[Output<std::vector<NormalizedLandmarkList>>(kLandmarksTag)];
hand_landmarker_outputs.world_landmark_lists >>

View File

@ -243,11 +243,12 @@ class SingleHandLandmarksDetectorGraph : public core::ModelTaskGraph {
const auto* model_resources,
CreateModelResources<HandLandmarksDetectorGraphOptions>(sc));
Graph graph;
ASSIGN_OR_RETURN(auto hand_landmark_detection_outs,
ASSIGN_OR_RETURN(
auto hand_landmark_detection_outs,
BuildSingleHandLandmarksDetectorGraph(
sc->Options<HandLandmarksDetectorGraphOptions>(),
*model_resources, graph[Input<Image>(kImageTag)],
graph[Input<NormalizedRect>(kHandRectTag)], graph));
sc->Options<HandLandmarksDetectorGraphOptions>(), *model_resources,
graph[Input<Image>(kImageTag)],
graph[Input<NormalizedRect>::Optional(kHandRectTag)], graph));
hand_landmark_detection_outs.hand_landmarks >>
graph[Output<NormalizedLandmarkList>(kLandmarksTag)];
hand_landmark_detection_outs.world_hand_landmarks >>