diff --git a/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph.cc b/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph.cc index d7163e331..923eab1ca 100644 --- a/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph.cc +++ b/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph.cc @@ -257,19 +257,28 @@ class HandDetectorGraph : public core::ModelTaskGraph { preprocessed_tensors >> inference.In("TENSORS"); auto model_output_tensors = inference.Out("TENSORS"); + // TODO: support hand detection metadata. + bool has_metadata = false; + // Generates a single side packet containing a vector of SSD anchors. auto& ssd_anchor = graph.AddNode("SsdAnchorsCalculator"); - ConfigureSsdAnchorsCalculator( - &ssd_anchor.GetOptions()); + auto& ssd_anchor_options = + ssd_anchor.GetOptions(); + if (!has_metadata) { + ConfigureSsdAnchorsCalculator(&ssd_anchor_options); + } auto anchors = ssd_anchor.SideOut(""); // Converts output tensors to Detections. auto& tensors_to_detections = graph.AddNode("TensorsToDetectionsCalculator"); - ConfigureTensorsToDetectionsCalculator( - subgraph_options, - &tensors_to_detections - .GetOptions()); + if (!has_metadata) { + ConfigureTensorsToDetectionsCalculator( + subgraph_options, + &tensors_to_detections + .GetOptions()); + } + model_output_tensors >> tensors_to_detections.In("TENSORS"); anchors >> tensors_to_detections.SideIn("ANCHORS"); auto detections = tensors_to_detections.Out("DETECTIONS"); diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/BUILD b/mediapipe/tasks/cc/vision/hand_landmarker/BUILD index 2552e7a10..7a83816b8 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/BUILD +++ b/mediapipe/tasks/cc/vision/hand_landmarker/BUILD @@ -148,6 +148,7 @@ cc_library( "//mediapipe/tasks/cc/vision/hand_landmarker/calculators:hand_landmarks_deduplication_calculator", "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_cc_proto", "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_cc_proto", + "//mediapipe/util:graph_builder_utils", ], alwayslink = 1, ) diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph.cc b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph.cc index 74d288ac1..4a3db9f4d 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph.cc +++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include +#include #include #include #include @@ -41,6 +42,7 @@ limitations under the License. #include "mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator.pb.h" #include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options.pb.h" #include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.pb.h" +#include "mediapipe/util/graph_builder_utils.h" namespace mediapipe { namespace tasks { @@ -53,7 +55,7 @@ using ::mediapipe::NormalizedRect; using ::mediapipe::api2::Input; using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::Graph; -using ::mediapipe::api2::builder::Source; +using ::mediapipe::api2::builder::Stream; using ::mediapipe::tasks::components::utils::DisallowIf; using ::mediapipe::tasks::core::ModelAssetBundleResources; using ::mediapipe::tasks::metadata::SetExternalFile; @@ -78,40 +80,46 @@ constexpr char kHandLandmarksDetectorTFLiteName[] = "hand_landmarks_detector.tflite"; struct HandLandmarkerOutputs { - Source> landmark_lists; - Source> world_landmark_lists; - Source> hand_rects_next_frame; - Source> handednesses; - Source> palm_rects; - Source> palm_detections; - Source image; + Stream> landmark_lists; + Stream> world_landmark_lists; + Stream> hand_rects_next_frame; + Stream> handednesses; + Stream> palm_rects; + Stream> palm_detections; + Stream image; }; // Sets the base options in the sub tasks. absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources, HandLandmarkerGraphOptions* options, bool is_copy) { - ASSIGN_OR_RETURN(const auto hand_detector_file, - resources.GetModelFile(kHandDetectorTFLiteName)); auto* hand_detector_graph_options = options->mutable_hand_detector_graph_options(); - SetExternalFile(hand_detector_file, - hand_detector_graph_options->mutable_base_options() - ->mutable_model_asset(), - is_copy); + if (!hand_detector_graph_options->base_options().has_model_asset()) { + ASSIGN_OR_RETURN(const auto hand_detector_file, + resources.GetModelFile(kHandDetectorTFLiteName)); + SetExternalFile(hand_detector_file, + hand_detector_graph_options->mutable_base_options() + ->mutable_model_asset(), + is_copy); + } hand_detector_graph_options->mutable_base_options() ->mutable_acceleration() ->CopyFrom(options->base_options().acceleration()); hand_detector_graph_options->mutable_base_options()->set_use_stream_mode( options->base_options().use_stream_mode()); - ASSIGN_OR_RETURN(const auto hand_landmarks_detector_file, - resources.GetModelFile(kHandLandmarksDetectorTFLiteName)); auto* hand_landmarks_detector_graph_options = options->mutable_hand_landmarks_detector_graph_options(); - SetExternalFile(hand_landmarks_detector_file, - hand_landmarks_detector_graph_options->mutable_base_options() - ->mutable_model_asset(), - is_copy); + if (!hand_landmarks_detector_graph_options->base_options() + .has_model_asset()) { + ASSIGN_OR_RETURN(const auto hand_landmarks_detector_file, + resources.GetModelFile(kHandLandmarksDetectorTFLiteName)); + SetExternalFile( + hand_landmarks_detector_file, + hand_landmarks_detector_graph_options->mutable_base_options() + ->mutable_model_asset(), + is_copy); + } hand_landmarks_detector_graph_options->mutable_base_options() ->mutable_acceleration() ->CopyFrom(options->base_options().acceleration()); @@ -119,7 +127,6 @@ absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources, ->set_use_stream_mode(options->base_options().use_stream_mode()); return absl::OkStatus(); } - } // namespace // A "mediapipe.tasks.vision.hand_landmarker.HandLandmarkerGraph" performs hand @@ -219,12 +226,15 @@ class HandLandmarkerGraph : public core::ModelTaskGraph { !sc->Service(::mediapipe::tasks::core::kModelResourcesCacheService) .IsAvailable())); } + Stream image_in = graph.In(kImageTag).Cast(); + std::optional> norm_rect_in; + if (HasInput(sc->OriginalNode(), kNormRectTag)) { + norm_rect_in = graph.In(kNormRectTag).Cast(); + } ASSIGN_OR_RETURN( auto hand_landmarker_outputs, - BuildHandLandmarkerGraph( - sc->Options(), - graph[Input(kImageTag)], - graph[Input::Optional(kNormRectTag)], graph)); + BuildHandLandmarkerGraph(sc->Options(), + image_in, norm_rect_in, graph)); hand_landmarker_outputs.landmark_lists >> graph[Output>(kLandmarksTag)]; hand_landmarker_outputs.world_landmark_lists >> @@ -262,8 +272,8 @@ class HandLandmarkerGraph : public core::ModelTaskGraph { // image_in: (mediapipe::Image) stream to run hand landmark detection on. // graph: the mediapipe graph instance to be updated. absl::StatusOr BuildHandLandmarkerGraph( - const HandLandmarkerGraphOptions& tasks_options, Source image_in, - Source norm_rect_in, Graph& graph) { + const HandLandmarkerGraphOptions& tasks_options, Stream image_in, + std::optional> norm_rect_in, Graph& graph) { const int max_num_hands = tasks_options.hand_detector_graph_options().num_hands(); @@ -293,10 +303,15 @@ class HandLandmarkerGraph : public core::ModelTaskGraph { // track the hands from the last frame. auto image_for_hand_detector = DisallowIf(image_in, has_enough_hands, graph); - auto norm_rect_in_for_hand_detector = - DisallowIf(norm_rect_in, has_enough_hands, graph); + std::optional> norm_rect_in_for_hand_detector; + if (norm_rect_in) { + norm_rect_in_for_hand_detector = + DisallowIf(norm_rect_in.value(), has_enough_hands, graph); + } image_for_hand_detector >> hand_detector.In("IMAGE"); - norm_rect_in_for_hand_detector >> hand_detector.In("NORM_RECT"); + if (norm_rect_in_for_hand_detector) { + norm_rect_in_for_hand_detector.value() >> hand_detector.In("NORM_RECT"); + } auto hand_rects_from_hand_detector = hand_detector.Out("HAND_RECTS"); auto& hand_association = graph.AddNode("HandAssociationCalculator"); hand_association.GetOptions() @@ -313,7 +328,9 @@ class HandLandmarkerGraph : public core::ModelTaskGraph { // series, and we don't want to enable the tracking and hand associations // between input images. Always use the hand detector graph. image_in >> hand_detector.In("IMAGE"); - norm_rect_in >> hand_detector.In("NORM_RECT"); + if (norm_rect_in) { + norm_rect_in.value() >> hand_detector.In("NORM_RECT"); + } auto hand_rects_from_hand_detector = hand_detector.Out("HAND_RECTS"); hand_rects_from_hand_detector >> clip_hand_rects.In(""); } diff --git a/mediapipe/tasks/testdata/vision/hand_landmarker.task b/mediapipe/tasks/testdata/vision/hand_landmarker.task index 1ae9f7f6b..748b2f013 100644 Binary files a/mediapipe/tasks/testdata/vision/hand_landmarker.task and b/mediapipe/tasks/testdata/vision/hand_landmarker.task differ diff --git a/third_party/external_files.bzl b/third_party/external_files.bzl index 7122c6771..f446b3728 100644 --- a/third_party/external_files.bzl +++ b/third_party/external_files.bzl @@ -306,8 +306,8 @@ def external_files(): http_file( name = "com_google_mediapipe_gesture_recognizer_task", - sha256 = "a966b1d4e774e0423c19c8aa71f070e5a72fe7a03c2663dd2f3cb0b0095ee3e1", - urls = ["https://storage.googleapis.com/mediapipe-assets/gesture_recognizer.task?generation=1668100501451433"], + sha256 = "d48562f535fd4ecd3cfea739d9663dd818eeaf6a8afb1b5e6f8f4747661f73d9", + urls = ["https://storage.googleapis.com/mediapipe-assets/gesture_recognizer.task?generation=1677051715043311"], ) http_file( @@ -342,8 +342,8 @@ def external_files(): http_file( name = "com_google_mediapipe_hand_landmarker_task", - sha256 = "2ed44f10872e87a5834b9b1130fb9ada30e107af2c6fcc4562ad788aca4e7bc4", - urls = ["https://storage.googleapis.com/mediapipe-assets/hand_landmarker.task?generation=1666153732577904"], + sha256 = "32d1eab97e80a9a20edb29231e15301ce65abfd0fa9d41cf1757e0ecc8078a4e", + urls = ["https://storage.googleapis.com/mediapipe-assets/hand_landmarker.task?generation=1677051718270846"], ) http_file(