Move ImagePreprocessing to "processors" folder.

PiperOrigin-RevId: 490444821
This commit is contained in:
MediaPipe Team 2022-11-23 02:03:35 -08:00 committed by Copybara-Service
parent c5ce523697
commit b5189758f7
23 changed files with 493 additions and 150 deletions

View File

@ -12,55 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library")
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
mediapipe_proto_library(
name = "image_preprocessing_options_proto",
srcs = ["image_preprocessing_options.proto"],
deps = [
"//mediapipe/calculators/tensor:image_to_tensor_calculator_proto",
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
],
)
cc_library(
name = "image_preprocessing",
srcs = ["image_preprocessing.cc"],
hdrs = ["image_preprocessing.h"],
deps = [
":image_preprocessing_options_cc_proto",
"//mediapipe/calculators/core:pass_through_calculator",
"//mediapipe/calculators/image:image_clone_calculator",
"//mediapipe/calculators/image:image_clone_calculator_cc_proto",
"//mediapipe/calculators/image:image_properties_calculator",
"//mediapipe/calculators/tensor:image_to_tensor_calculator",
"//mediapipe/calculators/tensor:image_to_tensor_calculator_cc_proto",
"//mediapipe/calculators/tensor:inference_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/framework/formats:tensor",
"//mediapipe/gpu:gpu_origin_cc_proto",
"//mediapipe/tasks/cc:common",
"//mediapipe/tasks/cc/core:model_resources",
"//mediapipe/tasks/cc/core/proto:acceleration_cc_proto",
"//mediapipe/tasks/cc/vision/utils:image_tensor_specs",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@org_tensorflow//tensorflow/lite/schema:schema_fbs",
],
alwayslink = 1,
)
# TODO: Enable this test
# TODO: Investigate rewriting the build rule to only link
# the Bert Preprocessor if it's needed.
cc_library(

View File

@ -100,3 +100,36 @@ cc_library(
],
alwayslink = 1,
)
cc_library(
name = "image_preprocessing_graph",
srcs = ["image_preprocessing_graph.cc"],
hdrs = ["image_preprocessing_graph.h"],
deps = [
"//mediapipe/calculators/core:pass_through_calculator",
"//mediapipe/calculators/image:image_clone_calculator",
"//mediapipe/calculators/image:image_clone_calculator_cc_proto",
"//mediapipe/calculators/image:image_properties_calculator",
"//mediapipe/calculators/tensor:image_to_tensor_calculator",
"//mediapipe/calculators/tensor:image_to_tensor_calculator_cc_proto",
"//mediapipe/calculators/tensor:inference_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/framework/formats:tensor",
"//mediapipe/gpu:gpu_origin_cc_proto",
"//mediapipe/tasks/cc:common",
"//mediapipe/tasks/cc/components/processors/proto:image_preprocessing_graph_options_cc_proto",
"//mediapipe/tasks/cc/core:model_resources",
"//mediapipe/tasks/cc/core/proto:acceleration_cc_proto",
"//mediapipe/tasks/cc/vision/utils:image_tensor_specs",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@org_tensorflow//tensorflow/lite/schema:schema_fbs",
],
alwayslink = 1,
)
# TODO: Enable this test

View File

