ImageGenerator Java API
PiperOrigin-RevId: 559310074
This commit is contained in:
		
							parent
							
								
									90781669cb
								
							
						
					
					
						commit
						2ebdb01d43
					
				|  | @ -0,0 +1,8 @@ | ||||||
|  | <?xml version="1.0" encoding="utf-8"?> | ||||||
|  | <manifest xmlns:android="http://schemas.android.com/apk/res/android" | ||||||
|  |     package="com.google.mediapipe.tasks.vision.imagegenerator"> | ||||||
|  | 
 | ||||||
|  |     <uses-sdk android:minSdkVersion="24" | ||||||
|  |         android:targetSdkVersion="30" /> | ||||||
|  | 
 | ||||||
|  | </manifest> | ||||||
|  | @ -0,0 +1,660 @@ | ||||||
|  | // Copyright 2023 The MediaPipe Authors. | ||||||
|  | // | ||||||
|  | // Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
|  | // you may not use this file except in compliance with the License. | ||||||
|  | // You may obtain a copy of the License at | ||||||
|  | // | ||||||
|  | //      http://www.apache.org/licenses/LICENSE-2.0 | ||||||
|  | // | ||||||
|  | // Unless required by applicable law or agreed to in writing, software | ||||||
|  | // distributed under the License is distributed on an "AS IS" BASIS, | ||||||
|  | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
|  | // See the License for the specific language governing permissions and | ||||||
|  | // limitations under the License. | ||||||
|  | 
 | ||||||
|  | package com.google.mediapipe.tasks.vision.imagegenerator; | ||||||
|  | 
 | ||||||
|  | import android.content.Context; | ||||||
|  | import android.graphics.Bitmap; | ||||||
|  | import android.util.Log; | ||||||
|  | import com.google.auto.value.AutoValue; | ||||||
|  | import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions; | ||||||
|  | import com.google.mediapipe.framework.AndroidPacketGetter; | ||||||
|  | import com.google.mediapipe.framework.Packet; | ||||||
|  | import com.google.mediapipe.framework.PacketGetter; | ||||||
|  | import com.google.mediapipe.framework.image.BitmapImageBuilder; | ||||||
|  | import com.google.mediapipe.framework.image.MPImage; | ||||||
|  | import com.google.mediapipe.tasks.core.BaseOptions; | ||||||
|  | import com.google.mediapipe.tasks.core.ErrorListener; | ||||||
|  | import com.google.mediapipe.tasks.core.OutputHandler; | ||||||
|  | import com.google.mediapipe.tasks.core.OutputHandler.PureResultListener; | ||||||
|  | import com.google.mediapipe.tasks.core.OutputHandler.ResultListener; | ||||||
|  | import com.google.mediapipe.tasks.core.TaskInfo; | ||||||
|  | import com.google.mediapipe.tasks.core.TaskOptions; | ||||||
|  | import com.google.mediapipe.tasks.core.TaskResult; | ||||||
|  | import com.google.mediapipe.tasks.core.TaskRunner; | ||||||
|  | import com.google.mediapipe.tasks.core.proto.ExternalFileProto; | ||||||
|  | import com.google.mediapipe.tasks.vision.core.BaseVisionTaskApi; | ||||||
|  | import com.google.mediapipe.tasks.vision.core.RunningMode; | ||||||
|  | import com.google.mediapipe.tasks.vision.facelandmarker.FaceLandmarker.FaceLandmarkerOptions; | ||||||
|  | import com.google.mediapipe.tasks.vision.facelandmarker.proto.FaceLandmarkerGraphOptionsProto.FaceLandmarkerGraphOptions; | ||||||
|  | import com.google.mediapipe.tasks.vision.imagegenerator.proto.ConditionedImageGraphOptionsProto.ConditionedImageGraphOptions; | ||||||
|  | import com.google.mediapipe.tasks.vision.imagegenerator.proto.ControlPluginGraphOptionsProto; | ||||||
|  | import com.google.mediapipe.tasks.vision.imagegenerator.proto.ImageGeneratorGraphOptionsProto; | ||||||
|  | import com.google.mediapipe.tasks.vision.imagesegmenter.ImageSegmenter.ImageSegmenterOptions; | ||||||
|  | import com.google.mediapipe.tasks.vision.imagesegmenter.proto.ImageSegmenterGraphOptionsProto.ImageSegmenterGraphOptions; | ||||||
|  | import com.google.protobuf.Any; | ||||||
|  | import com.google.protobuf.ExtensionRegistryLite; | ||||||
|  | import com.google.protobuf.InvalidProtocolBufferException; | ||||||
|  | import java.util.ArrayList; | ||||||
|  | import java.util.Arrays; | ||||||
|  | import java.util.HashMap; | ||||||
|  | import java.util.List; | ||||||
|  | import java.util.Map; | ||||||
|  | import java.util.Optional; | ||||||
|  | import javax.annotation.Nullable; | ||||||
|  | 
 | ||||||
