Add mediapipe tasks face detector graph
PiperOrigin-RevId: 504078951
This commit is contained in:
		
							parent
							
								
									ccd1461add
								
							
						
					
					
						commit
						873d7181bf
					
				
							
								
								
									
										61
									
								
								mediapipe/tasks/cc/vision/face_detector/BUILD
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										61
									
								
								mediapipe/tasks/cc/vision/face_detector/BUILD
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,61 @@
 | 
			
		|||
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#      http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
 | 
			
		||||
package(default_visibility = [
 | 
			
		||||
    # "//mediapipe/tasks:internal",
 | 
			
		||||
    "//visibility:public",
 | 
			
		||||
])
 | 
			
		||||
 | 
			
		||||
licenses(["notice"])
 | 
			
		||||
 | 
			
		||||
cc_library(
 | 
			
		||||
    name = "face_detector_graph",
 | 
			
		||||
    srcs = ["face_detector_graph.cc"],
 | 
			
		||||
    deps = [
 | 
			
		||||
        "//mediapipe/calculators/core:clip_vector_size_calculator",
 | 
			
		||||
        "//mediapipe/calculators/core:clip_vector_size_calculator_cc_proto",
 | 
			
		||||
        "//mediapipe/calculators/tensor:image_to_tensor_calculator_cc_proto",
 | 
			
		||||
        "//mediapipe/calculators/tensor:inference_calculator",
 | 
			
		||||
        "//mediapipe/calculators/tensor:tensors_to_detections_calculator",
 | 
			
		||||
        "//mediapipe/calculators/tensor:tensors_to_detections_calculator_cc_proto",
 | 
			
		||||
        "//mediapipe/calculators/tflite:ssd_anchors_calculator",
 | 
			
		||||
        "//mediapipe/calculators/tflite:ssd_anchors_calculator_cc_proto",
 | 
			
		||||
        "//mediapipe/calculators/util:detection_label_id_to_text_calculator",
 | 
			
		||||
        "//mediapipe/calculators/util:detection_label_id_to_text_calculator_cc_proto",
 | 
			
		||||
        "//mediapipe/calculators/util:detection_projection_calculator",
 | 
			
		||||
        "//mediapipe/calculators/util:detections_to_rects_calculator",
 | 
			
		||||
        "//mediapipe/calculators/util:detections_to_rects_calculator_cc_proto",
 | 
			
		||||
        "//mediapipe/calculators/util:non_max_suppression_calculator",
 | 
			
		||||
        "//mediapipe/calculators/util:non_max_suppression_calculator_cc_proto",
 | 
			
		||||
        "//mediapipe/calculators/util:rect_transformation_calculator",
 | 
			
		||||
        "//mediapipe/calculators/util:rect_transformation_calculator_cc_proto",
 | 
			
		||||
        "//mediapipe/framework/api2:builder",
 | 
			
		||||
        "//mediapipe/framework/api2:port",
 | 
			
		||||
        "//mediapipe/framework/formats:detection_cc_proto",
 | 
			
		||||
        "//mediapipe/framework/formats:image",
 | 
			
		||||
        "//mediapipe/framework/formats:rect_cc_proto",
 | 
			
		||||
        "//mediapipe/framework/formats:tensor",
 | 
			
		||||
        "//mediapipe/tasks/cc:common",
 | 
			
		||||
        "//mediapipe/tasks/cc/components/processors:image_preprocessing_graph",
 | 
			
		||||
        "//mediapipe/tasks/cc/core:model_resources",
 | 
			
		||||
        "//mediapipe/tasks/cc/core:model_task_graph",
 | 
			
		||||
        "//mediapipe/tasks/cc/core:utils",
 | 
			
		||||
        "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto",
 | 
			
		||||
        "//mediapipe/tasks/cc/vision/face_detector/proto:face_detector_graph_options_cc_proto",
 | 
			
		||||
        "//mediapipe/tasks/cc/vision/utils:image_tensor_specs",
 | 
			
		||||
        "@com_google_absl//absl/status",
 | 
			
		||||
        "@com_google_absl//absl/status:statusor",
 | 
			
		||||
    ],
 | 
			
		||||
    alwayslink = 1,
 | 
			
		||||
)
 | 
			
		||||
							
								
								
									
										208
									
								
								mediapipe/tasks/cc/vision/face_detector/face_detector_graph.cc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										208
									
								
								mediapipe/tasks/cc/vision/face_detector/face_detector_graph.cc
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,208 @@
 | 
			
		|||
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
 | 
			
		||||
 | 
			
		||||
Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
you may not use this file except in compliance with the License.
 | 
			
		||||
You may obtain a copy of the License at
 | 
			
		||||
 | 
			
		||||
    http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
 | 
			
		||||
Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
See the License for the specific language governing permissions and
 | 
			
		||||
limitations under the License.
 | 
			
		||||
==============================================================================*/
 | 
			
		||||
 | 
			
		||||
#include <vector>
 | 
			
		||||
 | 
			
		||||
#include "absl/status/status.h"
 | 
			
		||||
#include "absl/status/statusor.h"
 | 
			
		||||