@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/components/image_preprocessing.h"
#include "mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h"
#include <array>
#include <complex>
@ -33,7 +33,7 @@ limitations under the License.
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/gpu/gpu_origin.pb.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/image_preprocessing_options.pb.h"
#include "mediapipe/tasks/cc/components/processors/proto/image_preprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/core/proto/acceleration.pb.h"
#include "mediapipe/tasks/cc/vision/utils/image_tensor_specs.h"
@ -42,6 +42,7 @@ limitations under the License.
namespace mediapipe {
namespace tasks {
namespace components {
namespace processors {
namespace {
using ::mediapipe::Tensor;
@ -144,9 +145,9 @@ bool DetermineImagePreprocessingGpuBackend(
return acceleration.has_gpu();
}
absl::Status ConfigureImagePreprocessing(const ModelResources& model_resources,
bool use_gpu,
ImagePreprocessingOptions* options) {
absl::Status ConfigureImagePreprocessingGraph(
const ModelResources& model_resources, bool use_gpu,
proto::ImagePreprocessingGraphOptions* options) {
ASSIGN_OR_RETURN(auto image_tensor_specs,
BuildImageTensorSpecs(model_resources));
MP_RETURN_IF_ERROR(ConfigureImageToTensorCalculator(
@ -154,9 +155,9 @@ absl::Status ConfigureImagePreprocessing(const ModelResources& model_resources,
// The GPU backend isn't able to process int data. If the input tensor is
// quantized, forces the image preprocessing graph to use CPU backend.
if (use_gpu && image_tensor_specs.tensor_type != tflite::TensorType_UINT8) {
options->set_backend(ImagePreprocessingOptions::GPU_BACKEND);
options->set_backend(proto::ImagePreprocessingGraphOptions::GPU_BACKEND);
} else {
options->set_backend(ImagePreprocessingOptions::CPU_BACKEND);
options->set_backend(proto::ImagePreprocessingGraphOptions::CPU_BACKEND);
}
return absl::OkStatus();
}
@ -170,8 +171,7 @@ Source<Image> AddDataConverter(Source<Image> image_in, Graph& graph,
return image_converter[Output<Image>("")];
}
// A "mediapipe.tasks.components.ImagePreprocessingSubgraph" performs image
// preprocessing.
// An ImagePreprocessingGraph performs image preprocessing.
// - Accepts CPU input images and outputs CPU tensors.
//
// Inputs:
@ -192,7 +192,7 @@ Source<Image> AddDataConverter(Source<Image> image_in, Graph& graph,
// An std::array<float, 4> representing the letterbox padding from the 4
// sides ([left, top, right, bottom]) of the output image, normalized to
// [0.f, 1.f] by the output dimensions. The padding values are non-zero only
// when the "keep_aspect_ratio" is true in ImagePreprocessingOptions.
// when the "keep_aspect_ratio" is true in ImagePreprocessingGraphOptions.
// IMAGE_SIZE - std::pair<int,int> @Optional
// The size of the original input image as a <width, height> pair.
// IMAGE - Image @Optional
@ -200,15 +200,15 @@ Source<Image> AddDataConverter(Source<Image> image_in, Graph& graph,
// GPU).
//
// The recommended way of using this subgraph is through the GraphBuilder API
// using the 'ConfigureImagePreprocessing()' function. See header file for more
// details.
class ImagePreprocessingSubgraph : public Subgraph {
// using the 'ConfigureImagePreprocessingGraph()' function. See header file for
// more details.
class ImagePreprocessingGraph : public Subgraph {
public:
absl::StatusOr<CalculatorGraphConfig> GetConfig(
SubgraphContext* sc) override {
Graph graph;
auto output_streams = BuildImagePreprocessing(
sc->Options<ImagePreprocessingOptions>(),
sc->Options<proto::ImagePreprocessingGraphOptions>(),
graph[Input<Image>(kImageTag)],
graph[Input<NormalizedRect>::Optional(kNormRectTag)], graph);
output_streams.tensors >> graph[Output<std::vector<Tensor>>(kTensorsTag)];
@ -233,24 +233,25 @@ class ImagePreprocessingSubgraph : public Subgraph {
// - the image that has pixel data stored on the target storage
// (mediapipe::Image).
//
// options: the mediapipe tasks ImagePreprocessingOptions.
// options: the mediapipe tasks ImagePreprocessingGraphOptions.
// image_in: (mediapipe::Image) stream to preprocess.
// graph: the mediapipe builder::Graph instance to be updated.
ImagePreprocessingOutputStreams BuildImagePreprocessing(
const ImagePreprocessingOptions& options, Source<Image> image_in,
Source<NormalizedRect> norm_rect_in, Graph& graph) {
const proto::ImagePreprocessingGraphOptions& options,
Source<Image> image_in, Source<NormalizedRect> norm_rect_in,
Graph& graph) {
// Convert image to tensor.
auto& image_to_tensor = graph.AddNode("ImageToTensorCalculator");
image_to_tensor.GetOptions<mediapipe::ImageToTensorCalculatorOptions>()
.CopyFrom(options.image_to_tensor_options());
switch (options.backend()) {
case ImagePreprocessingOptions::CPU_BACKEND: {
case proto::ImagePreprocessingGraphOptions::CPU_BACKEND: {
auto cpu_image =
AddDataConverter(image_in, graph, /*output_on_gpu=*/false);
cpu_image >> image_to_tensor.In(kImageTag);
break;
}
case ImagePreprocessingOptions::GPU_BACKEND: {
case proto::ImagePreprocessingGraphOptions::GPU_BACKEND: {
auto gpu_image =
AddDataConverter(image_in, graph, /*output_on_gpu=*/true);
gpu_image >> image_to_tensor.In(kImageTag);
@ -284,8 +285,9 @@ class ImagePreprocessingSubgraph : public Subgraph {
}
};
REGISTER_MEDIAPIPE_GRAPH(
::mediapipe::tasks::components::ImagePreprocessingSubgraph);
::mediapipe::tasks::components::processors::ImagePreprocessingGraph);
} // namespace processors
} // namespace components
} // namespace tasks
} // namespace mediapipe

View File

@ -13,35 +13,36 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_IMAGE_PREPROCESSING_H_
#define MEDIAPIPE_TASKS_CC_COMPONENTS_IMAGE_PREPROCESSING_H_
#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_IMAGE_PREPROCESSING_GRAPH_H_
#define MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_IMAGE_PREPROCESSING_GRAPH_H_
#include "absl/status/status.h"
#include "mediapipe/tasks/cc/components/image_preprocessing_options.pb.h"
#include "mediapipe/tasks/cc/components/processors/proto/image_preprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/core/proto/acceleration.pb.h"
namespace mediapipe {
namespace tasks {
namespace components {
namespace processors {
// Configures an ImagePreprocessing subgraph using the provided model resources
// Configures an ImagePreprocessingGraph using the provided model resources
// When use_gpu is true, use GPU as backend to convert image to tensor.
// - Accepts CPU input images and outputs CPU tensors.
//
// Example usage:
//
// auto& preprocessing =
// graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph");
// graph.AddNode("mediapipe.tasks.components.processors.ImagePreprocessingGraph");
// core::proto::Acceleration acceleration;
// acceleration.mutable_xnnpack();
// bool use_gpu = DetermineImagePreprocessingGpuBackend(acceleration);
// MP_RETURN_IF_ERROR(ConfigureImagePreprocessing(
// MP_RETURN_IF_ERROR(ConfigureImagePreprocessingGraph(
// model_resources,
// use_gpu,
// &preprocessing.GetOptions<ImagePreprocessingOptions>()));
// &preprocessing.GetOptions<ImagePreprocessingGraphOptions>()));
//
// The resulting ImagePreprocessing subgraph has the following I/O:
// The resulting ImagePreprocessingGraph has the following I/O:
// Inputs:
// IMAGE - Image
// The image to preprocess.
@ -61,17 +62,18 @@ namespace components {
// IMAGE - Image @Optional
// The image that has the pixel data stored on the target storage (CPU vs
// GPU).
absl::Status ConfigureImagePreprocessing(
absl::Status ConfigureImagePreprocessingGraph(
const core::ModelResources& model_resources, bool use_gpu,
ImagePreprocessingOptions* options);
proto::ImagePreprocessingGraphOptions* options);
// Determine if the image preprocessing subgraph should use GPU as the backend
// Determine if the image preprocessing graph should use GPU as the backend
// according to the given acceleration setting.
bool DetermineImagePreprocessingGpuBackend(
const core::proto::Acceleration& acceleration);
} // namespace processors
} // namespace components
} // namespace tasks
} // namespace mediapipe
#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_IMAGE_PREPROCESSING_H_
#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_IMAGE_PREPROCESSING_GRAPH_H_

View File

@ -0,0 +1,343 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h"
#include <memory>
#include <utility>
#include "absl/flags/flag.h"
#include "absl/status/statusor.h"
#include "mediapipe/calculators/tensor/image_to_tensor_calculator.pb.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/tasks/cc/components/processors/proto/image_preprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/core/proto/acceleration.pb.h"
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
#include "mediapipe/tasks/cc/core/task_runner.h"
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
namespace mediapipe {
namespace tasks {
namespace components {
namespace processors {
namespace {
using ::mediapipe::api2::Input;
using ::mediapipe::api2::Output;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source;
using ::mediapipe::file::JoinPath;
using ::mediapipe::tasks::core::ModelResources;
using ::mediapipe::tasks::core::TaskRunner;
using ::mediapipe::tasks::vision::DecodeImageFromFile;
using ::testing::ContainerEq;
using ::testing::HasSubstr;
using ::testing::TestParamInfo;
using ::testing::TestWithParam;
using ::testing::Values;
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/";
constexpr char kMobileNetFloatWithMetadata[] = "mobilenet_v2_1.0_224.tflite";
constexpr char kMobileNetFloatWithoutMetadata[] =
"mobilenet_v1_0.25_224_1_default_1.tflite";
constexpr char kMobileNetQuantizedWithMetadata[] =
"mobilenet_v1_0.25_224_quant.tflite";
constexpr char kMobileNetQuantizedWithoutMetadata[] =
"mobilenet_v1_0.25_192_quantized_1_default_1.tflite";
constexpr char kTestImage[] = "burger.jpg";
constexpr int kTestImageWidth = 480;
constexpr int kTestImageHeight = 325;
constexpr char kTestModelResourcesTag[] = "test_model_resources";
constexpr std::array<float, 16> kIdentityMatrix = {1, 0, 0, 0, 0, 1, 0, 0,
0, 0, 1, 0, 0, 0, 0, 1};
constexpr char kImageTag[] = "IMAGE";
constexpr char kImageName[] = "image_in";
constexpr char kMatrixTag[] = "MATRIX";
constexpr char kMatrixName[] = "matrix_out";
constexpr char kTensorsTag[] = "TENSORS";
constexpr char kTensorsName[] = "tensors_out";
constexpr char kImageSizeTag[] = "IMAGE_SIZE";
constexpr char kImageSizeName[] = "image_size_out";
constexpr char kLetterboxPaddingTag[] = "LETTERBOX_PADDING";
constexpr char kLetterboxPaddingName[] = "letterbox_padding_out";
constexpr float kLetterboxMaxAbsError = 1e-5;
// Helper function to get ModelResources.
absl::StatusOr<std::unique_ptr<ModelResources>> CreateModelResourcesForModel(
absl::string_view model_name) {
auto external_file = std::make_unique<core::proto::ExternalFile>();
external_file->set_file_name(JoinPath("./", kTestDataDirectory, model_name));
return ModelResources::Create(kTestModelResourcesTag,
std::move(external_file));
}
// Helper function to create a TaskRunner from ModelResources.
absl::StatusOr<std::unique_ptr<TaskRunner>> CreateTaskRunner(
const ModelResources& model_resources, bool keep_aspect_ratio) {
Graph graph;
auto& preprocessing = graph.AddNode(
"mediapipe.tasks.components.processors.ImagePreprocessingGraph");
auto& options =
preprocessing.GetOptions<proto::ImagePreprocessingGraphOptions>();
options.mutable_image_to_tensor_options()->set_keep_aspect_ratio(
keep_aspect_ratio);
MP_RETURN_IF_ERROR(
ConfigureImagePreprocessingGraph(model_resources, false, &options));
graph[Input<Image>(kImageTag)].SetName(kImageName) >>
preprocessing.In(kImageTag);
preprocessing.Out(kTensorsTag).SetName(kTensorsName) >>
graph[Output<std::vector<Tensor>>(kTensorsTag)];
preprocessing.Out(kMatrixTag).SetName(kMatrixName) >>
graph[Output<std::array<float, 16>>(kMatrixTag)];
preprocessing.Out(kImageSizeTag).SetName(kImageSizeName) >>
graph[Output<std::pair<int, int>>(kImageSizeTag)];
preprocessing.Out(kLetterboxPaddingTag).SetName(kLetterboxPaddingName) >>
graph[Output<std::array<float, 4>>(kLetterboxPaddingTag)];
return TaskRunner::Create(graph.GetConfig());
}
class ConfigureTest : public tflite_shims::testing::Test {};
TEST_F(ConfigureTest, SucceedsWithQuantizedModelWithMetadata) {
MP_ASSERT_OK_AND_ASSIGN(
auto model_resources,
CreateModelResourcesForModel(kMobileNetQuantizedWithMetadata));
proto::ImagePreprocessingGraphOptions options;
MP_EXPECT_OK(
ConfigureImagePreprocessingGraph(*model_resources, false, &options));
EXPECT_THAT(options, EqualsProto(
R"pb(image_to_tensor_options {
output_tensor_width: 224
output_tensor_height: 224
output_tensor_uint_range { min: 0 max: 255 }
gpu_origin: TOP_LEFT
}
backend: CPU_BACKEND)pb"));
}
TEST_F(ConfigureTest, SucceedsWithQuantizedModelWithoutMetadata) {
MP_ASSERT_OK_AND_ASSIGN(
auto model_resources,
CreateModelResourcesForModel(kMobileNetQuantizedWithoutMetadata));
proto::ImagePreprocessingGraphOptions options;
MP_EXPECT_OK(
ConfigureImagePreprocessingGraph(*model_resources, false, &options));
EXPECT_THAT(options, EqualsProto(
R"pb(image_to_tensor_options {
output_tensor_width: 192
output_tensor_height: 192
output_tensor_uint_range { min: 0 max: 255 }
gpu_origin: TOP_LEFT
}
backend: CPU_BACKEND)pb"));
}
TEST_F(ConfigureTest, SucceedsWithFloatModelWithMetadata) {
MP_ASSERT_OK_AND_ASSIGN(
auto model_resources,
CreateModelResourcesForModel(kMobileNetFloatWithMetadata));
proto::ImagePreprocessingGraphOptions options;
MP_EXPECT_OK(
ConfigureImagePreprocessingGraph(*model_resources, false, &options));
EXPECT_THAT(options, EqualsProto(
R"pb(image_to_tensor_options {
output_tensor_width: 224
output_tensor_height: 224
output_tensor_float_range { min: -1 max: 1 }
gpu_origin: TOP_LEFT
}
backend: CPU_BACKEND)pb"));
}
TEST_F(ConfigureTest, SucceedsWithQuantizedModelFallbacksCpuBackend) {
MP_ASSERT_OK_AND_ASSIGN(
auto model_resources,
CreateModelResourcesForModel(kMobileNetQuantizedWithMetadata));
proto::ImagePreprocessingGraphOptions options;
core::proto::Acceleration acceleration;
acceleration.mutable_gpu();
bool use_gpu = DetermineImagePreprocessingGpuBackend(acceleration);
EXPECT_TRUE(use_gpu);
MP_EXPECT_OK(
ConfigureImagePreprocessingGraph(*model_resources, use_gpu, &options));
EXPECT_THAT(options, EqualsProto(
R"pb(image_to_tensor_options {
output_tensor_width: 224
output_tensor_height: 224
output_tensor_uint_range { min: 0 max: 255 }
gpu_origin: TOP_LEFT
}
backend: CPU_BACKEND)pb"));
}
TEST_F(ConfigureTest, SucceedsWithFloatModelGpuBackend) {
MP_ASSERT_OK_AND_ASSIGN(
auto model_resources,
CreateModelResourcesForModel(kMobileNetFloatWithMetadata));
proto::ImagePreprocessingGraphOptions options;
core::proto::Acceleration acceleration;
acceleration.mutable_gpu();
bool use_gpu = DetermineImagePreprocessingGpuBackend(acceleration);
EXPECT_TRUE(use_gpu);
MP_EXPECT_OK(
ConfigureImagePreprocessingGraph(*model_resources, use_gpu, &options));
EXPECT_THAT(options, EqualsProto(
R"pb(image_to_tensor_options {
output_tensor_width: 224
output_tensor_height: 224
output_tensor_float_range { min: -1 max: 1 }
gpu_origin: TOP_LEFT
}
backend: GPU_BACKEND)pb"));
}
TEST_F(ConfigureTest, FailsWithFloatModelWithoutMetadata) {
MP_ASSERT_OK_AND_ASSIGN(
auto model_resources,
CreateModelResourcesForModel(kMobileNetFloatWithoutMetadata));
proto::ImagePreprocessingGraphOptions options;
auto status =
ConfigureImagePreprocessingGraph(*model_resources, false, &options);
EXPECT_EQ(status.code(), absl::StatusCode::kNotFound);
EXPECT_THAT(status.message(),
HasSubstr("requires specifying NormalizationOptions metadata"));
}
// Struct holding the parameters for parameterized PreprocessingTest class.
struct PreprocessingParams {
// The name of this test, for convenience when displaying test results.
std::string test_name;
// The filename of the model to test.
std::string input_model_name;
// If true, keep test image aspect ratio.
bool keep_aspect_ratio;
// The expected output tensor type.
Tensor::ElementType expected_type;
// The expected outoput tensor shape.
std::vector<int> expected_shape;
// The expected output letterbox padding;
std::array<float, 4> expected_letterbox_padding;
};
class PreprocessingTest : public testing::TestWithParam<PreprocessingParams> {};
TEST_P(PreprocessingTest, Succeeds) {
MP_ASSERT_OK_AND_ASSIGN(
Image image,
DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kTestImage)));
MP_ASSERT_OK_AND_ASSIGN(
auto model_resources,
CreateModelResourcesForModel(GetParam().input_model_name));
MP_ASSERT_OK_AND_ASSIGN(
auto task_runner,
CreateTaskRunner(*model_resources, GetParam().keep_aspect_ratio));
auto output_packets =
task_runner->Process({{kImageName, MakePacket<Image>(std::move(image))}});
MP_ASSERT_OK(output_packets);
const std::vector<Tensor>& tensors =
(*output_packets)[kTensorsName].Get<std::vector<Tensor>>();
EXPECT_EQ(tensors.size(), 1);
EXPECT_EQ(tensors[0].element_type(), GetParam().expected_type);
EXPECT_THAT(tensors[0].shape().dims, ContainerEq(GetParam().expected_shape));
auto& matrix = (*output_packets)[kMatrixName].Get<std::array<float, 16>>();
if (!GetParam().keep_aspect_ratio) {
for (int i = 0; i < matrix.size(); ++i) {
EXPECT_FLOAT_EQ(matrix[i], kIdentityMatrix[i]);
}
}
auto& image_size =
(*output_packets)[kImageSizeName].Get<std::pair<int, int>>();
EXPECT_EQ(image_size.first, kTestImageWidth);
EXPECT_EQ(image_size.second, kTestImageHeight);
std::array<float, 4> letterbox_padding =
(*output_packets)[kLetterboxPaddingName].Get<std::array<float, 4>>();
for (int i = 0; i < letterbox_padding.size(); ++i) {
EXPECT_NEAR(letterbox_padding[i], GetParam().expected_letterbox_padding[i],
kLetterboxMaxAbsError);
}
}
INSTANTIATE_TEST_SUITE_P(
PreprocessingTest, PreprocessingTest,
Values(
PreprocessingParams{.test_name = "kMobileNetQuantizedWithMetadata",
.input_model_name = kMobileNetQuantizedWithMetadata,
.keep_aspect_ratio = false,
.expected_type = Tensor::ElementType::kUInt8,
.expected_shape = {1, 224, 224, 3},
.expected_letterbox_padding = {0, 0, 0, 0}},
PreprocessingParams{
.test_name = "kMobileNetQuantizedWithoutMetadata",
.input_model_name = kMobileNetQuantizedWithoutMetadata,
.keep_aspect_ratio = false,
.expected_type = Tensor::ElementType::kUInt8,
.expected_shape = {1, 192, 192, 3},
.expected_letterbox_padding = {0, 0, 0, 0}},
PreprocessingParams{.test_name = "kMobileNetFloatWithMetadata",
.input_model_name = kMobileNetFloatWithMetadata,
.keep_aspect_ratio = false,
.expected_type = Tensor::ElementType::kFloat32,
.expected_shape = {1, 224, 224, 3},
.expected_letterbox_padding = {0, 0, 0, 0}},
PreprocessingParams{
.test_name = "kMobileNetFloatWithMetadataKeepAspectRatio",
.input_model_name = kMobileNetFloatWithMetadata,
.keep_aspect_ratio = true,
.expected_type = Tensor::ElementType::kFloat32,
.expected_shape = {1, 224, 224, 3},
.expected_letterbox_padding = {/*left*/ 0,
/*top*/ 0.161458,
/*right*/ 0,
/*bottom*/ 0.161458}}),
[](const TestParamInfo<PreprocessingTest::ParamType>& info) {
return info.param.test_name;
});
} // namespace
} // namespace processors
} // namespace components
} // namespace tasks
} // namespace mediapipe

View File

@ -49,3 +49,13 @@ mediapipe_proto_library(
"//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator_proto",
],
)
mediapipe_proto_library(
name = "image_preprocessing_graph_options_proto",
srcs = ["image_preprocessing_graph_options.proto"],
deps = [
"//mediapipe/calculators/tensor:image_to_tensor_calculator_proto",
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
],
)

View File

@ -15,14 +15,14 @@ limitations under the License.
syntax = "proto2";
package mediapipe.tasks.components;
package mediapipe.tasks.components.processors.proto;
import "mediapipe/calculators/tensor/image_to_tensor_calculator.proto";
import "mediapipe/framework/calculator.proto";
message ImagePreprocessingOptions {
message ImagePreprocessingGraphOptions {
extend mediapipe.CalculatorOptions {
optional ImagePreprocessingOptions ext = 456882436;
optional ImagePreprocessingGraphOptions ext = 456882436;
}
// Options for the ImageToTensor calculator encapsulated by the

View File

@ -37,7 +37,6 @@ cc_library(
"//mediapipe/framework/formats:landmark_cc_proto",
"//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/tasks/cc:common",
"//mediapipe/tasks/cc/components:image_preprocessing",
"//mediapipe/tasks/cc/components/processors:classifier_options",
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto",
"//mediapipe/tasks/cc/core:base_options",
@ -105,10 +104,7 @@ cc_library(
"//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/framework/formats:tensor",
"//mediapipe/tasks/cc:common",
"//mediapipe/tasks/cc/components:image_preprocessing",
"//mediapipe/tasks/cc/components/processors:classification_postprocessing_graph",
"//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto",
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto",
"//mediapipe/tasks/cc/core:model_asset_bundle_resources",
"//mediapipe/tasks/cc/core:model_resources",
"//mediapipe/tasks/cc/core:model_resources_cache",

View File

@ -31,7 +31,6 @@ limitations under the License.
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/image_preprocessing.h"
#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"
#include "mediapipe/tasks/cc/core/base_task_api.h"
#include "mediapipe/tasks/cc/core/model_resources.h"

View File

@ -29,8 +29,6 @@ limitations under the License.
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h"
#include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"
#include "mediapipe/tasks/cc/core/model_asset_bundle_resources.h"
#include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/core/model_resources_cache.h"

View File

@ -46,7 +46,7 @@ cc_library(
"//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/framework/formats:tensor",
"//mediapipe/tasks/cc:common",
"//mediapipe/tasks/cc/components:image_preprocessing",
"//mediapipe/tasks/cc/components/processors:image_preprocessing_graph",
"//mediapipe/tasks/cc/core:model_resources",
"//mediapipe/tasks/cc/core:model_task_graph",
"//mediapipe/tasks/cc/core:utils",

View File

@ -35,7 +35,7 @@ limitations under the License.
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/image_preprocessing.h"
#include "mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h"
#include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/core/model_task_graph.h"
#include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h"
@ -226,21 +226,23 @@ class HandDetectorGraph : public core::ModelTaskGraph {
Source<NormalizedRect> norm_rect_in, Graph& graph) {
// Add image preprocessing subgraph. The model expects aspect ratio
// unchanged.
auto& preprocessing =
graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph");
auto& preprocessing = graph.AddNode(
"mediapipe.tasks.components.processors.ImagePreprocessingGraph");
auto& image_to_tensor_options =
*preprocessing
.GetOptions<tasks::components::ImagePreprocessingOptions>()
.GetOptions<components::processors::proto::
ImagePreprocessingGraphOptions>()
.mutable_image_to_tensor_options();
image_to_tensor_options.set_keep_aspect_ratio(true);
image_to_tensor_options.set_border_mode(
mediapipe::ImageToTensorCalculatorOptions::BORDER_ZERO);
bool use_gpu = components::DetermineImagePreprocessingGpuBackend(
subgraph_options.base_options().acceleration());
MP_RETURN_IF_ERROR(ConfigureImagePreprocessing(
bool use_gpu =
components::processors::DetermineImagePreprocessingGpuBackend(
subgraph_options.base_options().acceleration());
MP_RETURN_IF_ERROR(components::processors::ConfigureImagePreprocessingGraph(
model_resources, use_gpu,
&preprocessing
.GetOptions<tasks::components::ImagePreprocessingOptions>()));
&preprocessing.GetOptions<
components::processors::proto::ImagePreprocessingGraphOptions>()));
image_in >> preprocessing.In("IMAGE");
norm_rect_in >> preprocessing.In("NORM_RECT");
auto preprocessed_tensors = preprocessing.Out("TENSORS");

View File

@ -35,7 +35,6 @@ cc_library(
"//mediapipe/framework/formats:landmark_cc_proto",
"//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/tasks/cc:common",
"//mediapipe/tasks/cc/components:image_preprocessing",
"//mediapipe/tasks/cc/components/processors:classifier_options",
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto",
"//mediapipe/tasks/cc/core:base_options",
@ -89,7 +88,7 @@ cc_library(
"//mediapipe/modules/hand_landmark/calculators:hand_landmarks_to_rect_calculator",
"//mediapipe/tasks/cc:common",
"//mediapipe/tasks/cc/components/utils:gate",
"//mediapipe/tasks/cc/components:image_preprocessing",
"//mediapipe/tasks/cc/components/processors:image_preprocessing_graph",
"//mediapipe/tasks/cc/core:model_resources",
"//mediapipe/tasks/cc/core:model_task_graph",
"//mediapipe/tasks/cc/core:utils",

View File

@ -22,7 +22,6 @@ limitations under the License.
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/image_preprocessing.h"
#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"
#include "mediapipe/tasks/cc/core/base_task_api.h"
#include "mediapipe/tasks/cc/core/model_resources.h"

View File

@ -33,7 +33,7 @@ limitations under the License.
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/image_preprocessing.h"
#include "mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h"
#include "mediapipe/tasks/cc/components/utils/gate.h"
#include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/core/model_task_graph.h"
@ -281,14 +281,15 @@ class SingleHandLandmarksDetectorGraph : public core::ModelTaskGraph {
Source<NormalizedRect> hand_rect, Graph& graph) {
MP_RETURN_IF_ERROR(SanityCheckOptions(subgraph_options));
auto& preprocessing =
graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph");
bool use_gpu = components::DetermineImagePreprocessingGpuBackend(
subgraph_options.base_options().acceleration());
MP_RETURN_IF_ERROR(ConfigureImagePreprocessing(
auto& preprocessing = graph.AddNode(
"mediapipe.tasks.components.processors.ImagePreprocessingGraph");
bool use_gpu =
components::processors::DetermineImagePreprocessingGpuBackend(
subgraph_options.base_options().acceleration());
MP_RETURN_IF_ERROR(components::processors::ConfigureImagePreprocessingGraph(
model_resources, use_gpu,
&preprocessing
.GetOptions<tasks::components::ImagePreprocessingOptions>()));
&preprocessing.GetOptions<tasks::components::processors::proto::
ImagePreprocessingGraphOptions>()));
image_in >> preprocessing.In("IMAGE");
hand_rect >> preprocessing.In("NORM_RECT");
auto image_size = preprocessing[Output<std::pair<int, int>>("IMAGE_SIZE")];

View File

@ -59,11 +59,11 @@ cc_library(
"//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/tasks/cc/components:image_preprocessing",
"//mediapipe/tasks/cc/components:image_preprocessing_options_cc_proto",
"//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
"//mediapipe/tasks/cc/components/processors:classification_postprocessing_graph",
"//mediapipe/tasks/cc/components/processors:image_preprocessing_graph",
"//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto",
"//mediapipe/tasks/cc/components/processors/proto:image_preprocessing_graph_options_cc_proto",
"//mediapipe/tasks/cc/core:model_resources",
"//mediapipe/tasks/cc/core:model_task_graph",
"//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_cc_proto",

View File

@ -23,10 +23,10 @@ limitations under the License.
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
#include "mediapipe/tasks/cc/components/image_preprocessing.h"
#include "mediapipe/tasks/cc/components/image_preprocessing_options.pb.h"
#include "mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h"
#include "mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h"
#include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/components/processors/proto/image_preprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/core/model_task_graph.h"
#include "mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_graph_options.pb.h"
@ -135,14 +135,15 @@ class ImageClassifierGraph : public core::ModelTaskGraph {
Source<NormalizedRect> norm_rect_in, Graph& graph) {
// Adds preprocessing calculators and connects them to the graph input image
// stream.
auto& preprocessing =
graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph");
bool use_gpu = components::DetermineImagePreprocessingGpuBackend(
task_options.base_options().acceleration());
MP_RETURN_IF_ERROR(ConfigureImagePreprocessing(
auto& preprocessing = graph.AddNode(
"mediapipe.tasks.components.processors.ImagePreprocessingGraph");
bool use_gpu =
components::processors::DetermineImagePreprocessingGpuBackend(
task_options.base_options().acceleration());
MP_RETURN_IF_ERROR(components::processors::ConfigureImagePreprocessingGraph(
model_resources, use_gpu,
&preprocessing
.GetOptions<tasks::components::ImagePreprocessingOptions>()));
&preprocessing.GetOptions<tasks::components::processors::proto::
ImagePreprocessingGraphOptions>()));
image_in >> preprocessing.In(kImageTag);
norm_rect_in >> preprocessing.In(kNormRectTag);

View File

@ -57,12 +57,12 @@ cc_library(
"//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/tasks/cc/components:image_preprocessing",
"//mediapipe/tasks/cc/components:image_preprocessing_options_cc_proto",
"//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator",
"//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
"//mediapipe/tasks/cc/components/processors:embedding_postprocessing_graph",
"//mediapipe/tasks/cc/components/processors:image_preprocessing_graph",
"//mediapipe/tasks/cc/components/processors/proto:embedding_postprocessing_graph_options_cc_proto",
"//mediapipe/tasks/cc/components/processors/proto:image_preprocessing_graph_options_cc_proto",
"//mediapipe/tasks/cc/core:model_task_graph",
"//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_cc_proto",
"@com_google_absl//absl/status:statusor",

View File

@ -20,10 +20,10 @@ limitations under the License.
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
#include "mediapipe/tasks/cc/components/image_preprocessing.h"
#include "mediapipe/tasks/cc/components/image_preprocessing_options.pb.h"
#include "mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.h"
#include "mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h"
#include "mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/components/processors/proto/image_preprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/core/model_task_graph.h"
#include "mediapipe/tasks/cc/vision/image_embedder/proto/image_embedder_graph_options.pb.h"
@ -130,14 +130,15 @@ class ImageEmbedderGraph : public core::ModelTaskGraph {
Source<NormalizedRect> norm_rect_in, Graph& graph) {
// Adds preprocessing calculators and connects them to the graph input image
// stream.
auto& preprocessing =
graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph");
bool use_gpu = components::DetermineImagePreprocessingGpuBackend(
task_options.base_options().acceleration());
MP_RETURN_IF_ERROR(ConfigureImagePreprocessing(
auto& preprocessing = graph.AddNode(
"mediapipe.tasks.components.processors.ImagePreprocessingGraph");
bool use_gpu =
components::processors::DetermineImagePreprocessingGpuBackend(
task_options.base_options().acceleration());
MP_RETURN_IF_ERROR(components::processors::ConfigureImagePreprocessingGraph(
model_resources, use_gpu,
&preprocessing
.GetOptions<tasks::components::ImagePreprocessingOptions>()));
&preprocessing.GetOptions<tasks::components::processors::proto::
ImagePreprocessingGraphOptions>()));
image_in >> preprocessing.In(kImageTag);
norm_rect_in >> preprocessing.In(kNormRectTag);

View File

@ -56,10 +56,10 @@ cc_library(
"//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/framework/port:status",
"//mediapipe/tasks/cc:common",
"//mediapipe/tasks/cc/components:image_preprocessing",
"//mediapipe/tasks/cc/components:image_preprocessing_options_cc_proto",
"//mediapipe/tasks/cc/components/calculators/tensor:tensors_to_segmentation_calculator",
"//mediapipe/tasks/cc/components/calculators/tensor:tensors_to_segmentation_calculator_cc_proto",
"//mediapipe/tasks/cc/components/processors:image_preprocessing_graph",
"//mediapipe/tasks/cc/components/processors/proto:image_preprocessing_graph_options_cc_proto",
"//mediapipe/tasks/cc/components/proto:segmenter_options_cc_proto",
"//mediapipe/tasks/cc/core:model_resources",
"//mediapipe/tasks/cc/core:model_task_graph",

View File

@ -27,8 +27,8 @@ limitations under the License.
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.pb.h"
#include "mediapipe/tasks/cc/components/image_preprocessing.h"
#include "mediapipe/tasks/cc/components/image_preprocessing_options.pb.h"
#include "mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h"
#include "mediapipe/tasks/cc/components/processors/proto/image_preprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/components/proto/segmenter_options.pb.h"
#include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/core/model_task_graph.h"
@ -243,14 +243,15 @@ class ImageSegmenterGraph : public core::ModelTaskGraph {
// Adds preprocessing calculators and connects them to the graph input image
// stream.
auto& preprocessing =
graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph");
bool use_gpu = components::DetermineImagePreprocessingGpuBackend(
task_options.base_options().acceleration());
MP_RETURN_IF_ERROR(ConfigureImagePreprocessing(
auto& preprocessing = graph.AddNode(
"mediapipe.tasks.components.processors.ImagePreprocessingGraph");
bool use_gpu =
components::processors::DetermineImagePreprocessingGpuBackend(
task_options.base_options().acceleration());
MP_RETURN_IF_ERROR(components::processors::ConfigureImagePreprocessingGraph(
model_resources, use_gpu,
&preprocessing
.GetOptions<tasks::components::ImagePreprocessingOptions>()));
&preprocessing.GetOptions<tasks::components::processors::proto::
ImagePreprocessingGraphOptions>()));
image_in >> preprocessing.In(kImageTag);
norm_rect_in >> preprocessing.In(kNormRectTag);

View File

@ -71,9 +71,9 @@ cc_library(
"//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/framework/formats:tensor",
"//mediapipe/tasks/cc:common",
"//mediapipe/tasks/cc/components:image_preprocessing",
"//mediapipe/tasks/cc/components/calculators:score_calibration_calculator_cc_proto",
"//mediapipe/tasks/cc/components/calculators:score_calibration_utils",
"//mediapipe/tasks/cc/components/processors:image_preprocessing_graph",
"//mediapipe/tasks/cc/components/utils:source_or_node_output",
"//mediapipe/tasks/cc/core:model_resources",
"//mediapipe/tasks/cc/core:model_task_graph",

View File

@ -33,7 +33,7 @@ limitations under the License.
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/calculators/score_calibration_calculator.pb.h"
#include "mediapipe/tasks/cc/components/calculators/score_calibration_utils.h"
#include "mediapipe/tasks/cc/components/image_preprocessing.h"
#include "mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h"
#include "mediapipe/tasks/cc/components/utils/source_or_node_output.h"
#include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/core/model_task_graph.h"
@ -561,14 +561,15 @@ class ObjectDetectorGraph : public core::ModelTaskGraph {
// Adds preprocessing calculators and connects them to the graph input image
// stream.
auto& preprocessing =
graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph");
bool use_gpu = components::DetermineImagePreprocessingGpuBackend(
task_options.base_options().acceleration());
MP_RETURN_IF_ERROR(ConfigureImagePreprocessing(
auto& preprocessing = graph.AddNode(
"mediapipe.tasks.components.processors.ImagePreprocessingGraph");
bool use_gpu =
components::processors::DetermineImagePreprocessingGpuBackend(
task_options.base_options().acceleration());
MP_RETURN_IF_ERROR(components::processors::ConfigureImagePreprocessingGraph(
model_resources, use_gpu,
&preprocessing
.GetOptions<tasks::components::ImagePreprocessingOptions>()));
&preprocessing.GetOptions<tasks::components::processors::proto::
ImagePreprocessingGraphOptions>()));
image_in >> preprocessing.In(kImageTag);
norm_rect_in >> preprocessing.In(kNormRectTag);