Merge pull request #4430 from kinaryml:python-gpu-support

PiperOrigin-RevId: 535133616
This commit is contained in:
Copybara-Service 2023-05-24 23:53:00 -07:00
commit 034caf3d87
3 changed files with 62 additions and 6 deletions

View File

@ -34,6 +34,8 @@ py_library(
], ],
deps = [ deps = [
":optional_dependencies", ":optional_dependencies",
"//mediapipe/calculators/tensor:inference_calculator_py_pb2",
"//mediapipe/tasks/cc/core/proto:acceleration_py_pb2",
"//mediapipe/tasks/cc/core/proto:base_options_py_pb2", "//mediapipe/tasks/cc/core/proto:base_options_py_pb2",
"//mediapipe/tasks/cc/core/proto:external_file_py_pb2", "//mediapipe/tasks/cc/core/proto:external_file_py_pb2",
], ],

View File

@ -14,13 +14,19 @@
"""Base options for MediaPipe Task APIs.""" """Base options for MediaPipe Task APIs."""
import dataclasses import dataclasses
import enum
import os import os
import platform
from typing import Any, Optional from typing import Any, Optional
from mediapipe.calculators.tensor import inference_calculator_pb2
from mediapipe.tasks.cc.core.proto import acceleration_pb2
from mediapipe.tasks.cc.core.proto import base_options_pb2 from mediapipe.tasks.cc.core.proto import base_options_pb2
from mediapipe.tasks.cc.core.proto import external_file_pb2 from mediapipe.tasks.cc.core.proto import external_file_pb2
from mediapipe.tasks.python.core.optional_dependencies import doc_controls from mediapipe.tasks.python.core.optional_dependencies import doc_controls
_DelegateProto = inference_calculator_pb2.InferenceCalculatorOptions.Delegate
_AccelerationProto = acceleration_pb2.Acceleration
_BaseOptionsProto = base_options_pb2.BaseOptions _BaseOptionsProto = base_options_pb2.BaseOptions
_ExternalFileProto = external_file_pb2.ExternalFile _ExternalFileProto = external_file_pb2.ExternalFile
@ -41,11 +47,17 @@ class BaseOptions:
Attributes: Attributes:
model_asset_path: Path to the model asset file. model_asset_path: Path to the model asset file.
model_asset_buffer: The model asset file contents as bytes. model_asset_buffer: The model asset file contents as bytes.
delegate: Accelaration to use. Supported values are GPU and CPU. GPU support
is currently limited to Ubuntu platforms.
""" """
class Delegate(enum.Enum):
CPU = 0
GPU = 1
model_asset_path: Optional[str] = None model_asset_path: Optional[str] = None
model_asset_buffer: Optional[bytes] = None model_asset_buffer: Optional[bytes] = None
# TODO: Allow Python API to specify acceleration settings. delegate: Optional[Delegate] = None
@doc_controls.do_not_generate_docs @doc_controls.do_not_generate_docs
def to_pb2(self) -> _BaseOptionsProto: def to_pb2(self) -> _BaseOptionsProto:
@ -55,17 +67,44 @@ class BaseOptions:
else: else:
full_path = None full_path = None
platform_name = platform.system()
if self.delegate == BaseOptions.Delegate.GPU:
if platform_name == 'Linux':
acceleration_proto = _AccelerationProto(gpu=_DelegateProto.Gpu())
else:
raise NotImplementedError(
'GPU Delegate is not yet supported for ' + platform_name
)
elif self.delegate == BaseOptions.Delegate.CPU:
acceleration_proto = _AccelerationProto(tflite=_DelegateProto.TfLite())
else:
acceleration_proto = None
return _BaseOptionsProto( return _BaseOptionsProto(
model_asset=_ExternalFileProto( model_asset=_ExternalFileProto(
file_name=full_path, file_content=self.model_asset_buffer)) file_name=full_path, file_content=self.model_asset_buffer
),
acceleration=acceleration_proto,
)
@classmethod @classmethod
@doc_controls.do_not_generate_docs @doc_controls.do_not_generate_docs
def create_from_pb2(cls, pb2_obj: _BaseOptionsProto) -> 'BaseOptions': def create_from_pb2(cls, pb2_obj: _BaseOptionsProto) -> 'BaseOptions':
"""Creates a `BaseOptions` object from the given protobuf object.""" """Creates a `BaseOptions` object from the given protobuf object."""
delegate = None
if pb2_obj.acceleration is not None:
delegate = (
BaseOptions.Delegate.GPU
if pb2_obj.acceleration.gpu is not None
else BaseOptions.Delegate.CPU
)
return BaseOptions( return BaseOptions(
model_asset_path=pb2_obj.model_asset.file_name, model_asset_path=pb2_obj.model_asset.file_name,
model_asset_buffer=pb2_obj.model_asset.file_content) model_asset_buffer=pb2_obj.model_asset.file_content,
delegate=delegate,
)
def __eq__(self, other: Any) -> bool: def __eq__(self, other: Any) -> bool:
"""Checks if this object is equal to the given object. """Checks if this object is equal to the given object.