#include "mediapipe/calculators/core/clip_vector_size_calculator.pb.h"
 | 
			
		||||
#include "mediapipe/calculators/tensor/image_to_tensor_calculator.pb.h"
 | 
			
		||||
#include "mediapipe/calculators/tensor/tensors_to_detections_calculator.pb.h"
 | 
			
		||||
#include "mediapipe/calculators/tflite/ssd_anchors_calculator.pb.h"
 | 
			
		||||
#include "mediapipe/calculators/util/detection_label_id_to_text_calculator.pb.h"
 | 
			
		||||
#include "mediapipe/calculators/util/detections_to_rects_calculator.pb.h"
 | 
			
		||||
#include "mediapipe/calculators/util/non_max_suppression_calculator.pb.h"
 | 
			
		||||
#include "mediapipe/calculators/util/rect_transformation_calculator.pb.h"
 | 
			
		||||
#include "mediapipe/framework/api2/builder.h"
 | 
			
		||||
#include "mediapipe/framework/api2/port.h"
 | 
			
		||||
#include "mediapipe/framework/formats/detection.pb.h"
 | 
			
		||||
#include "mediapipe/framework/formats/image.h"
 | 
			
		||||
#include "mediapipe/framework/formats/rect.pb.h"
 | 
			
		||||
#include "mediapipe/framework/formats/tensor.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/common.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/core/model_resources.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/core/model_task_graph.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/core/utils.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/vision/face_detector/proto/face_detector_graph_options.pb.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/vision/utils/image_tensor_specs.h"
 | 
			
		||||
 | 
			
		||||
namespace mediapipe {
 | 
			
		||||
namespace tasks {
 | 
			
		||||
namespace vision {
 | 
			
		||||
namespace face_detector {
 | 
			
		||||
 | 
			
		||||
using ::mediapipe::NormalizedRect;
 | 
			
		||||
using ::mediapipe::Tensor;
 | 
			
		||||
using ::mediapipe::api2::Input;
 | 
			
		||||
using ::mediapipe::api2::Output;
 | 
			
		||||
using ::mediapipe::api2::builder::Graph;
 | 
			
		||||
using ::mediapipe::api2::builder::Source;
 | 
			
		||||
using ::mediapipe::tasks::vision::face_detector::proto::
 | 
			
		||||
    FaceDetectorGraphOptions;
 | 
			
		||||
 | 
			
		||||
namespace {
 | 
			
		||||
constexpr char kImageTag[] = "IMAGE";
 | 
			
		||||
constexpr char kNormRectTag[] = "NORM_RECT";
 | 
			
		||||
constexpr char kDetectionsTag[] = "DETECTIONS";
 | 
			
		||||
 | 
			
