Internal change

PiperOrigin-RevId: 512111461
This commit is contained in:
MediaPipe Team 2023-02-24 10:54:46 -08:00 committed by Copybara-Service
parent 9054ff7283
commit 17466fb7f1
5 changed files with 68 additions and 41 deletions

View File

@ -257,19 +257,28 @@ class HandDetectorGraph : public core::ModelTaskGraph {
preprocessed_tensors >> inference.In("TENSORS"); preprocessed_tensors >> inference.In("TENSORS");
auto model_output_tensors = inference.Out("TENSORS"); auto model_output_tensors = inference.Out("TENSORS");
// TODO: support hand detection metadata.
bool has_metadata = false;
// Generates a single side packet containing a vector of SSD anchors. // Generates a single side packet containing a vector of SSD anchors.
auto& ssd_anchor = graph.AddNode("SsdAnchorsCalculator"); auto& ssd_anchor = graph.AddNode("SsdAnchorsCalculator");
ConfigureSsdAnchorsCalculator( auto& ssd_anchor_options =
&ssd_anchor.GetOptions<mediapipe::SsdAnchorsCalculatorOptions>()); ssd_anchor.GetOptions<mediapipe::SsdAnchorsCalculatorOptions>();
if (!has_metadata) {
ConfigureSsdAnchorsCalculator(&ssd_anchor_options);
}
auto anchors = ssd_anchor.SideOut(""); auto anchors = ssd_anchor.SideOut("");
// Converts output tensors to Detections. // Converts output tensors to Detections.
auto& tensors_to_detections = auto& tensors_to_detections =
graph.AddNode("TensorsToDetectionsCalculator"); graph.AddNode("TensorsToDetectionsCalculator");
ConfigureTensorsToDetectionsCalculator( if (!has_metadata) {
subgraph_options, ConfigureTensorsToDetectionsCalculator(
&tensors_to_detections subgraph_options,
.GetOptions<mediapipe::TensorsToDetectionsCalculatorOptions>()); &tensors_to_detections
.GetOptions<mediapipe::TensorsToDetectionsCalculatorOptions>());
}
model_output_tensors >> tensors_to_detections.In("TENSORS"); model_output_tensors >> tensors_to_detections.In("TENSORS");
anchors >> tensors_to_detections.SideIn("ANCHORS"); anchors >> tensors_to_detections.SideIn("ANCHORS");
auto detections = tensors_to_detections.Out("DETECTIONS"); auto detections = tensors_to_detections.Out("DETECTIONS");

View File

@ -148,6 +148,7 @@ cc_library(
"//mediapipe/tasks/cc/vision/hand_landmarker/calculators:hand_landmarks_deduplication_calculator", "//mediapipe/tasks/cc/vision/hand_landmarker/calculators:hand_landmarks_deduplication_calculator",
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_cc_proto", "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_cc_proto", "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_cc_proto",
"//mediapipe/util:graph_builder_utils",
], ],
alwayslink = 1, alwayslink = 1,
) )

View File

