Add delegate options to base options for java API. and add unit tset for BaseOptions.

PiperOrigin-RevId: 544458644
This commit is contained in:
MediaPipe Team 2023-06-29 14:11:29 -07:00 committed by Copybara-Service
parent e15d5a797b
commit 0ea54b1461
7 changed files with 312 additions and 5 deletions

View File

@ -41,9 +41,15 @@ proto::Acceleration ConvertDelegateOptionsToAccelerationProto(
proto::Acceleration acceleration_proto = proto::Acceleration(); proto::Acceleration acceleration_proto = proto::Acceleration();
auto* gpu = acceleration_proto.mutable_gpu(); auto* gpu = acceleration_proto.mutable_gpu();
gpu->set_use_advanced_gpu_api(true); gpu->set_use_advanced_gpu_api(true);
if (!options.cached_kernel_path.empty()) {
gpu->set_cached_kernel_path(options.cached_kernel_path); gpu->set_cached_kernel_path(options.cached_kernel_path);
}
if (!options.serialized_model_dir.empty()) {
gpu->set_serialized_model_dir(options.serialized_model_dir); gpu->set_serialized_model_dir(options.serialized_model_dir);
}
if (!options.model_token.empty()) {
gpu->set_model_token(options.model_token); gpu->set_model_token(options.model_token);
}
return acceleration_proto; return acceleration_proto;
} }

View File

@ -59,14 +59,15 @@ TEST(DelegateOptionsTest, SucceedGpuOptions) {
BaseOptions base_options; BaseOptions base_options;
base_options.delegate = BaseOptions::Delegate::GPU; base_options.delegate = BaseOptions::Delegate::GPU;
BaseOptions::GpuOptions gpu_options; BaseOptions::GpuOptions gpu_options;
gpu_options.cached_kernel_path = kCachedModelDir; gpu_options.serialized_model_dir = kCachedModelDir;
gpu_options.model_token = kModelToken; gpu_options.model_token = kModelToken;
base_options.delegate_options = gpu_options; base_options.delegate_options = gpu_options;
proto::BaseOptions proto = ConvertBaseOptionsToProto(&base_options); proto::BaseOptions proto = ConvertBaseOptionsToProto(&base_options);
ASSERT_TRUE(proto.acceleration().has_gpu()); ASSERT_TRUE(proto.acceleration().has_gpu());
ASSERT_FALSE(proto.acceleration().has_tflite()); ASSERT_FALSE(proto.acceleration().has_tflite());
EXPECT_TRUE(proto.acceleration().gpu().use_advanced_gpu_api()); EXPECT_TRUE(proto.acceleration().gpu().use_advanced_gpu_api());
EXPECT_EQ(proto.acceleration().gpu().cached_kernel_path(), kCachedModelDir); EXPECT_FALSE(proto.acceleration().gpu().has_cached_kernel_path());
EXPECT_EQ(proto.acceleration().gpu().serialized_model_dir(), kCachedModelDir);
EXPECT_EQ(proto.acceleration().gpu().model_token(), kModelToken); EXPECT_EQ(proto.acceleration().gpu().model_token(), kModelToken);
} }

View File