View File

@ -30,6 +30,7 @@ from setuptools.command import build_py
from setuptools.command import install from setuptools.command import install
__version__ = 'dev' __version__ = 'dev'
MP_DISABLE_GPU = os.environ.get('MEDIAPIPE_DISABLE_GPU') != '0'
IS_WINDOWS = (platform.system() == 'Windows') IS_WINDOWS = (platform.system() == 'Windows')
IS_MAC = (platform.system() == 'Darwin') IS_MAC = (platform.system() == 'Darwin')
MP_ROOT_PATH = os.path.dirname(os.path.abspath(__file__)) MP_ROOT_PATH = os.path.dirname(os.path.abspath(__file__))
@ -279,10 +280,16 @@ class BuildModules(build_ext.build_ext):
'build', 'build',
'--compilation_mode=opt', '--compilation_mode=opt',
'--copt=-DNDEBUG', '--copt=-DNDEBUG',
'--define=MEDIAPIPE_DISABLE_GPU=1',
'--action_env=PYTHON_BIN_PATH=' + _normalize_path(sys.executable), '--action_env=PYTHON_BIN_PATH=' + _normalize_path(sys.executable),
binary_graph_target, binary_graph_target,
] ]
if MP_DISABLE_GPU:
bazel_command.append('--define=MEDIAPIPE_DISABLE_GPU=1')
else:
bazel_command.append('--copt=-DMESA_EGL_NO_X11_HEADERS')
bazel_command.append('--copt=-DEGL_NO_X11')
if not self.link_opencv and not IS_WINDOWS: if not self.link_opencv and not IS_WINDOWS:
bazel_command.append('--define=OPENCV=source') bazel_command.append('--define=OPENCV=source')
if subprocess.call(bazel_command) != 0: if subprocess.call(bazel_command) != 0:
@ -300,14 +307,21 @@ class GenerateMetadataSchema(build_ext.build_ext):
'object_detector_metadata_schema_py', 'object_detector_metadata_schema_py',
'schema_py', 'schema_py',
]: ]:
bazel_command = [ bazel_command = [
'bazel', 'bazel',
'build', 'build',
'--compilation_mode=opt', '--compilation_mode=opt',
'--define=MEDIAPIPE_DISABLE_GPU=1',
'--action_env=PYTHON_BIN_PATH=' + _normalize_path(sys.executable), '--action_env=PYTHON_BIN_PATH=' + _normalize_path(sys.executable),
'//mediapipe/tasks/metadata:' + target, '//mediapipe/tasks/metadata:' + target,
] ]
if MP_DISABLE_GPU:
bazel_command.append('--define=MEDIAPIPE_DISABLE_GPU=1')
else:
bazel_command.append('--copt=-DMESA_EGL_NO_X11_HEADERS')
bazel_command.append('--copt=-DEGL_NO_X11')
if subprocess.call(bazel_command) != 0: if subprocess.call(bazel_command) != 0:
sys.exit(-1) sys.exit(-1)
_copy_to_build_lib_dir( _copy_to_build_lib_dir(
@ -393,7 +407,8 @@ class BuildExtension(build_ext.build_ext):
'build', 'build',
'--compilation_mode=opt', '--compilation_mode=opt',
'--copt=-DNDEBUG', '--copt=-DNDEBUG',
'--define=MEDIAPIPE_DISABLE_GPU=1', '--copt=-DMESA_EGL_NO_X11_HEADERS',
'--copt=-DEGL_NO_X11',
'--action_env=PYTHON_BIN_PATH=' + _normalize_path(sys.executable), '--action_env=PYTHON_BIN_PATH=' + _normalize_path(sys.executable),
str(ext.bazel_target + '.so'), str(ext.bazel_target + '.so'),
] ]