diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/BUILD index d04fc4258..eb658c0e2 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/BUILD @@ -32,6 +32,7 @@ android_library( "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", "//mediapipe/tasks/cc/core/proto:external_file_java_proto_lite", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni:model_resources_cache_jni", + "//third_party:any_java_proto", "//third_party:autovalue", "@com_google_protobuf//:protobuf_javalite", "@maven//:com_google_guava_guava", diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskInfo.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskInfo.java index 3c422a8b2..ad3d01119 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskInfo.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskInfo.java @@ -20,6 +20,8 @@ import com.google.mediapipe.proto.CalculatorProto.CalculatorGraphConfig; import com.google.mediapipe.proto.CalculatorProto.CalculatorGraphConfig.Node; import com.google.mediapipe.proto.CalculatorProto.InputStreamInfo; import com.google.mediapipe.calculator.proto.FlowLimiterCalculatorProto.FlowLimiterCalculatorOptions; +import com.google.mediapipe.framework.MediaPipeException; +import com.google.protobuf.Any; import java.util.ArrayList; import java.util.List; @@ -110,10 +112,21 @@ public abstract class TaskInfo { */ CalculatorGraphConfig generateGraphConfig() { CalculatorGraphConfig.Builder graphBuilder = CalculatorGraphConfig.newBuilder(); - Node.Builder taskSubgraphBuilder = - Node.newBuilder() - .setCalculator(taskGraphName()) - .setOptions(taskOptions().convertToCalculatorOptionsProto()); + CalculatorOptions options = taskOptions().convertToCalculatorOptionsProto(); + Any anyOptions = taskOptions().convertToAnyProto(); + if (!(options == null ^ anyOptions == null)) { + throw new MediaPipeException( + MediaPipeException.StatusCode.INVALID_ARGUMENT.ordinal(), + "Only one of convertTo*Proto() method should be implemented for " + + taskOptions().getClass()); + } + Node.Builder taskSubgraphBuilder = Node.newBuilder().setCalculator(taskGraphName()); + if (options != null) { + taskSubgraphBuilder.setOptions(options); + } + if (anyOptions != null) { + taskSubgraphBuilder.addNodeOptions(anyOptions); + } for (String outputStream : outputStreams()) { taskSubgraphBuilder.addOutputStream(outputStream); graphBuilder.addOutputStream(outputStream); diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskOptions.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskOptions.java index 991acebaf..4ca258429 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskOptions.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskOptions.java @@ -20,18 +20,26 @@ import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions; import com.google.mediapipe.tasks.core.proto.AccelerationProto; import com.google.mediapipe.tasks.core.proto.BaseOptionsProto; import com.google.mediapipe.tasks.core.proto.ExternalFileProto; +import com.google.protobuf.Any; import com.google.protobuf.ByteString; /** * MediaPipe Tasks options base class. Any MediaPipe task-specific options class should extend - * {@link TaskOptions}. + * {@link TaskOptions} and implement exactly one of converTo*Proto() methods. */ public abstract class TaskOptions { /** * Converts a MediaPipe Tasks task-specific options to a {@link CalculatorOptions} protobuf * message. */ - public abstract CalculatorOptions convertToCalculatorOptionsProto(); + public CalculatorOptions convertToCalculatorOptionsProto() { + return null; + } + + /** Converts a MediaPipe Tasks task-specific options to an proto3 {@link Any} message. */ + public Any convertToAnyProto() { + return null; + } /** * Converts a {@link BaseOptions} instance to a {@link BaseOptionsProto.BaseOptions} protobuf diff --git a/third_party/BUILD b/third_party/BUILD index 470b7ff99..c1bee7a6e 100644 --- a/third_party/BUILD +++ b/third_party/BUILD @@ -378,3 +378,10 @@ java_library( "@maven//:com_google_auto_value_auto_value_annotations", ], ) + +java_proto_library( + name = "any_java_proto", + deps = [ + "@com_google_protobuf//:any_proto", + ], +)