diff --git a/mediapipe/tasks/cc/components/processors/BUILD b/mediapipe/tasks/cc/components/processors/BUILD index 1a1e75d41..f02c6cd04 100644 --- a/mediapipe/tasks/cc/components/processors/BUILD +++ b/mediapipe/tasks/cc/components/processors/BUILD @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +load("@org_tensorflow//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "cc_library_with_tflite") + package(default_visibility = ["//mediapipe/tasks:internal"]) licenses(["notice"]) @@ -99,10 +101,14 @@ cc_library( alwayslink = 1, ) -cc_library( +cc_library_with_tflite( name = "image_preprocessing_graph", srcs = ["image_preprocessing_graph.cc"], hdrs = ["image_preprocessing_graph.h"], + tflite_deps = [ + "//mediapipe/tasks/cc/core:model_resources", + "//mediapipe/tasks/cc/vision/utils:image_tensor_specs", + ], deps = [ "//mediapipe/calculators/core:pass_through_calculator", "//mediapipe/calculators/image:image_clone_calculator", @@ -120,10 +126,8 @@ cc_library( "//mediapipe/gpu:gpu_origin_cc_proto", "//mediapipe/tasks/cc:common", "//mediapipe/tasks/cc/components/processors/proto:image_preprocessing_graph_options_cc_proto", - "//mediapipe/tasks/cc/core:model_resources", "//mediapipe/tasks/cc/core/proto:acceleration_cc_proto", "//mediapipe/tasks/cc/core/proto:base_options_cc_proto", - "//mediapipe/tasks/cc/vision/utils:image_tensor_specs", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@org_tensorflow//tensorflow/lite/schema:schema_fbs", diff --git a/mediapipe/tasks/cc/components/processors/image_preprocessing_graph.cc b/mediapipe/tasks/cc/components/processors/image_preprocessing_graph.cc index 1604dfbd5..da9d66c71 100644 --- a/mediapipe/tasks/cc/components/processors/image_preprocessing_graph.cc +++ b/mediapipe/tasks/cc/components/processors/image_preprocessing_graph.cc @@ -271,8 +271,9 @@ class ImagePreprocessingGraph : public Subgraph { }; } }; + REGISTER_MEDIAPIPE_GRAPH( - ::mediapipe::tasks::components::processors::ImagePreprocessingGraph); + ::mediapipe::tasks::components::processors::ImagePreprocessingGraph) } // namespace processors } // namespace components diff --git a/mediapipe/tasks/cc/core/BUILD b/mediapipe/tasks/cc/core/BUILD index bb0d4b001..9185d0a97 100644 --- a/mediapipe/tasks/cc/core/BUILD +++ b/mediapipe/tasks/cc/core/BUILD @@ -91,16 +91,16 @@ cc_library( ], ) -# TODO: Switch to use cc_library_with_tflite after the MediaPipe InferenceCalculator -# supports TFLite-in-GMSCore. -cc_library( +cc_library_with_tflite( name = "model_task_graph", srcs = ["model_task_graph.cc"], hdrs = ["model_task_graph.h"], - deps = [ - ":model_asset_bundle_resources", + tflite_deps = [ ":model_resources", ":model_resources_cache", + ], + deps = [ + ":model_asset_bundle_resources", ":model_resources_calculator", "//mediapipe/calculators/tensor:inference_calculator_cc_proto", "//mediapipe/framework:calculator_cc_proto", @@ -224,19 +224,16 @@ cc_library_with_tflite( alwayslink = 1, ) -cc_test_with_tflite( +cc_test( name = "model_resources_calculator_test", srcs = ["model_resources_calculator_test.cc"], data = [ "//mediapipe/tasks/testdata/core:test_models", ], - tflite_deps = [ + deps = [ ":model_resources", ":model_resources_cache", ":model_resources_calculator", - "@org_tensorflow//tensorflow/lite:test_util", - ], - deps = [ "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:parse_text_proto", "//mediapipe/tasks/cc/core/proto:external_file_cc_proto", @@ -245,6 +242,7 @@ cc_test_with_tflite( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@org_tensorflow//tensorflow/lite:test_util", "@org_tensorflow//tensorflow/lite/core/api:op_resolver", ], ) @@ -303,22 +301,26 @@ cc_test_with_tflite( ], ) -cc_library( +cc_library_with_tflite( name = "base_task_api", hdrs = ["base_task_api.h"], - deps = [ + tflite_deps = [ ":task_runner", + ], + deps = [ "//mediapipe/calculators/core:flow_limiter_calculator", ], ) -cc_library( +cc_library_with_tflite( name = "task_api_factory", hdrs = ["task_api_factory.h"], - deps = [ + tflite_deps = [ ":base_task_api", ":model_resources", ":task_runner", + ], + deps = [ ":utils", "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework:executor", diff --git a/mediapipe/tasks/cc/core/model_task_graph.cc b/mediapipe/tasks/cc/core/model_task_graph.cc index b82a69718..57bb25bf8 100644 --- a/mediapipe/tasks/cc/core/model_task_graph.cc +++ b/mediapipe/tasks/cc/core/model_task_graph.cc @@ -31,6 +31,7 @@ limitations under the License. #include "mediapipe/framework/api2/builder.h" #include "mediapipe/framework/api2/port.h" #include "mediapipe/framework/calculator.pb.h" +#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/tasks/cc/common.h" #include "mediapipe/tasks/cc/core/model_asset_bundle_resources.h" #include "mediapipe/tasks/cc/core/model_resources.h" @@ -147,7 +148,8 @@ class InferenceSubgraph : public Subgraph { return delegate; } }; -REGISTER_MEDIAPIPE_GRAPH(::mediapipe::tasks::core::InferenceSubgraph); + +REGISTER_MEDIAPIPE_GRAPH(::mediapipe::tasks::core::InferenceSubgraph) absl::StatusOr ModelTaskGraph::GetConfig( SubgraphContext* sc) { diff --git a/mediapipe/tasks/cc/vision/core/BUILD b/mediapipe/tasks/cc/vision/core/BUILD index 6bcf2f5d6..d7ebfec68 100644 --- a/mediapipe/tasks/cc/vision/core/BUILD +++ b/mediapipe/tasks/cc/vision/core/BUILD @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +load("@org_tensorflow//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "cc_library_with_tflite") + licenses(["notice"]) package(default_visibility = ["//mediapipe/tasks:internal"]) @@ -31,9 +33,15 @@ cc_library( ], ) -cc_library( +cc_library_with_tflite( name = "base_vision_task_api", hdrs = ["base_vision_task_api.h"], + tflite_deps = [ + "//mediapipe/tasks/cc/core:base_task_api", + "//mediapipe/tasks/cc/core:task_api_factory", + "//mediapipe/tasks/cc/core:task_runner", + "//mediapipe/tasks/cc/vision/utils:image_tensor_specs", + ], deps = [ ":image_processing_options", ":running_mode", @@ -42,24 +50,22 @@ cc_library( "//mediapipe/framework/formats:image", "//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/tasks/cc/components/containers:rect", - "//mediapipe/tasks/cc/core:base_task_api", - "//mediapipe/tasks/cc/core:task_api_factory", - "//mediapipe/tasks/cc/core:task_runner", - "//mediapipe/tasks/cc/vision/utils:image_tensor_specs", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", ], ) -cc_library( +cc_library_with_tflite( name = "vision_task_api_factory", hdrs = ["vision_task_api_factory.h"], - deps = [ + tflite_deps = [ ":base_vision_task_api", + "//mediapipe/tasks/cc/core:task_api_factory", + ], + deps = [ "//mediapipe/calculators/core:flow_limiter_calculator", "//mediapipe/framework:calculator_cc_proto", - "//mediapipe/tasks/cc/core:task_api_factory", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", diff --git a/mediapipe/tasks/cc/vision/face_detector/BUILD b/mediapipe/tasks/cc/vision/face_detector/BUILD index fbfd94628..bdad3bd06 100644 --- a/mediapipe/tasks/cc/vision/face_detector/BUILD +++ b/mediapipe/tasks/cc/vision/face_detector/BUILD @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +load("@org_tensorflow//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "cc_library_with_tflite") package(default_visibility = [ "//mediapipe/tasks:internal", @@ -18,9 +19,15 @@ package(default_visibility = [ licenses(["notice"]) -cc_library( +cc_library_with_tflite( name = "face_detector_graph", srcs = ["face_detector_graph.cc"], + tflite_deps = [ + "//mediapipe/tasks/cc/components/processors:image_preprocessing_graph", + "//mediapipe/tasks/cc/core:model_task_graph", + "//mediapipe/tasks/cc/core:model_resources", + "//mediapipe/tasks/cc/vision/utils:image_tensor_specs", + ], deps = [ "//mediapipe/calculators/core:clip_vector_size_calculator", "//mediapipe/calculators/core:clip_vector_size_calculator_cc_proto", @@ -38,6 +45,7 @@ cc_library( "//mediapipe/calculators/util:non_max_suppression_calculator_cc_proto", "//mediapipe/calculators/util:rect_transformation_calculator", "//mediapipe/calculators/util:rect_transformation_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:builder", "//mediapipe/framework/api2:port", "//mediapipe/framework/formats:detection_cc_proto", @@ -45,36 +53,34 @@ cc_library( "//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework/formats:tensor", "//mediapipe/tasks/cc:common", - "//mediapipe/tasks/cc/components/processors:image_preprocessing_graph", - "//mediapipe/tasks/cc/core:model_resources", - "//mediapipe/tasks/cc/core:model_task_graph", "//mediapipe/tasks/cc/core:utils", "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", "//mediapipe/tasks/cc/vision/face_detector/proto:face_detector_graph_options_cc_proto", - "//mediapipe/tasks/cc/vision/utils:image_tensor_specs", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", ], alwayslink = 1, ) -cc_library( +cc_library_with_tflite( name = "face_detector", srcs = ["face_detector.cc"], hdrs = ["face_detector.h"], + tflite_deps = [ + ":face_detector_graph", + "//mediapipe/tasks/cc/vision/core:base_vision_task_api", + "//mediapipe/tasks/cc/vision/core:vision_task_api_factory", + ], visibility = ["//visibility:public"], deps = [ - ":face_detector_graph", "//mediapipe/framework/api2:builder", "//mediapipe/framework/formats:detection_cc_proto", "//mediapipe/framework/formats:image", "//mediapipe/tasks/cc/components/containers:detection_result", "//mediapipe/tasks/cc/core:base_options", "//mediapipe/tasks/cc/core:utils", - "//mediapipe/tasks/cc/vision/core:base_vision_task_api", "//mediapipe/tasks/cc/vision/core:image_processing_options", "//mediapipe/tasks/cc/vision/core:running_mode", - "//mediapipe/tasks/cc/vision/core:vision_task_api_factory", "//mediapipe/tasks/cc/vision/face_detector/proto:face_detector_graph_options_cc_proto", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", diff --git a/mediapipe/tasks/cc/vision/face_detector/face_detector_graph.cc b/mediapipe/tasks/cc/vision/face_detector/face_detector_graph.cc index 8586a7ebd..5a8a60101 100644 --- a/mediapipe/tasks/cc/vision/face_detector/face_detector_graph.cc +++ b/mediapipe/tasks/cc/vision/face_detector/face_detector_graph.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include #include "absl/status/status.h" @@ -26,6 +27,7 @@ limitations under the License. #include "mediapipe/calculators/util/rect_transformation_calculator.pb.h" #include "mediapipe/framework/api2/builder.h" #include "mediapipe/framework/api2/port.h" +#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/formats/detection.pb.h" #include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/rect.pb.h" @@ -213,14 +215,16 @@ class FaceDetectorGraph : public core::ModelTaskGraph { } private: + std::string GetImagePreprocessingGraphName() { + return "mediapipe.tasks.components.processors.ImagePreprocessingGraph"; + } absl::StatusOr BuildFaceDetectionSubgraph( const FaceDetectorGraphOptions& subgraph_options, const core::ModelResources& model_resources, Source image_in, Source norm_rect_in, Graph& graph) { // Image preprocessing subgraph to convert image to tensor for the tflite // model. - auto& preprocessing = graph.AddNode( - "mediapipe.tasks.components.processors.ImagePreprocessingGraph"); + auto& preprocessing = graph.AddNode(GetImagePreprocessingGraphName()); bool use_gpu = components::processors::DetermineImagePreprocessingGpuBackend( subgraph_options.base_options().acceleration()); @@ -337,7 +341,7 @@ class FaceDetectorGraph : public core::ModelTaskGraph { }; REGISTER_MEDIAPIPE_GRAPH( - ::mediapipe::tasks::vision::face_detector::FaceDetectorGraph); + ::mediapipe::tasks::vision::face_detector::FaceDetectorGraph) } // namespace face_detector } // namespace vision diff --git a/mediapipe/tasks/cc/vision/face_detector/face_detector_graph_test.cc b/mediapipe/tasks/cc/vision/face_detector/face_detector_graph_test.cc index 651ad722d..768b92cfd 100644 --- a/mediapipe/tasks/cc/vision/face_detector/face_detector_graph_test.cc +++ b/mediapipe/tasks/cc/vision/face_detector/face_detector_graph_test.cc @@ -23,6 +23,7 @@ limitations under the License. #include "absl/flags/flag.h" #include "absl/log/absl_check.h" #include "absl/status/statusor.h" +#include "absl/strings/match.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "mediapipe/framework/api2/builder.h" @@ -92,11 +93,10 @@ constexpr float kFaceDetectionMaxDiff = 0.01; // Helper function to create a TaskRunner. absl::StatusOr> CreateTaskRunner( - absl::string_view model_name) { + absl::string_view model_name, std::string graph_name) { Graph graph; - auto& face_detector_graph = - graph.AddNode("mediapipe.tasks.vision.face_detector.FaceDetectorGraph"); + auto& face_detector_graph = graph.AddNode(graph_name); auto options = std::make_unique(); options->mutable_base_options()->mutable_model_asset()->set_file_name( @@ -136,6 +136,8 @@ struct TestParams { std::string test_image_name; // Expected face detection results. std::vector expected_result; + // The name of the mediapipe graph to run. + std::string graph_name; }; class FaceDetectorGraphTest : public testing::TestWithParam {}; @@ -149,8 +151,9 @@ TEST_P(FaceDetectorGraphTest, Succeed) { input_norm_rect.set_y_center(0.5); input_norm_rect.set_width(1.0); input_norm_rect.set_height(1.0); - MP_ASSERT_OK_AND_ASSIGN( - auto task_runner, CreateTaskRunner(GetParam().face_detection_model_name)); + MP_ASSERT_OK_AND_ASSIGN(auto task_runner, + CreateTaskRunner(GetParam().face_detection_model_name, + GetParam().graph_name)); auto output_packets = task_runner->Process( {{kImageName, MakePacket(std::move(image))}, {kNormRectName, @@ -165,11 +168,15 @@ TEST_P(FaceDetectorGraphTest, Succeed) { INSTANTIATE_TEST_SUITE_P( FaceDetectorGraphTest, FaceDetectorGraphTest, - Values(TestParams{.test_name = "ShortRange", - .face_detection_model_name = kShortRangeBlazeFaceModel, - .test_image_name = kPortraitImage, - .expected_result = {GetExpectedFaceDetectionResult( - kPortraitExpectedDetection)}}), + Values( + TestParams{ + .test_name = "ShortRange", + .face_detection_model_name = kShortRangeBlazeFaceModel, + .test_image_name = kPortraitImage, + .expected_result = {GetExpectedFaceDetectionResult( + kPortraitExpectedDetection)}, + .graph_name = + "mediapipe.tasks.vision.face_detector.FaceDetectorGraph"}, ), [](const TestParamInfo& info) { return info.param.test_name; }); diff --git a/mediapipe/tasks/cc/vision/face_detector/face_detector_graph_test_task_runner_gms.cc b/mediapipe/tasks/cc/vision/face_detector/face_detector_graph_test_task_runner_gms.cc new file mode 100644 index 000000000..a8b82a2ac --- /dev/null +++ b/mediapipe/tasks/cc/vision/face_detector/face_detector_graph_test_task_runner_gms.cc @@ -0,0 +1,28 @@ +#include "mediapipe/tasks/cc/vision/face_detector/face_detector_graph_test_task_runner_gms.h" + +#include +#include + +#include "absl/status/statusor.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/tasks/cc/core/task_runner.h" +#include "tensorflow/lite/core/api/op_resolver.h" + +namespace mediapipe { +namespace tasks { +namespace vision { +namespace face_detector { +namespace test_util { + +absl::StatusOr> +CreateTaskRunnerGms(mediapipe::CalculatorGraphConfig config, + std::unique_ptr op_resolver) { + return mediapipe::tasks::core::TaskRunner::Create(std::move(config), + std::move(op_resolver)); +} + +} // namespace test_util +} // namespace face_detector +} // namespace vision +} // namespace tasks +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/face_detector/face_detector_graph_test_task_runner_gms.h b/mediapipe/tasks/cc/vision/face_detector/face_detector_graph_test_task_runner_gms.h new file mode 100644 index 000000000..c34464c9e --- /dev/null +++ b/mediapipe/tasks/cc/vision/face_detector/face_detector_graph_test_task_runner_gms.h @@ -0,0 +1,27 @@ +#ifndef MEDIAPIPE_TASKS_CC_VISION_FACE_DETECTOR_FACE_DETECTOR_GRAPH_TEST_TASK_RUNNER_GMS_H_ +#define MEDIAPIPE_TASKS_CC_VISION_FACE_DETECTOR_FACE_DETECTOR_GRAPH_TEST_TASK_RUNNER_GMS_H_ + +#include + +#include "absl/status/statusor.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/tasks/cc/core/task_runner.h" +#include "tensorflow/lite/core/api/op_resolver.h" + +namespace mediapipe { +namespace tasks { +namespace vision { +namespace face_detector { +namespace test_util { + +absl::StatusOr> +CreateTaskRunnerGms(mediapipe::CalculatorGraphConfig config, + std::unique_ptr op_resolver = nullptr); + +} // namespace test_util +} // namespace face_detector +} // namespace vision +} // namespace tasks +} // namespace mediapipe + +#endif // MEDIAPIPE_TASKS_CC_VISION_FACE_DETECTOR_FACE_DETECTOR_GRAPH_TEST_TASK_RUNNER_GMS_H_