No public description
PiperOrigin-RevId: 583973946
This commit is contained in:
parent
5cd3037443
commit
d8fd986517
|
@ -12,6 +12,8 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
load("@org_tensorflow//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "cc_library_with_tflite")
|
||||||
|
|
||||||
package(default_visibility = ["//mediapipe/tasks:internal"])
|
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||||
|
|
||||||
licenses(["notice"])
|
licenses(["notice"])
|
||||||
|
@ -99,10 +101,14 @@ cc_library(
|
||||||
alwayslink = 1,
|
alwayslink = 1,
|
||||||
)
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library_with_tflite(
|
||||||
name = "image_preprocessing_graph",
|
name = "image_preprocessing_graph",
|
||||||
srcs = ["image_preprocessing_graph.cc"],
|
srcs = ["image_preprocessing_graph.cc"],
|
||||||
hdrs = ["image_preprocessing_graph.h"],
|
hdrs = ["image_preprocessing_graph.h"],
|
||||||
|
tflite_deps = [
|
||||||
|
"//mediapipe/tasks/cc/core:model_resources",
|
||||||
|
"//mediapipe/tasks/cc/vision/utils:image_tensor_specs",
|
||||||
|
],
|
||||||
deps = [
|
deps = [
|
||||||
"//mediapipe/calculators/core:pass_through_calculator",
|
"//mediapipe/calculators/core:pass_through_calculator",
|
||||||
"//mediapipe/calculators/image:image_clone_calculator",
|
"//mediapipe/calculators/image:image_clone_calculator",
|
||||||
|
@ -120,10 +126,8 @@ cc_library(
|
||||||
"//mediapipe/gpu:gpu_origin_cc_proto",
|
"//mediapipe/gpu:gpu_origin_cc_proto",
|
||||||
"//mediapipe/tasks/cc:common",
|
"//mediapipe/tasks/cc:common",
|
||||||
"//mediapipe/tasks/cc/components/processors/proto:image_preprocessing_graph_options_cc_proto",
|
"//mediapipe/tasks/cc/components/processors/proto:image_preprocessing_graph_options_cc_proto",
|
||||||
"//mediapipe/tasks/cc/core:model_resources",
|
|
||||||
"//mediapipe/tasks/cc/core/proto:acceleration_cc_proto",
|
"//mediapipe/tasks/cc/core/proto:acceleration_cc_proto",
|
||||||
"//mediapipe/tasks/cc/core/proto:base_options_cc_proto",
|
"//mediapipe/tasks/cc/core/proto:base_options_cc_proto",
|
||||||
"//mediapipe/tasks/cc/vision/utils:image_tensor_specs",
|
|
||||||
"@com_google_absl//absl/status",
|
"@com_google_absl//absl/status",
|
||||||
"@com_google_absl//absl/status:statusor",
|
"@com_google_absl//absl/status:statusor",
|
||||||
"@org_tensorflow//tensorflow/lite/schema:schema_fbs",
|
"@org_tensorflow//tensorflow/lite/schema:schema_fbs",
|
||||||
|
|
|
@ -271,8 +271,9 @@ class ImagePreprocessingGraph : public Subgraph {
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
REGISTER_MEDIAPIPE_GRAPH(
|
REGISTER_MEDIAPIPE_GRAPH(
|
||||||
::mediapipe::tasks::components::processors::ImagePreprocessingGraph);
|
::mediapipe::tasks::components::processors::ImagePreprocessingGraph)
|
||||||
|
|
||||||
} // namespace processors
|
} // namespace processors
|
||||||
} // namespace components
|
} // namespace components
|
||||||
|
|
|
@ -91,16 +91,16 @@ cc_library(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
# TODO: Switch to use cc_library_with_tflite after the MediaPipe InferenceCalculator
|
cc_library_with_tflite(
|
||||||
# supports TFLite-in-GMSCore.
|
|
||||||
cc_library(
|
|
||||||
name = "model_task_graph",
|
name = "model_task_graph",
|
||||||
srcs = ["model_task_graph.cc"],
|
srcs = ["model_task_graph.cc"],
|
||||||
hdrs = ["model_task_graph.h"],
|
hdrs = ["model_task_graph.h"],
|
||||||
deps = [
|
tflite_deps = [
|
||||||
":model_asset_bundle_resources",
|
|
||||||
":model_resources",
|
":model_resources",
|
||||||
":model_resources_cache",
|
":model_resources_cache",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
":model_asset_bundle_resources",
|
||||||
":model_resources_calculator",
|
":model_resources_calculator",
|
||||||
"//mediapipe/calculators/tensor:inference_calculator_cc_proto",
|
"//mediapipe/calculators/tensor:inference_calculator_cc_proto",
|
||||||
"//mediapipe/framework:calculator_cc_proto",
|
"//mediapipe/framework:calculator_cc_proto",
|
||||||
|
@ -224,19 +224,16 @@ cc_library_with_tflite(
|
||||||
alwayslink = 1,
|
alwayslink = 1,
|
||||||
)
|
)
|
||||||
|
|
||||||
cc_test_with_tflite(
|
cc_test(
|
||||||
name = "model_resources_calculator_test",
|
name = "model_resources_calculator_test",
|
||||||
srcs = ["model_resources_calculator_test.cc"],
|
srcs = ["model_resources_calculator_test.cc"],
|
||||||
data = [
|
data = [
|
||||||
"//mediapipe/tasks/testdata/core:test_models",
|
"//mediapipe/tasks/testdata/core:test_models",
|
||||||
],
|
],
|
||||||
tflite_deps = [
|
deps = [
|
||||||
":model_resources",
|
":model_resources",
|
||||||
":model_resources_cache",
|
":model_resources_cache",
|
||||||
":model_resources_calculator",
|
":model_resources_calculator",
|
||||||
"@org_tensorflow//tensorflow/lite:test_util",
|
|
||||||
],
|
|
||||||
deps = [
|
|
||||||
"//mediapipe/framework/port:gtest_main",
|
"//mediapipe/framework/port:gtest_main",
|
||||||
"//mediapipe/framework/port:parse_text_proto",
|
"//mediapipe/framework/port:parse_text_proto",
|
||||||
"//mediapipe/tasks/cc/core/proto:external_file_cc_proto",
|
"//mediapipe/tasks/cc/core/proto:external_file_cc_proto",
|
||||||
|
@ -245,6 +242,7 @@ cc_test_with_tflite(
|
||||||
"@com_google_absl//absl/status",
|
"@com_google_absl//absl/status",
|
||||||
"@com_google_absl//absl/status:statusor",
|
"@com_google_absl//absl/status:statusor",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
|
"@org_tensorflow//tensorflow/lite:test_util",
|
||||||
"@org_tensorflow//tensorflow/lite/core/api:op_resolver",
|
"@org_tensorflow//tensorflow/lite/core/api:op_resolver",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -303,22 +301,26 @@ cc_test_with_tflite(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library_with_tflite(
|
||||||
name = "base_task_api",
|
name = "base_task_api",
|
||||||
hdrs = ["base_task_api.h"],
|
hdrs = ["base_task_api.h"],
|
||||||
deps = [
|
tflite_deps = [
|
||||||
":task_runner",
|
":task_runner",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
"//mediapipe/calculators/core:flow_limiter_calculator",
|
"//mediapipe/calculators/core:flow_limiter_calculator",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library_with_tflite(
|
||||||
name = "task_api_factory",
|
name = "task_api_factory",
|
||||||
hdrs = ["task_api_factory.h"],
|
hdrs = ["task_api_factory.h"],
|
||||||
deps = [
|
tflite_deps = [
|
||||||
":base_task_api",
|
":base_task_api",
|
||||||
":model_resources",
|
":model_resources",
|
||||||
":task_runner",
|
":task_runner",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
":utils",
|
":utils",
|
||||||
"//mediapipe/framework:calculator_cc_proto",
|
"//mediapipe/framework:calculator_cc_proto",
|
||||||
"//mediapipe/framework:executor",
|
"//mediapipe/framework:executor",
|
||||||
|
|
|
@ -31,6 +31,7 @@ limitations under the License.
|
||||||
#include "mediapipe/framework/api2/builder.h"
|
#include "mediapipe/framework/api2/builder.h"
|
||||||
#include "mediapipe/framework/api2/port.h"
|
#include "mediapipe/framework/api2/port.h"
|
||||||
#include "mediapipe/framework/calculator.pb.h"
|
#include "mediapipe/framework/calculator.pb.h"
|
||||||
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
#include "mediapipe/tasks/cc/common.h"
|
#include "mediapipe/tasks/cc/common.h"
|
||||||
#include "mediapipe/tasks/cc/core/model_asset_bundle_resources.h"
|
#include "mediapipe/tasks/cc/core/model_asset_bundle_resources.h"
|
||||||
#include "mediapipe/tasks/cc/core/model_resources.h"
|
#include "mediapipe/tasks/cc/core/model_resources.h"
|
||||||
|
@ -147,7 +148,8 @@ class InferenceSubgraph : public Subgraph {
|
||||||
return delegate;
|
return delegate;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
REGISTER_MEDIAPIPE_GRAPH(::mediapipe::tasks::core::InferenceSubgraph);
|
|
||||||
|
REGISTER_MEDIAPIPE_GRAPH(::mediapipe::tasks::core::InferenceSubgraph)
|
||||||
|
|
||||||
absl::StatusOr<CalculatorGraphConfig> ModelTaskGraph::GetConfig(
|
absl::StatusOr<CalculatorGraphConfig> ModelTaskGraph::GetConfig(
|
||||||
SubgraphContext* sc) {
|
SubgraphContext* sc) {
|
||||||
|
|
|
@ -12,6 +12,8 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
load("@org_tensorflow//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "cc_library_with_tflite")
|
||||||
|
|
||||||
licenses(["notice"])
|
licenses(["notice"])
|
||||||
|
|
||||||
package(default_visibility = ["//mediapipe/tasks:internal"])
|
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||||
|
@ -31,9 +33,15 @@ cc_library(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library_with_tflite(
|
||||||
name = "base_vision_task_api",
|
name = "base_vision_task_api",
|
||||||
hdrs = ["base_vision_task_api.h"],
|
hdrs = ["base_vision_task_api.h"],
|
||||||
|
tflite_deps = [
|
||||||
|
"//mediapipe/tasks/cc/core:base_task_api",
|
||||||
|
"//mediapipe/tasks/cc/core:task_api_factory",
|
||||||
|
"//mediapipe/tasks/cc/core:task_runner",
|
||||||
|
"//mediapipe/tasks/cc/vision/utils:image_tensor_specs",
|
||||||
|
],
|
||||||
deps = [
|
deps = [
|
||||||
":image_processing_options",
|
":image_processing_options",
|
||||||
":running_mode",
|
":running_mode",
|
||||||
|
@ -42,24 +50,22 @@ cc_library(
|
||||||
"//mediapipe/framework/formats:image",
|
"//mediapipe/framework/formats:image",
|
||||||
"//mediapipe/framework/formats:rect_cc_proto",
|
"//mediapipe/framework/formats:rect_cc_proto",
|
||||||
"//mediapipe/tasks/cc/components/containers:rect",
|
"//mediapipe/tasks/cc/components/containers:rect",
|
||||||
"//mediapipe/tasks/cc/core:base_task_api",
|
|
||||||
"//mediapipe/tasks/cc/core:task_api_factory",
|
|
||||||
"//mediapipe/tasks/cc/core:task_runner",
|
|
||||||
"//mediapipe/tasks/cc/vision/utils:image_tensor_specs",
|
|
||||||
"@com_google_absl//absl/status",
|
"@com_google_absl//absl/status",
|
||||||
"@com_google_absl//absl/status:statusor",
|
"@com_google_absl//absl/status:statusor",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library_with_tflite(
|
||||||
name = "vision_task_api_factory",
|
name = "vision_task_api_factory",
|
||||||
hdrs = ["vision_task_api_factory.h"],
|
hdrs = ["vision_task_api_factory.h"],
|
||||||
deps = [
|
tflite_deps = [
|
||||||
":base_vision_task_api",
|
":base_vision_task_api",
|
||||||
|
"//mediapipe/tasks/cc/core:task_api_factory",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
"//mediapipe/calculators/core:flow_limiter_calculator",
|
"//mediapipe/calculators/core:flow_limiter_calculator",
|
||||||
"//mediapipe/framework:calculator_cc_proto",
|
"//mediapipe/framework:calculator_cc_proto",
|
||||||
"//mediapipe/tasks/cc/core:task_api_factory",
|
|
||||||
"@com_google_absl//absl/status",
|
"@com_google_absl//absl/status",
|
||||||
"@com_google_absl//absl/status:statusor",
|
"@com_google_absl//absl/status:statusor",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
|
|
|
@ -11,6 +11,7 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
load("@org_tensorflow//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "cc_library_with_tflite")
|
||||||
|
|
||||||
package(default_visibility = [
|
package(default_visibility = [
|
||||||
"//mediapipe/tasks:internal",
|
"//mediapipe/tasks:internal",
|
||||||
|
@ -18,9 +19,15 @@ package(default_visibility = [
|
||||||
|
|
||||||
licenses(["notice"])
|
licenses(["notice"])
|
||||||
|
|
||||||
cc_library(
|
cc_library_with_tflite(
|
||||||
name = "face_detector_graph",
|
name = "face_detector_graph",
|
||||||
srcs = ["face_detector_graph.cc"],
|
srcs = ["face_detector_graph.cc"],
|
||||||
|
tflite_deps = [
|
||||||
|
"//mediapipe/tasks/cc/components/processors:image_preprocessing_graph",
|
||||||
|
"//mediapipe/tasks/cc/core:model_task_graph",
|
||||||
|
"//mediapipe/tasks/cc/core:model_resources",
|
||||||
|
"//mediapipe/tasks/cc/vision/utils:image_tensor_specs",
|
||||||
|
],
|
||||||
deps = [
|
deps = [
|
||||||
"//mediapipe/calculators/core:clip_vector_size_calculator",
|
"//mediapipe/calculators/core:clip_vector_size_calculator",
|
||||||
"//mediapipe/calculators/core:clip_vector_size_calculator_cc_proto",
|
"//mediapipe/calculators/core:clip_vector_size_calculator_cc_proto",
|
||||||
|
@ -38,6 +45,7 @@ cc_library(
|
||||||
"//mediapipe/calculators/util:non_max_suppression_calculator_cc_proto",
|
"//mediapipe/calculators/util:non_max_suppression_calculator_cc_proto",
|
||||||
"//mediapipe/calculators/util:rect_transformation_calculator",
|
"//mediapipe/calculators/util:rect_transformation_calculator",
|
||||||
"//mediapipe/calculators/util:rect_transformation_calculator_cc_proto",
|
"//mediapipe/calculators/util:rect_transformation_calculator_cc_proto",
|
||||||
|
"//mediapipe/framework:calculator_framework",
|
||||||
"//mediapipe/framework/api2:builder",
|
"//mediapipe/framework/api2:builder",
|
||||||
"//mediapipe/framework/api2:port",
|
"//mediapipe/framework/api2:port",
|
||||||
"//mediapipe/framework/formats:detection_cc_proto",
|
"//mediapipe/framework/formats:detection_cc_proto",
|
||||||
|
@ -45,36 +53,34 @@ cc_library(
|
||||||
"//mediapipe/framework/formats:rect_cc_proto",
|
"//mediapipe/framework/formats:rect_cc_proto",
|
||||||
"//mediapipe/framework/formats:tensor",
|
"//mediapipe/framework/formats:tensor",
|
||||||
"//mediapipe/tasks/cc:common",
|
"//mediapipe/tasks/cc:common",
|
||||||
"//mediapipe/tasks/cc/components/processors:image_preprocessing_graph",
|
|
||||||
"//mediapipe/tasks/cc/core:model_resources",
|
|
||||||
"//mediapipe/tasks/cc/core:model_task_graph",
|
|
||||||
"//mediapipe/tasks/cc/core:utils",
|
"//mediapipe/tasks/cc/core:utils",
|
||||||
"//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto",
|
"//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto",
|
||||||
"//mediapipe/tasks/cc/vision/face_detector/proto:face_detector_graph_options_cc_proto",
|
"//mediapipe/tasks/cc/vision/face_detector/proto:face_detector_graph_options_cc_proto",
|
||||||
"//mediapipe/tasks/cc/vision/utils:image_tensor_specs",
|
|
||||||
"@com_google_absl//absl/status",
|
"@com_google_absl//absl/status",
|
||||||
"@com_google_absl//absl/status:statusor",
|
"@com_google_absl//absl/status:statusor",
|
||||||
],
|
],
|
||||||
alwayslink = 1,
|
alwayslink = 1,
|
||||||
)
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library_with_tflite(
|
||||||
name = "face_detector",
|
name = "face_detector",
|
||||||
srcs = ["face_detector.cc"],
|
srcs = ["face_detector.cc"],
|
||||||
hdrs = ["face_detector.h"],
|
hdrs = ["face_detector.h"],
|
||||||
|
tflite_deps = [
|
||||||
|
":face_detector_graph",
|
||||||
|
"//mediapipe/tasks/cc/vision/core:base_vision_task_api",
|
||||||
|
"//mediapipe/tasks/cc/vision/core:vision_task_api_factory",
|
||||||
|
],
|
||||||
visibility = ["//visibility:public"],
|
visibility = ["//visibility:public"],
|
||||||
deps = [
|
deps = [
|
||||||
":face_detector_graph",
|
|
||||||
"//mediapipe/framework/api2:builder",
|
"//mediapipe/framework/api2:builder",
|
||||||
"//mediapipe/framework/formats:detection_cc_proto",
|
"//mediapipe/framework/formats:detection_cc_proto",
|
||||||
"//mediapipe/framework/formats:image",
|
"//mediapipe/framework/formats:image",
|
||||||
"//mediapipe/tasks/cc/components/containers:detection_result",
|
"//mediapipe/tasks/cc/components/containers:detection_result",
|
||||||
"//mediapipe/tasks/cc/core:base_options",
|
"//mediapipe/tasks/cc/core:base_options",
|
||||||
"//mediapipe/tasks/cc/core:utils",
|
"//mediapipe/tasks/cc/core:utils",
|
||||||
"//mediapipe/tasks/cc/vision/core:base_vision_task_api",
|
|
||||||
"//mediapipe/tasks/cc/vision/core:image_processing_options",
|
"//mediapipe/tasks/cc/vision/core:image_processing_options",
|
||||||
"//mediapipe/tasks/cc/vision/core:running_mode",
|
"//mediapipe/tasks/cc/vision/core:running_mode",
|
||||||
"//mediapipe/tasks/cc/vision/core:vision_task_api_factory",
|
|
||||||
"//mediapipe/tasks/cc/vision/face_detector/proto:face_detector_graph_options_cc_proto",
|
"//mediapipe/tasks/cc/vision/face_detector/proto:face_detector_graph_options_cc_proto",
|
||||||
"@com_google_absl//absl/status",
|
"@com_google_absl//absl/status",
|
||||||
"@com_google_absl//absl/status:statusor",
|
"@com_google_absl//absl/status:statusor",
|
||||||
|
|
|
@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "absl/status/status.h"
|
#include "absl/status/status.h"
|
||||||
|
@ -26,6 +27,7 @@ limitations under the License.
|
||||||
#include "mediapipe/calculators/util/rect_transformation_calculator.pb.h"
|
#include "mediapipe/calculators/util/rect_transformation_calculator.pb.h"
|
||||||
#include "mediapipe/framework/api2/builder.h"
|
#include "mediapipe/framework/api2/builder.h"
|
||||||
#include "mediapipe/framework/api2/port.h"
|
#include "mediapipe/framework/api2/port.h"
|
||||||
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
#include "mediapipe/framework/formats/detection.pb.h"
|
#include "mediapipe/framework/formats/detection.pb.h"
|
||||||
#include "mediapipe/framework/formats/image.h"
|
#include "mediapipe/framework/formats/image.h"
|
||||||
#include "mediapipe/framework/formats/rect.pb.h"
|
#include "mediapipe/framework/formats/rect.pb.h"
|
||||||
|
@ -213,14 +215,16 @@ class FaceDetectorGraph : public core::ModelTaskGraph {
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
std::string GetImagePreprocessingGraphName() {
|
||||||
|
return "mediapipe.tasks.components.processors.ImagePreprocessingGraph";
|
||||||
|
}
|
||||||
absl::StatusOr<FaceDetectionOuts> BuildFaceDetectionSubgraph(
|
absl::StatusOr<FaceDetectionOuts> BuildFaceDetectionSubgraph(
|
||||||
const FaceDetectorGraphOptions& subgraph_options,
|
const FaceDetectorGraphOptions& subgraph_options,
|
||||||
const core::ModelResources& model_resources, Source<Image> image_in,
|
const core::ModelResources& model_resources, Source<Image> image_in,
|
||||||
Source<NormalizedRect> norm_rect_in, Graph& graph) {
|
Source<NormalizedRect> norm_rect_in, Graph& graph) {
|
||||||
// Image preprocessing subgraph to convert image to tensor for the tflite
|
// Image preprocessing subgraph to convert image to tensor for the tflite
|
||||||
// model.
|
// model.
|
||||||
auto& preprocessing = graph.AddNode(
|
auto& preprocessing = graph.AddNode(GetImagePreprocessingGraphName());
|
||||||
"mediapipe.tasks.components.processors.ImagePreprocessingGraph");
|
|
||||||
bool use_gpu =
|
bool use_gpu =
|
||||||
components::processors::DetermineImagePreprocessingGpuBackend(
|
components::processors::DetermineImagePreprocessingGpuBackend(
|
||||||
subgraph_options.base_options().acceleration());
|
subgraph_options.base_options().acceleration());
|
||||||
|
@ -337,7 +341,7 @@ class FaceDetectorGraph : public core::ModelTaskGraph {
|
||||||
};
|
};
|
||||||
|
|
||||||
REGISTER_MEDIAPIPE_GRAPH(
|
REGISTER_MEDIAPIPE_GRAPH(
|
||||||
::mediapipe::tasks::vision::face_detector::FaceDetectorGraph);
|
::mediapipe::tasks::vision::face_detector::FaceDetectorGraph)
|
||||||
|
|
||||||
} // namespace face_detector
|
} // namespace face_detector
|
||||||
} // namespace vision
|
} // namespace vision
|
||||||
|
|
|
@ -23,6 +23,7 @@ limitations under the License.
|
||||||
#include "absl/flags/flag.h"
|
#include "absl/flags/flag.h"
|
||||||
#include "absl/log/absl_check.h"
|
#include "absl/log/absl_check.h"
|
||||||
#include "absl/status/statusor.h"
|
#include "absl/status/statusor.h"
|
||||||
|
#include "absl/strings/match.h"
|
||||||
#include "absl/strings/str_format.h"
|
#include "absl/strings/str_format.h"
|
||||||
#include "absl/strings/string_view.h"
|
#include "absl/strings/string_view.h"
|
||||||
#include "mediapipe/framework/api2/builder.h"
|
#include "mediapipe/framework/api2/builder.h"
|
||||||
|
@ -92,11 +93,10 @@ constexpr float kFaceDetectionMaxDiff = 0.01;
|
||||||
|
|
||||||
// Helper function to create a TaskRunner.
|
// Helper function to create a TaskRunner.
|
||||||
absl::StatusOr<std::unique_ptr<TaskRunner>> CreateTaskRunner(
|
absl::StatusOr<std::unique_ptr<TaskRunner>> CreateTaskRunner(
|
||||||
absl::string_view model_name) {
|
absl::string_view model_name, std::string graph_name) {
|
||||||
Graph graph;
|
Graph graph;
|
||||||
|
|
||||||
auto& face_detector_graph =
|
auto& face_detector_graph = graph.AddNode(graph_name);
|
||||||
graph.AddNode("mediapipe.tasks.vision.face_detector.FaceDetectorGraph");
|
|
||||||
|
|
||||||
auto options = std::make_unique<FaceDetectorGraphOptions>();
|
auto options = std::make_unique<FaceDetectorGraphOptions>();
|
||||||
options->mutable_base_options()->mutable_model_asset()->set_file_name(
|
options->mutable_base_options()->mutable_model_asset()->set_file_name(
|
||||||
|
@ -136,6 +136,8 @@ struct TestParams {
|
||||||
std::string test_image_name;
|
std::string test_image_name;
|
||||||
// Expected face detection results.
|
// Expected face detection results.
|
||||||
std::vector<Detection> expected_result;
|
std::vector<Detection> expected_result;
|
||||||
|
// The name of the mediapipe graph to run.
|
||||||
|
std::string graph_name;
|
||||||
};
|
};
|
||||||
|
|
||||||
class FaceDetectorGraphTest : public testing::TestWithParam<TestParams> {};
|
class FaceDetectorGraphTest : public testing::TestWithParam<TestParams> {};
|
||||||
|
@ -149,8 +151,9 @@ TEST_P(FaceDetectorGraphTest, Succeed) {
|
||||||
input_norm_rect.set_y_center(0.5);
|
input_norm_rect.set_y_center(0.5);
|
||||||
input_norm_rect.set_width(1.0);
|
input_norm_rect.set_width(1.0);
|
||||||
input_norm_rect.set_height(1.0);
|
input_norm_rect.set_height(1.0);
|
||||||
MP_ASSERT_OK_AND_ASSIGN(
|
MP_ASSERT_OK_AND_ASSIGN(auto task_runner,
|
||||||
auto task_runner, CreateTaskRunner(GetParam().face_detection_model_name));
|
CreateTaskRunner(GetParam().face_detection_model_name,
|
||||||
|
GetParam().graph_name));
|
||||||
auto output_packets = task_runner->Process(
|
auto output_packets = task_runner->Process(
|
||||||
{{kImageName, MakePacket<Image>(std::move(image))},
|
{{kImageName, MakePacket<Image>(std::move(image))},
|
||||||
{kNormRectName,
|
{kNormRectName,
|
||||||
|
@ -165,11 +168,15 @@ TEST_P(FaceDetectorGraphTest, Succeed) {
|
||||||
|
|
||||||
INSTANTIATE_TEST_SUITE_P(
|
INSTANTIATE_TEST_SUITE_P(
|
||||||
FaceDetectorGraphTest, FaceDetectorGraphTest,
|
FaceDetectorGraphTest, FaceDetectorGraphTest,
|
||||||
Values(TestParams{.test_name = "ShortRange",
|
Values(
|
||||||
.face_detection_model_name = kShortRangeBlazeFaceModel,
|
TestParams{
|
||||||
.test_image_name = kPortraitImage,
|
.test_name = "ShortRange",
|
||||||
.expected_result = {GetExpectedFaceDetectionResult(
|
.face_detection_model_name = kShortRangeBlazeFaceModel,
|
||||||
kPortraitExpectedDetection)}}),
|
.test_image_name = kPortraitImage,
|
||||||
|
.expected_result = {GetExpectedFaceDetectionResult(
|
||||||
|
kPortraitExpectedDetection)},
|
||||||
|
.graph_name =
|
||||||
|
"mediapipe.tasks.vision.face_detector.FaceDetectorGraph"}, ),
|
||||||
[](const TestParamInfo<FaceDetectorGraphTest::ParamType>& info) {
|
[](const TestParamInfo<FaceDetectorGraphTest::ParamType>& info) {
|
||||||
return info.param.test_name;
|
return info.param.test_name;
|
||||||
});
|
});
|
||||||
|
|
|
@ -0,0 +1,28 @@
|
||||||
|
#include "mediapipe/tasks/cc/vision/face_detector/face_detector_graph_test_task_runner_gms.h"
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
|
#include "absl/status/statusor.h"
|
||||||
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
|
#include "mediapipe/tasks/cc/core/task_runner.h"
|
||||||
|
#include "tensorflow/lite/core/api/op_resolver.h"
|
||||||
|
|
||||||
|
namespace mediapipe {
|
||||||
|
namespace tasks {
|
||||||
|
namespace vision {
|
||||||
|
namespace face_detector {
|
||||||
|
namespace test_util {
|
||||||
|
|
||||||
|
absl::StatusOr<std::unique_ptr<mediapipe::tasks::core::TaskRunner>>
|
||||||
|
CreateTaskRunnerGms(mediapipe::CalculatorGraphConfig config,
|
||||||
|
std::unique_ptr<tflite::OpResolver> op_resolver) {
|
||||||
|
return mediapipe::tasks::core::TaskRunner::Create(std::move(config),
|
||||||
|
std::move(op_resolver));
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace test_util
|
||||||
|
} // namespace face_detector
|
||||||
|
} // namespace vision
|
||||||
|
} // namespace tasks
|
||||||
|
} // namespace mediapipe
|
|
@ -0,0 +1,27 @@
|
||||||
|
#ifndef MEDIAPIPE_TASKS_CC_VISION_FACE_DETECTOR_FACE_DETECTOR_GRAPH_TEST_TASK_RUNNER_GMS_H_
|
||||||
|
#define MEDIAPIPE_TASKS_CC_VISION_FACE_DETECTOR_FACE_DETECTOR_GRAPH_TEST_TASK_RUNNER_GMS_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
#include "absl/status/statusor.h"
|
||||||
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
|
#include "mediapipe/tasks/cc/core/task_runner.h"
|
||||||
|
#include "tensorflow/lite/core/api/op_resolver.h"
|
||||||
|
|
||||||
|
namespace mediapipe {
|
||||||
|
namespace tasks {
|
||||||
|
namespace vision {
|
||||||
|
namespace face_detector {
|
||||||
|
namespace test_util {
|
||||||
|
|
||||||
|
absl::StatusOr<std::unique_ptr<mediapipe::tasks::core::TaskRunner>>
|
||||||
|
CreateTaskRunnerGms(mediapipe::CalculatorGraphConfig config,
|
||||||
|
std::unique_ptr<tflite::OpResolver> op_resolver = nullptr);
|
||||||
|
|
||||||
|
} // namespace test_util
|
||||||
|
} // namespace face_detector
|
||||||
|
} // namespace vision
|
||||||
|
} // namespace tasks
|
||||||
|
} // namespace mediapipe
|
||||||
|
|
||||||
|
#endif // MEDIAPIPE_TASKS_CC_VISION_FACE_DETECTOR_FACE_DETECTOR_GRAPH_TEST_TASK_RUNNER_GMS_H_
|
Loading…
Reference in New Issue
Block a user