@ -54,6 +54,9 @@ public abstract class BaseOptions {
*/ */
public abstract Builder setDelegate(Delegate delegate); public abstract Builder setDelegate(Delegate delegate);
/** Options for the chosen delegate. If not set, the default delegate options is used. */
public abstract Builder setDelegateOptions(DelegateOptions delegateOptions);
abstract BaseOptions autoBuild(); abstract BaseOptions autoBuild();
/** /**
@ -79,6 +82,23 @@ public abstract class BaseOptions {
throw new IllegalArgumentException( throw new IllegalArgumentException(
"The model buffer should be either a direct ByteBuffer or a MappedByteBuffer."); "The model buffer should be either a direct ByteBuffer or a MappedByteBuffer.");
} }
boolean delegateMatchesDelegateOptions = true;
if (options.delegateOptions().isPresent()) {
switch (options.delegate()) {
case CPU:
delegateMatchesDelegateOptions =
options.delegateOptions().get() instanceof DelegateOptions.CpuOptions;
break;
case GPU:
delegateMatchesDelegateOptions =
options.delegateOptions().get() instanceof DelegateOptions.GpuOptions;
break;
}
if (!delegateMatchesDelegateOptions) {
throw new IllegalArgumentException(
"Specified Delegate type does not match the provided delegate options.");
}
}
return options; return options;
} }
} }
@ -91,6 +111,67 @@ public abstract class BaseOptions {
abstract Delegate delegate(); abstract Delegate delegate();
abstract Optional<DelegateOptions> delegateOptions();
/** Advanced config options for the used delegate. */
public abstract static class DelegateOptions {
/** Options for CPU. */
@AutoValue
public abstract static class CpuOptions extends DelegateOptions {
public static Builder builder() {
Builder builder = new AutoValue_BaseOptions_DelegateOptions_CpuOptions.Builder();
return builder;
}
/** Builder for {@link CpuOptions}. */
@AutoValue.Builder
public abstract static class Builder {
public abstract CpuOptions build();
}
}
/** Options for GPU. */
@AutoValue
public abstract static class GpuOptions extends DelegateOptions {
// Load pre-compiled serialized binary cache to accelerate init process.
// Only available on Android. Kernel caching will only be enabled if this
// path is set. NOTE: binary cache usage may be skipped if valid serialized
// model, specified by "serialized_model_dir", exists.
abstract Optional<String> cachedKernelPath();
// A dir to load from and save to a pre-compiled serialized model used to
// accelerate init process.
// NOTE: serialized model takes precedence over binary cache
// specified by "cached_kernel_path", which still can be used if
// serialized model is invalid or missing.
abstract Optional<String> serializedModelDir();
// Unique token identifying the model. Used in conjunction with
// "serialized_model_dir". It is the caller's responsibility to ensure
// there is no clash of the tokens.
abstract Optional<String> modelToken();
public static Builder builder() {
return new AutoValue_BaseOptions_DelegateOptions_GpuOptions.Builder();
}
/** Builder for {@link GpuOptions}. */
@AutoValue.Builder
public abstract static class Builder {
public abstract Builder setCachedKernelPath(String cachedKernelPath);
public abstract Builder setSerializedModelDir(String serializedModelDir);
public abstract Builder setModelToken(String modelToken);
public abstract GpuOptions build();
}
}
}
public static Builder builder() { public static Builder builder() {
return new AutoValue_BaseOptions.Builder().setDelegate(Delegate.CPU); return new AutoValue_BaseOptions.Builder().setDelegate(Delegate.CPU);
} }

View File

@ -61,17 +61,51 @@ public abstract class TaskOptions {
accelerationBuilder.setTflite( accelerationBuilder.setTflite(
InferenceCalculatorProto.InferenceCalculatorOptions.Delegate.TfLite InferenceCalculatorProto.InferenceCalculatorOptions.Delegate.TfLite
.getDefaultInstance()); .getDefaultInstance());
options
.delegateOptions()
.ifPresent(
delegateOptions ->
setDelegateOptions(
accelerationBuilder,
(BaseOptions.DelegateOptions.CpuOptions) delegateOptions));
break; break;
case GPU: case GPU:
accelerationBuilder.setGpu( accelerationBuilder.setGpu(
InferenceCalculatorProto.InferenceCalculatorOptions.Delegate.Gpu.newBuilder() InferenceCalculatorProto.InferenceCalculatorOptions.Delegate.Gpu.newBuilder()
.setUseAdvancedGpuApi(true) .setUseAdvancedGpuApi(true)
.build()); .build());
options
.delegateOptions()
.ifPresent(
delegateOptions ->
setDelegateOptions(
accelerationBuilder,
(BaseOptions.DelegateOptions.GpuOptions) delegateOptions));
break; break;
} }
return BaseOptionsProto.BaseOptions.newBuilder() return BaseOptionsProto.BaseOptions.newBuilder()
.setModelAsset(externalFileBuilder.build()) .setModelAsset(externalFileBuilder.build())
.setAcceleration(accelerationBuilder.build()) .setAcceleration(accelerationBuilder.build())
.build(); .build();
} }
private void setDelegateOptions(
AccelerationProto.Acceleration.Builder accelerationBuilder,
BaseOptions.DelegateOptions.CpuOptions options) {
accelerationBuilder.setTflite(
InferenceCalculatorProto.InferenceCalculatorOptions.Delegate.TfLite.getDefaultInstance());
}
private void setDelegateOptions(
AccelerationProto.Acceleration.Builder accelerationBuilder,
BaseOptions.DelegateOptions.GpuOptions options) {
InferenceCalculatorProto.InferenceCalculatorOptions.Delegate.Gpu.Builder gpuBuilder =
InferenceCalculatorProto.InferenceCalculatorOptions.Delegate.Gpu.newBuilder()
.setUseAdvancedGpuApi(true);
options.cachedKernelPath().ifPresent(gpuBuilder::setCachedKernelPath);
options.serializedModelDir().ifPresent(gpuBuilder::setSerializedModelDir);
options.modelToken().ifPresent(gpuBuilder::setModelToken);
accelerationBuilder.setGpu(gpuBuilder.build());
}
} }

