Add delegate options to base options for java API. and add unit tset for BaseOptions.
PiperOrigin-RevId: 544458644
This commit is contained in:
parent
e15d5a797b
commit
0ea54b1461
|
@ -41,9 +41,15 @@ proto::Acceleration ConvertDelegateOptionsToAccelerationProto(
|
|||
proto::Acceleration acceleration_proto = proto::Acceleration();
|
||||
auto* gpu = acceleration_proto.mutable_gpu();
|
||||
gpu->set_use_advanced_gpu_api(true);
|
||||
if (!options.cached_kernel_path.empty()) {
|
||||
gpu->set_cached_kernel_path(options.cached_kernel_path);
|
||||
}
|
||||
if (!options.serialized_model_dir.empty()) {
|
||||
gpu->set_serialized_model_dir(options.serialized_model_dir);
|
||||
}
|
||||
if (!options.model_token.empty()) {
|
||||
gpu->set_model_token(options.model_token);
|
||||
}
|
||||
return acceleration_proto;
|
||||
}
|
||||
|
||||
|
|
|
@ -59,14 +59,15 @@ TEST(DelegateOptionsTest, SucceedGpuOptions) {
|
|||
BaseOptions base_options;
|
||||
base_options.delegate = BaseOptions::Delegate::GPU;
|
||||
BaseOptions::GpuOptions gpu_options;
|
||||
gpu_options.cached_kernel_path = kCachedModelDir;
|
||||
gpu_options.serialized_model_dir = kCachedModelDir;
|
||||
gpu_options.model_token = kModelToken;
|
||||
base_options.delegate_options = gpu_options;
|
||||
proto::BaseOptions proto = ConvertBaseOptionsToProto(&base_options);
|
||||
ASSERT_TRUE(proto.acceleration().has_gpu());
|
||||
ASSERT_FALSE(proto.acceleration().has_tflite());
|
||||
EXPECT_TRUE(proto.acceleration().gpu().use_advanced_gpu_api());
|
||||
EXPECT_EQ(proto.acceleration().gpu().cached_kernel_path(), kCachedModelDir);
|
||||
EXPECT_FALSE(proto.acceleration().gpu().has_cached_kernel_path());
|
||||
EXPECT_EQ(proto.acceleration().gpu().serialized_model_dir(), kCachedModelDir);
|
||||
EXPECT_EQ(proto.acceleration().gpu().model_token(), kModelToken);
|
||||
}
|
||||
|
||||
|
|
|
@ -54,6 +54,9 @@ public abstract class BaseOptions {
|
|||
*/
|
||||
public abstract Builder setDelegate(Delegate delegate);
|
||||
|
||||
/** Options for the chosen delegate. If not set, the default delegate options is used. */
|
||||
public abstract Builder setDelegateOptions(DelegateOptions delegateOptions);
|
||||
|
||||
abstract BaseOptions autoBuild();
|
||||
|
||||
/**
|
||||
|
@ -79,6 +82,23 @@ public abstract class BaseOptions {
|
|||
throw new IllegalArgumentException(
|
||||
"The model buffer should be either a direct ByteBuffer or a MappedByteBuffer.");
|
||||
}
|
||||
boolean delegateMatchesDelegateOptions = true;
|
||||
if (options.delegateOptions().isPresent()) {
|
||||
switch (options.delegate()) {
|
||||
case CPU:
|
||||
delegateMatchesDelegateOptions =
|
||||
options.delegateOptions().get() instanceof DelegateOptions.CpuOptions;
|
||||
break;
|
||||
case GPU:
|
||||
delegateMatchesDelegateOptions =
|
||||
options.delegateOptions().get() instanceof DelegateOptions.GpuOptions;
|
||||
break;
|
||||
}
|
||||
if (!delegateMatchesDelegateOptions) {
|
||||
throw new IllegalArgumentException(
|
||||
"Specified Delegate type does not match the provided delegate options.");
|
||||
}
|
||||
}
|
||||
return options;
|
||||
}
|
||||
}
|
||||
|
@ -91,6 +111,67 @@ public abstract class BaseOptions {
|
|||
|
||||
abstract Delegate delegate();
|
||||
|
||||
abstract Optional<DelegateOptions> delegateOptions();
|
||||
|
||||
/** Advanced config options for the used delegate. */
|
||||
public abstract static class DelegateOptions {
|
||||
|
||||
/** Options for CPU. */
|
||||
@AutoValue
|
||||
public abstract static class CpuOptions extends DelegateOptions {
|
||||
|
||||
public static Builder builder() {
|
||||
Builder builder = new AutoValue_BaseOptions_DelegateOptions_CpuOptions.Builder();
|
||||
return builder;
|
||||
}
|
||||
|
||||
/** Builder for {@link CpuOptions}. */
|
||||
@AutoValue.Builder
|
||||
public abstract static class Builder {
|
||||
|
||||
public abstract CpuOptions build();
|
||||
}
|
||||
}
|
||||
|
||||
/** Options for GPU. */
|
||||
@AutoValue
|
||||
public abstract static class GpuOptions extends DelegateOptions {
|
||||
// Load pre-compiled serialized binary cache to accelerate init process.
|
||||
// Only available on Android. Kernel caching will only be enabled if this
|
||||
// path is set. NOTE: binary cache usage may be skipped if valid serialized
|
||||
// model, specified by "serialized_model_dir", exists.
|
||||
abstract Optional<String> cachedKernelPath();
|
||||
|
||||
// A dir to load from and save to a pre-compiled serialized model used to
|
||||
// accelerate init process.
|
||||
// NOTE: serialized model takes precedence over binary cache
|
||||
// specified by "cached_kernel_path", which still can be used if
|
||||
// serialized model is invalid or missing.
|
||||
abstract Optional<String> serializedModelDir();
|
||||
|
||||
// Unique token identifying the model. Used in conjunction with
|
||||
// "serialized_model_dir". It is the caller's responsibility to ensure
|
||||
// there is no clash of the tokens.
|
||||
abstract Optional<String> modelToken();
|
||||
|
||||
public static Builder builder() {
|
||||
return new AutoValue_BaseOptions_DelegateOptions_GpuOptions.Builder();
|
||||
}
|
||||
|
||||
/** Builder for {@link GpuOptions}. */
|
||||
@AutoValue.Builder
|
||||
public abstract static class Builder {
|
||||
public abstract Builder setCachedKernelPath(String cachedKernelPath);
|
||||
|
||||
public abstract Builder setSerializedModelDir(String serializedModelDir);
|
||||
|
||||
public abstract Builder setModelToken(String modelToken);
|
||||
|
||||
public abstract GpuOptions build();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public static Builder builder() {
|
||||
return new AutoValue_BaseOptions.Builder().setDelegate(Delegate.CPU);
|
||||
}
|
||||
|
|
|
@ -61,17 +61,51 @@ public abstract class TaskOptions {
|
|||
accelerationBuilder.setTflite(
|
||||
InferenceCalculatorProto.InferenceCalculatorOptions.Delegate.TfLite
|
||||
.getDefaultInstance());
|
||||
options
|
||||
.delegateOptions()
|
||||
.ifPresent(
|
||||
delegateOptions ->
|
||||
setDelegateOptions(
|
||||
accelerationBuilder,
|
||||
(BaseOptions.DelegateOptions.CpuOptions) delegateOptions));
|
||||
break;
|
||||
case GPU:
|
||||
accelerationBuilder.setGpu(
|
||||
InferenceCalculatorProto.InferenceCalculatorOptions.Delegate.Gpu.newBuilder()
|
||||
.setUseAdvancedGpuApi(true)
|
||||
.build());
|
||||
options
|
||||
.delegateOptions()
|
||||
.ifPresent(
|
||||
delegateOptions ->
|
||||
setDelegateOptions(
|
||||
accelerationBuilder,
|
||||
(BaseOptions.DelegateOptions.GpuOptions) delegateOptions));
|
||||
break;
|
||||
}
|
||||
|
||||
return BaseOptionsProto.BaseOptions.newBuilder()
|
||||
.setModelAsset(externalFileBuilder.build())
|
||||
.setAcceleration(accelerationBuilder.build())
|
||||
.build();
|
||||
}
|
||||
|
||||
private void setDelegateOptions(
|
||||
AccelerationProto.Acceleration.Builder accelerationBuilder,
|
||||
BaseOptions.DelegateOptions.CpuOptions options) {
|
||||
accelerationBuilder.setTflite(
|
||||
InferenceCalculatorProto.InferenceCalculatorOptions.Delegate.TfLite.getDefaultInstance());
|
||||
}
|
||||
|
||||
private void setDelegateOptions(
|
||||
AccelerationProto.Acceleration.Builder accelerationBuilder,
|
||||
BaseOptions.DelegateOptions.GpuOptions options) {
|
||||
InferenceCalculatorProto.InferenceCalculatorOptions.Delegate.Gpu.Builder gpuBuilder =
|
||||
InferenceCalculatorProto.InferenceCalculatorOptions.Delegate.Gpu.newBuilder()
|
||||
.setUseAdvancedGpuApi(true);
|
||||
options.cachedKernelPath().ifPresent(gpuBuilder::setCachedKernelPath);
|
||||
options.serializedModelDir().ifPresent(gpuBuilder::setSerializedModelDir);
|
||||
options.modelToken().ifPresent(gpuBuilder::setModelToken);
|
||||
accelerationBuilder.setGpu(gpuBuilder.build());
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,24 @@
|
|||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
|
||||
package="com.google.mediapipe.tasks.coretest"
|
||||
android:versionCode="1"
|
||||
android:versionName="1.0" >
|
||||
|
||||
<uses-permission android:name="android.permission.READ_EXTERNAL_STORAGE"/>
|
||||
<uses-permission android:name="android.permission.WRITE_EXTERNAL_STORAGE"/>
|
||||
|
||||
<uses-sdk android:minSdkVersion="24"
|
||||
android:targetSdkVersion="30" />
|
||||
|
||||
<application
|
||||
android:label="coretest"
|
||||
android:name="android.support.multidex.MultiDexApplication"
|
||||
android:taskAffinity="">
|
||||
<uses-library android:name="android.test.runner" />
|
||||
</application>
|
||||
|
||||
<instrumentation
|
||||
android:name="com.google.android.apps.common.testing.testrunner.GoogleInstrumentationTestRunner"
|
||||
android:targetPackage="com.google.mediapipe.tasks.coretest" />
|
||||
|
||||
</manifest>
|
|
@ -23,3 +23,5 @@ android_library(
|
|||
"//third_party/java/android_libs/guava_jdk5:io",
|
||||
],
|
||||
)
|
||||
|
||||
# TODO: Enable this in OSS
|
||||
|
|
|
@ -0,0 +1,159 @@
|
|||
// Copyright 2023 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package com.google.mediapipe.tasks.core;
|
||||
|
||||
import static com.google.common.truth.Truth.assertThat;
|
||||
import static org.junit.Assert.assertThrows;
|
||||
|
||||
import androidx.test.ext.junit.runners.AndroidJUnit4;
|
||||
import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions;
|
||||
import com.google.mediapipe.tasks.core.proto.AccelerationProto;
|
||||
import com.google.mediapipe.tasks.core.proto.BaseOptionsProto;
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.Suite;
|
||||
import org.junit.runners.Suite.SuiteClasses;
|
||||
|
||||
/** Test for {@link BaseOptions} */
|
||||
@RunWith(Suite.class)
|
||||
@SuiteClasses({BaseOptionsTest.General.class, BaseOptionsTest.ConvertProtoTest.class})
|
||||
public class BaseOptionsTest {
|
||||
|
||||
static final String MODEL_ASSET_PATH = "dummy_model.tflite";
|
||||
static final String SERIALIZED_MODEL_DIR = "dummy_serialized_model_dir";
|
||||
static final String MODEL_TOKEN = "dummy_model_token";
|
||||
static final String CACHED_KERNEL_PATH = "dummy_cached_kernel_path";
|
||||
|
||||
@RunWith(AndroidJUnit4.class)
|
||||
public static final class General extends BaseOptionsTest {
|
||||
@Test
|
||||
public void succeedsWithDefaultOptions() throws Exception {
|
||||
BaseOptions options = BaseOptions.builder().setModelAssetPath(MODEL_ASSET_PATH).build();
|
||||
assertThat(options.modelAssetPath().isPresent()).isTrue();
|
||||
assertThat(options.modelAssetPath().get()).isEqualTo(MODEL_ASSET_PATH);
|
||||
assertThat(options.delegate()).isEqualTo(Delegate.CPU);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void succeedsWithGpuOptions() throws Exception {
|
||||
BaseOptions options =
|
||||
BaseOptions.builder()
|
||||
.setModelAssetPath(MODEL_ASSET_PATH)
|
||||
.setDelegate(Delegate.GPU)
|
||||
.setDelegateOptions(
|
||||
BaseOptions.DelegateOptions.GpuOptions.builder()
|
||||
.setSerializedModelDir(SERIALIZED_MODEL_DIR)
|
||||
.setModelToken(MODEL_TOKEN)
|
||||
.setCachedKernelPath(CACHED_KERNEL_PATH)
|
||||
.build())
|
||||
.build();
|
||||
assertThat(
|
||||
((BaseOptions.DelegateOptions.GpuOptions) options.delegateOptions().get())
|
||||
.serializedModelDir()
|
||||
.get())
|
||||
.isEqualTo(SERIALIZED_MODEL_DIR);
|
||||
assertThat(
|
||||
((BaseOptions.DelegateOptions.GpuOptions) options.delegateOptions().get())
|
||||
.modelToken()
|
||||
.get())
|
||||
.isEqualTo(MODEL_TOKEN);
|
||||
assertThat(
|
||||
((BaseOptions.DelegateOptions.GpuOptions) options.delegateOptions().get())
|
||||
.cachedKernelPath()
|
||||
.get())
|
||||
.isEqualTo(CACHED_KERNEL_PATH);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void failsWithInvalidDelegateOptions() throws Exception {
|
||||
IllegalArgumentException exception =
|
||||
assertThrows(
|
||||
IllegalArgumentException.class,
|
||||
() ->
|
||||
BaseOptions.builder()
|
||||
.setModelAssetPath(MODEL_ASSET_PATH)
|
||||
.setDelegate(Delegate.CPU)
|
||||
.setDelegateOptions(
|
||||
BaseOptions.DelegateOptions.GpuOptions.builder()
|
||||
.setSerializedModelDir(SERIALIZED_MODEL_DIR)
|
||||
.setModelToken(MODEL_TOKEN)
|
||||
.build())
|
||||
.build());
|
||||
assertThat(exception)
|
||||
.hasMessageThat()
|
||||
.contains("Specified Delegate type does not match the provided delegate options.");
|
||||
}
|
||||
}
|
||||
|
||||
/** A mock TaskOptions class providing access to convertBaseOptionsToProto. */
|
||||
public static class MockTaskOptions extends TaskOptions {
|
||||
|
||||
public MockTaskOptions(BaseOptions baseOptions) {
|
||||
baseOptionsProto = convertBaseOptionsToProto(baseOptions);
|
||||
}
|
||||
|
||||
public BaseOptionsProto.BaseOptions getBaseOptionsProto() {
|
||||
return baseOptionsProto;
|
||||
}
|
||||
|
||||
private BaseOptionsProto.BaseOptions baseOptionsProto;
|
||||
|
||||
@Override
|
||||
public CalculatorOptions convertToCalculatorOptionsProto() {
|
||||
return CalculatorOptions.newBuilder().build();
|
||||
}
|
||||
}
|
||||
|
||||
/** Test for converting {@link BaseOptions} to {@link BaseOptionsProto} */
|
||||
@RunWith(AndroidJUnit4.class)
|
||||
public static final class ConvertProtoTest extends BaseOptionsTest {
|
||||
@Test
|
||||
public void succeedsWithDefaultOptions() throws Exception {
|
||||
BaseOptions options =
|
||||
BaseOptions.builder()
|
||||
.setModelAssetPath(MODEL_ASSET_PATH)
|
||||
.setDelegate(Delegate.CPU)
|
||||
.setDelegateOptions(BaseOptions.DelegateOptions.CpuOptions.builder().build())
|
||||
.build();
|
||||
MockTaskOptions taskOptions = new MockTaskOptions(options);
|
||||
AccelerationProto.Acceleration acceleration =
|
||||
taskOptions.getBaseOptionsProto().getAcceleration();
|
||||
assertThat(acceleration.hasTflite()).isTrue();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void succeedsWithGpuOptions() throws Exception {
|
||||
BaseOptions options =
|
||||
BaseOptions.builder()
|
||||
.setModelAssetPath(MODEL_ASSET_PATH)
|
||||
.setDelegate(Delegate.GPU)
|
||||
.setDelegateOptions(
|
||||
BaseOptions.DelegateOptions.GpuOptions.builder()
|
||||
.setModelToken(MODEL_TOKEN)
|
||||
.setSerializedModelDir(SERIALIZED_MODEL_DIR)
|
||||
.build())
|
||||
.build();
|
||||
MockTaskOptions taskOptions = new MockTaskOptions(options);
|
||||
AccelerationProto.Acceleration acceleration =
|
||||
taskOptions.getBaseOptionsProto().getAcceleration();
|
||||
assertThat(acceleration.hasTflite()).isFalse();
|
||||
assertThat(acceleration.hasGpu()).isTrue();
|
||||
assertThat(acceleration.getGpu().getUseAdvancedGpuApi()).isTrue();
|
||||
assertThat(acceleration.getGpu().hasCachedKernelPath()).isFalse();
|
||||
assertThat(acceleration.getGpu().getModelToken()).isEqualTo(MODEL_TOKEN);
|
||||
assertThat(acceleration.getGpu().getSerializedModelDir()).isEqualTo(SERIALIZED_MODEL_DIR);
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user