		||||
void ConfigureSsdAnchorsCalculator(
 | 
			
		||||
    mediapipe::SsdAnchorsCalculatorOptions* options) {
 | 
			
		||||
  // TODO config SSD anchors parameters from metadata.
 | 
			
		||||
  options->set_num_layers(1);
 | 
			
		||||
  options->set_min_scale(0.1484375);
 | 
			
		||||
  options->set_max_scale(0.75);
 | 
			
		||||
  options->set_input_size_height(192);
 | 
			
		||||
  options->set_input_size_width(192);
 | 
			
		||||
  options->set_anchor_offset_x(0.5);
 | 
			
		||||
  options->set_anchor_offset_y(0.5);
 | 
			
		||||
  options->add_strides(4);
 | 
			
		||||
  options->add_aspect_ratios(1.0);
 | 
			
		||||
  options->set_fixed_anchor_size(true);
 | 
			
		||||
  options->set_interpolated_scale_aspect_ratio(0.0);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void ConfigureTensorsToDetectionsCalculator(
 | 
			
		||||
    const FaceDetectorGraphOptions& tasks_options,
 | 
			
		||||
    mediapipe::TensorsToDetectionsCalculatorOptions* options) {
 | 
			
		||||
  // TODO use metadata to configure these fields.
 | 
			
		||||
  options->set_num_classes(1);
 | 
			
		||||
  options->set_num_boxes(2304);
 | 
			
		||||
  options->set_num_coords(16);
 | 
			
		||||
  options->set_box_coord_offset(0);
 | 
			
		||||
  options->set_keypoint_coord_offset(4);
 | 
			
		||||
  options->set_num_keypoints(6);
 | 
			
		||||
  options->set_num_values_per_keypoint(2);
 | 
			
		||||
  options->set_sigmoid_score(true);
 | 
			
		||||
  options->set_score_clipping_thresh(100.0);
 | 
			
		||||
  options->set_reverse_output_order(true);
 | 
			
		||||
  options->set_min_score_thresh(tasks_options.min_detection_confidence());
 | 
			
		||||
  options->set_x_scale(192.0);
 | 
			
		||||
  options->set_y_scale(192.0);
 | 
			
		||||
  options->set_w_scale(192.0);
 | 
			
		||||
  options->set_h_scale(192.0);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void ConfigureNonMaxSuppressionCalculator(
 | 
			
		||||
    const FaceDetectorGraphOptions& tasks_options,
 | 
			
		||||
    mediapipe::NonMaxSuppressionCalculatorOptions* options) {
 | 
			
		||||
  options->set_min_suppression_threshold(
 | 
			
		||||
      tasks_options.min_suppression_threshold());
 | 
			
		||||
  options->set_overlap_type(
 | 
			
		||||
      mediapipe::NonMaxSuppressionCalculatorOptions::INTERSECTION_OVER_UNION);
 | 
			
		||||
  options->set_algorithm(
 | 
			
		||||
      mediapipe::NonMaxSuppressionCalculatorOptions::WEIGHTED);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
}  // namespace
 | 
			
		||||
 | 
			
		||||
class FaceDetectorGraph : public core::ModelTaskGraph {
 | 
			
		||||
 public:
 | 
			
		||||
  absl::StatusOr<CalculatorGraphConfig> GetConfig(
 | 
			
		||||
      SubgraphContext* sc) override {
 | 
			
		||||
    ASSIGN_OR_RETURN(const auto* model_resources,
 | 
			
		||||
                     CreateModelResources<FaceDetectorGraphOptions>(sc));
 | 
			
		||||
    Graph graph;
 | 
			
		||||
    ASSIGN_OR_RETURN(auto face_detections,
 | 
			
		||||
                     BuildFaceDetectionSubgraph(
 | 
			
		||||
                         sc->Options<FaceDetectorGraphOptions>(),
 | 
			
		||||
                         *model_resources, graph[Input<Image>(kImageTag)],
 | 
			
		||||
                         graph[Input<NormalizedRect>(kNormRectTag)], graph));
 | 
			
		||||
    face_detections >> graph[Output<std::vector<Detection>>(kDetectionsTag)];
 | 
			
		||||
    return graph.GetConfig();
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
 private:
 | 
			
		||||
  absl::StatusOr<Source<std::vector<Detection>>> BuildFaceDetectionSubgraph(
 | 
			
		||||
      const FaceDetectorGraphOptions& subgraph_options,
 | 
			
		||||
      const core::ModelResources& model_resources, Source<Image> image_in,
 | 
			
		||||
      Source<NormalizedRect> norm_rect_in, Graph& graph) {
 | 
			
		||||
    // Image preprocessing subgraph to convert image to tensor for the tflite
 | 
			
		||||
    // model.
 | 
			
		||||
    auto& preprocessing = graph.AddNode(
 | 
			
		||||
        "mediapipe.tasks.components.processors.ImagePreprocessingGraph");
 | 
			
		||||
    bool use_gpu =
 | 
			
		||||
        components::processors::DetermineImagePreprocessingGpuBackend(
 | 
			
		||||
            subgraph_options.base_options().acceleration());
 | 
			
		||||
    MP_RETURN_IF_ERROR(components::processors::ConfigureImagePreprocessingGraph(
 | 
			
		||||
        model_resources, use_gpu,
 | 
			
		||||
        &preprocessing.GetOptions<
 | 
			
		||||
            components::processors::proto::ImagePreprocessingGraphOptions>()));
 | 
			
		||||
    auto& image_to_tensor_options =
 | 
			
		||||
        *preprocessing
 | 
			
		||||
             .GetOptions<components::processors::proto::
 | 
			
		||||
                             ImagePreprocessingGraphOptions>()
 | 
			
		||||
             .mutable_image_to_tensor_options();
 | 
			
		||||
    image_to_tensor_options.set_keep_aspect_ratio(true);
 | 
			
		||||
    image_to_tensor_options.set_border_mode(
 | 
			
		||||
        mediapipe::ImageToTensorCalculatorOptions::BORDER_ZERO);
 | 
			
		||||
    image_in >> preprocessing.In("IMAGE");
 | 
			
		||||
    norm_rect_in >> preprocessing.In("NORM_RECT");
 | 
			
		||||
    auto preprocessed_tensors = preprocessing.Out("TENSORS");
 | 
			
		||||
    auto matrix = preprocessing.Out("MATRIX");
 | 
			
		||||
 | 
			
		||||
    // Face detection model inferece.
 | 
			
		||||
    auto& inference = AddInference(
 | 
			
		||||
        model_resources, subgraph_options.base_options().acceleration(), graph);
 | 
			
		||||
    preprocessed_tensors >> inference.In("TENSORS");
 | 
			
		||||
    auto model_output_tensors =
 | 
			
		||||
        inference.Out("TENSORS").Cast<std::vector<Tensor>>();
 | 
			
		||||
 | 
			
		||||
    // Generates a single side packet containing a vector of SSD anchors.
 | 
			
		||||
    auto& ssd_anchor = graph.AddNode("SsdAnchorsCalculator");
 | 
			
		||||
    ConfigureSsdAnchorsCalculator(
 | 
			
		||||
        &ssd_anchor.GetOptions<mediapipe::SsdAnchorsCalculatorOptions>());
 | 
			
		||||
    auto anchors = ssd_anchor.SideOut("");
 | 
			
		||||
 | 
			
		||||
    // Converts output tensors to Detections.
 | 
			
		||||
    auto& tensors_to_detections =
 | 
			
		||||
        graph.AddNode("TensorsToDetectionsCalculator");
 | 
			
		||||
    ConfigureTensorsToDetectionsCalculator(
 | 
			
		||||
        subgraph_options,
 | 
			
		||||
        &tensors_to_detections
 | 
			
		||||
             .GetOptions<mediapipe::TensorsToDetectionsCalculatorOptions>());
 | 
			
		||||
    model_output_tensors >> tensors_to_detections.In("TENSORS");
 | 
			
		||||
    anchors >> tensors_to_detections.SideIn("ANCHORS");
 | 
			
		||||
    auto detections = tensors_to_detections.Out("DETECTIONS");
 | 
			
		||||
 | 
			
		||||
    // Non maximum suppression removes redundant face detections.
 | 
			
		||||
    auto& non_maximum_suppression =
 | 
			
		||||
        graph.AddNode("NonMaxSuppressionCalculator");
 | 
			
		||||
    ConfigureNonMaxSuppressionCalculator(
 | 
			
		||||
        subgraph_options,
 | 
			
		||||
        &non_maximum_suppression
 | 
			
		||||
             .GetOptions<mediapipe::NonMaxSuppressionCalculatorOptions>());
 | 
			
		||||
    detections >> non_maximum_suppression.In("");
 | 
			
		||||
    auto nms_detections = non_maximum_suppression.Out("");
 | 
			
		||||
 | 
			
		||||
    // Projects detections back into the input image coordinates system.
 | 
			
		||||
    auto& detection_projection = graph.AddNode("DetectionProjectionCalculator");
 | 
			
		||||
    nms_detections >> detection_projection.In("DETECTIONS");
 | 
			
		||||
    matrix >> detection_projection.In("PROJECTION_MATRIX");
 | 
			
		||||
    auto face_detections =
 | 
			
		||||
        detection_projection[Output<std::vector<Detection>>("DETECTIONS")];
 | 
			
		||||
 | 
			
		||||
    return {face_detections};
 | 
			
		||||
  }
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
REGISTER_MEDIAPIPE_GRAPH(
 | 
			
		||||
    ::mediapipe::tasks::vision::face_detector::FaceDetectorGraph);
 | 
			
		||||
 | 
			
		||||
}  // namespace face_detector
 | 
			
		||||
}  // namespace vision
 | 
			
		||||
}  // namespace tasks
 | 
			
		||||
}  // namespace mediapipe
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,183 @@
 | 
			
