Update base_options.py
This commit is contained in:
parent
f63baaf8d2
commit
6100f0e76e
|
@ -16,6 +16,7 @@
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import enum
|
import enum
|
||||||
import os
|
import os
|
||||||
|
import platform
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
from mediapipe.calculators.tensor import inference_calculator_pb2
|
from mediapipe.calculators.tensor import inference_calculator_pb2
|
||||||
|
@ -63,10 +64,22 @@ class BaseOptions:
|
||||||
else:
|
else:
|
||||||
full_path = None
|
full_path = None
|
||||||
|
|
||||||
if self.delegate == BaseOptions.Delegate.GPU:
|
platform = platform.system()
|
||||||
acceleration_proto = _AccelerationProto(gpu=_DelegateProto.Gpu())
|
|
||||||
|
if self.delegate is not None:
|
||||||
|
if platform == "Linux":
|
||||||
|
if self.delegate == BaseOptions.Delegate.GPU:
|
||||||
|
acceleration_proto = _AccelerationProto(gpu=_DelegateProto.Gpu())
|
||||||
|
else:
|
||||||
|
acceleration_proto = _AccelerationProto(tflite=_DelegateProto.TfLite())
|
||||||
|
elif platform == "Windows":
|
||||||
|
raise Exception("Delegate is unsupported for Windows.")
|
||||||
|
elif platform == "Darwin":
|
||||||
|
raise Exception("Delegate is unsupported for MacOS.")
|
||||||
|
else:
|
||||||
|
raise Exception("Unidentified system")
|
||||||
else:
|
else:
|
||||||
acceleration_proto = _AccelerationProto(tflite=_DelegateProto.TfLite())
|
acceleration_proto = None
|
||||||
|
|
||||||
return _BaseOptionsProto(
|
return _BaseOptionsProto(
|
||||||
model_asset=_ExternalFileProto(
|
model_asset=_ExternalFileProto(
|
||||||
|
|
Loading…
Reference in New Issue
Block a user