@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/ ==============================================================================*/
#include <memory> #include <memory>
#include <optional>
#include <type_traits> #include <type_traits>
#include <utility> #include <utility>
#include <vector> #include <vector>
@ -41,6 +42,7 @@ limitations under the License.
#include "mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator.pb.h" #include "mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator.pb.h"
#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options.pb.h" #include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.pb.h" #include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.pb.h"
#include "mediapipe/util/graph_builder_utils.h"
namespace mediapipe { namespace mediapipe {
namespace tasks { namespace tasks {
@ -53,7 +55,7 @@ using ::mediapipe::NormalizedRect;
using ::mediapipe::api2::Input; using ::mediapipe::api2::Input;
using ::mediapipe::api2::Output; using ::mediapipe::api2::Output;
using ::mediapipe::api2::builder::Graph; using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source; using ::mediapipe::api2::builder::Stream;
using ::mediapipe::tasks::components::utils::DisallowIf; using ::mediapipe::tasks::components::utils::DisallowIf;
using ::mediapipe::tasks::core::ModelAssetBundleResources; using ::mediapipe::tasks::core::ModelAssetBundleResources;
using ::mediapipe::tasks::metadata::SetExternalFile; using ::mediapipe::tasks::metadata::SetExternalFile;
@ -78,40 +80,46 @@ constexpr char kHandLandmarksDetectorTFLiteName[] =
"hand_landmarks_detector.tflite"; "hand_landmarks_detector.tflite";
struct HandLandmarkerOutputs { struct HandLandmarkerOutputs {
Source<std::vector<NormalizedLandmarkList>> landmark_lists; Stream<std::vector<NormalizedLandmarkList>> landmark_lists;
Source<std::vector<LandmarkList>> world_landmark_lists; Stream<std::vector<LandmarkList>> world_landmark_lists;
Source<std::vector<NormalizedRect>> hand_rects_next_frame; Stream<std::vector<NormalizedRect>> hand_rects_next_frame;
Source<std::vector<ClassificationList>> handednesses; Stream<std::vector<ClassificationList>> handednesses;
Source<std::vector<NormalizedRect>> palm_rects; Stream<std::vector<NormalizedRect>> palm_rects;
Source<std::vector<Detection>> palm_detections; Stream<std::vector<Detection>> palm_detections;
Source<Image> image; Stream<Image> image;
}; };
// Sets the base options in the sub tasks. // Sets the base options in the sub tasks.
absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources, absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources,
HandLandmarkerGraphOptions* options, HandLandmarkerGraphOptions* options,
bool is_copy) { bool is_copy) {
ASSIGN_OR_RETURN(const auto hand_detector_file,
resources.GetModelFile(kHandDetectorTFLiteName));
auto* hand_detector_graph_options = auto* hand_detector_graph_options =
options->mutable_hand_detector_graph_options(); options->mutable_hand_detector_graph_options();
SetExternalFile(hand_detector_file, if (!hand_detector_graph_options->base_options().has_model_asset()) {
hand_detector_graph_options->mutable_base_options() ASSIGN_OR_RETURN(const auto hand_detector_file,
->mutable_model_asset(), resources.GetModelFile(kHandDetectorTFLiteName));
is_copy); SetExternalFile(hand_detector_file,
hand_detector_graph_options->mutable_base_options()
->mutable_model_asset(),
is_copy);
}
hand_detector_graph_options->mutable_base_options() hand_detector_graph_options->mutable_base_options()
->mutable_acceleration() ->mutable_acceleration()
->CopyFrom(options->base_options().acceleration()); ->CopyFrom(options->base_options().acceleration());
hand_detector_graph_options->mutable_base_options()->set_use_stream_mode( hand_detector_graph_options->mutable_base_options()->set_use_stream_mode(
options->base_options().use_stream_mode()); options->base_options().use_stream_mode());
ASSIGN_OR_RETURN(const auto hand_landmarks_detector_file,
resources.GetModelFile(kHandLandmarksDetectorTFLiteName));
auto* hand_landmarks_detector_graph_options = auto* hand_landmarks_detector_graph_options =
options->mutable_hand_landmarks_detector_graph_options(); options->mutable_hand_landmarks_detector_graph_options();
SetExternalFile(hand_landmarks_detector_file, if (!hand_landmarks_detector_graph_options->base_options()
hand_landmarks_detector_graph_options->mutable_base_options() .has_model_asset()) {
->mutable_model_asset(), ASSIGN_OR_RETURN(const auto hand_landmarks_detector_file,
is_copy); resources.GetModelFile(kHandLandmarksDetectorTFLiteName));
SetExternalFile(
hand_landmarks_detector_file,
hand_landmarks_detector_graph_options->mutable_base_options()
->mutable_model_asset(),
is_copy);
}
hand_landmarks_detector_graph_options->mutable_base_options() hand_landmarks_detector_graph_options->mutable_base_options()
->mutable_acceleration() ->mutable_acceleration()
->CopyFrom(options->base_options().acceleration()); ->CopyFrom(options->base_options().acceleration());
@ -119,7 +127,6 @@ absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources,
->set_use_stream_mode(options->base_options().use_stream_mode()); ->set_use_stream_mode(options->base_options().use_stream_mode());
return absl::OkStatus(); return absl::OkStatus();
} }
} // namespace } // namespace
// A "mediapipe.tasks.vision.hand_landmarker.HandLandmarkerGraph" performs hand // A "mediapipe.tasks.vision.hand_landmarker.HandLandmarkerGraph" performs hand
@ -219,12 +226,15 @@ class HandLandmarkerGraph : public core::ModelTaskGraph {
!sc->Service(::mediapipe::tasks::core::kModelResourcesCacheService) !sc->Service(::mediapipe::tasks::core::kModelResourcesCacheService)
.IsAvailable())); .IsAvailable()));
} }
Stream<Image> image_in = graph.In(kImageTag).Cast<Image>();
std::optional<Stream<NormalizedRect>> norm_rect_in;
if (HasInput(sc->OriginalNode(), kNormRectTag)) {
norm_rect_in = graph.In(kNormRectTag).Cast<NormalizedRect>();
}
ASSIGN_OR_RETURN( ASSIGN_OR_RETURN(
auto hand_landmarker_outputs, auto hand_landmarker_outputs,
BuildHandLandmarkerGraph( BuildHandLandmarkerGraph(sc->Options<HandLandmarkerGraphOptions>(),
sc->Options<HandLandmarkerGraphOptions>(), image_in, norm_rect_in, graph));
graph[Input<Image>(kImageTag)],
graph[Input<NormalizedRect>::Optional(kNormRectTag)], graph));
hand_landmarker_outputs.landmark_lists >> hand_landmarker_outputs.landmark_lists >>
graph[Output<std::vector<NormalizedLandmarkList>>(kLandmarksTag)]; graph[Output<std::vector<NormalizedLandmarkList>>(kLandmarksTag)];
hand_landmarker_outputs.world_landmark_lists >> hand_landmarker_outputs.world_landmark_lists >>
@ -262,8 +272,8 @@ class HandLandmarkerGraph : public core::ModelTaskGraph {
// image_in: (mediapipe::Image) stream to run hand landmark detection on. // image_in: (mediapipe::Image) stream to run hand landmark detection on.
// graph: the mediapipe graph instance to be updated. // graph: the mediapipe graph instance to be updated.
absl::StatusOr<HandLandmarkerOutputs> BuildHandLandmarkerGraph( absl::StatusOr<HandLandmarkerOutputs> BuildHandLandmarkerGraph(
const HandLandmarkerGraphOptions& tasks_options, Source<Image> image_in, const HandLandmarkerGraphOptions& tasks_options, Stream<Image> image_in,
Source<NormalizedRect> norm_rect_in, Graph& graph) { std::optional<Stream<NormalizedRect>> norm_rect_in, Graph& graph) {
const int max_num_hands = const int max_num_hands =
tasks_options.hand_detector_graph_options().num_hands(); tasks_options.hand_detector_graph_options().num_hands();
@ -293,10 +303,15 @@ class HandLandmarkerGraph : public core::ModelTaskGraph {
// track the hands from the last frame. // track the hands from the last frame.
auto image_for_hand_detector = auto image_for_hand_detector =
DisallowIf(image_in, has_enough_hands, graph); DisallowIf(image_in, has_enough_hands, graph);
auto norm_rect_in_for_hand_detector = std::optional<Stream<NormalizedRect>> norm_rect_in_for_hand_detector;
DisallowIf(norm_rect_in, has_enough_hands, graph); if (norm_rect_in) {
norm_rect_in_for_hand_detector =
DisallowIf(norm_rect_in.value(), has_enough_hands, graph);
}
image_for_hand_detector >> hand_detector.In("IMAGE"); image_for_hand_detector >> hand_detector.In("IMAGE");
norm_rect_in_for_hand_detector >> hand_detector.In("NORM_RECT"); if (norm_rect_in_for_hand_detector) {
norm_rect_in_for_hand_detector.value() >> hand_detector.In("NORM_RECT");
}
auto hand_rects_from_hand_detector = hand_detector.Out("HAND_RECTS"); auto hand_rects_from_hand_detector = hand_detector.Out("HAND_RECTS");
auto& hand_association = graph.AddNode("HandAssociationCalculator"); auto& hand_association = graph.AddNode("HandAssociationCalculator");
hand_association.GetOptions<HandAssociationCalculatorOptions>() hand_association.GetOptions<HandAssociationCalculatorOptions>()
@ -313,7 +328,9 @@ class HandLandmarkerGraph : public core::ModelTaskGraph {
// series, and we don't want to enable the tracking and hand associations // series, and we don't want to enable the tracking and hand associations
// between input images. Always use the hand detector graph. // between input images. Always use the hand detector graph.
image_in >> hand_detector.In("IMAGE"); image_in >> hand_detector.In("IMAGE");
norm_rect_in >> hand_detector.In("NORM_RECT"); if (norm_rect_in) {
norm_rect_in.value() >> hand_detector.In("NORM_RECT");
}
auto hand_rects_from_hand_detector = hand_detector.Out("HAND_RECTS"); auto hand_rects_from_hand_detector = hand_detector.Out("HAND_RECTS");
hand_rects_from_hand_detector >> clip_hand_rects.In(""); hand_rects_from_hand_detector >> clip_hand_rects.In("");
} }

View File

@ -306,8 +306,8 @@ def external_files():
http_file( http_file(
name = "com_google_mediapipe_gesture_recognizer_task", name = "com_google_mediapipe_gesture_recognizer_task",
sha256 = "a966b1d4e774e0423c19c8aa71f070e5a72fe7a03c2663dd2f3cb0b0095ee3e1", sha256 = "d48562f535fd4ecd3cfea739d9663dd818eeaf6a8afb1b5e6f8f4747661f73d9",
urls = ["https://storage.googleapis.com/mediapipe-assets/gesture_recognizer.task?generation=1668100501451433"], urls = ["https://storage.googleapis.com/mediapipe-assets/gesture_recognizer.task?generation=1677051715043311"],
) )
http_file( http_file(
@ -342,8 +342,8 @@ def external_files():
http_file( http_file(
name = "com_google_mediapipe_hand_landmarker_task", name = "com_google_mediapipe_hand_landmarker_task",
sha256 = "2ed44f10872e87a5834b9b1130fb9ada30e107af2c6fcc4562ad788aca4e7bc4", sha256 = "32d1eab97e80a9a20edb29231e15301ce65abfd0fa9d41cf1757e0ecc8078a4e",
urls = ["https://storage.googleapis.com/mediapipe-assets/hand_landmarker.task?generation=1666153732577904"], urls = ["https://storage.googleapis.com/mediapipe-assets/hand_landmarker.task?generation=1677051718270846"],
) )
http_file( http_file(