|  | /** Performs image generation from a text prompt. */ | ||||||
|  | public final class ImageGenerator extends BaseVisionTaskApi { | ||||||
|  | 
 | ||||||
|  |   private static final String STEPS_STREAM_NAME = "steps"; | ||||||
|  |   private static final String ITERATION_STREAM_NAME = "iteration"; | ||||||
|  |   private static final String PROMPT_STREAM_NAME = "prompt"; | ||||||
|  |   private static final String RAND_SEED_STREAM_NAME = "rand_seed"; | ||||||
|  |   private static final String SOURCE_CONDITION_IMAGE_STREAM_NAME = "source_condition_image"; | ||||||
|  |   private static final String CONDITION_IMAGE_STREAM_NAME = "condition_image"; | ||||||
|  |   private static final String SELECT_STREAM_NAME = "select"; | ||||||
|  |   private static final int GENERATED_IMAGE_OUT_STREAM_INDEX = 0; | ||||||
|  |   private static final int STEPS_OUT_STREAM_INDEX = 1; | ||||||
|  |   private static final int ITERATION_OUT_STREAM_INDEX = 2; | ||||||
|  |   private static final String TASK_GRAPH_NAME = | ||||||
|  |       "mediapipe.tasks.vision.image_generator.ImageGeneratorGraph"; | ||||||
|  |   private static final String CONDITION_IMAGE_GRAPHS_CONTAINER_NAME = | ||||||
|  |       "mediapipe.tasks.vision.image_generator.ConditionedImageGraphContainer"; | ||||||
|  |   private static final String TAG = "ImageGenerator"; | ||||||
|  |   private TaskRunner conditionImageGraphsContainerTaskRunner; | ||||||
|  |   private Map<ConditionOptions.ConditionType, Integer> conditionTypeIndex; | ||||||
|  |   private boolean useConditionImage = false; | ||||||
|  | 
 | ||||||
|  |   /** | ||||||
|  |    * Creates an {@link ImageGenerator} instance from an {@link ImageGeneratorOptions}. | ||||||
|  |    * | ||||||
|  |    * @param context an Android {@link Context}. | ||||||
|  |    * @param generatorOptions an {@link ImageGeneratorOptions} instance. | ||||||
|  |    * @throws MediaPipeException if there is an error during {@link ImageGenerator} creation. | ||||||
|  |    */ | ||||||
|  |   public static ImageGenerator createFromOptions( | ||||||
|  |       Context context, ImageGeneratorOptions generatorOptions) { | ||||||
|  |     return createFromOptions(context, generatorOptions, null); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   /** | ||||||
|  |    * Creates an {@link ImageGenerator} instance, from {@link ImageGeneratorOptions} and {@link | ||||||
|  |    * ConditionOptions}, if plugin models are used to generate an image based on the condition image. | ||||||
|  |    * | ||||||
|  |    * @param context an Android {@link Context}. | ||||||
|  |    * @param generatorOptions an {@link ImageGeneratorOptions} instance. | ||||||
|  |    * @param conditionOptions an {@link ConditionOptions} instance. | ||||||
|  |    * @throws MediaPipeException if there is an error during {@link ImageGenerator} creation. | ||||||
|  |    */ | ||||||
|  |   public static ImageGenerator createFromOptions( | ||||||
|  |       Context context, | ||||||
|  |       ImageGeneratorOptions generatorOptions, | ||||||
|  |       @Nullable ConditionOptions conditionOptions) { | ||||||
|  |     List<String> inputStreams = new ArrayList<>(); | ||||||
|  |     inputStreams.addAll( | ||||||
|  |         Arrays.asList( | ||||||
|  |             "STEPS:" + STEPS_STREAM_NAME, | ||||||
|  |             "ITERATION:" + ITERATION_STREAM_NAME, | ||||||
|  |             "PROMPT:" + PROMPT_STREAM_NAME, | ||||||
|  |             "RAND_SEED:" + RAND_SEED_STREAM_NAME)); | ||||||
|  |     final boolean useConditionImage = conditionOptions != null; | ||||||
|  |     if (useConditionImage) { | ||||||
|  |       inputStreams.add("SELECT:" + SELECT_STREAM_NAME); | ||||||
|  |       inputStreams.add("CONDITION_IMAGE:" + CONDITION_IMAGE_STREAM_NAME); | ||||||
|  |       generatorOptions.conditionOptions = Optional.of(conditionOptions); | ||||||
|  |     } | ||||||
|  |     List<String> outputStreams = | ||||||
|  |         Arrays.asList("IMAGE:image_out", "STEPS:steps_out", "ITERATION:iteration_out"); | ||||||
|  | 
 | ||||||
|  |     OutputHandler<ImageGeneratorResult, Void> handler = new OutputHandler<>(); | ||||||
|  |     handler.setOutputPacketConverter( | ||||||
|  |         new OutputHandler.OutputPacketConverter<ImageGeneratorResult, Void>() { | ||||||
|  |           @Override | ||||||
|  |           @Nullable | ||||||
|  |           public ImageGeneratorResult convertToTaskResult(List<Packet> packets) { | ||||||
|  |             int iteration = PacketGetter.getInt32(packets.get(ITERATION_OUT_STREAM_INDEX)); | ||||||
|  |             int steps = PacketGetter.getInt32(packets.get(STEPS_OUT_STREAM_INDEX)); | ||||||
|  |             Log.i("ImageGenerator", "Iteration: " + iteration + ", Steps: " + steps); | ||||||
|  |             if (iteration != steps - 1) { | ||||||
|  |               return null; | ||||||
|  |             } | ||||||
|  |             Log.i("ImageGenerator", "processing generated image"); | ||||||
|  |             Packet packet = packets.get(GENERATED_IMAGE_OUT_STREAM_INDEX); | ||||||
|  |             Bitmap generatedBitmap = AndroidPacketGetter.getBitmapFromRgb(packet); | ||||||
|  |             BitmapImageBuilder bitmapImageBuilder = new BitmapImageBuilder(generatedBitmap); | ||||||
|  |             return ImageGeneratorResult.create( | ||||||
|  |                 bitmapImageBuilder.build(), packet.getTimestamp() / MICROSECONDS_PER_MILLISECOND); | ||||||
|  |           } | ||||||
|  | 
 | ||||||
|  |           @Override | ||||||
|  |           public Void convertToTaskInput(List<Packet> packets) { | ||||||
|  |             return null; | ||||||
|  |           } | ||||||
|  |         }); | ||||||
|  |     handler.setHandleTimestampBoundChanges(true); | ||||||
|  |     if (generatorOptions.resultListener().isPresent()) { | ||||||
|  |       ResultListener<ImageGeneratorResult, Void> resultListener = | ||||||
|  |           new ResultListener<ImageGeneratorResult, Void>() { | ||||||
|  |             @Override | ||||||
|  |             public void run(ImageGeneratorResult imageGeneratorResult, Void input) { | ||||||
|  |               generatorOptions.resultListener().get().run(imageGeneratorResult); | ||||||
|  |             } | ||||||
|  |           }; | ||||||
|  |       handler.setResultListener(resultListener); | ||||||
|  |     } | ||||||
|  |     generatorOptions.errorListener().ifPresent(handler::setErrorListener); | ||||||
|  |     TaskRunner runner = | ||||||
|  |         TaskRunner.create( | ||||||
|  |             context, | ||||||
|  |             TaskInfo.<ImageGeneratorOptions>builder() | ||||||
|  |                 .setTaskName(ImageGenerator.class.getSimpleName()) | ||||||
|  |                 .setTaskRunningModeName(RunningMode.IMAGE.name()) | ||||||
|  |                 .setTaskGraphName(TASK_GRAPH_NAME) | ||||||
|  |                 .setInputStreams(inputStreams) | ||||||
|  |                 .setOutputStreams(outputStreams) | ||||||
|  |                 .setTaskOptions(generatorOptions) | ||||||
|  |                 .setEnableFlowLimiting(false) | ||||||
|  |                 .build(), | ||||||
|  |             handler); | ||||||
|  |     ImageGenerator imageGenerator = new ImageGenerator(runner); | ||||||
|  |     if (useConditionImage) { | ||||||
|  |       imageGenerator.useConditionImage = true; | ||||||
|  |       inputStreams = | ||||||
|  |           Arrays.asList( | ||||||
|  |               "IMAGE:" + SOURCE_CONDITION_IMAGE_STREAM_NAME, "SELECT:" + SELECT_STREAM_NAME); | ||||||
|  |       outputStreams = Arrays.asList("CONDITION_IMAGE:" + CONDITION_IMAGE_STREAM_NAME); | ||||||
|  |       OutputHandler<ConditionImageResult, Void> conditionImageHandler = new OutputHandler<>(); | ||||||
|  |       conditionImageHandler.setOutputPacketConverter( | ||||||
|  |           new OutputHandler.OutputPacketConverter<ConditionImageResult, Void>() { | ||||||
|  |             @Override | ||||||
|  |             public ConditionImageResult convertToTaskResult(List<Packet> packets) { | ||||||
|  |               Packet packet = packets.get(0); | ||||||
|  |               return new AutoValue_ImageGenerator_ConditionImageResult( | ||||||
|  |                   new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromRgb(packet)).build(), | ||||||
|  |                   packet.getTimestamp() / MICROSECONDS_PER_MILLISECOND); | ||||||
|  |             } | ||||||
|  | 
 | ||||||
|  |             @Override | ||||||
|  |             public Void convertToTaskInput(List<Packet> packets) { | ||||||
|  |               return null; | ||||||
|  |             } | ||||||
|  |           }); | ||||||
|  |       conditionImageHandler.setHandleTimestampBoundChanges(true); | ||||||
|  |       imageGenerator.conditionImageGraphsContainerTaskRunner = | ||||||
|  |           TaskRunner.create( | ||||||
|  |               context, | ||||||
|  |               TaskInfo.<ImageGeneratorOptions>builder() | ||||||
|  |                   .setTaskName(ImageGenerator.class.getSimpleName()) | ||||||
|  |                   .setTaskRunningModeName(RunningMode.IMAGE.name()) | ||||||
|  |                   .setTaskGraphName(CONDITION_IMAGE_GRAPHS_CONTAINER_NAME) | ||||||
|  |                   .setInputStreams(inputStreams) | ||||||
|  |                   .setOutputStreams(outputStreams) | ||||||
|  |                   .setTaskOptions(generatorOptions) | ||||||
|  |                   .setEnableFlowLimiting(false) | ||||||
|  |                   .build(), | ||||||
|  |               conditionImageHandler); | ||||||
|  |       imageGenerator.conditionTypeIndex = new HashMap<>(); | ||||||
|  |       if (conditionOptions.faceConditionOptions().isPresent()) { | ||||||
|  |         imageGenerator.conditionTypeIndex.put( | ||||||
|  |             ConditionOptions.ConditionType.FACE, imageGenerator.conditionTypeIndex.size()); | ||||||
|  |       } | ||||||
|  |       if (conditionOptions.edgeConditionOptions().isPresent()) { | ||||||
|  |         imageGenerator.conditionTypeIndex.put( | ||||||
|  |             ConditionOptions.ConditionType.EDGE, imageGenerator.conditionTypeIndex.size()); | ||||||
|  |       } | ||||||
|  |       if (conditionOptions.depthConditionOptions().isPresent()) { | ||||||
|  |         imageGenerator.conditionTypeIndex.put( | ||||||
|  |             ConditionOptions.ConditionType.DEPTH, imageGenerator.conditionTypeIndex.size()); | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  |     return imageGenerator; | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   private ImageGenerator(TaskRunner taskRunner) { | ||||||
|  |     super(taskRunner, RunningMode.IMAGE, "", ""); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   /** | ||||||
|  |    * Generates an image for iterations and the given random seed. Only valid when the ImageGenerator | ||||||
|  |    * is created without condition options. | ||||||
|  |    * | ||||||
|  |    * @param prompt The text prompt describing the image to be generated. | ||||||
|  |    * @param iterations The total iterations to generate the image. | ||||||
|  |    * @param seed The random seed used during image generation. | ||||||
|  |    */ | ||||||
|  |   public ImageGeneratorResult generate(String prompt, int iterations, int seed) { | ||||||
|  |     return runIterations(prompt, iterations, seed, null, 0); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   /** | ||||||
|  |    * Generates an image based on the source image for iterations and the given random seed. Only | ||||||
|  |    * valid when the ImageGenerator is created with condition options. | ||||||
|  |    * | ||||||
|  |    * @param prompt The text prompt describing the image to be generated. | ||||||
|  |    * @param sourceConditionImage The source image used to create the condition image, which is used | ||||||
|  |    *     as a guidance for the image generation. | ||||||
|  |    * @param conditionType The {@link ConditionOptions.ConditionType} specifying the type of | ||||||
|  |    *     condition image. | ||||||
|  |    * @param iterations The total iterations to generate the image. | ||||||
|  |    * @param seed The random seed used during image generation. | ||||||
|  |    */ | ||||||
|  |   public ImageGeneratorResult generate( | ||||||
|  |       String prompt, | ||||||
|  |       MPImage sourceConditionImage, | ||||||
|  |       ConditionOptions.ConditionType conditionType, | ||||||
|  |       int iterations, | ||||||
|  |       int seed) { | ||||||
|  |     return runIterations( | ||||||
|  |         prompt, | ||||||
|  |         iterations, | ||||||
|  |         seed, | ||||||
|  |         createConditionImage(sourceConditionImage, conditionType), | ||||||
|  |         conditionTypeIndex.get(conditionType)); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   /** | ||||||
|  |    * Create the condition image of specified condition type from the source image. Currently support | ||||||
|  |    * face landmarks, depth image and edge image as the condition image. | ||||||
|  |    * | ||||||
|  |    * @param sourceConditionImage The source image used to create the condition image. | ||||||
|  |    * @param conditionType The {@link ConditionOptions.ConditionType} specifying the type of | ||||||
|  |    *     condition image. | ||||||
|  |    */ | ||||||
|  |   public MPImage createConditionImage( | ||||||
|  |       MPImage sourceConditionImage, ConditionOptions.ConditionType conditionType) { | ||||||
|  |     if (!conditionTypeIndex.containsKey(conditionType)) { | ||||||
|  |       throw new IllegalArgumentException( | ||||||
|  |           "The condition type " + conditionType.name() + " is not created during initialization."); | ||||||
|  |     } | ||||||
|  |     Map<String, Packet> inputPackets = new HashMap<>(); | ||||||
|  |     inputPackets.put( | ||||||
|  |         SOURCE_CONDITION_IMAGE_STREAM_NAME, | ||||||
|  |         conditionImageGraphsContainerTaskRunner | ||||||
|  |             .getPacketCreator() | ||||||
|  |             .createImage(sourceConditionImage)); | ||||||
|  |     inputPackets.put( | ||||||
|  |         SELECT_STREAM_NAME, | ||||||
|  |         conditionImageGraphsContainerTaskRunner | ||||||
|  |             .getPacketCreator() | ||||||
|  |             .createInt32(conditionTypeIndex.get(conditionType))); | ||||||
|  |     ConditionImageResult result = | ||||||
|  |         (ConditionImageResult) conditionImageGraphsContainerTaskRunner.process(inputPackets); | ||||||
|  |     return result.conditionImage(); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   private ImageGeneratorResult runIterations( | ||||||
|  |       String prompt, int steps, int seed, @Nullable MPImage conditionImage, int select) { | ||||||
|  |     ImageGeneratorResult result = null; | ||||||
|  |     long timestamp = System.currentTimeMillis() * MICROSECONDS_PER_MILLISECOND; | ||||||
|  |     for (int i = 0; i < steps; i++) { | ||||||
|  |       Map<String, Packet> inputPackets = new HashMap<>(); | ||||||
|  |       if (i == 0 && useConditionImage) { | ||||||
|  |         inputPackets.put( | ||||||
|  |             CONDITION_IMAGE_STREAM_NAME, runner.getPacketCreator().createImage(conditionImage)); | ||||||
|  |         inputPackets.put(SELECT_STREAM_NAME, runner.getPacketCreator().createInt32(select)); | ||||||
|  |       } | ||||||
|  |       inputPackets.put(PROMPT_STREAM_NAME, runner.getPacketCreator().createString(prompt)); | ||||||
|  |       inputPackets.put(STEPS_STREAM_NAME, runner.getPacketCreator().createInt32(steps)); | ||||||
|  |       inputPackets.put(ITERATION_STREAM_NAME, runner.getPacketCreator().createInt32(i)); | ||||||
|  |       inputPackets.put(RAND_SEED_STREAM_NAME, runner.getPacketCreator().createInt32(seed)); | ||||||
|  |       result = (ImageGeneratorResult) runner.process(inputPackets, timestamp++); | ||||||
|  |     } | ||||||
|  |     if (useConditionImage) { | ||||||
|  |       // Add condition image to the ImageGeneratorResult. | ||||||
|  |       return ImageGeneratorResult.create( | ||||||
|  |           result.generatedImage(), conditionImage, result.timestampMs()); | ||||||
|  |     } | ||||||
|  |     return result; | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   /** Closes and cleans up the task runners. */ | ||||||
|  |   @Override | ||||||
|  |   public void close() { | ||||||
|  |     runner.close(); | ||||||
|  |     conditionImageGraphsContainerTaskRunner.close(); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   /** A container class for the condition image. */ | ||||||
|  |   @AutoValue | ||||||
|  |   protected abstract static class ConditionImageResult implements TaskResult { | ||||||
|  | 
 | ||||||
|  |     public abstract MPImage conditionImage(); | ||||||
|  | 
 | ||||||
|  |     @Override | ||||||
|  |     public abstract long timestampMs(); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   /** Options for setting up an {@link ImageGenerator}. */ | ||||||
|  |   @AutoValue | ||||||
|  |   public abstract static class ImageGeneratorOptions extends TaskOptions { | ||||||
|  | 
 | ||||||
|  |     /** Builder for {@link ImageGeneratorOptions}. */ | ||||||
|  |     @AutoValue.Builder | ||||||
|  |     public abstract static class Builder { | ||||||
|  | 
 | ||||||
|  |       /** Sets the text to image model directory storing the model weights. */ | ||||||
|  |       public abstract Builder setText2ImageModelDirectory(String modelDirectory); | ||||||
|  | 
 | ||||||
|  |       /** Sets the path to LoRA weights file. */ | ||||||
|  |       public abstract Builder setLoraWeightsFilePath(String loraWeightsFilePath); | ||||||
|  | 
 | ||||||
|  |       public abstract Builder setResultListener( | ||||||
|  |           PureResultListener<ImageGeneratorResult> resultListener); | ||||||
|  | 
 | ||||||
|  |       /** Sets an optional {@link ErrorListener}}. */ | ||||||
|  |       public abstract Builder setErrorListener(ErrorListener value); | ||||||
|  | 
 | ||||||
|  |       abstract ImageGeneratorOptions autoBuild(); | ||||||
|  | 
 | ||||||
|  |       /** Validates and builds the {@link ImageGeneratorOptions} instance. */ | ||||||
|  |       public final ImageGeneratorOptions build() { | ||||||
|  |         return autoBuild(); | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     abstract String text2ImageModelDirectory(); | ||||||
|  | 
 | ||||||
|  |     abstract Optional<String> loraWeightsFilePath(); | ||||||
|  | 
 | ||||||
|  |     abstract Optional<PureResultListener<ImageGeneratorResult>> resultListener(); | ||||||
|  | 
 | ||||||
|  |     abstract Optional<ErrorListener> errorListener(); | ||||||
|  | 
 | ||||||
|  |     private Optional<ConditionOptions> conditionOptions; | ||||||
|  | 
 | ||||||
|  |     public static Builder builder() { | ||||||
|  |       return new AutoValue_ImageGenerator_ImageGeneratorOptions.Builder() | ||||||
|  |           .setText2ImageModelDirectory(""); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     /** Converts an {@link ImageGeneratorOptions} to a {@link Any} protobuf message. */ | ||||||
|  |     @Override | ||||||
|  |     public Any convertToAnyProto() { | ||||||
|  |       ImageGeneratorGraphOptionsProto.ImageGeneratorGraphOptions.Builder taskOptionsBuilder = | ||||||
|  |           ImageGeneratorGraphOptionsProto.ImageGeneratorGraphOptions.newBuilder(); | ||||||
|  |       if (conditionOptions != null && conditionOptions.isPresent()) { | ||||||
|  |         try { | ||||||
|  |           taskOptionsBuilder.mergeFrom( | ||||||
|  |               conditionOptions.get().convertToAnyProto().getValue(), | ||||||
|  |               ExtensionRegistryLite.getGeneratedRegistry()); | ||||||
|  |         } catch (InvalidProtocolBufferException e) { | ||||||
|  |           Log.e(TAG, "Error converting ConditionOptions to proto. " + e.getMessage()); | ||||||
|  |           e.printStackTrace(); | ||||||
|  |         } | ||||||
|  |       } | ||||||
|  |       taskOptionsBuilder.setText2ImageModelDirectory(text2ImageModelDirectory()); | ||||||
|  |       if (loraWeightsFilePath().isPresent()) { | ||||||
|  |         ExternalFileProto.ExternalFile.Builder externalFileBuilder = | ||||||
|  |             ExternalFileProto.ExternalFile.newBuilder(); | ||||||
|  |         externalFileBuilder.setFileName(loraWeightsFilePath().get()); | ||||||
|  |         taskOptionsBuilder.setLoraWeightsFile(externalFileBuilder.build()); | ||||||
|  |       } | ||||||
|  |       return Any.newBuilder() | ||||||
|  |           .setTypeUrl( | ||||||
|  |               "type.googleapis.com/mediapipe.tasks.vision.image_generator.proto.ImageGeneratorGraphOptions") | ||||||
|  |           .setValue(taskOptionsBuilder.build().toByteString()) | ||||||
|  |           .build(); | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   /** Options for setting up the conditions types and the plugin models */ | ||||||
|  |   @AutoValue | ||||||
|  |   public abstract static class ConditionOptions extends TaskOptions { | ||||||
|  | 
 | ||||||
|  |     /** The supported condition type. */ | ||||||
|  |     public enum ConditionType { | ||||||
|  |       FACE, | ||||||
|  |       EDGE, | ||||||
|  |       DEPTH | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     /** Builder for {@link ConditionOptions}. At least one type of condition options must be set. */ | ||||||
|  |     @AutoValue.Builder | ||||||
|  |     public abstract static class Builder { | ||||||
|  |       public abstract Builder setFaceConditionOptions(FaceConditionOptions faceConditionOptions); | ||||||
|  | 
 | ||||||
|  |       public abstract Builder setDepthConditionOptions(DepthConditionOptions depthConditionOptions); | ||||||
|  | 
 | ||||||
|  |       public abstract Builder setEdgeConditionOptions(EdgeConditionOptions edgeConditionOptions); | ||||||
|  | 
 | ||||||
|  |       abstract ConditionOptions autoBuild(); | ||||||
|  | 
 | ||||||
|  |       /** Validates and builds the {@link ConditionOptions} instance. */ | ||||||
|  |       public final ConditionOptions build() { | ||||||
|  |         ConditionOptions options = autoBuild(); | ||||||
|  |         if (!options.faceConditionOptions().isPresent() | ||||||
|  |             && !options.depthConditionOptions().isPresent() | ||||||
|  |             && !options.edgeConditionOptions().isPresent()) { | ||||||
|  |           throw new IllegalArgumentException( | ||||||
|  |               "At least one of `faceConditionOptions`, `depthConditionOptions` and" | ||||||
|  |                   + " `edgeConditionOptions` must be set."); | ||||||
|  |         } | ||||||
|  |         return options; | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     abstract Optional<FaceConditionOptions> faceConditionOptions(); | ||||||
|  | 
 | ||||||
|  |     abstract Optional<DepthConditionOptions> depthConditionOptions(); | ||||||
|  | 
 | ||||||
|  |     abstract Optional<EdgeConditionOptions> edgeConditionOptions(); | ||||||
|  | 
 | ||||||
|  |     public static Builder builder() { | ||||||
|  |       return new AutoValue_ImageGenerator_ConditionOptions.Builder(); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     /** | ||||||
|  |      * Converts an {@link ImageGeneratorOptions} to a {@link CalculatorOptions} protobuf message. | ||||||
|  |      */ | ||||||
|  |     @Override | ||||||
|  |     public Any convertToAnyProto() { | ||||||
|  |       ImageGeneratorGraphOptionsProto.ImageGeneratorGraphOptions.Builder taskOptionsBuilder = | ||||||
|  |           ImageGeneratorGraphOptionsProto.ImageGeneratorGraphOptions.newBuilder(); | ||||||
|  |       if (faceConditionOptions().isPresent()) { | ||||||
|  |         taskOptionsBuilder.addControlPluginGraphsOptions( | ||||||
|  |             ControlPluginGraphOptionsProto.ControlPluginGraphOptions.newBuilder() | ||||||
|  |                 .setBaseOptions( | ||||||
|  |                     convertBaseOptionsToProto(faceConditionOptions().get().baseOptions())) | ||||||
|  |                 .setConditionedImageGraphOptions( | ||||||
|  |                     ConditionedImageGraphOptions.newBuilder() | ||||||
|  |                         .setFaceConditionTypeOptions(faceConditionOptions().get().convertToProto()) | ||||||
|  |                         .build()) | ||||||
|  |                 .build()); | ||||||
|  |       } | ||||||
|  |       if (edgeConditionOptions().isPresent()) { | ||||||
|  |         taskOptionsBuilder.addControlPluginGraphsOptions( | ||||||
|  |             ControlPluginGraphOptionsProto.ControlPluginGraphOptions.newBuilder() | ||||||
|  |                 .setBaseOptions( | ||||||
|  |                     convertBaseOptionsToProto(edgeConditionOptions().get().baseOptions())) | ||||||
|  |                 .setConditionedImageGraphOptions( | ||||||
|  |                     ConditionedImageGraphOptions.newBuilder() | ||||||
|  |                         .setEdgeConditionTypeOptions(edgeConditionOptions().get().convertToProto()) | ||||||
|  |                         .build()) | ||||||
|  |                 .build()); | ||||||
|  |         if (depthConditionOptions().isPresent()) { | ||||||
|  |           taskOptionsBuilder.addControlPluginGraphsOptions( | ||||||
|  |               ControlPluginGraphOptionsProto.ControlPluginGraphOptions.newBuilder() | ||||||
|  |                   .setBaseOptions( | ||||||
|  |                       convertBaseOptionsToProto(depthConditionOptions().get().baseOptions())) | ||||||
|  |                   .setConditionedImageGraphOptions( | ||||||
|  |                       ConditionedImageGraphOptions.newBuilder() | ||||||
|  |                           .setDepthConditionTypeOptions( | ||||||
|  |                               depthConditionOptions().get().convertToProto()) | ||||||
|  |                           .build()) | ||||||
|  |                   .build()); | ||||||
|  |         } | ||||||
|  |       } | ||||||
|  |       return Any.newBuilder() | ||||||
|  |           .setTypeUrl( | ||||||
|  |               "type.googleapis.com/mediapipe.tasks.vision.image_generator.proto.ImageGeneratorGraphOptions") | ||||||
|  |           .setValue(taskOptionsBuilder.build().toByteString()) | ||||||
|  |           .build(); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     /** Options for drawing face landmarks image. */ | ||||||
|  |     @AutoValue | ||||||
|  |     public abstract static class FaceConditionOptions extends TaskOptions { | ||||||
|  | 
 | ||||||
|  |       /** Builder for {@link FaceConditionOptions}. */ | ||||||
|  |       @AutoValue.Builder | ||||||
|  |       public abstract static class Builder { | ||||||
|  |         /** Set the base options for plugin model. */ | ||||||
|  |         public abstract Builder setBaseOptions(BaseOptions baseOptions); | ||||||
|  | 
 | ||||||
|  |         /* {@link FaceLandmarkerOptions} used to detect face landmarks in the source image. */ | ||||||
|  |         public abstract Builder setFaceLandmarkerOptions( | ||||||
|  |             FaceLandmarkerOptions faceLandmarkerOptions); | ||||||
|  | 
 | ||||||
|  |         abstract FaceConditionOptions autoBuild(); | ||||||
|  | 
 | ||||||
|  |         /** Validates and builds the {@link FaceConditionOptions} instance. */ | ||||||
|  |         public final FaceConditionOptions build() { | ||||||
|  |           return autoBuild(); | ||||||
|  |         } | ||||||
|  |       } | ||||||
|  | 
 | ||||||
|  |       abstract BaseOptions baseOptions(); | ||||||
|  | 
 | ||||||
|  |       abstract FaceLandmarkerOptions faceLandmarkerOptions(); | ||||||
|  | 
 | ||||||
|  |       public static Builder builder() { | ||||||
|  |         return new AutoValue_ImageGenerator_ConditionOptions_FaceConditionOptions.Builder(); | ||||||
|  |       } | ||||||
|  | 
 | ||||||
|  |       ConditionedImageGraphOptions.FaceConditionTypeOptions convertToProto() { | ||||||
|  |         return ConditionedImageGraphOptions.FaceConditionTypeOptions.newBuilder() | ||||||
|  |             .setFaceLandmarkerGraphOptions( | ||||||
|  |                 FaceLandmarkerGraphOptions.newBuilder() | ||||||
|  |                     .mergeFrom( | ||||||
|  |                         faceLandmarkerOptions() | ||||||
|  |                             .convertToCalculatorOptionsProto() | ||||||
|  |                             .getExtension(FaceLandmarkerGraphOptions.ext)) | ||||||
|  |                     .build()) | ||||||
|  |             .build(); | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     /** Options for detecting depth image. */ | ||||||
|  |     @AutoValue | ||||||
|  |     public abstract static class DepthConditionOptions extends TaskOptions { | ||||||
|  | 
 | ||||||
|  |       /** Builder for {@link DepthConditionOptions}. */ | ||||||
|  |       @AutoValue.Builder | ||||||
|  |       public abstract static class Builder { | ||||||
|  | 
 | ||||||
|  |         /** Set the base options for plugin model. */ | ||||||
|  |         public abstract Builder setBaseOptions(BaseOptions baseOptions); | ||||||
|  | 
 | ||||||
|  |         /** {@link ImageSegmenterOptions} used to detect depth image from the source image. */ | ||||||
|  |         public abstract Builder setImageSegmenterOptions( | ||||||
|  |             ImageSegmenterOptions imageSegmenterOptions); | ||||||
|  | 
 | ||||||
|  |         abstract DepthConditionOptions autoBuild(); | ||||||
|  | 
 | ||||||
|  |         /** Validates and builds the {@link DepthConditionOptions} instance. */ | ||||||
|  |         public final DepthConditionOptions build() { | ||||||
|  |           DepthConditionOptions options = autoBuild(); | ||||||
|  |           return options; | ||||||
|  |         } | ||||||
|  |       } | ||||||
|  | 
 | ||||||
|  |       abstract BaseOptions baseOptions(); | ||||||
|  | 
 | ||||||
|  |       abstract ImageSegmenterOptions imageSegmenterOptions(); | ||||||
|  | 
 | ||||||
|  |       public static Builder builder() { | ||||||
|  |         return new AutoValue_ImageGenerator_ConditionOptions_DepthConditionOptions.Builder(); | ||||||
|  |       } | ||||||
|  | 
 | ||||||
|  |       ConditionedImageGraphOptions.DepthConditionTypeOptions convertToProto() { | ||||||
|  |         return ConditionedImageGraphOptions.DepthConditionTypeOptions.newBuilder() | ||||||
|  |             .setImageSegmenterGraphOptions( | ||||||
|  |                 imageSegmenterOptions() | ||||||
|  |                     .convertToCalculatorOptionsProto() | ||||||
|  |                     .getExtension(ImageSegmenterGraphOptions.ext)) | ||||||
|  |             .build(); | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     /** Options for detecting edge image. */ | ||||||
|  |     @AutoValue | ||||||
|  |     public abstract static class EdgeConditionOptions { | ||||||
|  | 
 | ||||||
|  |       /** | ||||||
|  |        * Builder for {@link EdgeConditionOptions}. | ||||||
|  |        * | ||||||
|  |        * <p>These parameters are used to config Canny edge algorithm of OpenCV. | ||||||
|  |        * | ||||||
|  |        * <p>See more details: | ||||||
|  |        * https://docs.opencv.org/3.4/dd/d1a/group__imgproc__feature.html#ga04723e007ed888ddf11d9ba04e2232de | ||||||
|  |        */ | ||||||
|  |       @AutoValue.Builder | ||||||
|  |       public abstract static class Builder { | ||||||
|  | 
 | ||||||
|  |         /** Set the base options for plugin model. */ | ||||||
|  |         public abstract Builder setBaseOptions(BaseOptions baseOptions); | ||||||
|  | 
 | ||||||
|  |         /** First threshold for the hysteresis procedure. */ | ||||||
|  |         public abstract Builder setThreshold1(Float threshold1); | ||||||
|  | 
 | ||||||
|  |         /** Second threshold for the hysteresis procedure. */ | ||||||
|  |         public abstract Builder setThreshold2(Float threshold2); | ||||||
|  | 
 | ||||||
|  |         /** Aperture size for the Sobel operator. Typical range is 3~7. */ | ||||||
|  |         public abstract Builder setApertureSize(Integer apertureSize); | ||||||
|  | 
 | ||||||
|  |         /** | ||||||
|  |          * flag, indicating whether a more accurate L2 norm should be used to calculate the image | ||||||
|  |          * gradient magnitude ( L2gradient=true ), or whether the default L1 norm is enough ( | ||||||
|  |          * L2gradient=false ). | ||||||
|  |          */ | ||||||
|  |         public abstract Builder setL2Gradient(Boolean l2Gradient); | ||||||
|  | 
 | ||||||
|  |         abstract EdgeConditionOptions autoBuild(); | ||||||
|  | 
 | ||||||
|  |         /** Validates and builds the {@link EdgeConditionOptions} instance. */ | ||||||
|  |         public final EdgeConditionOptions build() { | ||||||
|  |           return autoBuild(); | ||||||
|  |         } | ||||||
|  |       } | ||||||
|  | 
 | ||||||
|  |       abstract BaseOptions baseOptions(); | ||||||
|  | 
 | ||||||
|  |       abstract Float threshold1(); | ||||||
|  | 
 | ||||||
|  |       abstract Float threshold2(); | ||||||
|  | 
 | ||||||
|  |       abstract Integer apertureSize(); | ||||||
|  | 
 | ||||||
|  |       abstract Boolean l2Gradient(); | ||||||
|  | 
 | ||||||
|  |       public static Builder builder() { | ||||||
|  |         return new AutoValue_ImageGenerator_ConditionOptions_EdgeConditionOptions.Builder() | ||||||
|  |             .setThreshold1(100f) | ||||||
|  |             .setThreshold2(200f) | ||||||
|  |             .setApertureSize(3) | ||||||
|  |             .setL2Gradient(false); | ||||||
|  |       } | ||||||
|  | 
 | ||||||
|  |       ConditionedImageGraphOptions.EdgeConditionTypeOptions convertToProto() { | ||||||
|  |         return ConditionedImageGraphOptions.EdgeConditionTypeOptions.newBuilder() | ||||||
|  |             .setThreshold1(threshold1()) | ||||||
|  |             .setThreshold2(threshold2()) | ||||||
|  |             .setApertureSize(apertureSize()) | ||||||
|  |             .setL2Gradient(l2Gradient()) | ||||||
|  |             .build(); | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  | } | ||||||
|  | @ -0,0 +1,44 @@ | ||||||
|  | // Copyright 2023 The MediaPipe Authors. | ||||||
|  | // | ||||||
|  | // Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
|  | // you may not use this file except in compliance with the License. | ||||||
|  | // You may obtain a copy of the License at | ||||||
|  | // | ||||||
|  | //      http://www.apache.org/licenses/LICENSE-2.0 | ||||||
|  | // | ||||||
|  | // Unless required by applicable law or agreed to in writing, software | ||||||
|  | // distributed under the License is distributed on an "AS IS" BASIS, | ||||||
|  | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
|  | // See the License for the specific language governing permissions and | ||||||
|  | // limitations under the License. | ||||||
|  | 
 | ||||||
|  | package com.google.mediapipe.tasks.vision.imagegenerator; | ||||||
|  | 
 | ||||||
|  | import com.google.auto.value.AutoValue; | ||||||
|  | import com.google.mediapipe.framework.image.MPImage; | ||||||
|  | import com.google.mediapipe.tasks.core.TaskResult; | ||||||
|  | import java.util.Optional; | ||||||
|  | 
 | ||||||
|  | /** Represents the image generation results generated by {@link ImageGenerator}. */ | ||||||
|  | @AutoValue | ||||||
|  | public abstract class ImageGeneratorResult implements TaskResult { | ||||||
|  | 
 | ||||||
|  |   /** Create an {@link ImageGeneratorResult} instance from the generated image. */ | ||||||
|  |   public static ImageGeneratorResult create( | ||||||
|  |       MPImage generatedImage, MPImage conditionImage, long timestampMs) { | ||||||
|  |     return new AutoValue_ImageGeneratorResult( | ||||||
|  |         generatedImage, Optional.of(conditionImage), timestampMs); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   /** Create an {@link ImageGeneratorResult} instance from the generated image. */ | ||||||
|  |   public static ImageGeneratorResult create(MPImage generatedImage, long timestampMs) { | ||||||
|  |     return new AutoValue_ImageGeneratorResult(generatedImage, Optional.empty(), timestampMs); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   public abstract MPImage generatedImage(); | ||||||
|  | 
 | ||||||
|  |   public abstract Optional<MPImage> conditionImage(); | ||||||
|  | 
 | ||||||
|  |   @Override | ||||||
|  |   public abstract long timestampMs(); | ||||||
|  | } | ||||||
		Loading…
	
		Reference in New Issue
	
	Block a user