		|||
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
 | 
			
		||||
 | 
			
		||||
Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
you may not use this file except in compliance with the License.
 | 
			
		||||
You may obtain a copy of the License at
 | 
			
		||||
 | 
			
		||||
    http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
 | 
			
		||||
Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
See the License for the specific language governing permissions and
 | 
			
		||||
limitations under the License.
 | 
			
		||||
==============================================================================*/
 | 
			
		||||
 | 
			
		||||
#include <cmath>
 | 
			
		||||
#include <iostream>
 | 
			
		||||
#include <memory>
 | 
			
		||||
#include <string>
 | 
			
		||||
#include <utility>
 | 
			
		||||
#include <vector>
 | 
			
		||||
 | 
			
		||||
#include "absl/flags/flag.h"
 | 
			
		||||
#include "absl/status/statusor.h"
 | 
			
		||||
#include "absl/strings/str_format.h"
 | 
			
		||||
#include "absl/strings/string_view.h"
 | 
			
		||||
#include "mediapipe/framework/api2/builder.h"
 | 
			
		||||
#include "mediapipe/framework/api2/port.h"
 | 
			
		||||
#include "mediapipe/framework/calculator_framework.h"
 | 
			
		||||
#include "mediapipe/framework/deps/file_path.h"
 | 
			
		||||
#include "mediapipe/framework/formats/detection.pb.h"
 | 
			
		||||
#include "mediapipe/framework/formats/image.h"
 | 
			
		||||
#include "mediapipe/framework/formats/rect.pb.h"
 | 
			
		||||
#include "mediapipe/framework/packet.h"
 | 
			
		||||
#include "mediapipe/framework/port/file_helpers.h"
 | 
			
		||||
#include "mediapipe/framework/port/gmock.h"
 | 
			
		||||
#include "mediapipe/framework/port/gtest.h"
 | 
			
		||||
#include "mediapipe/framework/port/parse_text_proto.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/core/model_resources.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/core/task_runner.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/vision/face_detector/proto/face_detector_graph_options.pb.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
 | 
			
		||||
 | 
			
		||||
namespace mediapipe {
 | 
			
		||||
namespace tasks {
 | 
			
		||||
namespace vision {
 | 
			
		||||
namespace face_detector {
 | 
			
		||||
namespace {
 | 
			
		||||
 | 
			
		||||
using ::file::Defaults;
 | 
			
		||||
using ::file::GetTextProto;
 | 
			
		||||
using ::mediapipe::NormalizedRect;
 | 
			
		||||
using ::mediapipe::api2::Input;
 | 
			