View File

@ -0,0 +1,24 @@
<?xml version="1.0" encoding="utf-8"?>
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
package="com.google.mediapipe.tasks.coretest"
android:versionCode="1"
android:versionName="1.0" >
<uses-permission android:name="android.permission.READ_EXTERNAL_STORAGE"/>
<uses-permission android:name="android.permission.WRITE_EXTERNAL_STORAGE"/>
<uses-sdk android:minSdkVersion="24"
android:targetSdkVersion="30" />
<application
android:label="coretest"
android:name="android.support.multidex.MultiDexApplication"
android:taskAffinity="">
<uses-library android:name="android.test.runner" />
</application>
<instrumentation
android:name="com.google.android.apps.common.testing.testrunner.GoogleInstrumentationTestRunner"
android:targetPackage="com.google.mediapipe.tasks.coretest" />
</manifest>

View File

@ -23,3 +23,5 @@ android_library(
"//third_party/java/android_libs/guava_jdk5:io", "//third_party/java/android_libs/guava_jdk5:io",
], ],
) )
# TODO: Enable this in OSS

View File

@ -0,0 +1,159 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.core;
import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.assertThrows;
import androidx.test.ext.junit.runners.AndroidJUnit4;
import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions;
import com.google.mediapipe.tasks.core.proto.AccelerationProto;
import com.google.mediapipe.tasks.core.proto.BaseOptionsProto;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Suite;
import org.junit.runners.Suite.SuiteClasses;
/** Test for {@link BaseOptions} */
@RunWith(Suite.class)
@SuiteClasses({BaseOptionsTest.General.class, BaseOptionsTest.ConvertProtoTest.class})
public class BaseOptionsTest {
static final String MODEL_ASSET_PATH = "dummy_model.tflite";
static final String SERIALIZED_MODEL_DIR = "dummy_serialized_model_dir";
static final String MODEL_TOKEN = "dummy_model_token";
static final String CACHED_KERNEL_PATH = "dummy_cached_kernel_path";
@RunWith(AndroidJUnit4.class)
public static final class General extends BaseOptionsTest {
@Test
public void succeedsWithDefaultOptions() throws Exception {
BaseOptions options = BaseOptions.builder().setModelAssetPath(MODEL_ASSET_PATH).build();
assertThat(options.modelAssetPath().isPresent()).isTrue();
assertThat(options.modelAssetPath().get()).isEqualTo(MODEL_ASSET_PATH);
assertThat(options.delegate()).isEqualTo(Delegate.CPU);
}
@Test
public void succeedsWithGpuOptions() throws Exception {
BaseOptions options =
BaseOptions.builder()
.setModelAssetPath(MODEL_ASSET_PATH)
.setDelegate(Delegate.GPU)
.setDelegateOptions(
BaseOptions.DelegateOptions.GpuOptions.builder()
.setSerializedModelDir(SERIALIZED_MODEL_DIR)
.setModelToken(MODEL_TOKEN)
.setCachedKernelPath(CACHED_KERNEL_PATH)
.build())
.build();
assertThat(
((BaseOptions.DelegateOptions.GpuOptions) options.delegateOptions().get())
.serializedModelDir()
.get())
.isEqualTo(SERIALIZED_MODEL_DIR);
assertThat(
((BaseOptions.DelegateOptions.GpuOptions) options.delegateOptions().get())
.modelToken()
.get())
.isEqualTo(MODEL_TOKEN);
assertThat(
((BaseOptions.DelegateOptions.GpuOptions) options.delegateOptions().get())
.cachedKernelPath()
.get())
.isEqualTo(CACHED_KERNEL_PATH);
}
@Test
public void failsWithInvalidDelegateOptions() throws Exception {
IllegalArgumentException exception =
assertThrows(
IllegalArgumentException.class,
() ->
BaseOptions.builder()
.setModelAssetPath(MODEL_ASSET_PATH)
.setDelegate(Delegate.CPU)
.setDelegateOptions(
BaseOptions.DelegateOptions.GpuOptions.builder()
.setSerializedModelDir(SERIALIZED_MODEL_DIR)
.setModelToken(MODEL_TOKEN)
.build())
.build());
assertThat(exception)
.hasMessageThat()
.contains("Specified Delegate type does not match the provided delegate options.");
}
}
/** A mock TaskOptions class providing access to convertBaseOptionsToProto. */
public static class MockTaskOptions extends TaskOptions {
public MockTaskOptions(BaseOptions baseOptions) {
baseOptionsProto = convertBaseOptionsToProto(baseOptions);
}
public BaseOptionsProto.BaseOptions getBaseOptionsProto() {
return baseOptionsProto;
}
private BaseOptionsProto.BaseOptions baseOptionsProto;
@Override
public CalculatorOptions convertToCalculatorOptionsProto() {
return CalculatorOptions.newBuilder().build();
}
}
/** Test for converting {@link BaseOptions} to {@link BaseOptionsProto} */
@RunWith(AndroidJUnit4.class)
public static final class ConvertProtoTest extends BaseOptionsTest {
@Test
public void succeedsWithDefaultOptions() throws Exception {
BaseOptions options =
BaseOptions.builder()
.setModelAssetPath(MODEL_ASSET_PATH)
.setDelegate(Delegate.CPU)
.setDelegateOptions(BaseOptions.DelegateOptions.CpuOptions.builder().build())
.build();
MockTaskOptions taskOptions = new MockTaskOptions(options);
AccelerationProto.Acceleration acceleration =
taskOptions.getBaseOptionsProto().getAcceleration();
assertThat(acceleration.hasTflite()).isTrue();
}
@Test
public void succeedsWithGpuOptions() throws Exception {
BaseOptions options =
BaseOptions.builder()
.setModelAssetPath(MODEL_ASSET_PATH)
.setDelegate(Delegate.GPU)
.setDelegateOptions(
BaseOptions.DelegateOptions.GpuOptions.builder()
.setModelToken(MODEL_TOKEN)
.setSerializedModelDir(SERIALIZED_MODEL_DIR)
.build())
.build();
MockTaskOptions taskOptions = new MockTaskOptions(options);
AccelerationProto.Acceleration acceleration =
taskOptions.getBaseOptionsProto().getAcceleration();
assertThat(acceleration.hasTflite()).isFalse();
assertThat(acceleration.hasGpu()).isTrue();
assertThat(acceleration.getGpu().getUseAdvancedGpuApi()).isTrue();
assertThat(acceleration.getGpu().hasCachedKernelPath()).isFalse();
assertThat(acceleration.getGpu().getModelToken()).isEqualTo(MODEL_TOKEN);
assertThat(acceleration.getGpu().getSerializedModelDir()).isEqualTo(SERIALIZED_MODEL_DIR);
}
}
}