No public description
PiperOrigin-RevId: 583973946
This commit is contained in:
		
							parent
							
								
									5cd3037443
								
							
						
					
					
						commit
						d8fd986517
					
				| 
						 | 
				
			
			@ -12,6 +12,8 @@
 | 
			
		|||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
 | 
			
		||||
load("@org_tensorflow//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "cc_library_with_tflite")
 | 
			
		||||
 | 
			
		||||
package(default_visibility = ["//mediapipe/tasks:internal"])
 | 
			
		||||
 | 
			
		||||
licenses(["notice"])
 | 
			
		||||
| 
						 | 
				
			
			@ -99,10 +101,14 @@ cc_library(
 | 
			
		|||
    alwayslink = 1,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
cc_library(
 | 
			
		||||
cc_library_with_tflite(
 | 
			
		||||
    name = "image_preprocessing_graph",
 | 
			
		||||
    srcs = ["image_preprocessing_graph.cc"],
 | 
			
		||||
    hdrs = ["image_preprocessing_graph.h"],
 | 
			
		||||
    tflite_deps = [
 | 
			
		||||
        "//mediapipe/tasks/cc/core:model_resources",
 | 
			
		||||
        "//mediapipe/tasks/cc/vision/utils:image_tensor_specs",
 | 
			
		||||
    ],
 | 
			
		||||
    deps = [
 | 
			
		||||
        "//mediapipe/calculators/core:pass_through_calculator",
 | 
			
		||||
        "//mediapipe/calculators/image:image_clone_calculator",
 | 
			
		||||
| 
						 | 
				
			
			@ -120,10 +126,8 @@ cc_library(
 | 
			
		|||
        "//mediapipe/gpu:gpu_origin_cc_proto",
 | 
			
		||||
        "//mediapipe/tasks/cc:common",
 | 
			
		||||
        "//mediapipe/tasks/cc/components/processors/proto:image_preprocessing_graph_options_cc_proto",
 | 
			
		||||
        "//mediapipe/tasks/cc/core:model_resources",
 | 
			
		||||
        "//mediapipe/tasks/cc/core/proto:acceleration_cc_proto",
 | 
			
		||||
        "//mediapipe/tasks/cc/core/proto:base_options_cc_proto",
 | 
			
		||||
        "//mediapipe/tasks/cc/vision/utils:image_tensor_specs",
 | 
			
		||||
        "@com_google_absl//absl/status",
 | 
			
		||||
        "@com_google_absl//absl/status:statusor",
 | 
			
		||||
        "@org_tensorflow//tensorflow/lite/schema:schema_fbs",
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -271,8 +271,9 @@ class ImagePreprocessingGraph : public Subgraph {
 | 
			
		|||
    };
 | 
			
		||||
  }
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
REGISTER_MEDIAPIPE_GRAPH(
 | 
			
		||||
    ::mediapipe::tasks::components::processors::ImagePreprocessingGraph);
 | 
			
		||||
    ::mediapipe::tasks::components::processors::ImagePreprocessingGraph)
 | 
			
		||||
 | 
			
		||||
}  // namespace processors
 | 
			
		||||
}  // namespace components
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -91,16 +91,16 @@ cc_library(
 | 
			
		|||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
# TODO: Switch to use cc_library_with_tflite after the MediaPipe InferenceCalculator
 | 
			
		||||
# supports TFLite-in-GMSCore.
 | 
			
		||||
cc_library(
 | 
			
		||||
cc_library_with_tflite(
 | 
			
		||||
    name = "model_task_graph",
 | 
			
		||||
    srcs = ["model_task_graph.cc"],
 | 
			
		||||
    hdrs = ["model_task_graph.h"],
 | 
			
		||||
    deps = [
 | 
			
		||||
        ":model_asset_bundle_resources",
 | 
			
		||||
    tflite_deps = [
 | 
			
		||||
        ":model_resources",
 | 
			
		||||
        ":model_resources_cache",
 | 
			
		||||
    ],
 | 
			
		||||
    deps = [
 | 
			
		||||
        ":model_asset_bundle_resources",
 | 
			
		||||
        ":model_resources_calculator",
 | 
			
		||||
        "//mediapipe/calculators/tensor:inference_calculator_cc_proto",
 | 
			
		||||
        "//mediapipe/framework:calculator_cc_proto",
 | 
			
		||||
| 
						 | 
				
			
			@ -224,19 +224,16 @@ cc_library_with_tflite(
 | 
			
		|||
    alwayslink = 1,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
cc_test_with_tflite(
 | 
			
		||||
cc_test(
 | 
			
		||||
    name = "model_resources_calculator_test",
 | 
			
		||||
    srcs = ["model_resources_calculator_test.cc"],
 | 
			
		||||
    data = [
 | 
			
		||||
        "//mediapipe/tasks/testdata/core:test_models",
 | 
			
		||||
    ],
 | 
			
		||||
    tflite_deps = [
 | 
			
		||||
    deps = [
 | 
			
		||||
        ":model_resources",
 | 
			
		||||
        ":model_resources_cache",
 | 
			
		||||
        ":model_resources_calculator",
 | 
			
		||||
        "@org_tensorflow//tensorflow/lite:test_util",
 | 
			
		||||
    ],
 | 
			
		||||
    deps = [
 | 
			
		||||
        "//mediapipe/framework/port:gtest_main",
 | 
			
		||||
        "//mediapipe/framework/port:parse_text_proto",
 | 
			
		||||
        "//mediapipe/tasks/cc/core/proto:external_file_cc_proto",
 | 
			
		||||
| 
						 | 
				
			
			@ -245,6 +242,7 @@ cc_test_with_tflite(
 | 
			
		|||
        "@com_google_absl//absl/status",
 | 
			
		||||
        "@com_google_absl//absl/status:statusor",
 | 
			
		||||
        "@com_google_absl//absl/strings",
 | 
			
		||||
        "@org_tensorflow//tensorflow/lite:test_util",
 | 
			
		||||
        "@org_tensorflow//tensorflow/lite/core/api:op_resolver",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
| 
						 | 
				
			
			@ -303,22 +301,26 @@ cc_test_with_tflite(
 | 
			
		|||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
cc_library(
 | 
			
		||||
cc_library_with_tflite(
 | 
			
		||||
    name = "base_task_api",
 | 
			
		||||
    hdrs = ["base_task_api.h"],
 | 
			
		||||
    deps = [
 | 
			
		||||
    tflite_deps = [
 | 
			
		||||
        ":task_runner",
 | 
			
		||||
    ],
 | 
			
		||||
    deps = [
 | 
			
		||||
        "//mediapipe/calculators/core:flow_limiter_calculator",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
cc_library(
 | 
			
		||||
cc_library_with_tflite(
 | 
			
		||||
    name = "task_api_factory",
 | 
			
		||||
    hdrs = ["task_api_factory.h"],
 | 
			
		||||
    deps = [
 | 
			
		||||
    tflite_deps = [
 | 
			
		||||
        ":base_task_api",
 | 
			
		||||
        ":model_resources",
 | 
			
		||||
        ":task_runner",
 | 
			
		||||
    ],
 | 
			
		||||
    deps = [
 | 
			
		||||
        ":utils",
 | 
			
		||||
        "//mediapipe/framework:calculator_cc_proto",
 | 
			
		||||
        "//mediapipe/framework:executor",
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -31,6 +31,7 @@ limitations under the License.
 | 
			
		|||
#include "mediapipe/framework/api2/builder.h"
 | 
			
		||||
#include "mediapipe/framework/api2/port.h"
 | 
			
		||||
#include "mediapipe/framework/calculator.pb.h"
 | 
			
		||||
#include "mediapipe/framework/calculator_framework.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/common.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/core/model_asset_bundle_resources.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/core/model_resources.h"
 | 
			
		||||
| 
						 | 
				
			
			@ -147,7 +148,8 @@ class InferenceSubgraph : public Subgraph {
 | 
			
		|||
    return delegate;
 | 
			
		||||
  }
 | 
			
		||||
};
 | 
			
		||||
REGISTER_MEDIAPIPE_GRAPH(::mediapipe::tasks::core::InferenceSubgraph);
 | 
			
		||||
 | 
			
		||||
REGISTER_MEDIAPIPE_GRAPH(::mediapipe::tasks::core::InferenceSubgraph)
 | 
			
		||||
 | 
			
		||||
absl::StatusOr<CalculatorGraphConfig> ModelTaskGraph::GetConfig(
 | 
			
		||||
    SubgraphContext* sc) {
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -12,6 +12,8 @@
 | 
			
		|||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
 | 
			
		||||
load("@org_tensorflow//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "cc_library_with_tflite")
 | 
			
		||||
 | 
			
		||||
licenses(["notice"])
 | 
			
		||||
 | 
			
		||||
package(default_visibility = ["//mediapipe/tasks:internal"])
 | 
			
		||||
| 
						 | 
				
			
			@ -31,9 +33,15 @@ cc_library(
 | 
			
		|||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
cc_library(
 | 
			
		||||
cc_library_with_tflite(
 | 
			
		||||
    name = "base_vision_task_api",
 | 
			
		||||
    hdrs = ["base_vision_task_api.h"],
 | 
			
		||||
    tflite_deps = [
 | 
			
		||||
        "//mediapipe/tasks/cc/core:base_task_api",
 | 
			
		||||
        "//mediapipe/tasks/cc/core:task_api_factory",
 | 
			
		||||
        "//mediapipe/tasks/cc/core:task_runner",
 | 
			
		||||
        "//mediapipe/tasks/cc/vision/utils:image_tensor_specs",
 | 
			
		||||
    ],
 | 
			
		||||
    deps = [
 | 
			
		||||
        ":image_processing_options",
 | 
			
		||||
        ":running_mode",
 | 
			
		||||
| 
						 | 
				
			
			@ -42,24 +50,22 @@ cc_library(
 | 
			
		|||
        "//mediapipe/framework/formats:image",
 | 
			
		||||
        "//mediapipe/framework/formats:rect_cc_proto",
 | 
			
		||||
        "//mediapipe/tasks/cc/components/containers:rect",
 | 
			
		||||
        "//mediapipe/tasks/cc/core:base_task_api",
 | 
			
		||||
        "//mediapipe/tasks/cc/core:task_api_factory",
 | 
			
		||||
        "//mediapipe/tasks/cc/core:task_runner",
 | 
			
		||||
        "//mediapipe/tasks/cc/vision/utils:image_tensor_specs",
 | 
			
		||||
        "@com_google_absl//absl/status",
 | 
			
		||||
        "@com_google_absl//absl/status:statusor",
 | 
			
		||||
        "@com_google_absl//absl/strings",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
cc_library(
 | 
			
		||||
cc_library_with_tflite(
 | 
			
		||||
    name = "vision_task_api_factory",
 | 
			
		||||
    hdrs = ["vision_task_api_factory.h"],
 | 
			
		||||
    deps = [
 | 
			
		||||
    tflite_deps = [
 | 
			
		||||
        ":base_vision_task_api",
 | 
			
		||||
        "//mediapipe/tasks/cc/core:task_api_factory",
 | 
			
		||||
    ],
 | 
			
		||||
    deps = [
 | 
			
		||||
        "//mediapipe/calculators/core:flow_limiter_calculator",
 | 
			
		||||
        "//mediapipe/framework:calculator_cc_proto",
 | 
			
		||||
        "//mediapipe/tasks/cc/core:task_api_factory",
 | 
			
		||||
        "@com_google_absl//absl/status",
 | 
			
		||||
        "@com_google_absl//absl/status:statusor",
 | 
			
		||||
        "@com_google_absl//absl/strings",
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -11,6 +11,7 @@
 | 
			
		|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
load("@org_tensorflow//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "cc_library_with_tflite")
 | 
			
		||||
 | 
			
		||||
package(default_visibility = [
 | 
			
		||||
    "//mediapipe/tasks:internal",
 | 
			
		||||
| 
						 | 
				
			
			@ -18,9 +19,15 @@ package(default_visibility = [
 | 
			
		|||
 | 
			
		||||
licenses(["notice"])
 | 
			
		||||
 | 
			
		||||
cc_library(
 | 
			
		||||
cc_library_with_tflite(
 | 
			
		||||
    name = "face_detector_graph",
 | 
			
		||||
    srcs = ["face_detector_graph.cc"],
 | 
			
		||||
    tflite_deps = [
 | 
			
		||||
        "//mediapipe/tasks/cc/components/processors:image_preprocessing_graph",
 | 
			
		||||
        "//mediapipe/tasks/cc/core:model_task_graph",
 | 
			
		||||
        "//mediapipe/tasks/cc/core:model_resources",
 | 
			
		||||
        "//mediapipe/tasks/cc/vision/utils:image_tensor_specs",
 | 
			
		||||
    ],
 | 
			
		||||
    deps = [
 | 
			
		||||
        "//mediapipe/calculators/core:clip_vector_size_calculator",
 | 
			
		||||
        "//mediapipe/calculators/core:clip_vector_size_calculator_cc_proto",
 | 
			
		||||
| 
						 | 
				
			
			@ -38,6 +45,7 @@ cc_library(
 | 
			
		|||
        "//mediapipe/calculators/util:non_max_suppression_calculator_cc_proto",
 | 
			
		||||
        "//mediapipe/calculators/util:rect_transformation_calculator",
 | 
			
		||||
        "//mediapipe/calculators/util:rect_transformation_calculator_cc_proto",
 | 
			
		||||
        "//mediapipe/framework:calculator_framework",
 | 
			
		||||
        "//mediapipe/framework/api2:builder",
 | 
			
		||||
        "//mediapipe/framework/api2:port",
 | 
			
		||||
        "//mediapipe/framework/formats:detection_cc_proto",
 | 
			
		||||
| 
						 | 
				
			
			@ -45,36 +53,34 @@ cc_library(
 | 
			
		|||
        "//mediapipe/framework/formats:rect_cc_proto",
 | 
			
		||||
        "//mediapipe/framework/formats:tensor",
 | 
			
		||||
        "//mediapipe/tasks/cc:common",
 | 
			
		||||
        "//mediapipe/tasks/cc/components/processors:image_preprocessing_graph",
 | 
			
		||||
        "//mediapipe/tasks/cc/core:model_resources",
 | 
			
		||||
        "//mediapipe/tasks/cc/core:model_task_graph",
 | 
			
		||||
        "//mediapipe/tasks/cc/core:utils",
 | 
			
		||||
        "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto",
 | 
			
		||||
        "//mediapipe/tasks/cc/vision/face_detector/proto:face_detector_graph_options_cc_proto",
 | 
			
		||||
        "//mediapipe/tasks/cc/vision/utils:image_tensor_specs",
 | 
			
		||||
        "@com_google_absl//absl/status",
 | 
			
		||||
        "@com_google_absl//absl/status:statusor",
 | 
			
		||||
    ],
 | 
			
		||||
    alwayslink = 1,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
cc_library(
 | 
			
		||||
cc_library_with_tflite(
 | 
			
		||||
    name = "face_detector",
 | 
			
		||||
    srcs = ["face_detector.cc"],
 | 
			
		||||
    hdrs = ["face_detector.h"],
 | 
			
		||||
    tflite_deps = [
 | 
			
		||||
        ":face_detector_graph",
 | 
			
		||||
        "//mediapipe/tasks/cc/vision/core:base_vision_task_api",
 | 
			
		||||
        "//mediapipe/tasks/cc/vision/core:vision_task_api_factory",
 | 
			
		||||
    ],
 | 
			
		||||
    visibility = ["//visibility:public"],
 | 
			
		||||
    deps = [
 | 
			
		||||
        ":face_detector_graph",
 | 
			
		||||
        "//mediapipe/framework/api2:builder",
 | 
			
		||||
        "//mediapipe/framework/formats:detection_cc_proto",
 | 
			
		||||
        "//mediapipe/framework/formats:image",
 | 
			
		||||
        "//mediapipe/tasks/cc/components/containers:detection_result",
 | 
			
		||||
        "//mediapipe/tasks/cc/core:base_options",
 | 
			
		||||
        "//mediapipe/tasks/cc/core:utils",
 | 
			
		||||
        "//mediapipe/tasks/cc/vision/core:base_vision_task_api",
 | 
			
		||||
        "//mediapipe/tasks/cc/vision/core:image_processing_options",
 | 
			
		||||
        "//mediapipe/tasks/cc/vision/core:running_mode",
 | 
			
		||||
        "//mediapipe/tasks/cc/vision/core:vision_task_api_factory",
 | 
			
		||||
        "//mediapipe/tasks/cc/vision/face_detector/proto:face_detector_graph_options_cc_proto",
 | 
			
		||||
        "@com_google_absl//absl/status",
 | 
			
		||||
        "@com_google_absl//absl/status:statusor",
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 | 
			
		|||
limitations under the License.
 | 
			
		||||
==============================================================================*/
 | 
			
		||||
 | 
			
		||||
#include <string>
 | 
			
		||||
#include <vector>
 | 
			
		||||
 | 
			
		||||
#include "absl/status/status.h"
 | 
			
		||||
| 
						 | 
				
			
			@ -26,6 +27,7 @@ limitations under the License.
 | 
			
		|||
#include "mediapipe/calculators/util/rect_transformation_calculator.pb.h"
 | 
			
		||||
#include "mediapipe/framework/api2/builder.h"
 | 
			
		||||
#include "mediapipe/framework/api2/port.h"
 | 
			
		||||
#include "mediapipe/framework/calculator_framework.h"
 | 
			
		||||
#include "mediapipe/framework/formats/detection.pb.h"
 | 
			
		||||
#include "mediapipe/framework/formats/image.h"
 | 
			
		||||
#include "mediapipe/framework/formats/rect.pb.h"
 | 
			
		||||
| 
						 | 
				
			
			@ -213,14 +215,16 @@ class FaceDetectorGraph : public core::ModelTaskGraph {
 | 
			
		|||
  }
 | 
			
		||||
 | 
			
		||||
 private:
 | 
			
		||||
  std::string GetImagePreprocessingGraphName() {
 | 
			
		||||
    return "mediapipe.tasks.components.processors.ImagePreprocessingGraph";
 | 
			
		||||
  }
 | 
			
		||||
  absl::StatusOr<FaceDetectionOuts> BuildFaceDetectionSubgraph(
 | 
			
		||||
      const FaceDetectorGraphOptions& subgraph_options,
 | 
			
		||||
      const core::ModelResources& model_resources, Source<Image> image_in,
 | 
			
		||||
      Source<NormalizedRect> norm_rect_in, Graph& graph) {
 | 
			
		||||
    // Image preprocessing subgraph to convert image to tensor for the tflite
 | 
			
		||||
    // model.
 | 
			
		||||
    auto& preprocessing = graph.AddNode(
 | 
			
		||||
        "mediapipe.tasks.components.processors.ImagePreprocessingGraph");
 | 
			
		||||
    auto& preprocessing = graph.AddNode(GetImagePreprocessingGraphName());
 | 
			
		||||
    bool use_gpu =
 | 
			
		||||
        components::processors::DetermineImagePreprocessingGpuBackend(
 | 
			
		||||
            subgraph_options.base_options().acceleration());
 | 
			
		||||
| 
						 | 
				
			
			@ -337,7 +341,7 @@ class FaceDetectorGraph : public core::ModelTaskGraph {
 | 
			
		|||
};
 | 
			
		||||
 | 
			
		||||
REGISTER_MEDIAPIPE_GRAPH(
 | 
			
		||||
    ::mediapipe::tasks::vision::face_detector::FaceDetectorGraph);
 | 
			
		||||
    ::mediapipe::tasks::vision::face_detector::FaceDetectorGraph)
 | 
			
		||||
 | 
			
		||||
}  // namespace face_detector
 | 
			
		||||
}  // namespace vision
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -23,6 +23,7 @@ limitations under the License.
 | 
			
		|||
#include "absl/flags/flag.h"
 | 
			
		||||
#include "absl/log/absl_check.h"
 | 
			
		||||
#include "absl/status/statusor.h"
 | 
			
		||||
#include "absl/strings/match.h"
 | 
			
		||||
#include "absl/strings/str_format.h"
 | 
			
		||||
#include "absl/strings/string_view.h"
 | 
			
		||||
#include "mediapipe/framework/api2/builder.h"
 | 
			
		||||
| 
						 | 
				
			
			@ -92,11 +93,10 @@ constexpr float kFaceDetectionMaxDiff = 0.01;
 | 
			
		|||
 | 
			
		||||
// Helper function to create a TaskRunner.
 | 
			
		||||
absl::StatusOr<std::unique_ptr<TaskRunner>> CreateTaskRunner(
 | 
			
		||||
    absl::string_view model_name) {
 | 
			
		||||
    absl::string_view model_name, std::string graph_name) {
 | 
			
		||||
  Graph graph;
 | 
			
		||||
 | 
			
		||||
  auto& face_detector_graph =
 | 
			
		||||
      graph.AddNode("mediapipe.tasks.vision.face_detector.FaceDetectorGraph");
 | 
			
		||||
  auto& face_detector_graph = graph.AddNode(graph_name);
 | 
			
		||||
 | 
			
		||||
  auto options = std::make_unique<FaceDetectorGraphOptions>();
 | 
			
		||||
  options->mutable_base_options()->mutable_model_asset()->set_file_name(
 | 
			
		||||
| 
						 | 
				
			
			@ -136,6 +136,8 @@ struct TestParams {
 | 
			
		|||
  std::string test_image_name;
 | 
			
		||||
  // Expected face detection results.
 | 
			
		||||
  std::vector<Detection> expected_result;
 | 
			
		||||
  // The name of the mediapipe graph to run.
 | 
			
		||||
  std::string graph_name;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
class FaceDetectorGraphTest : public testing::TestWithParam<TestParams> {};
 | 
			
		||||
| 
						 | 
				
			
			@ -149,8 +151,9 @@ TEST_P(FaceDetectorGraphTest, Succeed) {
 | 
			
		|||
  input_norm_rect.set_y_center(0.5);
 | 
			
		||||
  input_norm_rect.set_width(1.0);
 | 
			
		||||
  input_norm_rect.set_height(1.0);
 | 
			
		||||
  MP_ASSERT_OK_AND_ASSIGN(
 | 
			
		||||
      auto task_runner, CreateTaskRunner(GetParam().face_detection_model_name));
 | 
			
		||||
  MP_ASSERT_OK_AND_ASSIGN(auto task_runner,
 | 
			
		||||
                          CreateTaskRunner(GetParam().face_detection_model_name,
 | 
			
		||||
                                           GetParam().graph_name));
 | 
			
		||||
  auto output_packets = task_runner->Process(
 | 
			
		||||
      {{kImageName, MakePacket<Image>(std::move(image))},
 | 
			
		||||
       {kNormRectName,
 | 
			
		||||
| 
						 | 
				
			
			@ -165,11 +168,15 @@ TEST_P(FaceDetectorGraphTest, Succeed) {
 | 
			
		|||
 | 
			
		||||
INSTANTIATE_TEST_SUITE_P(
 | 
			
		||||
    FaceDetectorGraphTest, FaceDetectorGraphTest,
 | 
			
		||||
    Values(TestParams{.test_name = "ShortRange",
 | 
			
		||||
                      .face_detection_model_name = kShortRangeBlazeFaceModel,
 | 
			
		||||
                      .test_image_name = kPortraitImage,
 | 
			
		||||
                      .expected_result = {GetExpectedFaceDetectionResult(
 | 
			
		||||
                          kPortraitExpectedDetection)}}),
 | 
			
		||||
    Values(
 | 
			
		||||
        TestParams{
 | 
			
		||||
            .test_name = "ShortRange",
 | 
			
		||||
            .face_detection_model_name = kShortRangeBlazeFaceModel,
 | 
			
		||||
            .test_image_name = kPortraitImage,
 | 
			
		||||
            .expected_result = {GetExpectedFaceDetectionResult(
 | 
			
		||||
                kPortraitExpectedDetection)},
 | 
			
		||||
            .graph_name =
 | 
			
		||||
                "mediapipe.tasks.vision.face_detector.FaceDetectorGraph"}, ),
 | 
			
		||||
    [](const TestParamInfo<FaceDetectorGraphTest::ParamType>& info) {
 | 
			
		||||
      return info.param.test_name;
 | 
			
		||||
    });
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -0,0 +1,28 @@
 | 
			
		|||
#include "mediapipe/tasks/cc/vision/face_detector/face_detector_graph_test_task_runner_gms.h"
 | 
			
		||||
 | 
			
		||||
#include <memory>
 | 
			
		||||
#include <utility>
 | 
			
		||||
 | 
			
		||||
#include "absl/status/statusor.h"
 | 
			
		||||
#include "mediapipe/framework/calculator_framework.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/core/task_runner.h"
 | 
			
		||||
#include "tensorflow/lite/core/api/op_resolver.h"
 | 
			
		||||
 | 
			
		||||
namespace mediapipe {
 | 
			
		||||
namespace tasks {
 | 
			
		||||
namespace vision {
 | 
			
		||||
namespace face_detector {
 | 
			
		||||
namespace test_util {
 | 
			
		||||
 | 
			
		||||
absl::StatusOr<std::unique_ptr<mediapipe::tasks::core::TaskRunner>>
 | 
			
		||||
CreateTaskRunnerGms(mediapipe::CalculatorGraphConfig config,
 | 
			
		||||
                    std::unique_ptr<tflite::OpResolver> op_resolver) {
 | 
			
		||||
  return mediapipe::tasks::core::TaskRunner::Create(std::move(config),
 | 
			
		||||
                                                    std::move(op_resolver));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
}  // namespace test_util
 | 
			
		||||
}  // namespace face_detector
 | 
			
		||||
}  // namespace vision
 | 
			
		||||
}  // namespace tasks
 | 
			
		||||
}  // namespace mediapipe
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,27 @@
 | 
			
		|||
#ifndef MEDIAPIPE_TASKS_CC_VISION_FACE_DETECTOR_FACE_DETECTOR_GRAPH_TEST_TASK_RUNNER_GMS_H_
 | 
			
		||||
#define MEDIAPIPE_TASKS_CC_VISION_FACE_DETECTOR_FACE_DETECTOR_GRAPH_TEST_TASK_RUNNER_GMS_H_
 | 
			
		||||
 | 
			
		||||
#include <memory>
 | 
			
		||||
 | 
			
		||||
#include "absl/status/statusor.h"
 | 
			
		||||
#include "mediapipe/framework/calculator_framework.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/core/task_runner.h"
 | 
			
		||||
#include "tensorflow/lite/core/api/op_resolver.h"
 | 
			
		||||
 | 
			
		||||
namespace mediapipe {
 | 
			
		||||
namespace tasks {
 | 
			
		||||
namespace vision {
 | 
			
		||||
namespace face_detector {
 | 
			
		||||
namespace test_util {
 | 
			
		||||
 | 
			
		||||
absl::StatusOr<std::unique_ptr<mediapipe::tasks::core::TaskRunner>>
 | 
			
		||||
CreateTaskRunnerGms(mediapipe::CalculatorGraphConfig config,
 | 
			
		||||
                    std::unique_ptr<tflite::OpResolver> op_resolver = nullptr);
 | 
			
		||||
 | 
			
		||||
}  // namespace test_util
 | 
			
		||||
}  // namespace face_detector
 | 
			
		||||
}  // namespace vision
 | 
			
		||||
}  // namespace tasks
 | 
			
		||||
}  // namespace mediapipe
 | 
			
		||||
 | 
			
		||||
#endif  // MEDIAPIPE_TASKS_CC_VISION_FACE_DETECTOR_FACE_DETECTOR_GRAPH_TEST_TASK_RUNNER_GMS_H_
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user