		||||
using ::mediapipe::api2::Output;
 | 
			
		||||
using ::mediapipe::api2::builder::Graph;
 | 
			
		||||
using ::mediapipe::api2::builder::Source;
 | 
			
		||||
using ::mediapipe::file::JoinPath;
 | 
			
		||||
using ::mediapipe::tasks::core::TaskRunner;
 | 
			
		||||
using ::mediapipe::tasks::vision::DecodeImageFromFile;
 | 
			
		||||
using ::mediapipe::tasks::vision::face_detector::proto::
 | 
			
		||||
    FaceDetectorGraphOptions;
 | 
			
		||||
using ::testing::EqualsProto;
 | 
			
		||||
using ::testing::Pointwise;
 | 
			
		||||
using ::testing::TestParamInfo;
 | 
			
		||||
using ::testing::TestWithParam;
 | 
			
		||||
using ::testing::Values;
 | 
			
		||||
using ::testing::proto::Approximately;
 | 
			
		||||
using ::testing::proto::Partially;
 | 
			
		||||
 | 
			
		||||
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/";
 | 
			
		||||
constexpr char kFullRangeBlazeFaceModel[] = "face_detection_full_range.tflite";
 | 
			
		||||
constexpr char kFullRangeSparseBlazeFaceModel[] =
 | 
			
		||||
    "face_detection_full_range_sparse.tflite";
 | 
			
		||||
constexpr char kPortraitImage[] = "portrait.jpg";
 | 
			
		||||
constexpr char kPortraitExpectedDetection[] =
 | 
			
		||||
    "portrait_expected_detection.pbtxt";
 | 
			
		||||
 | 
			
		||||
constexpr char kImageTag[] = "IMAGE";
 | 
			
		||||
constexpr char kImageName[] = "image";
 | 
			
		||||
constexpr char kNormRectTag[] = "NORM_RECT";
 | 
			
		||||
constexpr char kNormRectName[] = "norm_rect";
 | 
			
		||||
constexpr char kDetectionsTag[] = "DETECTIONS";
 | 
			
		||||
constexpr char kDetectionsName[] = "detections";
 | 
			
		||||
 | 
			
		||||
constexpr float kFaceDetectionMaxDiff = 0.01;
 | 
			
		||||
 | 
			
		||||
// Helper function to create a TaskRunner.
 | 
			
