No public description
PiperOrigin-RevId: 583973946
This commit is contained in:
parent
5cd3037443
commit
d8fd986517
|
@ -12,6 +12,8 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
load("@org_tensorflow//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "cc_library_with_tflite")
|
||||
|
||||
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||
|
||||
licenses(["notice"])
|
||||
|
@ -99,10 +101,14 @@ cc_library(
|
|||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
cc_library_with_tflite(
|
||||
name = "image_preprocessing_graph",
|
||||
srcs = ["image_preprocessing_graph.cc"],
|
||||
hdrs = ["image_preprocessing_graph.h"],
|
||||
tflite_deps = [
|
||||
"//mediapipe/tasks/cc/core:model_resources",
|
||||
"//mediapipe/tasks/cc/vision/utils:image_tensor_specs",
|
||||
],
|
||||
deps = [
|
||||
"//mediapipe/calculators/core:pass_through_calculator",
|
||||
"//mediapipe/calculators/image:image_clone_calculator",
|
||||
|
@ -120,10 +126,8 @@ cc_library(
|
|||
"//mediapipe/gpu:gpu_origin_cc_proto",
|
||||
"//mediapipe/tasks/cc:common",
|
||||
"//mediapipe/tasks/cc/components/processors/proto:image_preprocessing_graph_options_cc_proto",
|
||||
"//mediapipe/tasks/cc/core:model_resources",
|
||||
"//mediapipe/tasks/cc/core/proto:acceleration_cc_proto",
|
||||
"//mediapipe/tasks/cc/core/proto:base_options_cc_proto",
|
||||
"//mediapipe/tasks/cc/vision/utils:image_tensor_specs",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@org_tensorflow//tensorflow/lite/schema:schema_fbs",
|
||||
|
|
|
@ -271,8 +271,9 @@ class ImagePreprocessingGraph : public Subgraph {
|
|||
};
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_MEDIAPIPE_GRAPH(
|
||||
::mediapipe::tasks::components::processors::ImagePreprocessingGraph);
|
||||
::mediapipe::tasks::components::processors::ImagePreprocessingGraph)
|
||||
|
||||
} // namespace processors
|
||||
} // namespace components
|
||||
|
|
|
@ -91,16 +91,16 @@ cc_library(
|
|||
],
|
||||
)
|
||||
|
||||
# TODO: Switch to use cc_library_with_tflite after the MediaPipe InferenceCalculator
|
||||
# supports TFLite-in-GMSCore.
|
||||
cc_library(
|
||||
cc_library_with_tflite(
|
||||
name = "model_task_graph",
|
||||
srcs = ["model_task_graph.cc"],
|
||||
hdrs = ["model_task_graph.h"],
|
||||
deps = [
|
||||
":model_asset_bundle_resources",
|
||||
tflite_deps = [
|
||||
":model_resources",
|
||||
":model_resources_cache",
|
||||
],
|
||||
deps = [
|
||||
":model_asset_bundle_resources",
|
||||
":model_resources_calculator",
|
||||
"//mediapipe/calculators/tensor:inference_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_cc_proto",
|
||||
|
@ -224,19 +224,16 @@ cc_library_with_tflite(
|
|||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_test_with_tflite(
|
||||
cc_test(
|
||||
name = "model_resources_calculator_test",
|
||||
srcs = ["model_resources_calculator_test.cc"],
|
||||
data = [
|
||||
"//mediapipe/tasks/testdata/core:test_models",
|
||||
],
|
||||
tflite_deps = [
|
||||
deps = [
|
||||
":model_resources",
|
||||
":model_resources_cache",
|
||||
":model_resources_calculator",
|
||||
"@org_tensorflow//tensorflow/lite:test_util",
|
||||
],
|
||||
deps = [
|
||||
"//mediapipe/framework/port:gtest_main",
|
||||
"//mediapipe/framework/port:parse_text_proto",
|
||||
"//mediapipe/tasks/cc/core/proto:external_file_cc_proto",
|
||||
|
@ -245,6 +242,7 @@ cc_test_with_tflite(
|
|||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@org_tensorflow//tensorflow/lite:test_util",
|
||||
"@org_tensorflow//tensorflow/lite/core/api:op_resolver",
|
||||
],
|
||||
)
|
||||
|
@ -303,22 +301,26 @@ cc_test_with_tflite(
|
|||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
cc_library_with_tflite(
|
||||
name = "base_task_api",
|
||||
hdrs = ["base_task_api.h"],
|
||||
deps = [
|
||||
tflite_deps = [
|
||||
":task_runner",
|
||||
],
|
||||
deps = [
|
||||
"//mediapipe/calculators/core:flow_limiter_calculator",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
cc_library_with_tflite(
|
||||
name = "task_api_factory",
|
||||
hdrs = ["task_api_factory.h"],
|
||||
deps = [
|
||||
tflite_deps = [
|
||||
":base_task_api",
|
||||
":model_resources",
|
||||
":task_runner",
|
||||
],
|
||||
deps = [
|
||||
":utils",
|
||||
"//mediapipe/framework:calculator_cc_proto",
|
||||
"//mediapipe/framework:executor",
|
||||
|
|
|
@ -31,6 +31,7 @@ limitations under the License.
|
|||
#include "mediapipe/framework/api2/builder.h"
|
||||
#include "mediapipe/framework/api2/port.h"
|
||||
#include "mediapipe/framework/calculator.pb.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/tasks/cc/common.h"
|
||||
#include "mediapipe/tasks/cc/core/model_asset_bundle_resources.h"
|
||||
#include "mediapipe/tasks/cc/core/model_resources.h"
|
||||
|
@ -147,7 +148,8 @@ class InferenceSubgraph : public Subgraph {
|
|||
return delegate;
|
||||
}
|
||||
};
|
||||
REGISTER_MEDIAPIPE_GRAPH(::mediapipe::tasks::core::InferenceSubgraph);
|
||||
|
||||
REGISTER_MEDIAPIPE_GRAPH(::mediapipe::tasks::core::InferenceSubgraph)
|
||||
|
||||
absl::StatusOr<CalculatorGraphConfig> ModelTaskGraph::GetConfig(
|
||||
SubgraphContext* sc) {
|
||||
|
|
|
@ -12,6 +12,8 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
load("@org_tensorflow//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "cc_library_with_tflite")
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||
|
@ -31,9 +33,15 @@ cc_library(
|
|||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
cc_library_with_tflite(
|
||||
name = "base_vision_task_api",
|
||||
hdrs = ["base_vision_task_api.h"],
|
||||
tflite_deps = [
|
||||
"//mediapipe/tasks/cc/core:base_task_api",
|
||||
"//mediapipe/tasks/cc/core:task_api_factory",
|
||||
"//mediapipe/tasks/cc/core:task_runner",
|
||||
"//mediapipe/tasks/cc/vision/utils:image_tensor_specs",
|
||||
],
|
||||
deps = [
|
||||
":image_processing_options",
|
||||
":running_mode",
|
||||
|
@ -42,24 +50,22 @@ cc_library(
|
|||
"//mediapipe/framework/formats:image",
|
||||
"//mediapipe/framework/formats:rect_cc_proto",
|
||||
"//mediapipe/tasks/cc/components/containers:rect",
|
||||
"//mediapipe/tasks/cc/core:base_task_api",
|
||||
"//mediapipe/tasks/cc/core:task_api_factory",
|
||||
"//mediapipe/tasks/cc/core:task_runner",
|
||||
"//mediapipe/tasks/cc/vision/utils:image_tensor_specs",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
cc_library_with_tflite(
|
||||
name = "vision_task_api_factory",
|
||||
hdrs = ["vision_task_api_factory.h"],
|
||||
deps = [
|
||||
tflite_deps = [
|
||||
":base_vision_task_api",
|
||||
"//mediapipe/tasks/cc/core:task_api_factory",
|
||||
],
|
||||
deps = [
|
||||
"//mediapipe/calculators/core:flow_limiter_calculator",
|
||||
"//mediapipe/framework:calculator_cc_proto",
|
||||
"//mediapipe/tasks/cc/core:task_api_factory",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@com_google_absl//absl/strings",
|
||||
|
|
|
@ -11,6 +11,7 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
load("@org_tensorflow//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "cc_library_with_tflite")
|
||||
|
||||
package(default_visibility = [
|
||||
"//mediapipe/tasks:internal",
|
||||
|
@ -18,9 +19,15 @@ package(default_visibility = [
|
|||
|
||||
licenses(["notice"])
|
||||
|
||||
cc_library(
|
||||
cc_library_with_tflite(
|
||||
name = "face_detector_graph",
|
||||
srcs = ["face_detector_graph.cc"],
|
||||
tflite_deps = [
|
||||
"//mediapipe/tasks/cc/components/processors:image_preprocessing_graph",
|
||||
"//mediapipe/tasks/cc/core:model_task_graph",
|
||||
"//mediapipe/tasks/cc/core:model_resources",
|
||||
"//mediapipe/tasks/cc/vision/utils:image_tensor_specs",
|
||||
],
|
||||
deps = [
|
||||
"//mediapipe/calculators/core:clip_vector_size_calculator",
|
||||
"//mediapipe/calculators/core:clip_vector_size_calculator_cc_proto",
|
||||
|
@ -38,6 +45,7 @@ cc_library(
|
|||
"//mediapipe/calculators/util:non_max_suppression_calculator_cc_proto",
|
||||
"//mediapipe/calculators/util:rect_transformation_calculator",
|
||||
"//mediapipe/calculators/util:rect_transformation_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/api2:builder",
|
||||
"//mediapipe/framework/api2:port",
|
||||
"//mediapipe/framework/formats:detection_cc_proto",
|
||||
|
@ -45,36 +53,34 @@ cc_library(
|
|||
"//mediapipe/framework/formats:rect_cc_proto",
|
||||
"//mediapipe/framework/formats:tensor",
|
||||
"//mediapipe/tasks/cc:common",
|
||||
"//mediapipe/tasks/cc/components/processors:image_preprocessing_graph",
|
||||
"//mediapipe/tasks/cc/core:model_resources",
|
||||
"//mediapipe/tasks/cc/core:model_task_graph",
|
||||
"//mediapipe/tasks/cc/core:utils",
|
||||
"//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto",
|
||||
"//mediapipe/tasks/cc/vision/face_detector/proto:face_detector_graph_options_cc_proto",
|
||||
"//mediapipe/tasks/cc/vision/utils:image_tensor_specs",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
cc_library_with_tflite(
|
||||
name = "face_detector",
|
||||
srcs = ["face_detector.cc"],
|
||||
hdrs = ["face_detector.h"],
|
||||
tflite_deps = [
|
||||
":face_detector_graph",
|
||||
"//mediapipe/tasks/cc/vision/core:base_vision_task_api",
|
||||
"//mediapipe/tasks/cc/vision/core:vision_task_api_factory",
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":face_detector_graph",
|
||||
"//mediapipe/framework/api2:builder",
|
||||
"//mediapipe/framework/formats:detection_cc_proto",
|
||||
"//mediapipe/framework/formats:image",
|
||||
"//mediapipe/tasks/cc/components/containers:detection_result",
|
||||
"//mediapipe/tasks/cc/core:base_options",
|
||||
"//mediapipe/tasks/cc/core:utils",
|
||||
"//mediapipe/tasks/cc/vision/core:base_vision_task_api",
|
||||
"//mediapipe/tasks/cc/vision/core:image_processing_options",
|
||||
"//mediapipe/tasks/cc/vision/core:running_mode",
|
||||
"//mediapipe/tasks/cc/vision/core:vision_task_api_factory",
|
||||
"//mediapipe/tasks/cc/vision/face_detector/proto:face_detector_graph_options_cc_proto",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
|
|
|
@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/status/status.h"
|
||||
|
@ -26,6 +27,7 @@ limitations under the License.
|
|||
#include "mediapipe/calculators/util/rect_transformation_calculator.pb.h"
|
||||
#include "mediapipe/framework/api2/builder.h"
|
||||
#include "mediapipe/framework/api2/port.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/formats/detection.pb.h"
|
||||
#include "mediapipe/framework/formats/image.h"
|
||||
#include "mediapipe/framework/formats/rect.pb.h"
|
||||
|
@ -213,14 +215,16 @@ class FaceDetectorGraph : public core::ModelTaskGraph {
|
|||
}
|
||||
|
||||
private:
|
||||
std::string GetImagePreprocessingGraphName() {
|
||||
return "mediapipe.tasks.components.processors.ImagePreprocessingGraph";
|
||||
}
|
||||
absl::StatusOr<FaceDetectionOuts> BuildFaceDetectionSubgraph(
|
||||
const FaceDetectorGraphOptions& subgraph_options,
|
||||
const core::ModelResources& model_resources, Source<Image> image_in,
|
||||
Source<NormalizedRect> norm_rect_in, Graph& graph) {
|
||||
// Image preprocessing subgraph to convert image to tensor for the tflite
|
||||
// model.
|
||||
auto& preprocessing = graph.AddNode(
|
||||
"mediapipe.tasks.components.processors.ImagePreprocessingGraph");
|
||||
auto& preprocessing = graph.AddNode(GetImagePreprocessingGraphName());
|
||||
bool use_gpu =
|
||||
components::processors::DetermineImagePreprocessingGpuBackend(
|
||||
subgraph_options.base_options().acceleration());
|
||||
|
@ -337,7 +341,7 @@ class FaceDetectorGraph : public core::ModelTaskGraph {
|
|||
};
|
||||
|
||||
REGISTER_MEDIAPIPE_GRAPH(
|
||||
::mediapipe::tasks::vision::face_detector::FaceDetectorGraph);
|
||||
::mediapipe::tasks::vision::face_detector::FaceDetectorGraph)
|
||||
|
||||
} // namespace face_detector
|
||||
} // namespace vision
|
||||
|
|
|
@ -23,6 +23,7 @@ limitations under the License.
|
|||
#include "absl/flags/flag.h"
|
||||
#include "absl/log/absl_check.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "absl/strings/match.h"
|
||||
#include "absl/strings/str_format.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "mediapipe/framework/api2/builder.h"
|
||||
|
@ -92,11 +93,10 @@ constexpr float kFaceDetectionMaxDiff = 0.01;
|
|||
|
||||
// Helper function to create a TaskRunner.
|
||||
absl::StatusOr<std::unique_ptr<TaskRunner>> CreateTaskRunner(
|
||||
absl::string_view model_name) {
|
||||
absl::string_view model_name, std::string graph_name) {
|
||||
Graph graph;
|
||||
|
||||
auto& face_detector_graph =
|
||||
graph.AddNode("mediapipe.tasks.vision.face_detector.FaceDetectorGraph");
|
||||
auto& face_detector_graph = graph.AddNode(graph_name);
|
||||
|
||||
auto options = std::make_unique<FaceDetectorGraphOptions>();
|
||||
options->mutable_base_options()->mutable_model_asset()->set_file_name(
|
||||
|
@ -136,6 +136,8 @@ struct TestParams {
|
|||
std::string test_image_name;
|
||||
// Expected face detection results.
|
||||
std::vector<Detection> expected_result;
|
||||
// The name of the mediapipe graph to run.
|
||||
std::string graph_name;
|
||||
};
|
||||
|
||||
class FaceDetectorGraphTest : public testing::TestWithParam<TestParams> {};
|
||||
|
@ -149,8 +151,9 @@ TEST_P(FaceDetectorGraphTest, Succeed) {
|
|||
input_norm_rect.set_y_center(0.5);
|
||||
input_norm_rect.set_width(1.0);
|
||||
input_norm_rect.set_height(1.0);
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
auto task_runner, CreateTaskRunner(GetParam().face_detection_model_name));
|
||||
MP_ASSERT_OK_AND_ASSIGN(auto task_runner,
|
||||
CreateTaskRunner(GetParam().face_detection_model_name,
|
||||
GetParam().graph_name));
|
||||
auto output_packets = task_runner->Process(
|
||||
{{kImageName, MakePacket<Image>(std::move(image))},
|
||||
{kNormRectName,
|
||||
|
@ -165,11 +168,15 @@ TEST_P(FaceDetectorGraphTest, Succeed) {
|
|||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
FaceDetectorGraphTest, FaceDetectorGraphTest,
|
||||
Values(TestParams{.test_name = "ShortRange",
|
||||
.face_detection_model_name = kShortRangeBlazeFaceModel,
|
||||
.test_image_name = kPortraitImage,
|
||||
.expected_result = {GetExpectedFaceDetectionResult(
|
||||
kPortraitExpectedDetection)}}),
|
||||
Values(
|
||||
TestParams{
|
||||
.test_name = "ShortRange",
|
||||
.face_detection_model_name = kShortRangeBlazeFaceModel,
|
||||
.test_image_name = kPortraitImage,
|
||||
.expected_result = {GetExpectedFaceDetectionResult(
|
||||
kPortraitExpectedDetection)},
|
||||
.graph_name =
|
||||
"mediapipe.tasks.vision.face_detector.FaceDetectorGraph"}, ),
|
||||
[](const TestParamInfo<FaceDetectorGraphTest::ParamType>& info) {
|
||||
return info.param.test_name;
|
||||
});
|
||||
|
|
|
@ -0,0 +1,28 @@
|
|||
#include "mediapipe/tasks/cc/vision/face_detector/face_detector_graph_test_task_runner_gms.h"
|
||||
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
|
||||
#include "absl/status/statusor.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/tasks/cc/core/task_runner.h"
|
||||
#include "tensorflow/lite/core/api/op_resolver.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
namespace vision {
|
||||
namespace face_detector {
|
||||
namespace test_util {
|
||||
|
||||
absl::StatusOr<std::unique_ptr<mediapipe::tasks::core::TaskRunner>>
|
||||
CreateTaskRunnerGms(mediapipe::CalculatorGraphConfig config,
|
||||
std::unique_ptr<tflite::OpResolver> op_resolver) {
|
||||
return mediapipe::tasks::core::TaskRunner::Create(std::move(config),
|
||||
std::move(op_resolver));
|
||||
}
|
||||
|
||||
} // namespace test_util
|
||||
} // namespace face_detector
|
||||
} // namespace vision
|
||||
} // namespace tasks
|
||||
} // namespace mediapipe
|
|
@ -0,0 +1,27 @@
|
|||
#ifndef MEDIAPIPE_TASKS_CC_VISION_FACE_DETECTOR_FACE_DETECTOR_GRAPH_TEST_TASK_RUNNER_GMS_H_
|
||||
#define MEDIAPIPE_TASKS_CC_VISION_FACE_DETECTOR_FACE_DETECTOR_GRAPH_TEST_TASK_RUNNER_GMS_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "absl/status/statusor.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/tasks/cc/core/task_runner.h"
|
||||
#include "tensorflow/lite/core/api/op_resolver.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
namespace vision {
|
||||
namespace face_detector {
|
||||
namespace test_util {
|
||||
|
||||
absl::StatusOr<std::unique_ptr<mediapipe::tasks::core::TaskRunner>>
|
||||
CreateTaskRunnerGms(mediapipe::CalculatorGraphConfig config,
|
||||
std::unique_ptr<tflite::OpResolver> op_resolver = nullptr);
|
||||
|
||||
} // namespace test_util
|
||||
} // namespace face_detector
|
||||
} // namespace vision
|
||||
} // namespace tasks
|
||||
} // namespace mediapipe
|
||||
|
||||
#endif // MEDIAPIPE_TASKS_CC_VISION_FACE_DETECTOR_FACE_DETECTOR_GRAPH_TEST_TASK_RUNNER_GMS_H_
|
Loading…
Reference in New Issue
Block a user