GPU_ORIGIN configurable through base options proto.

PiperOrigin-RevId: 573251085
This commit is contained in:
MediaPipe Team 2023-10-13 10:08:14 -07:00 committed by Copybara-Service
parent 8823046e4b
commit 1bd800697e
23 changed files with 86 additions and 18 deletions

View File

@ -122,6 +122,7 @@ cc_library(
"//mediapipe/tasks/cc/components/processors/proto:image_preprocessing_graph_options_cc_proto", "//mediapipe/tasks/cc/components/processors/proto:image_preprocessing_graph_options_cc_proto",
"//mediapipe/tasks/cc/core:model_resources", "//mediapipe/tasks/cc/core:model_resources",
"//mediapipe/tasks/cc/core/proto:acceleration_cc_proto", "//mediapipe/tasks/cc/core/proto:acceleration_cc_proto",
"//mediapipe/tasks/cc/core/proto:base_options_cc_proto",
"//mediapipe/tasks/cc/vision/utils:image_tensor_specs", "//mediapipe/tasks/cc/vision/utils:image_tensor_specs",
"@com_google_absl//absl/status", "@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor", "@com_google_absl//absl/status:statusor",

View File

@ -36,6 +36,7 @@ limitations under the License.
#include "mediapipe/tasks/cc/components/processors/proto/image_preprocessing_graph_options.pb.h" #include "mediapipe/tasks/cc/components/processors/proto/image_preprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/core/proto/acceleration.pb.h" #include "mediapipe/tasks/cc/core/proto/acceleration.pb.h"
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
#include "mediapipe/tasks/cc/vision/utils/image_tensor_specs.h" #include "mediapipe/tasks/cc/vision/utils/image_tensor_specs.h"
#include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/schema/schema_generated.h"
@ -73,7 +74,7 @@ struct ImagePreprocessingOutputStreams {
// Fills in the ImageToTensorCalculatorOptions based on the ImageTensorSpecs. // Fills in the ImageToTensorCalculatorOptions based on the ImageTensorSpecs.
absl::Status ConfigureImageToTensorCalculator( absl::Status ConfigureImageToTensorCalculator(
const ImageTensorSpecs& image_tensor_specs, const ImageTensorSpecs& image_tensor_specs, GpuOrigin::Mode gpu_origin,
mediapipe::ImageToTensorCalculatorOptions* options) { mediapipe::ImageToTensorCalculatorOptions* options) {
options->set_output_tensor_width(image_tensor_specs.image_width); options->set_output_tensor_width(image_tensor_specs.image_width);
options->set_output_tensor_height(image_tensor_specs.image_height); options->set_output_tensor_height(image_tensor_specs.image_height);
@ -109,7 +110,7 @@ absl::Status ConfigureImageToTensorCalculator(
} }
// TODO: need to support different GPU origin on different // TODO: need to support different GPU origin on different
// platforms or applications. // platforms or applications.
options->set_gpu_origin(mediapipe::GpuOrigin::TOP_LEFT); options->set_gpu_origin(gpu_origin);
return absl::OkStatus(); return absl::OkStatus();
} }
@ -125,10 +126,19 @@ bool DetermineImagePreprocessingGpuBackend(
absl::Status ConfigureImagePreprocessingGraph( absl::Status ConfigureImagePreprocessingGraph(
const ModelResources& model_resources, bool use_gpu, const ModelResources& model_resources, bool use_gpu,
proto::ImagePreprocessingGraphOptions* options) { proto::ImagePreprocessingGraphOptions* options) {
return ConfigureImagePreprocessingGraph(model_resources, use_gpu,
GpuOrigin::TOP_LEFT, options);
}
absl::Status ConfigureImagePreprocessingGraph(
const ModelResources& model_resources, bool use_gpu,
GpuOrigin::Mode gpu_origin,
proto::ImagePreprocessingGraphOptions* options) {
MP_ASSIGN_OR_RETURN(auto image_tensor_specs, MP_ASSIGN_OR_RETURN(auto image_tensor_specs,
vision::BuildInputImageTensorSpecs(model_resources)); vision::BuildInputImageTensorSpecs(model_resources));
MP_RETURN_IF_ERROR(ConfigureImageToTensorCalculator( MP_RETURN_IF_ERROR(ConfigureImageToTensorCalculator(
image_tensor_specs, options->mutable_image_to_tensor_options())); image_tensor_specs, gpu_origin,
options->mutable_image_to_tensor_options()));
// The GPU backend isn't able to process int data. If the input tensor is // The GPU backend isn't able to process int data. If the input tensor is
// quantized, forces the image preprocessing graph to use CPU backend. // quantized, forces the image preprocessing graph to use CPU backend.
if (use_gpu && image_tensor_specs.tensor_type != tflite::TensorType_UINT8) { if (use_gpu && image_tensor_specs.tensor_type != tflite::TensorType_UINT8) {

View File

@ -17,6 +17,7 @@ limitations under the License.
#define MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_IMAGE_PREPROCESSING_GRAPH_H_ #define MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_IMAGE_PREPROCESSING_GRAPH_H_
#include "absl/status/status.h" #include "absl/status/status.h"
#include "mediapipe/gpu/gpu_origin.pb.h"
#include "mediapipe/tasks/cc/components/processors/proto/image_preprocessing_graph_options.pb.h" #include "mediapipe/tasks/cc/components/processors/proto/image_preprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/core/proto/acceleration.pb.h" #include "mediapipe/tasks/cc/core/proto/acceleration.pb.h"
@ -62,6 +63,12 @@ namespace processors {
// IMAGE - Image @Optional // IMAGE - Image @Optional
// The image that has the pixel data stored on the target storage (CPU vs // The image that has the pixel data stored on the target storage (CPU vs
// GPU). // GPU).
absl::Status ConfigureImagePreprocessingGraph(
const core::ModelResources& model_resources, bool use_gpu,
::mediapipe::GpuOrigin::Mode gpu_origin,
proto::ImagePreprocessingGraphOptions* options);
// A convenient function of the above. gpu_origin is set to TOP_LEFT by default.
absl::Status ConfigureImagePreprocessingGraph( absl::Status ConfigureImagePreprocessingGraph(
const core::ModelResources& model_resources, bool use_gpu, const core::ModelResources& model_resources, bool use_gpu,
proto::ImagePreprocessingGraphOptions* options); proto::ImagePreprocessingGraphOptions* options);

View File

@ -31,6 +31,7 @@ limitations under the License.
#include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/status_matchers.h" #include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/gpu/gpu_origin.pb.h"
#include "mediapipe/tasks/cc/components/processors/proto/image_preprocessing_graph_options.pb.h" #include "mediapipe/tasks/cc/components/processors/proto/image_preprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/core/proto/acceleration.pb.h" #include "mediapipe/tasks/cc/core/proto/acceleration.pb.h"
@ -230,6 +231,25 @@ TEST_F(ConfigureTest, SucceedsWithFloatModelGpuBackend) {
backend: GPU_BACKEND)pb")); backend: GPU_BACKEND)pb"));
} }
TEST_F(ConfigureTest, SucceedsGpuOriginConventional) {
MP_ASSERT_OK_AND_ASSIGN(
auto model_resources,
CreateModelResourcesForModel(kMobileNetFloatWithMetadata));
proto::ImagePreprocessingGraphOptions options;
MP_EXPECT_OK(ConfigureImagePreprocessingGraph(
*model_resources, true, mediapipe::GpuOrigin::CONVENTIONAL, &options));
EXPECT_THAT(options, EqualsProto(
R"pb(image_to_tensor_options {
output_tensor_width: 224
output_tensor_height: 224
output_tensor_float_range { min: -1 max: 1 }
gpu_origin: CONVENTIONAL
}
backend: GPU_BACKEND)pb"));
}
TEST_F(ConfigureTest, FailsWithFloatModelWithoutMetadata) { TEST_F(ConfigureTest, FailsWithFloatModelWithoutMetadata) {
MP_ASSERT_OK_AND_ASSIGN( MP_ASSERT_OK_AND_ASSIGN(
auto model_resources, auto model_resources,

View File

@ -46,6 +46,7 @@ mediapipe_proto_library(
deps = [ deps = [
":acceleration_proto", ":acceleration_proto",
":external_file_proto", ":external_file_proto",
"//mediapipe/gpu:gpu_origin_proto",
], ],
) )

View File

@ -17,6 +17,7 @@ syntax = "proto2";
package mediapipe.tasks.core.proto; package mediapipe.tasks.core.proto;
import "mediapipe/gpu/gpu_origin.proto";
import "mediapipe/tasks/cc/core/proto/acceleration.proto"; import "mediapipe/tasks/cc/core/proto/acceleration.proto";
import "mediapipe/tasks/cc/core/proto/external_file.proto"; import "mediapipe/tasks/cc/core/proto/external_file.proto";
@ -24,7 +25,7 @@ option java_package = "com.google.mediapipe.tasks.core.proto";
option java_outer_classname = "BaseOptionsProto"; option java_outer_classname = "BaseOptionsProto";
// Base options for mediapipe tasks. // Base options for mediapipe tasks.
// Next Id: 4 // Next Id: 5
message BaseOptions { message BaseOptions {
// The external model asset, as a single standalone TFLite file. It could be // The external model asset, as a single standalone TFLite file. It could be
// packed with TFLite Model Metadata[1] and associated files if exist. Fail to // packed with TFLite Model Metadata[1] and associated files if exist. Fail to
@ -40,4 +41,7 @@ message BaseOptions {
// Acceleration setting to use available delegate on the device. // Acceleration setting to use available delegate on the device.
optional Acceleration acceleration = 3; optional Acceleration acceleration = 3;
// Gpu origin for calculators with gpu supported.
optional mediapipe.GpuOrigin.Mode gpu_origin = 4 [default = TOP_LEFT];
} }

View File

@ -225,7 +225,7 @@ class FaceDetectorGraph : public core::ModelTaskGraph {
components::processors::DetermineImagePreprocessingGpuBackend( components::processors::DetermineImagePreprocessingGpuBackend(
subgraph_options.base_options().acceleration()); subgraph_options.base_options().acceleration());
MP_RETURN_IF_ERROR(components::processors::ConfigureImagePreprocessingGraph( MP_RETURN_IF_ERROR(components::processors::ConfigureImagePreprocessingGraph(
model_resources, use_gpu, model_resources, use_gpu, subgraph_options.base_options().gpu_origin(),
&preprocessing.GetOptions< &preprocessing.GetOptions<
components::processors::proto::ImagePreprocessingGraphOptions>())); components::processors::proto::ImagePreprocessingGraphOptions>()));
auto& image_to_tensor_options = auto& image_to_tensor_options =

View File

@ -134,6 +134,9 @@ absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources,
->CopyFrom(options->base_options().acceleration()); ->CopyFrom(options->base_options().acceleration());
face_detector_graph_options->mutable_base_options()->set_use_stream_mode( face_detector_graph_options->mutable_base_options()->set_use_stream_mode(
options->base_options().use_stream_mode()); options->base_options().use_stream_mode());
face_detector_graph_options->mutable_base_options()->set_gpu_origin(
options->base_options().gpu_origin());
auto* face_landmarks_detector_graph_options = auto* face_landmarks_detector_graph_options =
options->mutable_face_landmarks_detector_graph_options(); options->mutable_face_landmarks_detector_graph_options();
if (!face_landmarks_detector_graph_options->base_options() if (!face_landmarks_detector_graph_options->base_options()
@ -151,6 +154,8 @@ absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources,
->CopyFrom(options->base_options().acceleration()); ->CopyFrom(options->base_options().acceleration());
face_landmarks_detector_graph_options->mutable_base_options() face_landmarks_detector_graph_options->mutable_base_options()
->set_use_stream_mode(options->base_options().use_stream_mode()); ->set_use_stream_mode(options->base_options().use_stream_mode());
face_landmarks_detector_graph_options->mutable_base_options()->set_gpu_origin(
options->base_options().gpu_origin());
absl::StatusOr<absl::string_view> face_blendshape_model = absl::StatusOr<absl::string_view> face_blendshape_model =
resources.GetFile(kFaceBlendshapeTFLiteName); resources.GetFile(kFaceBlendshapeTFLiteName);

View File

@ -266,7 +266,7 @@ class SingleFaceLandmarksDetectorGraph : public core::ModelTaskGraph {
components::processors::DetermineImagePreprocessingGpuBackend( components::processors::DetermineImagePreprocessingGpuBackend(
subgraph_options.base_options().acceleration()); subgraph_options.base_options().acceleration());
MP_RETURN_IF_ERROR(components::processors::ConfigureImagePreprocessingGraph( MP_RETURN_IF_ERROR(components::processors::ConfigureImagePreprocessingGraph(
model_resources, use_gpu, model_resources, use_gpu, subgraph_options.base_options().gpu_origin(),
&preprocessing.GetOptions<tasks::components::processors::proto:: &preprocessing.GetOptions<tasks::components::processors::proto::
ImagePreprocessingGraphOptions>())); ImagePreprocessingGraphOptions>()));
image_in >> preprocessing.In(kImageTag); image_in >> preprocessing.In(kImageTag);

View File

@ -395,7 +395,7 @@ class FaceStylizerGraph : public core::ModelTaskGraph {
components::processors::DetermineImagePreprocessingGpuBackend( components::processors::DetermineImagePreprocessingGpuBackend(
task_options.base_options().acceleration()); task_options.base_options().acceleration());
MP_RETURN_IF_ERROR(components::processors::ConfigureImagePreprocessingGraph( MP_RETURN_IF_ERROR(components::processors::ConfigureImagePreprocessingGraph(
*model_resources, use_gpu, *model_resources, use_gpu, task_options.base_options().gpu_origin(),
&preprocessing.GetOptions<tasks::components::processors::proto:: &preprocessing.GetOptions<tasks::components::processors::proto::
ImagePreprocessingGraphOptions>())); ImagePreprocessingGraphOptions>()));
auto& image_to_tensor_options = auto& image_to_tensor_options =

View File

@ -131,6 +131,11 @@ absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources,
} }
hand_gesture_recognizer_graph_options->mutable_base_options() hand_gesture_recognizer_graph_options->mutable_base_options()
->set_use_stream_mode(options->base_options().use_stream_mode()); ->set_use_stream_mode(options->base_options().use_stream_mode());
hand_landmarker_graph_options->mutable_base_options()->set_gpu_origin(
options->base_options().gpu_origin());
hand_gesture_recognizer_graph_options->mutable_base_options()->set_gpu_origin(
options->base_options().gpu_origin());
return absl::OkStatus(); return absl::OkStatus();
} }

View File

@ -242,7 +242,7 @@ class HandDetectorGraph : public core::ModelTaskGraph {
components::processors::DetermineImagePreprocessingGpuBackend( components::processors::DetermineImagePreprocessingGpuBackend(
subgraph_options.base_options().acceleration()); subgraph_options.base_options().acceleration());
MP_RETURN_IF_ERROR(components::processors::ConfigureImagePreprocessingGraph( MP_RETURN_IF_ERROR(components::processors::ConfigureImagePreprocessingGraph(
model_resources, use_gpu, model_resources, use_gpu, subgraph_options.base_options().gpu_origin(),
&preprocessing.GetOptions< &preprocessing.GetOptions<
components::processors::proto::ImagePreprocessingGraphOptions>())); components::processors::proto::ImagePreprocessingGraphOptions>()));
image_in >> preprocessing.In("IMAGE"); image_in >> preprocessing.In("IMAGE");

View File

@ -125,6 +125,11 @@ absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources,
->CopyFrom(options->base_options().acceleration()); ->CopyFrom(options->base_options().acceleration());
hand_landmarks_detector_graph_options->mutable_base_options() hand_landmarks_detector_graph_options->mutable_base_options()
->set_use_stream_mode(options->base_options().use_stream_mode()); ->set_use_stream_mode(options->base_options().use_stream_mode());
hand_detector_graph_options->mutable_base_options()->set_gpu_origin(
options->base_options().gpu_origin());
hand_landmarks_detector_graph_options->mutable_base_options()->set_gpu_origin(
options->base_options().gpu_origin());
return absl::OkStatus(); return absl::OkStatus();
} }
} // namespace } // namespace

View File

@ -256,7 +256,7 @@ class SingleHandLandmarksDetectorGraph : public core::ModelTaskGraph {
components::processors::DetermineImagePreprocessingGpuBackend( components::processors::DetermineImagePreprocessingGpuBackend(
subgraph_options.base_options().acceleration()); subgraph_options.base_options().acceleration());
MP_RETURN_IF_ERROR(components::processors::ConfigureImagePreprocessingGraph( MP_RETURN_IF_ERROR(components::processors::ConfigureImagePreprocessingGraph(
model_resources, use_gpu, model_resources, use_gpu, subgraph_options.base_options().gpu_origin(),
&preprocessing.GetOptions<tasks::components::processors::proto:: &preprocessing.GetOptions<tasks::components::processors::proto::
ImagePreprocessingGraphOptions>())); ImagePreprocessingGraphOptions>()));
image_in >> preprocessing.In("IMAGE"); image_in >> preprocessing.In("IMAGE");

View File

@ -142,7 +142,7 @@ class ImageClassifierGraph : public core::ModelTaskGraph {
components::processors::DetermineImagePreprocessingGpuBackend( components::processors::DetermineImagePreprocessingGpuBackend(
task_options.base_options().acceleration()); task_options.base_options().acceleration());
MP_RETURN_IF_ERROR(components::processors::ConfigureImagePreprocessingGraph( MP_RETURN_IF_ERROR(components::processors::ConfigureImagePreprocessingGraph(
model_resources, use_gpu, model_resources, use_gpu, task_options.base_options().gpu_origin(),
&preprocessing.GetOptions<tasks::components::processors::proto:: &preprocessing.GetOptions<tasks::components::processors::proto::
ImagePreprocessingGraphOptions>())); ImagePreprocessingGraphOptions>()));
image_in >> preprocessing.In(kImageTag); image_in >> preprocessing.In(kImageTag);

View File

@ -137,7 +137,7 @@ class ImageEmbedderGraph : public core::ModelTaskGraph {
components::processors::DetermineImagePreprocessingGpuBackend( components::processors::DetermineImagePreprocessingGpuBackend(
task_options.base_options().acceleration()); task_options.base_options().acceleration());
MP_RETURN_IF_ERROR(components::processors::ConfigureImagePreprocessingGraph( MP_RETURN_IF_ERROR(components::processors::ConfigureImagePreprocessingGraph(
model_resources, use_gpu, model_resources, use_gpu, task_options.base_options().gpu_origin(),
&preprocessing.GetOptions<tasks::components::processors::proto:: &preprocessing.GetOptions<tasks::components::processors::proto::
ImagePreprocessingGraphOptions>())); ImagePreprocessingGraphOptions>()));
image_in >> preprocessing.In(kImageTag); image_in >> preprocessing.In(kImageTag);

View File

@ -255,8 +255,8 @@ void ConfigureTensorConverterCalculator(
// the tflite model. // the tflite model.
absl::StatusOr<ImageAndTensorsOnDevice> ConvertImageToTensors( absl::StatusOr<ImageAndTensorsOnDevice> ConvertImageToTensors(
Source<Image> image_in, Source<NormalizedRect> norm_rect_in, bool use_gpu, Source<Image> image_in, Source<NormalizedRect> norm_rect_in, bool use_gpu,
bool is_hair_segmentation, const core::ModelResources& model_resources, const core::proto::BaseOptions& base_options, bool is_hair_segmentation,
Graph& graph) { const core::ModelResources& model_resources, Graph& graph) {
MP_ASSIGN_OR_RETURN(const tflite::Tensor* tflite_input_tensor, MP_ASSIGN_OR_RETURN(const tflite::Tensor* tflite_input_tensor,
GetInputTensor(model_resources)); GetInputTensor(model_resources));
if (tflite_input_tensor->shape()->size() != 4) { if (tflite_input_tensor->shape()->size() != 4) {
@ -279,7 +279,7 @@ absl::StatusOr<ImageAndTensorsOnDevice> ConvertImageToTensors(
auto& preprocessing = graph.AddNode( auto& preprocessing = graph.AddNode(
"mediapipe.tasks.components.processors.ImagePreprocessingGraph"); "mediapipe.tasks.components.processors.ImagePreprocessingGraph");
MP_RETURN_IF_ERROR(components::processors::ConfigureImagePreprocessingGraph( MP_RETURN_IF_ERROR(components::processors::ConfigureImagePreprocessingGraph(
model_resources, use_gpu, model_resources, use_gpu, base_options.gpu_origin(),
&preprocessing.GetOptions<tasks::components::processors::proto:: &preprocessing.GetOptions<tasks::components::processors::proto::
ImagePreprocessingGraphOptions>())); ImagePreprocessingGraphOptions>()));
image_in >> preprocessing.In(kImageTag); image_in >> preprocessing.In(kImageTag);
@ -518,7 +518,8 @@ class ImageSegmenterGraph : public core::ModelTaskGraph {
MP_ASSIGN_OR_RETURN( MP_ASSIGN_OR_RETURN(
auto image_and_tensors, auto image_and_tensors,
ConvertImageToTensors(image_in, norm_rect_in, use_gpu, ConvertImageToTensors(image_in, norm_rect_in, use_gpu,
is_hair_segmentation, model_resources, graph)); task_options.base_options(), is_hair_segmentation,
model_resources, graph));
// Adds inference subgraph and connects its input stream to the output // Adds inference subgraph and connects its input stream to the output
// tensors produced by the ImageToTensorCalculator. // tensors produced by the ImageToTensorCalculator.
auto& inference = AddInference( auto& inference = AddInference(

View File

@ -192,7 +192,7 @@ class ObjectDetectorGraph : public core::ModelTaskGraph {
components::processors::DetermineImagePreprocessingGpuBackend( components::processors::DetermineImagePreprocessingGpuBackend(
task_options.base_options().acceleration()); task_options.base_options().acceleration());
MP_RETURN_IF_ERROR(components::processors::ConfigureImagePreprocessingGraph( MP_RETURN_IF_ERROR(components::processors::ConfigureImagePreprocessingGraph(
model_resources, use_gpu, model_resources, use_gpu, task_options.base_options().gpu_origin(),
&preprocessing.GetOptions<tasks::components::processors::proto:: &preprocessing.GetOptions<tasks::components::processors::proto::
ImagePreprocessingGraphOptions>())); ImagePreprocessingGraphOptions>()));
image_in >> preprocessing.In(kImageTag); image_in >> preprocessing.In(kImageTag);

View File

@ -240,7 +240,7 @@ class PoseDetectorGraph : public core::ModelTaskGraph {
components::processors::DetermineImagePreprocessingGpuBackend( components::processors::DetermineImagePreprocessingGpuBackend(
subgraph_options.base_options().acceleration()); subgraph_options.base_options().acceleration());
MP_RETURN_IF_ERROR(components::processors::ConfigureImagePreprocessingGraph( MP_RETURN_IF_ERROR(components::processors::ConfigureImagePreprocessingGraph(
model_resources, use_gpu, model_resources, use_gpu, subgraph_options.base_options().gpu_origin(),
&preprocessing.GetOptions< &preprocessing.GetOptions<
components::processors::proto::ImagePreprocessingGraphOptions>())); components::processors::proto::ImagePreprocessingGraphOptions>()));
auto& image_to_tensor_options = auto& image_to_tensor_options =

View File

@ -134,6 +134,11 @@ absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources,
pose_landmarks_detector_graph_options->mutable_base_options() pose_landmarks_detector_graph_options->mutable_base_options()
->set_use_stream_mode(options->base_options().use_stream_mode()); ->set_use_stream_mode(options->base_options().use_stream_mode());
pose_detector_graph_options->mutable_base_options()->set_gpu_origin(
options->base_options().gpu_origin());
pose_landmarks_detector_graph_options->mutable_base_options()->set_gpu_origin(
options->base_options().gpu_origin());
return absl::OkStatus(); return absl::OkStatus();
} }

View File

@ -347,7 +347,7 @@ class SinglePoseLandmarksDetectorGraph : public core::ModelTaskGraph {
components::processors::DetermineImagePreprocessingGpuBackend( components::processors::DetermineImagePreprocessingGpuBackend(
subgraph_options.base_options().acceleration()); subgraph_options.base_options().acceleration());
MP_RETURN_IF_ERROR(components::processors::ConfigureImagePreprocessingGraph( MP_RETURN_IF_ERROR(components::processors::ConfigureImagePreprocessingGraph(
model_resources, use_gpu, model_resources, use_gpu, subgraph_options.base_options().gpu_origin(),
&preprocessing.GetOptions<tasks::components::processors::proto:: &preprocessing.GetOptions<tasks::components::processors::proto::
ImagePreprocessingGraphOptions>())); ImagePreprocessingGraphOptions>()));
image_in >> preprocessing.In(kImageTag); image_in >> preprocessing.In(kImageTag);

View File

@ -66,6 +66,7 @@ mediapipe_ts_library(
":task_runner", ":task_runner",
":task_runner_test_utils", ":task_runner_test_utils",
"//mediapipe/calculators/tensor:inference_calculator_jspb_proto", "//mediapipe/calculators/tensor:inference_calculator_jspb_proto",
"//mediapipe/gpu:gpu_origin_jspb_proto",
"//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
"//mediapipe/web/graph_runner:graph_runner_ts", "//mediapipe/web/graph_runner:graph_runner_ts",
], ],

View File

@ -17,6 +17,7 @@ import 'jasmine';
// Placeholder for internal dependency on encodeByteArray // Placeholder for internal dependency on encodeByteArray
import {InferenceCalculatorOptions} from '../../../calculators/tensor/inference_calculator_pb'; import {InferenceCalculatorOptions} from '../../../calculators/tensor/inference_calculator_pb';
import {GpuOrigin as GpuOriginProto} from '../../../gpu/gpu_origin_pb';
import {BaseOptions as BaseOptionsProto} from '../../../tasks/cc/core/proto/base_options_pb'; import {BaseOptions as BaseOptionsProto} from '../../../tasks/cc/core/proto/base_options_pb';
import {TaskRunner} from '../../../tasks/web/core/task_runner'; import {TaskRunner} from '../../../tasks/web/core/task_runner';
import {createSpyWasmModule, SpyWasmModule} from '../../../tasks/web/core/task_runner_test_utils'; import {createSpyWasmModule, SpyWasmModule} from '../../../tasks/web/core/task_runner_test_utils';
@ -111,6 +112,7 @@ describe('TaskRunner', () => {
tflite: {}, tflite: {},
nnapi: undefined, nnapi: undefined,
}, },
gpuOrigin: GpuOriginProto.Mode.TOP_LEFT,
}; };
const mockBytesResultWithGpuDelegate = { const mockBytesResultWithGpuDelegate = {
...mockBytesResult, ...mockBytesResult,
@ -146,6 +148,7 @@ describe('TaskRunner', () => {
tflite: {}, tflite: {},
nnapi: undefined, nnapi: undefined,
}, },
gpuOrigin: GpuOriginProto.Mode.TOP_LEFT,
}; };
let fetchSpy: jasmine.Spy; let fetchSpy: jasmine.Spy;