		||||
absl::StatusOr<std::unique_ptr<TaskRunner>> CreateTaskRunner(
 | 
			
		||||
    absl::string_view model_name) {
 | 
			
		||||
  Graph graph;
 | 
			
		||||
 | 
			
		||||
  auto& face_detector_graph =
 | 
			
		||||
      graph.AddNode("mediapipe.tasks.vision.face_detector.FaceDetectorGraph");
 | 
			
		||||
 | 
			
		||||
  auto options = std::make_unique<FaceDetectorGraphOptions>();
 | 
			
		||||
  options->mutable_base_options()->mutable_model_asset()->set_file_name(
 | 
			
		||||
      JoinPath("./", kTestDataDirectory, model_name));
 | 
			
		||||
  options->set_min_detection_confidence(0.6);
 | 
			
		||||
  options->set_min_suppression_threshold(0.3);
 | 
			
		||||
  face_detector_graph.GetOptions<FaceDetectorGraphOptions>().Swap(
 | 
			
		||||
      options.get());
 | 
			
		||||
 | 
			
		||||
  graph[Input<Image>(kImageTag)].SetName(kImageName) >>
 | 
			
		||||
      face_detector_graph.In(kImageTag);
 | 
			
		||||
  graph[Input<NormalizedRect>(kNormRectTag)].SetName(kNormRectName) >>
 | 
			
		||||
      face_detector_graph.In(kNormRectTag);
 | 
			
		||||
 | 
			
		||||
  face_detector_graph.Out(kDetectionsTag).SetName(kDetectionsName) >>
 | 
			
		||||
      graph[Output<std::vector<Detection>>(kDetectionsTag)];
 | 
			
		||||
 | 
			
		||||
  return TaskRunner::Create(
 | 
			
		||||
      graph.GetConfig(), std::make_unique<core::MediaPipeBuiltinOpResolver>());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
Detection GetExpectedFaceDetectionResult(absl::string_view file_name) {
 | 
			
		||||
  Detection detection;
 | 
			
		||||
  CHECK_OK(GetTextProto(file::JoinPath("./", kTestDataDirectory, file_name),
 | 
			
		||||
                        &detection, Defaults()))
 | 
			
		||||
      << "Expected face detection result does not exist.";
 | 
			
		||||
  return detection;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
struct TestParams {
 | 
			
		||||
  // The name of this test, for convenience when displaying test results.
 | 
			
		||||
  std::string test_name;
 | 
			
		||||
  // The filename of face landmark detection model.
 | 
			
		||||
  std::string face_detection_model_name;
 | 
			
		||||
  // The filename of test image.
 | 
			
		||||
  std::string test_image_name;
 | 
			
		||||
  // Expected face detection results.
 | 
			
		||||
  std::vector<Detection> expected_result;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
class FaceDetectorGraphTest : public testing::TestWithParam<TestParams> {};
 | 
			
		||||
 | 
			
		||||
TEST_P(FaceDetectorGraphTest, Succeed) {
 | 
			
		||||
  MP_ASSERT_OK_AND_ASSIGN(
 | 
			
		||||
      Image image, DecodeImageFromFile(JoinPath("./", kTestDataDirectory,
 | 
			
		||||
                                                GetParam().test_image_name)));
 | 
			
		||||
  NormalizedRect input_norm_rect;
 | 
			
		||||
  input_norm_rect.set_x_center(0.5);
 | 
			
		||||
  input_norm_rect.set_y_center(0.5);
 | 
			
		||||
  input_norm_rect.set_width(1.0);
 | 
			
		||||
  input_norm_rect.set_height(1.0);
 | 
			
		||||
  MP_ASSERT_OK_AND_ASSIGN(
 | 
			
		||||
      auto task_runner, CreateTaskRunner(GetParam().face_detection_model_name));
 | 
			
		||||
  auto output_packets = task_runner->Process(
 | 
			
		||||
      {{kImageName, MakePacket<Image>(std::move(image))},
 | 
			
		||||
       {kNormRectName,
 | 
			
		||||
        MakePacket<NormalizedRect>(std::move(input_norm_rect))}});
 | 
			
		||||
  MP_ASSERT_OK(output_packets);
 | 
			
		||||
  const std::vector<Detection>& face_detections =
 | 
			
		||||
      (*output_packets)[kDetectionsName].Get<std::vector<Detection>>();
 | 
			
		||||
  EXPECT_THAT(face_detections, Pointwise(Approximately(Partially(EqualsProto()),
 | 
			
		||||
                                                       kFaceDetectionMaxDiff),
 | 
			
		||||
                                         GetParam().expected_result));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
INSTANTIATE_TEST_SUITE_P(
 | 
			
		||||
    FaceDetectorGraphTest, FaceDetectorGraphTest,
 | 
			
		||||
    Values(TestParams{.test_name = "FullRange",
 | 
			
		||||
                      .face_detection_model_name = kFullRangeBlazeFaceModel,
 | 
			
		||||
                      .test_image_name = kPortraitImage,
 | 
			
		||||
                      .expected_result = {GetExpectedFaceDetectionResult(
 | 
			
		||||
                          kPortraitExpectedDetection)}},
 | 
			
		||||
           TestParams{
 | 
			
		||||
               .test_name = "FullRangeSparse",
 | 
			
		||||
               .face_detection_model_name = kFullRangeSparseBlazeFaceModel,
 | 
			
		||||
               .test_image_name = kPortraitImage,
 | 
			
		||||
               .expected_result = {GetExpectedFaceDetectionResult(
 | 
			
		||||
                   kPortraitExpectedDetection)}}),
 | 
			
		||||
    [](const TestParamInfo<FaceDetectorGraphTest::ParamType>& info) {
 | 
			
		||||
      return info.param.test_name;
 | 
			
		||||
    });
 | 
			
		||||
 | 
			
		||||
}  // namespace
 | 
			
		||||
}  // namespace face_detector
 | 
			
		||||
}  // namespace vision
 | 
			
		||||
}  // namespace tasks
 | 
			
		||||
}  // namespace mediapipe
 | 
			
		||||
							
								
								
									
										31
									
								
								mediapipe/tasks/cc/vision/face_detector/proto/BUILD
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										31
									
								
								mediapipe/tasks/cc/vision/face_detector/proto/BUILD
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,31 @@
 | 
			
		|||
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#      http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
 | 
			
		||||
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library")
 | 
			
		||||
 | 
			
		||||
package(default_visibility = [
 | 
			
		||||
    "//mediapipe/tasks:internal",
 | 
			
		||||
])
 | 
			
		||||
 | 
			
		||||
licenses(["notice"])
 | 
			
		||||
 | 
			
