No public description

PiperOrigin-RevId: 583973946
Authored by Matt Kreileder on 2023-11-20 03:29:57 -08:00; committed by Copybara-Service
parent 5cd3037443
commit d8fd986517
10 changed files with 136 additions and 49 deletions


@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
load("@org_tensorflow//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "cc_library_with_tflite")
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
@@ -99,10 +101,14 @@ cc_library(
alwayslink = 1,
)
cc_library(
cc_library_with_tflite(
name = "image_preprocessing_graph",
srcs = ["image_preprocessing_graph.cc"],
hdrs = ["image_preprocessing_graph.h"],
tflite_deps = [
"//mediapipe/tasks/cc/core:model_resources",
"//mediapipe/tasks/cc/vision/utils:image_tensor_specs",
],
deps = [
"//mediapipe/calculators/core:pass_through_calculator",
"//mediapipe/calculators/image:image_clone_calculator",
@@ -120,10 +126,8 @@ cc_library(
"//mediapipe/gpu:gpu_origin_cc_proto",
"//mediapipe/tasks/cc:common",
"//mediapipe/tasks/cc/components/processors/proto:image_preprocessing_graph_options_cc_proto",
"//mediapipe/tasks/cc/core:model_resources",
"//mediapipe/tasks/cc/core/proto:acceleration_cc_proto",
"//mediapipe/tasks/cc/core/proto:base_options_cc_proto",
"//mediapipe/tasks/cc/vision/utils:image_tensor_specs",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@org_tensorflow//tensorflow/lite/schema:schema_fbs",


@@ -271,8 +271,9 @@ class ImagePreprocessingGraph : public Subgraph {
};
}
};
REGISTER_MEDIAPIPE_GRAPH(
::mediapipe::tasks::components::processors::ImagePreprocessingGraph);
::mediapipe::tasks::components::processors::ImagePreprocessingGraph)
} // namespace processors
} // namespace components


@@ -91,16 +91,16 @@ cc_library(
],
)
# TODO: Switch to use cc_library_with_tflite after the MediaPipe InferenceCalculator
# supports TFLite-in-GMSCore.
cc_library(
cc_library_with_tflite(
name = "model_task_graph",
srcs = ["model_task_graph.cc"],
hdrs = ["model_task_graph.h"],
deps = [
":model_asset_bundle_resources",
tflite_deps = [
":model_resources",
":model_resources_cache",
],
deps = [
":model_asset_bundle_resources",
":model_resources_calculator",
"//mediapipe/calculators/tensor:inference_calculator_cc_proto",
"//mediapipe/framework:calculator_cc_proto",
@@ -224,19 +224,16 @@ cc_library_with_tflite(
alwayslink = 1,
)
cc_test_with_tflite(
cc_test(
name = "model_resources_calculator_test",
srcs = ["model_resources_calculator_test.cc"],
data = [
"//mediapipe/tasks/testdata/core:test_models",
],
tflite_deps = [
deps = [
":model_resources",
":model_resources_cache",
":model_resources_calculator",
"@org_tensorflow//tensorflow/lite:test_util",
],
deps = [
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto",
"//mediapipe/tasks/cc/core/proto:external_file_cc_proto",
@@ -245,6 +242,7 @@ cc_test_with_tflite(
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@org_tensorflow//tensorflow/lite:test_util",
"@org_tensorflow//tensorflow/lite/core/api:op_resolver",
],
)
@@ -303,22 +301,26 @@ cc_test_with_tflite(
],
)
cc_library(
cc_library_with_tflite(
name = "base_task_api",
hdrs = ["base_task_api.h"],
deps = [
tflite_deps = [
":task_runner",
],
deps = [
"//mediapipe/calculators/core:flow_limiter_calculator",
],
)
cc_library(
cc_library_with_tflite(
name = "task_api_factory",
hdrs = ["task_api_factory.h"],
deps = [
tflite_deps = [
":base_task_api",
":model_resources",
":task_runner",
],
deps = [
":utils",
"//mediapipe/framework:calculator_cc_proto",
"//mediapipe/framework:executor",


@@ -31,6 +31,7 @@ limitations under the License.
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/core/model_asset_bundle_resources.h"
#include "mediapipe/tasks/cc/core/model_resources.h"
@@ -147,7 +148,8 @@ class InferenceSubgraph : public Subgraph {
return delegate;
}
};
REGISTER_MEDIAPIPE_GRAPH(::mediapipe::tasks::core::InferenceSubgraph);
REGISTER_MEDIAPIPE_GRAPH(::mediapipe::tasks::core::InferenceSubgraph)
absl::StatusOr<CalculatorGraphConfig> ModelTaskGraph::GetConfig(
SubgraphContext* sc) {


@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
load("@org_tensorflow//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "cc_library_with_tflite")
licenses(["notice"])
package(default_visibility = ["//mediapipe/tasks:internal"])
@@ -31,9 +33,15 @@ cc_library(
],
)
cc_library(
cc_library_with_tflite(
name = "base_vision_task_api",
hdrs = ["base_vision_task_api.h"],
tflite_deps = [
"//mediapipe/tasks/cc/core:base_task_api",
"//mediapipe/tasks/cc/core:task_api_factory",
"//mediapipe/tasks/cc/core:task_runner",
"//mediapipe/tasks/cc/vision/utils:image_tensor_specs",
],
deps = [
":image_processing_options",
":running_mode",
@@ -42,24 +50,22 @@ cc_library(
"//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/tasks/cc/components/containers:rect",
"//mediapipe/tasks/cc/core:base_task_api",
"//mediapipe/tasks/cc/core:task_api_factory",
"//mediapipe/tasks/cc/core:task_runner",
"//mediapipe/tasks/cc/vision/utils:image_tensor_specs",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
],
)
cc_library(
cc_library_with_tflite(
name = "vision_task_api_factory",
hdrs = ["vision_task_api_factory.h"],
deps = [
tflite_deps = [
":base_vision_task_api",
"//mediapipe/tasks/cc/core:task_api_factory",
],
deps = [
"//mediapipe/calculators/core:flow_limiter_calculator",
"//mediapipe/framework:calculator_cc_proto",
"//mediapipe/tasks/cc/core:task_api_factory",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",


@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
load("@org_tensorflow//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "cc_library_with_tflite")
package(default_visibility = [
"//mediapipe/tasks:internal",
@@ -18,9 +19,15 @@ package(default_visibility = [
licenses(["notice"])
cc_library(
cc_library_with_tflite(
name = "face_detector_graph",
srcs = ["face_detector_graph.cc"],
tflite_deps = [
"//mediapipe/tasks/cc/components/processors:image_preprocessing_graph",
"//mediapipe/tasks/cc/core:model_task_graph",
"//mediapipe/tasks/cc/core:model_resources",
"//mediapipe/tasks/cc/vision/utils:image_tensor_specs",
],
deps = [
"//mediapipe/calculators/core:clip_vector_size_calculator",
"//mediapipe/calculators/core:clip_vector_size_calculator_cc_proto",
@@ -38,6 +45,7 @@ cc_library(
"//mediapipe/calculators/util:non_max_suppression_calculator_cc_proto",
"//mediapipe/calculators/util:rect_transformation_calculator",
"//mediapipe/calculators/util:rect_transformation_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:detection_cc_proto",
@@ -45,36 +53,34 @@ cc_library(
"//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/framework/formats:tensor",
"//mediapipe/tasks/cc:common",
"//mediapipe/tasks/cc/components/processors:image_preprocessing_graph",
"//mediapipe/tasks/cc/core:model_resources",
"//mediapipe/tasks/cc/core:model_task_graph",
"//mediapipe/tasks/cc/core:utils",
"//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto",
"//mediapipe/tasks/cc/vision/face_detector/proto:face_detector_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/utils:image_tensor_specs",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
],
alwayslink = 1,
)
cc_library(
cc_library_with_tflite(
name = "face_detector",
srcs = ["face_detector.cc"],
hdrs = ["face_detector.h"],
tflite_deps = [
":face_detector_graph",
"//mediapipe/tasks/cc/vision/core:base_vision_task_api",
"//mediapipe/tasks/cc/vision/core:vision_task_api_factory",
],
visibility = ["//visibility:public"],
deps = [
":face_detector_graph",
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/formats:detection_cc_proto",
"//mediapipe/framework/formats:image",
"//mediapipe/tasks/cc/components/containers:detection_result",
"//mediapipe/tasks/cc/core:base_options",
"//mediapipe/tasks/cc/core:utils",
"//mediapipe/tasks/cc/vision/core:base_vision_task_api",
"//mediapipe/tasks/cc/vision/core:image_processing_options",
"//mediapipe/tasks/cc/vision/core:running_mode",
"//mediapipe/tasks/cc/vision/core:vision_task_api_factory",
"//mediapipe/tasks/cc/vision/face_detector/proto:face_detector_graph_options_cc_proto",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",


@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <string>
#include <vector>
#include "absl/status/status.h"
@@ -26,6 +27,7 @@ limitations under the License.
#include "mediapipe/calculators/util/rect_transformation_calculator.pb.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/rect.pb.h"
@@ -213,14 +215,16 @@ class FaceDetectorGraph : public core::ModelTaskGraph {
}
private:
std::string GetImagePreprocessingGraphName() {
return "mediapipe.tasks.components.processors.ImagePreprocessingGraph";
}
absl::StatusOr<FaceDetectionOuts> BuildFaceDetectionSubgraph(
const FaceDetectorGraphOptions& subgraph_options,
const core::ModelResources& model_resources, Source<Image> image_in,
Source<NormalizedRect> norm_rect_in, Graph& graph) {
// Image preprocessing subgraph to convert image to tensor for the tflite
// model.
auto& preprocessing = graph.AddNode(
"mediapipe.tasks.components.processors.ImagePreprocessingGraph");
auto& preprocessing = graph.AddNode(GetImagePreprocessingGraphName());
bool use_gpu =
components::processors::DetermineImagePreprocessingGpuBackend(
subgraph_options.base_options().acceleration());
@@ -337,7 +341,7 @@ class FaceDetectorGraph : public core::ModelTaskGraph {
};
REGISTER_MEDIAPIPE_GRAPH(
::mediapipe::tasks::vision::face_detector::FaceDetectorGraph);
::mediapipe::tasks::vision::face_detector::FaceDetectorGraph)
} // namespace face_detector
} // namespace vision


@@ -23,6 +23,7 @@ limitations under the License.
#include "absl/flags/flag.h"
#include "absl/log/absl_check.h"
#include "absl/status/statusor.h"
#include "absl/strings/match.h"
#include "absl/strings/str_format.h"
#include "absl/strings/string_view.h"
#include "mediapipe/framework/api2/builder.h"
@@ -92,11 +93,10 @@ constexpr float kFaceDetectionMaxDiff = 0.01;
// Helper function to create a TaskRunner.
absl::StatusOr<std::unique_ptr<TaskRunner>> CreateTaskRunner(
absl::string_view model_name) {
absl::string_view model_name, std::string graph_name) {
Graph graph;
auto& face_detector_graph =
graph.AddNode("mediapipe.tasks.vision.face_detector.FaceDetectorGraph");
auto& face_detector_graph = graph.AddNode(graph_name);
auto options = std::make_unique<FaceDetectorGraphOptions>();
options->mutable_base_options()->mutable_model_asset()->set_file_name(
@@ -136,6 +136,8 @@ struct TestParams {
std::string test_image_name;
// Expected face detection results.
std::vector<Detection> expected_result;
// The name of the mediapipe graph to run.
std::string graph_name;
};
class FaceDetectorGraphTest : public testing::TestWithParam<TestParams> {};
@@ -149,8 +151,9 @@ TEST_P(FaceDetectorGraphTest, Succeed) {
input_norm_rect.set_y_center(0.5);
input_norm_rect.set_width(1.0);
input_norm_rect.set_height(1.0);
MP_ASSERT_OK_AND_ASSIGN(
auto task_runner, CreateTaskRunner(GetParam().face_detection_model_name));
MP_ASSERT_OK_AND_ASSIGN(auto task_runner,
CreateTaskRunner(GetParam().face_detection_model_name,
GetParam().graph_name));
auto output_packets = task_runner->Process(
{{kImageName, MakePacket<Image>(std::move(image))},
{kNormRectName,
@@ -165,11 +168,15 @@ TEST_P(FaceDetectorGraphTest, Succeed) {
INSTANTIATE_TEST_SUITE_P(
FaceDetectorGraphTest, FaceDetectorGraphTest,
Values(TestParams{.test_name = "ShortRange",
.face_detection_model_name = kShortRangeBlazeFaceModel,
.test_image_name = kPortraitImage,
.expected_result = {GetExpectedFaceDetectionResult(
kPortraitExpectedDetection)}}),
Values(
TestParams{
.test_name = "ShortRange",
.face_detection_model_name = kShortRangeBlazeFaceModel,
.test_image_name = kPortraitImage,
.expected_result = {GetExpectedFaceDetectionResult(
kPortraitExpectedDetection)},
.graph_name =
"mediapipe.tasks.vision.face_detector.FaceDetectorGraph"}, ),
[](const TestParamInfo<FaceDetectorGraphTest::ParamType>& info) {
return info.param.test_name;
});
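
The test refactor above stops hard-coding "mediapipe.tasks.vision.face_detector.FaceDetectorGraph": CreateTaskRunner now receives the graph name as an argument and each TestParams entry carries the graph it should run, so the same test source can be instantiated against a differently registered FaceDetectorGraph (for example a TFLite-in-GMSCore build using the task-runner helper added in the next two files). This diff does not show the test's BUILD rules; the target below is purely a hypothetical sketch of how such a variant could be wired, reusing the cc_test_with_tflite macro seen earlier in this commit:

# Hypothetical wiring only: target name and dependency list are assumptions.
cc_test_with_tflite(
    name = "face_detector_graph_test",
    srcs = ["face_detector_graph_test.cc"],
    tflite_deps = [
        ":face_detector_graph",
        "//mediapipe/tasks/cc/core:task_runner",
    ],
    deps = [
        "//mediapipe/framework/port:gtest_main",
        "@com_google_absl//absl/strings",
    ],
)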


@@ -0,0 +1,28 @@
#include "mediapipe/tasks/cc/vision/face_detector/face_detector_graph_test_task_runner_gms.h"
#include <memory>
#include <utility>
#include "absl/status/statusor.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/tasks/cc/core/task_runner.h"
#include "tensorflow/lite/core/api/op_resolver.h"
namespace mediapipe {
namespace tasks {
namespace vision {
namespace face_detector {
namespace test_util {
absl::StatusOr<std::unique_ptr<mediapipe::tasks::core::TaskRunner>>
CreateTaskRunnerGms(mediapipe::CalculatorGraphConfig config,
std::unique_ptr<tflite::OpResolver> op_resolver) {
return mediapipe::tasks::core::TaskRunner::Create(std::move(config),
std::move(op_resolver));
}
} // namespace test_util
} // namespace face_detector
} // namespace vision
} // namespace tasks
} // namespace mediapipe


@@ -0,0 +1,27 @@
#ifndef MEDIAPIPE_TASKS_CC_VISION_FACE_DETECTOR_FACE_DETECTOR_GRAPH_TEST_TASK_RUNNER_GMS_H_
#define MEDIAPIPE_TASKS_CC_VISION_FACE_DETECTOR_FACE_DETECTOR_GRAPH_TEST_TASK_RUNNER_GMS_H_
#include <memory>
#include "absl/status/statusor.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/tasks/cc/core/task_runner.h"
#include "tensorflow/lite/core/api/op_resolver.h"
namespace mediapipe {
namespace tasks {
namespace vision {
namespace face_detector {
namespace test_util {
absl::StatusOr<std::unique_ptr<mediapipe::tasks::core::TaskRunner>>
CreateTaskRunnerGms(mediapipe::CalculatorGraphConfig config,
std::unique_ptr<tflite::OpResolver> op_resolver = nullptr);
} // namespace test_util
} // namespace face_detector
} // namespace vision
} // namespace tasks
} // namespace mediapipe
#endif // MEDIAPIPE_TASKS_CC_VISION_FACE_DETECTOR_FACE_DETECTOR_GRAPH_TEST_TASK_RUNNER_GMS_H_
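
The two new files above add CreateTaskRunnerGms, a small factory that builds a mediapipe::tasks::core::TaskRunner from a CalculatorGraphConfig plus an optional tflite::OpResolver (defaulting to nullptr). Routing the test's task-runner creation through this helper presumably gives the GMSCore-backed build a single place to substitute its own op resolver. The BUILD target for these files is not part of the visible diff; a hypothetical declaration, with the name and dependency labels assumed from the #includes, might look like:

# Hypothetical target, not shown in this diff; name and deps are assumptions.
cc_library_with_tflite(
    name = "face_detector_graph_test_task_runner_gms",
    testonly = True,
    srcs = ["face_detector_graph_test_task_runner_gms.cc"],
    hdrs = ["face_detector_graph_test_task_runner_gms.h"],
    tflite_deps = [
        "//mediapipe/tasks/cc/core:task_runner",
    ],
    deps = [
        "//mediapipe/framework:calculator_framework",
        "@com_google_absl//absl/status:statusor",
        "@org_tensorflow//tensorflow/lite/core/api:op_resolver",
    ],
)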