ImageGenerator Java API
PiperOrigin-RevId: 559310074
This commit is contained in:
parent
90781669cb
commit
2ebdb01d43
|
@ -0,0 +1,8 @@
|
|||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
|
||||
package="com.google.mediapipe.tasks.vision.imagegenerator">
|
||||
|
||||
<uses-sdk android:minSdkVersion="24"
|
||||
android:targetSdkVersion="30" />
|
||||
|
||||
</manifest>
|
|
@ -0,0 +1,660 @@
|
|||
// Copyright 2023 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package com.google.mediapipe.tasks.vision.imagegenerator;
|
||||
|
||||
import android.content.Context;
|
||||
import android.graphics.Bitmap;
|
||||
import android.util.Log;
|
||||
import com.google.auto.value.AutoValue;
|
||||
import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions;
|
||||
import com.google.mediapipe.framework.AndroidPacketGetter;
|
||||
import com.google.mediapipe.framework.Packet;
|
||||
import com.google.mediapipe.framework.PacketGetter;
|
||||
import com.google.mediapipe.framework.image.BitmapImageBuilder;
|
||||
import com.google.mediapipe.framework.image.MPImage;
|
||||
import com.google.mediapipe.tasks.core.BaseOptions;
|
||||
import com.google.mediapipe.tasks.core.ErrorListener;
|
||||
import com.google.mediapipe.tasks.core.OutputHandler;
|
||||
import com.google.mediapipe.tasks.core.OutputHandler.PureResultListener;
|
||||
import com.google.mediapipe.tasks.core.OutputHandler.ResultListener;
|
||||
import com.google.mediapipe.tasks.core.TaskInfo;
|
||||
import com.google.mediapipe.tasks.core.TaskOptions;
|
||||
import com.google.mediapipe.tasks.core.TaskResult;
|
||||
import com.google.mediapipe.tasks.core.TaskRunner;
|
||||
import com.google.mediapipe.tasks.core.proto.ExternalFileProto;
|
||||
import com.google.mediapipe.tasks.vision.core.BaseVisionTaskApi;
|
||||
import com.google.mediapipe.tasks.vision.core.RunningMode;
|
||||
import com.google.mediapipe.tasks.vision.facelandmarker.FaceLandmarker.FaceLandmarkerOptions;
|
||||
import com.google.mediapipe.tasks.vision.facelandmarker.proto.FaceLandmarkerGraphOptionsProto.FaceLandmarkerGraphOptions;
|
||||
import com.google.mediapipe.tasks.vision.imagegenerator.proto.ConditionedImageGraphOptionsProto.ConditionedImageGraphOptions;
|
||||
import com.google.mediapipe.tasks.vision.imagegenerator.proto.ControlPluginGraphOptionsProto;
|
||||
import com.google.mediapipe.tasks.vision.imagegenerator.proto.ImageGeneratorGraphOptionsProto;
|
||||
import com.google.mediapipe.tasks.vision.imagesegmenter.ImageSegmenter.ImageSegmenterOptions;
|
||||
import com.google.mediapipe.tasks.vision.imagesegmenter.proto.ImageSegmenterGraphOptionsProto.ImageSegmenterGraphOptions;
|
||||
import com.google.protobuf.Any;
|
||||
import com.google.protobuf.ExtensionRegistryLite;
|
||||
import com.google.protobuf.InvalidProtocolBufferException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import javax.annotation.Nullable;
|
||||
|
||||
/** Performs image generation from a text prompt. */
|
||||
public final class ImageGenerator extends BaseVisionTaskApi {
|
||||
|
||||
private static final String STEPS_STREAM_NAME = "steps";
|
||||
private static final String ITERATION_STREAM_NAME = "iteration";
|
||||
private static final String PROMPT_STREAM_NAME = "prompt";
|
||||
private static final String RAND_SEED_STREAM_NAME = "rand_seed";
|
||||
private static final String SOURCE_CONDITION_IMAGE_STREAM_NAME = "source_condition_image";
|
||||
private static final String CONDITION_IMAGE_STREAM_NAME = "condition_image";
|
||||
private static final String SELECT_STREAM_NAME = "select";
|
||||
private static final int GENERATED_IMAGE_OUT_STREAM_INDEX = 0;
|
||||
private static final int STEPS_OUT_STREAM_INDEX = 1;
|
||||
private static final int ITERATION_OUT_STREAM_INDEX = 2;
|
||||
private static final String TASK_GRAPH_NAME =
|
||||
"mediapipe.tasks.vision.image_generator.ImageGeneratorGraph";
|
||||
private static final String CONDITION_IMAGE_GRAPHS_CONTAINER_NAME =
|
||||
"mediapipe.tasks.vision.image_generator.ConditionedImageGraphContainer";
|
||||
private static final String TAG = "ImageGenerator";
|
||||
private TaskRunner conditionImageGraphsContainerTaskRunner;
|
||||
private Map<ConditionOptions.ConditionType, Integer> conditionTypeIndex;
|
||||
private boolean useConditionImage = false;
|
||||
|
||||
/**
|
||||
* Creates an {@link ImageGenerator} instance from an {@link ImageGeneratorOptions}.
|
||||
*
|
||||
* @param context an Android {@link Context}.
|
||||
* @param generatorOptions an {@link ImageGeneratorOptions} instance.
|
||||
* @throws MediaPipeException if there is an error during {@link ImageGenerator} creation.
|
||||
*/
|
||||
public static ImageGenerator createFromOptions(
|
||||
Context context, ImageGeneratorOptions generatorOptions) {
|
||||
return createFromOptions(context, generatorOptions, null);
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates an {@link ImageGenerator} instance, from {@link ImageGeneratorOptions} and {@link
|
||||
* ConditionOptions}, if plugin models are used to generate an image based on the condition image.
|
||||
*
|
||||
* @param context an Android {@link Context}.
|
||||
* @param generatorOptions an {@link ImageGeneratorOptions} instance.
|
||||
* @param conditionOptions an {@link ConditionOptions} instance.
|
||||
* @throws MediaPipeException if there is an error during {@link ImageGenerator} creation.
|
||||
*/
|
||||
public static ImageGenerator createFromOptions(
|
||||
Context context,
|
||||
ImageGeneratorOptions generatorOptions,
|
||||
@Nullable ConditionOptions conditionOptions) {
|
||||
List<String> inputStreams = new ArrayList<>();
|
||||
inputStreams.addAll(
|
||||
Arrays.asList(
|
||||
"STEPS:" + STEPS_STREAM_NAME,
|
||||
"ITERATION:" + ITERATION_STREAM_NAME,
|
||||
"PROMPT:" + PROMPT_STREAM_NAME,
|
||||
"RAND_SEED:" + RAND_SEED_STREAM_NAME));
|
||||
final boolean useConditionImage = conditionOptions != null;
|
||||
if (useConditionImage) {
|
||||
inputStreams.add("SELECT:" + SELECT_STREAM_NAME);
|
||||
inputStreams.add("CONDITION_IMAGE:" + CONDITION_IMAGE_STREAM_NAME);
|
||||
generatorOptions.conditionOptions = Optional.of(conditionOptions);
|
||||
}
|
||||
List<String> outputStreams =
|
||||
Arrays.asList("IMAGE:image_out", "STEPS:steps_out", "ITERATION:iteration_out");
|
||||
|
||||
OutputHandler<ImageGeneratorResult, Void> handler = new OutputHandler<>();
|
||||
handler.setOutputPacketConverter(
|
||||
new OutputHandler.OutputPacketConverter<ImageGeneratorResult, Void>() {
|
||||
@Override
|
||||
@Nullable
|
||||
public ImageGeneratorResult convertToTaskResult(List<Packet> packets) {
|
||||
int iteration = PacketGetter.getInt32(packets.get(ITERATION_OUT_STREAM_INDEX));
|
||||
int steps = PacketGetter.getInt32(packets.get(STEPS_OUT_STREAM_INDEX));
|
||||
Log.i("ImageGenerator", "Iteration: " + iteration + ", Steps: " + steps);
|
||||
if (iteration != steps - 1) {
|
||||
return null;
|
||||
}
|
||||
Log.i("ImageGenerator", "processing generated image");
|
||||
Packet packet = packets.get(GENERATED_IMAGE_OUT_STREAM_INDEX);
|
||||
Bitmap generatedBitmap = AndroidPacketGetter.getBitmapFromRgb(packet);
|
||||
BitmapImageBuilder bitmapImageBuilder = new BitmapImageBuilder(generatedBitmap);
|
||||
return ImageGeneratorResult.create(
|
||||
bitmapImageBuilder.build(), packet.getTimestamp() / MICROSECONDS_PER_MILLISECOND);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Void convertToTaskInput(List<Packet> packets) {
|
||||
return null;
|
||||
}
|
||||
});
|
||||
handler.setHandleTimestampBoundChanges(true);
|
||||
if (generatorOptions.resultListener().isPresent()) {
|
||||
ResultListener<ImageGeneratorResult, Void> resultListener =
|
||||
new ResultListener<ImageGeneratorResult, Void>() {
|
||||
@Override
|
||||
public void run(ImageGeneratorResult imageGeneratorResult, Void input) {
|
||||
generatorOptions.resultListener().get().run(imageGeneratorResult);
|
||||
}
|
||||
};
|
||||
handler.setResultListener(resultListener);
|
||||
}
|
||||
generatorOptions.errorListener().ifPresent(handler::setErrorListener);
|
||||
TaskRunner runner =
|
||||
TaskRunner.create(
|
||||
context,
|
||||
TaskInfo.<ImageGeneratorOptions>builder()
|
||||
.setTaskName(ImageGenerator.class.getSimpleName())
|
||||
.setTaskRunningModeName(RunningMode.IMAGE.name())
|
||||
.setTaskGraphName(TASK_GRAPH_NAME)
|
||||
.setInputStreams(inputStreams)
|
||||
.setOutputStreams(outputStreams)
|
||||
.setTaskOptions(generatorOptions)
|
||||
.setEnableFlowLimiting(false)
|
||||
.build(),
|
||||
handler);
|
||||
ImageGenerator imageGenerator = new ImageGenerator(runner);
|
||||
if (useConditionImage) {
|
||||
imageGenerator.useConditionImage = true;
|
||||
inputStreams =
|
||||
Arrays.asList(
|
||||
"IMAGE:" + SOURCE_CONDITION_IMAGE_STREAM_NAME, "SELECT:" + SELECT_STREAM_NAME);
|
||||
outputStreams = Arrays.asList("CONDITION_IMAGE:" + CONDITION_IMAGE_STREAM_NAME);
|
||||
OutputHandler<ConditionImageResult, Void> conditionImageHandler = new OutputHandler<>();
|
||||
conditionImageHandler.setOutputPacketConverter(
|
||||
new OutputHandler.OutputPacketConverter<ConditionImageResult, Void>() {
|
||||
@Override
|
||||
public ConditionImageResult convertToTaskResult(List<Packet> packets) {
|
||||
Packet packet = packets.get(0);
|
||||
return new AutoValue_ImageGenerator_ConditionImageResult(
|
||||
new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromRgb(packet)).build(),
|
||||
packet.getTimestamp() / MICROSECONDS_PER_MILLISECOND);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Void convertToTaskInput(List<Packet> packets) {
|
||||
return null;
|
||||
}
|
||||
});
|
||||
conditionImageHandler.setHandleTimestampBoundChanges(true);
|
||||
imageGenerator.conditionImageGraphsContainerTaskRunner =
|
||||
TaskRunner.create(
|
||||
context,
|
||||
TaskInfo.<ImageGeneratorOptions>builder()
|
||||
.setTaskName(ImageGenerator.class.getSimpleName())
|
||||
.setTaskRunningModeName(RunningMode.IMAGE.name())
|
||||
.setTaskGraphName(CONDITION_IMAGE_GRAPHS_CONTAINER_NAME)
|
||||
.setInputStreams(inputStreams)
|
||||
.setOutputStreams(outputStreams)
|
||||
.setTaskOptions(generatorOptions)
|
||||
.setEnableFlowLimiting(false)
|
||||
.build(),
|
||||
conditionImageHandler);
|
||||
imageGenerator.conditionTypeIndex = new HashMap<>();
|
||||
if (conditionOptions.faceConditionOptions().isPresent()) {
|
||||
imageGenerator.conditionTypeIndex.put(
|
||||
ConditionOptions.ConditionType.FACE, imageGenerator.conditionTypeIndex.size());
|
||||
}
|
||||
if (conditionOptions.edgeConditionOptions().isPresent()) {
|
||||
imageGenerator.conditionTypeIndex.put(
|
||||
ConditionOptions.ConditionType.EDGE, imageGenerator.conditionTypeIndex.size());
|
||||
}
|
||||
if (conditionOptions.depthConditionOptions().isPresent()) {
|
||||
imageGenerator.conditionTypeIndex.put(
|
||||
ConditionOptions.ConditionType.DEPTH, imageGenerator.conditionTypeIndex.size());
|
||||
}
|
||||
}
|
||||
return imageGenerator;
|
||||
}
|
||||
|
||||
private ImageGenerator(TaskRunner taskRunner) {
|
||||
super(taskRunner, RunningMode.IMAGE, "", "");
|
||||
}
|
||||
|
||||
/**
|
||||
* Generates an image for iterations and the given random seed. Only valid when the ImageGenerator
|
||||
* is created without condition options.
|
||||
*
|
||||
* @param prompt The text prompt describing the image to be generated.
|
||||
* @param iterations The total iterations to generate the image.
|
||||
* @param seed The random seed used during image generation.
|
||||
*/
|
||||
public ImageGeneratorResult generate(String prompt, int iterations, int seed) {
|
||||
return runIterations(prompt, iterations, seed, null, 0);
|
||||
}
|
||||
|
||||
/**
|
||||
* Generates an image based on the source image for iterations and the given random seed. Only
|
||||
* valid when the ImageGenerator is created with condition options.
|
||||
*
|
||||
* @param prompt The text prompt describing the image to be generated.
|
||||
* @param sourceConditionImage The source image used to create the condition image, which is used
|
||||
* as a guidance for the image generation.
|
||||
* @param conditionType The {@link ConditionOptions.ConditionType} specifying the type of
|
||||
* condition image.
|
||||
* @param iterations The total iterations to generate the image.
|
||||
* @param seed The random seed used during image generation.
|
||||
*/
|
||||
public ImageGeneratorResult generate(
|
||||
String prompt,
|
||||
MPImage sourceConditionImage,
|
||||
ConditionOptions.ConditionType conditionType,
|
||||
int iterations,
|
||||
int seed) {
|
||||
return runIterations(
|
||||
prompt,
|
||||
iterations,
|
||||
seed,
|
||||
createConditionImage(sourceConditionImage, conditionType),
|
||||
conditionTypeIndex.get(conditionType));
|
||||
}
|
||||
|
||||
/**
|
||||
* Create the condition image of specified condition type from the source image. Currently support
|
||||
* face landmarks, depth image and edge image as the condition image.
|
||||
*
|
||||
* @param sourceConditionImage The source image used to create the condition image.
|
||||
* @param conditionType The {@link ConditionOptions.ConditionType} specifying the type of
|
||||
* condition image.
|
||||
*/
|
||||
public MPImage createConditionImage(
|
||||
MPImage sourceConditionImage, ConditionOptions.ConditionType conditionType) {
|
||||
if (!conditionTypeIndex.containsKey(conditionType)) {
|
||||
throw new IllegalArgumentException(
|
||||
"The condition type " + conditionType.name() + " is not created during initialization.");
|
||||
}
|
||||
Map<String, Packet> inputPackets = new HashMap<>();
|
||||
inputPackets.put(
|
||||
SOURCE_CONDITION_IMAGE_STREAM_NAME,
|
||||
conditionImageGraphsContainerTaskRunner
|
||||
.getPacketCreator()
|
||||
.createImage(sourceConditionImage));
|
||||
inputPackets.put(
|
||||
SELECT_STREAM_NAME,
|
||||
conditionImageGraphsContainerTaskRunner
|
||||
.getPacketCreator()
|
||||
.createInt32(conditionTypeIndex.get(conditionType)));
|
||||
ConditionImageResult result =
|
||||
(ConditionImageResult) conditionImageGraphsContainerTaskRunner.process(inputPackets);
|
||||
return result.conditionImage();
|
||||
}
|
||||
|
||||
private ImageGeneratorResult runIterations(
|
||||
String prompt, int steps, int seed, @Nullable MPImage conditionImage, int select) {
|
||||
ImageGeneratorResult result = null;
|
||||
long timestamp = System.currentTimeMillis() * MICROSECONDS_PER_MILLISECOND;
|
||||
for (int i = 0; i < steps; i++) {
|
||||
Map<String, Packet> inputPackets = new HashMap<>();
|
||||
if (i == 0 && useConditionImage) {
|
||||
inputPackets.put(
|
||||
CONDITION_IMAGE_STREAM_NAME, runner.getPacketCreator().createImage(conditionImage));
|
||||
inputPackets.put(SELECT_STREAM_NAME, runner.getPacketCreator().createInt32(select));
|
||||
}
|
||||
inputPackets.put(PROMPT_STREAM_NAME, runner.getPacketCreator().createString(prompt));
|
||||
inputPackets.put(STEPS_STREAM_NAME, runner.getPacketCreator().createInt32(steps));
|
||||
inputPackets.put(ITERATION_STREAM_NAME, runner.getPacketCreator().createInt32(i));
|
||||
inputPackets.put(RAND_SEED_STREAM_NAME, runner.getPacketCreator().createInt32(seed));
|
||||
result = (ImageGeneratorResult) runner.process(inputPackets, timestamp++);
|
||||
}
|
||||
if (useConditionImage) {
|
||||
// Add condition image to the ImageGeneratorResult.
|
||||
return ImageGeneratorResult.create(
|
||||
result.generatedImage(), conditionImage, result.timestampMs());
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/** Closes and cleans up the task runners. */
|
||||
@Override
|
||||
public void close() {
|
||||
runner.close();
|
||||
conditionImageGraphsContainerTaskRunner.close();
|
||||
}
|
||||
|
||||
/** A container class for the condition image. */
|
||||
@AutoValue
|
||||
protected abstract static class ConditionImageResult implements TaskResult {
|
||||
|
||||
public abstract MPImage conditionImage();
|
||||
|
||||
@Override
|
||||
public abstract long timestampMs();
|
||||
}
|
||||
|
||||
/** Options for setting up an {@link ImageGenerator}. */
|
||||
@AutoValue
|
||||
public abstract static class ImageGeneratorOptions extends TaskOptions {
|
||||
|
||||
/** Builder for {@link ImageGeneratorOptions}. */
|
||||
@AutoValue.Builder
|
||||
public abstract static class Builder {
|
||||
|
||||
/** Sets the text to image model directory storing the model weights. */
|
||||
public abstract Builder setText2ImageModelDirectory(String modelDirectory);
|
||||
|
||||
/** Sets the path to LoRA weights file. */
|
||||
public abstract Builder setLoraWeightsFilePath(String loraWeightsFilePath);
|
||||
|
||||
public abstract Builder setResultListener(
|
||||
PureResultListener<ImageGeneratorResult> resultListener);
|
||||
|
||||
/** Sets an optional {@link ErrorListener}}. */
|
||||
public abstract Builder setErrorListener(ErrorListener value);
|
||||
|
||||
abstract ImageGeneratorOptions autoBuild();
|
||||
|
||||
/** Validates and builds the {@link ImageGeneratorOptions} instance. */
|
||||
public final ImageGeneratorOptions build() {
|
||||
return autoBuild();
|
||||
}
|
||||
}
|
||||
|
||||
abstract String text2ImageModelDirectory();
|
||||
|
||||
abstract Optional<String> loraWeightsFilePath();
|
||||
|
||||
abstract Optional<PureResultListener<ImageGeneratorResult>> resultListener();
|
||||
|
||||
abstract Optional<ErrorListener> errorListener();
|
||||
|
||||
private Optional<ConditionOptions> conditionOptions;
|
||||
|
||||
public static Builder builder() {
|
||||
return new AutoValue_ImageGenerator_ImageGeneratorOptions.Builder()
|
||||
.setText2ImageModelDirectory("");
|
||||
}
|
||||
|
||||
/** Converts an {@link ImageGeneratorOptions} to a {@link Any} protobuf message. */
|
||||
@Override
|
||||
public Any convertToAnyProto() {
|
||||
ImageGeneratorGraphOptionsProto.ImageGeneratorGraphOptions.Builder taskOptionsBuilder =
|
||||
ImageGeneratorGraphOptionsProto.ImageGeneratorGraphOptions.newBuilder();
|
||||
if (conditionOptions != null && conditionOptions.isPresent()) {
|
||||
try {
|
||||
taskOptionsBuilder.mergeFrom(
|
||||
conditionOptions.get().convertToAnyProto().getValue(),
|
||||
ExtensionRegistryLite.getGeneratedRegistry());
|
||||
} catch (InvalidProtocolBufferException e) {
|
||||
Log.e(TAG, "Error converting ConditionOptions to proto. " + e.getMessage());
|
||||
e.printStackTrace();
|
||||
}
|
||||
}
|
||||
taskOptionsBuilder.setText2ImageModelDirectory(text2ImageModelDirectory());
|
||||
if (loraWeightsFilePath().isPresent()) {
|
||||
ExternalFileProto.ExternalFile.Builder externalFileBuilder =
|
||||
ExternalFileProto.ExternalFile.newBuilder();
|
||||
externalFileBuilder.setFileName(loraWeightsFilePath().get());
|
||||
taskOptionsBuilder.setLoraWeightsFile(externalFileBuilder.build());
|
||||
}
|
||||
return Any.newBuilder()
|
||||
.setTypeUrl(
|
||||
"type.googleapis.com/mediapipe.tasks.vision.image_generator.proto.ImageGeneratorGraphOptions")
|
||||
.setValue(taskOptionsBuilder.build().toByteString())
|
||||
.build();
|
||||
}
|
||||
}
|
||||
|
||||
/** Options for setting up the conditions types and the plugin models */
|
||||
@AutoValue
|
||||
public abstract static class ConditionOptions extends TaskOptions {
|
||||
|
||||
/** The supported condition type. */
|
||||
public enum ConditionType {
|
||||
FACE,
|
||||
EDGE,
|
||||
DEPTH
|
||||
}
|
||||
|
||||
/** Builder for {@link ConditionOptions}. At least one type of condition options must be set. */
|
||||
@AutoValue.Builder
|
||||
public abstract static class Builder {
|
||||
public abstract Builder setFaceConditionOptions(FaceConditionOptions faceConditionOptions);
|
||||
|
||||
public abstract Builder setDepthConditionOptions(DepthConditionOptions depthConditionOptions);
|
||||
|
||||
public abstract Builder setEdgeConditionOptions(EdgeConditionOptions edgeConditionOptions);
|
||||
|
||||
abstract ConditionOptions autoBuild();
|
||||
|
||||
/** Validates and builds the {@link ConditionOptions} instance. */
|
||||
public final ConditionOptions build() {
|
||||
ConditionOptions options = autoBuild();
|
||||
if (!options.faceConditionOptions().isPresent()
|
||||
&& !options.depthConditionOptions().isPresent()
|
||||
&& !options.edgeConditionOptions().isPresent()) {
|
||||
throw new IllegalArgumentException(
|
||||
"At least one of `faceConditionOptions`, `depthConditionOptions` and"
|
||||
+ " `edgeConditionOptions` must be set.");
|
||||
}
|
||||
return options;
|
||||
}
|
||||
}
|
||||
|
||||
abstract Optional<FaceConditionOptions> faceConditionOptions();
|
||||
|
||||
abstract Optional<DepthConditionOptions> depthConditionOptions();
|
||||
|
||||
abstract Optional<EdgeConditionOptions> edgeConditionOptions();
|
||||
|
||||
public static Builder builder() {
|
||||
return new AutoValue_ImageGenerator_ConditionOptions.Builder();
|
||||
}
|
||||
|
||||
/**
|
||||
* Converts an {@link ImageGeneratorOptions} to a {@link CalculatorOptions} protobuf message.
|
||||
*/
|
||||
@Override
|
||||
public Any convertToAnyProto() {
|
||||
ImageGeneratorGraphOptionsProto.ImageGeneratorGraphOptions.Builder taskOptionsBuilder =
|
||||
ImageGeneratorGraphOptionsProto.ImageGeneratorGraphOptions.newBuilder();
|
||||
if (faceConditionOptions().isPresent()) {
|
||||
taskOptionsBuilder.addControlPluginGraphsOptions(
|
||||
ControlPluginGraphOptionsProto.ControlPluginGraphOptions.newBuilder()
|
||||
.setBaseOptions(
|
||||
convertBaseOptionsToProto(faceConditionOptions().get().baseOptions()))
|
||||
.setConditionedImageGraphOptions(
|
||||
ConditionedImageGraphOptions.newBuilder()
|
||||
.setFaceConditionTypeOptions(faceConditionOptions().get().convertToProto())
|
||||
.build())
|
||||
.build());
|
||||
}
|
||||
if (edgeConditionOptions().isPresent()) {
|
||||
taskOptionsBuilder.addControlPluginGraphsOptions(
|
||||
ControlPluginGraphOptionsProto.ControlPluginGraphOptions.newBuilder()
|
||||
.setBaseOptions(
|
||||
convertBaseOptionsToProto(edgeConditionOptions().get().baseOptions()))
|
||||
.setConditionedImageGraphOptions(
|
||||
ConditionedImageGraphOptions.newBuilder()
|
||||
.setEdgeConditionTypeOptions(edgeConditionOptions().get().convertToProto())
|
||||
.build())
|
||||
.build());
|
||||
if (depthConditionOptions().isPresent()) {
|
||||
taskOptionsBuilder.addControlPluginGraphsOptions(
|
||||
ControlPluginGraphOptionsProto.ControlPluginGraphOptions.newBuilder()
|
||||
.setBaseOptions(
|
||||
convertBaseOptionsToProto(depthConditionOptions().get().baseOptions()))
|
||||
.setConditionedImageGraphOptions(
|
||||
ConditionedImageGraphOptions.newBuilder()
|
||||
.setDepthConditionTypeOptions(
|
||||
depthConditionOptions().get().convertToProto())
|
||||
.build())
|
||||
.build());
|
||||
}
|
||||
}
|
||||
return Any.newBuilder()
|
||||
.setTypeUrl(
|
||||
"type.googleapis.com/mediapipe.tasks.vision.image_generator.proto.ImageGeneratorGraphOptions")
|
||||
.setValue(taskOptionsBuilder.build().toByteString())
|
||||
.build();
|
||||
}
|
||||
|
||||
/** Options for drawing face landmarks image. */
|
||||
@AutoValue
|
||||
public abstract static class FaceConditionOptions extends TaskOptions {
|
||||
|
||||
/** Builder for {@link FaceConditionOptions}. */
|
||||
@AutoValue.Builder
|
||||
public abstract static class Builder {
|
||||
/** Set the base options for plugin model. */
|
||||
public abstract Builder setBaseOptions(BaseOptions baseOptions);
|
||||
|
||||
/* {@link FaceLandmarkerOptions} used to detect face landmarks in the source image. */
|
||||
public abstract Builder setFaceLandmarkerOptions(
|
||||
FaceLandmarkerOptions faceLandmarkerOptions);
|
||||
|
||||
abstract FaceConditionOptions autoBuild();
|
||||
|
||||
/** Validates and builds the {@link FaceConditionOptions} instance. */
|
||||
public final FaceConditionOptions build() {
|
||||
return autoBuild();
|
||||
}
|
||||
}
|
||||
|
||||
abstract BaseOptions baseOptions();
|
||||
|
||||
abstract FaceLandmarkerOptions faceLandmarkerOptions();
|
||||
|
||||
public static Builder builder() {
|
||||
return new AutoValue_ImageGenerator_ConditionOptions_FaceConditionOptions.Builder();
|
||||
}
|
||||
|
||||
ConditionedImageGraphOptions.FaceConditionTypeOptions convertToProto() {
|
||||
return ConditionedImageGraphOptions.FaceConditionTypeOptions.newBuilder()
|
||||
.setFaceLandmarkerGraphOptions(
|
||||
FaceLandmarkerGraphOptions.newBuilder()
|
||||
.mergeFrom(
|
||||
faceLandmarkerOptions()
|
||||
.convertToCalculatorOptionsProto()
|
||||
.getExtension(FaceLandmarkerGraphOptions.ext))
|
||||
.build())
|
||||
.build();
|
||||
}
|
||||
}
|
||||
|
||||
/** Options for detecting depth image. */
|
||||
@AutoValue
|
||||
public abstract static class DepthConditionOptions extends TaskOptions {
|
||||
|
||||
/** Builder for {@link DepthConditionOptions}. */
|
||||
@AutoValue.Builder
|
||||
public abstract static class Builder {
|
||||
|
||||
/** Set the base options for plugin model. */
|
||||
public abstract Builder setBaseOptions(BaseOptions baseOptions);
|
||||
|
||||
/** {@link ImageSegmenterOptions} used to detect depth image from the source image. */
|
||||
public abstract Builder setImageSegmenterOptions(
|
||||
ImageSegmenterOptions imageSegmenterOptions);
|
||||
|
||||
abstract DepthConditionOptions autoBuild();
|
||||
|
||||
/** Validates and builds the {@link DepthConditionOptions} instance. */
|
||||
public final DepthConditionOptions build() {
|
||||
DepthConditionOptions options = autoBuild();
|
||||
return options;
|
||||
}
|
||||
}
|
||||
|
||||
abstract BaseOptions baseOptions();
|
||||
|
||||
abstract ImageSegmenterOptions imageSegmenterOptions();
|
||||
|
||||
public static Builder builder() {
|
||||
return new AutoValue_ImageGenerator_ConditionOptions_DepthConditionOptions.Builder();
|
||||
}
|
||||
|
||||
ConditionedImageGraphOptions.DepthConditionTypeOptions convertToProto() {
|
||||
return ConditionedImageGraphOptions.DepthConditionTypeOptions.newBuilder()
|
||||
.setImageSegmenterGraphOptions(
|
||||
imageSegmenterOptions()
|
||||
.convertToCalculatorOptionsProto()
|
||||
.getExtension(ImageSegmenterGraphOptions.ext))
|
||||
.build();
|
||||
}
|
||||
}
|
||||
|
||||
/** Options for detecting edge image. */
|
||||
@AutoValue
|
||||
public abstract static class EdgeConditionOptions {
|
||||
|
||||
/**
|
||||
* Builder for {@link EdgeConditionOptions}.
|
||||
*
|
||||
* <p>These parameters are used to config Canny edge algorithm of OpenCV.
|
||||
*
|
||||
* <p>See more details:
|
||||
* https://docs.opencv.org/3.4/dd/d1a/group__imgproc__feature.html#ga04723e007ed888ddf11d9ba04e2232de
|
||||
*/
|
||||
@AutoValue.Builder
|
||||
public abstract static class Builder {
|
||||
|
||||
/** Set the base options for plugin model. */
|
||||
public abstract Builder setBaseOptions(BaseOptions baseOptions);
|
||||
|
||||
/** First threshold for the hysteresis procedure. */
|
||||
public abstract Builder setThreshold1(Float threshold1);
|
||||
|
||||
/** Second threshold for the hysteresis procedure. */
|
||||
public abstract Builder setThreshold2(Float threshold2);
|
||||
|
||||
/** Aperture size for the Sobel operator. Typical range is 3~7. */
|
||||
public abstract Builder setApertureSize(Integer apertureSize);
|
||||
|
||||
/**
|
||||
* flag, indicating whether a more accurate L2 norm should be used to calculate the image
|
||||
* gradient magnitude ( L2gradient=true ), or whether the default L1 norm is enough (
|
||||
* L2gradient=false ).
|
||||
*/
|
||||
public abstract Builder setL2Gradient(Boolean l2Gradient);
|
||||
|
||||
abstract EdgeConditionOptions autoBuild();
|
||||
|
||||
/** Validates and builds the {@link EdgeConditionOptions} instance. */
|
||||
public final EdgeConditionOptions build() {
|
||||
return autoBuild();
|
||||
}
|
||||
}
|
||||
|
||||
abstract BaseOptions baseOptions();
|
||||
|
||||
abstract Float threshold1();
|
||||
|
||||
abstract Float threshold2();
|
||||
|
||||
abstract Integer apertureSize();
|
||||
|
||||
abstract Boolean l2Gradient();
|
||||
|
||||
public static Builder builder() {
|
||||
return new AutoValue_ImageGenerator_ConditionOptions_EdgeConditionOptions.Builder()
|
||||
.setThreshold1(100f)
|
||||
.setThreshold2(200f)
|
||||
.setApertureSize(3)
|
||||
.setL2Gradient(false);
|
||||
}
|
||||
|
||||
ConditionedImageGraphOptions.EdgeConditionTypeOptions convertToProto() {
|
||||
return ConditionedImageGraphOptions.EdgeConditionTypeOptions.newBuilder()
|
||||
.setThreshold1(threshold1())
|
||||
.setThreshold2(threshold2())
|
||||
.setApertureSize(apertureSize())
|
||||
.setL2Gradient(l2Gradient())
|
||||
.build();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,44 @@
|
|||
// Copyright 2023 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package com.google.mediapipe.tasks.vision.imagegenerator;
|
||||
|
||||
import com.google.auto.value.AutoValue;
|
||||
import com.google.mediapipe.framework.image.MPImage;
|
||||
import com.google.mediapipe.tasks.core.TaskResult;
|
||||
import java.util.Optional;
|
||||
|
||||
/** Represents the image generation results generated by {@link ImageGenerator}. */
|
||||
@AutoValue
|
||||
public abstract class ImageGeneratorResult implements TaskResult {
|
||||
|
||||
/** Create an {@link ImageGeneratorResult} instance from the generated image. */
|
||||
public static ImageGeneratorResult create(
|
||||
MPImage generatedImage, MPImage conditionImage, long timestampMs) {
|
||||
return new AutoValue_ImageGeneratorResult(
|
||||
generatedImage, Optional.of(conditionImage), timestampMs);
|
||||
}
|
||||
|
||||
/** Create an {@link ImageGeneratorResult} instance from the generated image. */
|
||||
public static ImageGeneratorResult create(MPImage generatedImage, long timestampMs) {
|
||||
return new AutoValue_ImageGeneratorResult(generatedImage, Optional.empty(), timestampMs);
|
||||
}
|
||||
|
||||
public abstract MPImage generatedImage();
|
||||
|
||||
public abstract Optional<MPImage> conditionImage();
|
||||
|
||||
@Override
|
||||
public abstract long timestampMs();
|
||||
}
|
Loading…
Reference in New Issue
Block a user