		||||
mediapipe_proto_library(
 | 
			
		||||
    name = "face_detector_graph_options_proto",
 | 
			
		||||
    srcs = ["face_detector_graph_options.proto"],
 | 
			
		||||
    deps = [
 | 
			
		||||
        "//mediapipe/framework:calculator_options_proto",
 | 
			
		||||
        "//mediapipe/framework:calculator_proto",
 | 
			
		||||
        "//mediapipe/tasks/cc/core/proto:base_options_proto",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,42 @@
 | 
			
		|||
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
 | 
			
		||||
 | 
			
		||||
Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
you may not use this file except in compliance with the License.
 | 
			
		||||
You may obtain a copy of the License at
 | 
			
		||||
 | 
			
		||||
    http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
 | 
			
		||||
Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
See the License for the specific language governing permissions and
 | 
			
		||||
limitations under the License.
 | 
			
		||||
==============================================================================*/
 | 
			
		||||
 | 
			
		||||
syntax = "proto2";
 | 
			
		||||
 | 
			
		||||
package mediapipe.tasks.vision.face_detector.proto;
 | 
			
		||||
 | 
			
		||||
import "mediapipe/framework/calculator.proto";
 | 
			
		||||
import "mediapipe/framework/calculator_options.proto";
 | 
			
		||||
import "mediapipe/tasks/cc/core/proto/base_options.proto";
 | 
			
		||||
 | 
			
		||||
option java_package = "com.google.mediapipe.tasks.vision.facedetector.proto";
 | 
			
		||||
option java_outer_classname = "FaceDetectorGraphOptionsProto";
 | 
			
		||||
 | 
			
		||||
message FaceDetectorGraphOptions {
 | 
			
		||||
  extend mediapipe.CalculatorOptions {
 | 
			
		||||
    optional FaceDetectorGraphOptions ext = 502141897;
 | 
			
		||||
  }
 | 
			
		||||
  // Base options for configuring Task library, such as specifying the TfLite
 | 
			
		||||
  // model file with metadata, accelerator options, etc.
 | 
			
		||||
  optional core.proto.BaseOptions base_options = 1;
 | 
			
		||||
 | 
			
		||||
  // Minimum confidence value ([0.0, 1.0]) for confidence score to be considered
 | 
			
		||||
  // successfully detecting a face in the image.
 | 
			
		||||
  optional float min_detection_confidence = 2 [default = 0.5];
 | 
			
		||||
 | 
			
		||||
  // IoU threshold ([0,0, 1.0]) for non-maximu-suppression to be considered
 | 
			
		||||
  // duplicate detetions.
 | 
			
		||||
  optional float min_suppression_threshold = 3 [default = 0.5];
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										8
									
								
								mediapipe/tasks/testdata/vision/BUILD
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										8
									
								
								mediapipe/tasks/testdata/vision/BUILD
									
									
									
									
										vendored
									
									
								
							| 
						 | 
				
			
			@ -37,6 +37,8 @@ mediapipe_files(srcs = [
 | 
			
		|||
    "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite",
 | 
			
		||||
    "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29_with_dummy_score_calibration.tflite",
 | 
			
		||||
    "deeplabv3.tflite",
 | 
			
		||||
    "face_detection_full_range.tflite",
 | 
			
		||||
    "face_detection_full_range_sparse.tflite",
 | 
			
		||||
    "fist.jpg",
 | 
			
		||||
    "fist.png",
 | 
			
		||||
    "hand_landmark_full.tflite",
 | 
			
		||||
| 
						 | 
				
			
			@ -58,6 +60,7 @@ mediapipe_files(srcs = [
 | 
			
		|||
    "palm_detection_full.tflite",
 | 
			
		||||
    "pointing_up.jpg",
 | 
			
		||||
    "pointing_up_rotated.jpg",
 | 
			
		||||
    "portrait.jpg",
 | 
			
		||||
    "right_hands.jpg",
 | 
			
		||||
    "right_hands_rotated.jpg",
 | 
			
		||||
    "segmentation_golden_rotation0.png",
 | 
			
		||||
| 
						 | 
				
			
			@ -79,6 +82,7 @@ exports_files(
 | 
			
		|||
        "expected_right_down_hand_landmarks.prototxt",
 | 
			
		||||
        "expected_right_up_hand_landmarks.prototxt",
 | 
			
		||||
        "gesture_recognizer.task",
 | 
			
		||||
        "portrait_expected_detection.pbtxt",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -106,6 +110,7 @@ filegroup(
 | 
			
		|||
        "multi_objects_rotated.jpg",
 | 
			
		||||
        "pointing_up.jpg",
 | 
			
		||||
        "pointing_up_rotated.jpg",
 | 
			
		||||
        "portrait.jpg",
 | 
			
		||||
        "right_hands.jpg",
 | 
			
		||||
        "right_hands_rotated.jpg",
 | 
			
		||||
        "segmentation_golden_rotation0.png",
 | 
			
		||||
| 
						 | 
				
			
			@ -129,6 +134,8 @@ filegroup(
 | 
			
		|||
        "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite",
 | 
			
		||||
        "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29_with_dummy_score_calibration.tflite",
 | 
			
		||||
        "deeplabv3.tflite",
 | 
			
		||||
        "face_detection_full_range.tflite",
 | 
			
		||||
        "face_detection_full_range_sparse.tflite",
 | 
			
		||||
        "hand_landmark_full.tflite",
 | 
			
		||||
        "hand_landmark_lite.tflite",
 | 
			
		||||
        "hand_landmarker.task",
 | 
			
		||||
| 
						 | 
				
			
			@ -161,6 +168,7 @@ filegroup(
 | 
			
		|||
        "hand_detector_result_two_hands.pbtxt",
 | 
			
		||||
        "pointing_up_landmarks.pbtxt",
 | 
			
		||||
        "pointing_up_rotated_landmarks.pbtxt",
 | 
			
		||||
        "portrait_expected_detection.pbtxt",
 | 
			
		||||
        "thumb_up_landmarks.pbtxt",
 | 
			
		||||
        "thumb_up_rotated_landmarks.pbtxt",
 | 
			
		||||
        "victory_landmarks.pbtxt",
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
							
								
								
									
										35
									
								
								mediapipe/tasks/testdata/vision/portrait_expected_detection.pbtxt
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										35
									
								
								mediapipe/tasks/testdata/vision/portrait_expected_detection.pbtxt
									
									
									
									
										vendored
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,35 @@
 | 
			
		|||
# proto-file: mediapipe/framework/formats/detection.proto
 | 
			
		||||
# proto-message: Detection
 | 
			
		||||
location_data {
 | 
			
		||||
  format: RELATIVE_BOUNDING_BOX
 | 
			
		||||
  relative_bounding_box {
 | 
			
		||||
    xmin: 0.35494408
 | 
			
		||||
    ymin: 0.1059662
 | 
			
		||||
    width: 0.28768203
 | 
			
		||||
    height: 0.23037356
 | 
			
		||||
  }
 | 
			
		||||
  relative_keypoints {
 | 
			
		||||
    x: 0.44416338
 | 
			
		||||
    y: 0.17643969
 | 
			
		||||
  }
 | 
			
		||||
  relative_keypoints {
 | 
			
		||||
    x: 0.55514044
 | 
			
		||||
    y: 0.17731678
 | 
			
		||||
  }
 | 
			
		||||
  relative_keypoints {
 | 
			
		||||
    x: 0.5046702
 | 
			
		||||
    y: 0.2265771
 | 
			
		||||
  }
 | 
			
		||||
  relative_keypoints {
 | 
			
		||||
    x: 0.50227845
 | 
			
		||||
    y: 0.2719954
 | 
			
		||||
  }
 | 
			
		||||
  relative_keypoints {
 | 
			
		||||
    x: 0.37245658
 | 
			
		||||
    y: 0.20143759
 | 
			
		||||
  }
 | 
			
		||||
  relative_keypoints {
 | 
			
		||||
    x: 0.6084143
 | 
			
		||||
    y: 0.20409837
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										20
									
								
								third_party/external_files.bzl
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										20
									
								
								third_party/external_files.bzl
									
									
									
									
										vendored
									
									
								
							| 
						 | 
				
			
			@ -240,14 +240,14 @@ def external_files():
 | 
			
		|||
 | 
			
		||||
    http_file(
 | 
			
		||||
        name = "com_google_mediapipe_face_detection_full_range_sparse_tflite",
 | 
			
		||||
        sha256 = "671dd2f9ed11a78436fc21cc42357a803dfc6f73e9fb86541be942d5716c2dce",
 | 
			
		||||
        urls = ["https://storage.googleapis.com/mediapipe-assets/face_detection_full_range_sparse.tflite?generation=1661875739104017"],
 | 
			
		||||
        sha256 = "2c3728e6da56f21e21a320433396fb06d40d9088f2247c05e5635a688d45dfe1",
 | 
			
		||||
        urls = ["https://storage.googleapis.com/mediapipe-assets/face_detection_full_range_sparse.tflite?generation=1674261618323821"],
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    http_file(
 | 
			
		||||
        name = "com_google_mediapipe_face_detection_full_range_tflite",
 | 
			
		||||
        sha256 = "99bf9494d84f50acc6617d89873f71bf6635a841ea699c17cb3377f9507cfec3",
 | 
			
		||||
        urls = ["https://storage.googleapis.com/mediapipe-assets/face_detection_full_range.tflite?generation=1661875742733283"],
 | 
			
		||||
        sha256 = "3698b18f063835bc609069ef052228fbe86d9c9a6dc8dcb7c7c2d69aed2b181b",
 | 
			
		||||
        urls = ["https://storage.googleapis.com/mediapipe-assets/face_detection_full_range.tflite?generation=1674261620964007"],
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    http_file(
 | 
			
		||||
| 
						 | 
				
			
			@ -712,6 +712,18 @@ def external_files():
 | 
			
		|||
        urls = ["https://storage.googleapis.com/mediapipe-assets/pointing_up_rotated_landmarks.pbtxt?generation=1666629486774022"],
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    http_file(
 | 
			
		||||
        name = "com_google_mediapipe_portrait_expected_detection_pbtxt",
 | 
			
		||||
        sha256 = "bb54e08e87844ef14bb185d5cb808908eb6011bfa6db48bd22d9650f6fda338b",
 | 
			
		||||
        urls = ["https://storage.googleapis.com/mediapipe-assets/portrait_expected_detection.pbtxt?generation=1674261627835475"],
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    http_file(
 | 
			
		||||
        name = "com_google_mediapipe_portrait_jpg",
 | 
			
		||||
        sha256 = "a6f11efaa834706db23f275b6115058fa87fc7f14362681e6abe14e82749de3e",
 | 
			
		||||
        urls = ["https://storage.googleapis.com/mediapipe-assets/portrait.jpg?generation=1674261630039907"],
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    http_file(
 | 
			
		||||
        name = "com_google_mediapipe_pose_detection_tflite",
 | 
			
		||||
        sha256 = "a63c614bef30d35947f13be361820b1e4e3bec9cfeebf4d11216a18373108e85",
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue
	
	Block a user