diff --git a/mediapipe/tasks/cc/vision/image_generator/image_generator_graph.cc b/mediapipe/tasks/cc/vision/image_generator/image_generator_graph.cc index 639a73e34..efbfd86e9 100644 --- a/mediapipe/tasks/cc/vision/image_generator/image_generator_graph.cc +++ b/mediapipe/tasks/cc/vision/image_generator/image_generator_graph.cc @@ -68,6 +68,7 @@ constexpr absl::string_view kRandSeedTag = "RAND_SEED"; constexpr absl::string_view kPluginTensorsTag = "PLUGIN_TENSORS"; constexpr absl::string_view kConditionImageTag = "CONDITION_IMAGE"; constexpr absl::string_view kSelectTag = "SELECT"; +constexpr absl::string_view kShowResultTag = "SHOW_RESULT"; constexpr absl::string_view kMetadataFilename = "metadata"; constexpr absl::string_view kLoraRankStr = "lora_rank"; @@ -78,6 +79,7 @@ struct ImageGeneratorInputs { Source rand_seed; std::optional> condition_image; std::optional> select_condition_type; + std::optional> show_result; }; struct ImageGeneratorOutputs { @@ -209,6 +211,9 @@ REGISTER_MEDIAPIPE_GRAPH( // valid, if condtrol plugin graph options are set in the graph options. // SELECT - int // The index of the selected the control plugin graph. +// SHOW_RESULT - bool @Optional +// Whether to show the diffusion result at the current step. If this stream +// is not empty, regardless show_every_n_iteration in the options. // // Outputs: // IMAGE - Image @@ -218,6 +223,9 @@ REGISTER_MEDIAPIPE_GRAPH( // ITERATION - int @optional // The current iteration in the generating steps. The same as ITERATION // input. +// SHOW_RESULT - bool @Optional +// Whether to show the diffusion result at the current step. The same as +// input SHOW_RESULT. 
class ImageGeneratorGraph : public core::ModelTaskGraph { public: absl::StatusOr GetConfig( @@ -239,6 +247,10 @@ class ImageGeneratorGraph : public core::ModelTaskGraph { condition_image = graph.In(kConditionImageTag).Cast(); select_condition_type = graph.In(kSelectTag).Cast(); } + std::optional> show_result; + if (HasInput(sc->OriginalNode(), kShowResultTag)) { + show_result = graph.In(kShowResultTag).Cast(); + } ASSIGN_OR_RETURN( auto outputs, BuildImageGeneratorGraph( @@ -251,6 +263,7 @@ class ImageGeneratorGraph : public core::ModelTaskGraph { /*rand_seed=*/graph.In(kRandSeedTag).Cast(), /*condition_image*/ condition_image, /*select_condition_type*/ select_condition_type, + /*show_result*/ show_result, }, graph)); outputs.generated_image >> graph.Out(kImageTag).Cast(); @@ -261,6 +274,10 @@ class ImageGeneratorGraph : public core::ModelTaskGraph { graph.In(kStepsTag) >> pass_through.In(1); pass_through.Out(0) >> graph[Output::Optional(kIterationTag)]; pass_through.Out(1) >> graph[Output::Optional(kStepsTag)]; + if (HasOutput(sc->OriginalNode(), kShowResultTag)) { + graph.In(kShowResultTag) >> pass_through.In(2); + pass_through.Out(2) >> graph[Output::Optional(kShowResultTag)]; + } return graph.GetConfig(); } @@ -299,15 +316,22 @@ class ImageGeneratorGraph : public core::ModelTaskGraph { inputs.steps >> stable_diff.In(kStepsTag); inputs.iteration >> stable_diff.In(kIterationTag); inputs.rand_seed >> stable_diff.In(kRandSeedTag); + if (inputs.show_result.has_value()) { + *inputs.show_result >> stable_diff.In(kShowResultTag); + } mediapipe::StableDiffusionIterateCalculatorOptions& options = stable_diff .GetOptions(); - options.set_base_seed(0); - options.set_output_image_height(kPluginsOutputSize); - options.set_output_image_width(kPluginsOutputSize); - options.set_file_folder(subgraph_options.text2image_model_directory()); - options.set_show_every_n_iteration(100); - options.set_emit_empty_packet(true); + if 
(subgraph_options.has_stable_diffusion_iterate_options()) { + options = subgraph_options.stable_diffusion_iterate_options(); + } else { + options.set_base_seed(0); + options.set_output_image_height(kPluginsOutputSize); + options.set_output_image_width(kPluginsOutputSize); + options.set_file_folder(subgraph_options.text2image_model_directory()); + options.set_show_every_n_iteration(100); + options.set_emit_empty_packet(true); + } if (lora_resources.has_value()) { auto& lora_layer_weights_mapping = *options.mutable_lora_weights_layer_mapping(); diff --git a/mediapipe/tasks/cc/vision/image_generator/proto/BUILD b/mediapipe/tasks/cc/vision/image_generator/proto/BUILD index 38e1048cf..971bb7f07 100644 --- a/mediapipe/tasks/cc/vision/image_generator/proto/BUILD +++ b/mediapipe/tasks/cc/vision/image_generator/proto/BUILD @@ -48,5 +48,6 @@ mediapipe_proto_library( deps = [ ":control_plugin_graph_options_proto", "//mediapipe/tasks/cc/core/proto:external_file_proto", + "//mediapipe/tasks/cc/vision/image_generator/diffuser:stable_diffusion_iterate_calculator_proto", ], ) diff --git a/mediapipe/tasks/cc/vision/image_generator/proto/image_generator_graph_options.proto b/mediapipe/tasks/cc/vision/image_generator/proto/image_generator_graph_options.proto index 867080dc3..5bbf8de15 100644 --- a/mediapipe/tasks/cc/vision/image_generator/proto/image_generator_graph_options.proto +++ b/mediapipe/tasks/cc/vision/image_generator/proto/image_generator_graph_options.proto @@ -18,6 +18,7 @@ syntax = "proto3"; package mediapipe.tasks.vision.image_generator.proto; import "mediapipe/tasks/cc/core/proto/external_file.proto"; +import "mediapipe/tasks/cc/vision/image_generator/diffuser/stable_diffusion_iterate_calculator.proto"; import "mediapipe/tasks/cc/vision/image_generator/proto/control_plugin_graph_options.proto"; option java_package = "com.google.mediapipe.tasks.vision.imagegenerator.proto"; @@ -32,4 +33,7 @@ message ImageGeneratorGraphOptions { core.proto.ExternalFile lora_weights_file 
= 2; repeated proto.ControlPluginGraphOptions control_plugin_graphs_options = 3; + + mediapipe.StableDiffusionIterateCalculatorOptions + stable_diffusion_iterate_options = 4; } diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl b/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl index 0fc4a4974..aa1b91b06 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl @@ -67,6 +67,10 @@ _VISION_TASKS_IMAGE_GENERATOR_JAVA_PROTO_LITE_TARGETS = [ "//mediapipe/tasks/cc/vision/face_landmarker/proto:face_blendshapes_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarker_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarks_detector_graph_options_java_proto_lite", + "//mediapipe/tasks/cc/vision/image_generator/diffuser:stable_diffusion_iterate_calculator_java_proto_lite", + "//mediapipe/tasks/cc/vision/image_generator/proto:conditioned_image_graph_options_java_proto_lite", + "//mediapipe/tasks/cc/vision/image_generator/proto:control_plugin_graph_options_java_proto_lite", + "//mediapipe/tasks/cc/vision/image_generator/proto:image_generator_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_java_proto_lite", "//mediapipe/tasks/cc/vision/image_segmenter/calculators:tensors_to_segmentation_calculator_java_proto_lite", diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagegenerator/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagegenerator/BUILD index 5c4bb3f95..6c5b2400b 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagegenerator/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagegenerator/BUILD @@ -67,8 +67,8 
@@ android_library( "//mediapipe/java/com/google/mediapipe/framework:android_framework", "//mediapipe/java/com/google/mediapipe/framework/image", "//mediapipe/tasks/cc/core/proto:external_file_java_proto_lite", - "//mediapipe/tasks/cc/vision/face_detector/proto:face_detector_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarker_graph_options_java_proto_lite", + "//mediapipe/tasks/cc/vision/image_generator/diffuser:stable_diffusion_iterate_calculator_java_proto_lite", "//mediapipe/tasks/cc/vision/image_generator/proto:conditioned_image_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/image_generator/proto:control_plugin_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/image_generator/proto:image_generator_graph_options_java_proto_lite", diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagegenerator/ImageGenerator.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagegenerator/ImageGenerator.java index 1de8e4c46..91218922f 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagegenerator/ImageGenerator.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagegenerator/ImageGenerator.java @@ -19,6 +19,7 @@ import android.graphics.Bitmap; import android.util.Log; import androidx.annotation.Nullable; import com.google.auto.value.AutoValue; +import com.google.mediapipe.calculator.proto.StableDiffusionIterateCalculatorOptionsProto; import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions; import com.google.mediapipe.framework.AndroidPacketGetter; import com.google.mediapipe.framework.Packet; @@ -28,8 +29,6 @@ import com.google.mediapipe.framework.image.MPImage; import com.google.mediapipe.tasks.core.BaseOptions; import com.google.mediapipe.tasks.core.ErrorListener; import com.google.mediapipe.tasks.core.OutputHandler; -import com.google.mediapipe.tasks.core.OutputHandler.PureResultListener; -import 
com.google.mediapipe.tasks.core.OutputHandler.ResultListener; import com.google.mediapipe.tasks.core.TaskInfo; import com.google.mediapipe.tasks.core.TaskOptions; import com.google.mediapipe.tasks.core.TaskResult; @@ -64,17 +63,23 @@ public final class ImageGenerator extends BaseVisionTaskApi { private static final String SOURCE_CONDITION_IMAGE_STREAM_NAME = "source_condition_image"; private static final String CONDITION_IMAGE_STREAM_NAME = "condition_image"; private static final String SELECT_STREAM_NAME = "select"; + private static final String SHOW_RESULT_STREAM_NAME = "show_result"; private static final int GENERATED_IMAGE_OUT_STREAM_INDEX = 0; private static final int STEPS_OUT_STREAM_INDEX = 1; private static final int ITERATION_OUT_STREAM_INDEX = 2; + private static final int SHOW_RESULT_OUT_STREAM_INDEX = 3; private static final String TASK_GRAPH_NAME = "mediapipe.tasks.vision.image_generator.ImageGeneratorGraph"; private static final String CONDITION_IMAGE_GRAPHS_CONTAINER_NAME = "mediapipe.tasks.vision.image_generator.ConditionedImageGraphContainer"; private static final String TAG = "ImageGenerator"; + private static final int GENERATED_IMAGE_WIDTH = 512; + private static final int GENERATED_IMAGE_HEIGHT = 512; private TaskRunner conditionImageGraphsContainerTaskRunner; private Map conditionTypeIndex; private boolean useConditionImage = false; + private CachedInputs cachedInputs = new CachedInputs(); + private boolean inProcessing = false; /** * Creates an {@link ImageGenerator} instance from an {@link ImageGeneratorOptions}. 
@@ -107,7 +112,8 @@ public final class ImageGenerator extends BaseVisionTaskApi { "STEPS:" + STEPS_STREAM_NAME, "ITERATION:" + ITERATION_STREAM_NAME, "PROMPT:" + PROMPT_STREAM_NAME, - "RAND_SEED:" + RAND_SEED_STREAM_NAME)); + "RAND_SEED:" + RAND_SEED_STREAM_NAME, + "SHOW_RESULT:" + SHOW_RESULT_STREAM_NAME)); final boolean useConditionImage = conditionOptions != null; if (useConditionImage) { inputStreams.add("SELECT:" + SELECT_STREAM_NAME); @@ -115,7 +121,11 @@ public final class ImageGenerator extends BaseVisionTaskApi { generatorOptions.conditionOptions = Optional.of(conditionOptions); } List outputStreams = - Arrays.asList("IMAGE:image_out", "STEPS:steps_out", "ITERATION:iteration_out"); + Arrays.asList( + "IMAGE:image_out", + "STEPS:steps_out", + "ITERATION:iteration_out", + "SHOW_RESULT:show_result_out"); OutputHandler handler = new OutputHandler<>(); handler.setOutputPacketConverter( @@ -125,16 +135,22 @@ public final class ImageGenerator extends BaseVisionTaskApi { public ImageGeneratorResult convertToTaskResult(List packets) { int iteration = PacketGetter.getInt32(packets.get(ITERATION_OUT_STREAM_INDEX)); int steps = PacketGetter.getInt32(packets.get(STEPS_OUT_STREAM_INDEX)); - Log.i("ImageGenerator", "Iteration: " + iteration + ", Steps: " + steps); - if (iteration != steps - 1) { + boolean showResult = + PacketGetter.getBool(packets.get(SHOW_RESULT_OUT_STREAM_INDEX)) + || iteration == steps - 1; + Log.i( + "ImageGenerator", + "Iteration: " + iteration + ", Steps: " + steps + ", ShowResult: " + showResult); + if (showResult) { + Log.i("ImageGenerator", "processing generated image"); + Packet packet = packets.get(GENERATED_IMAGE_OUT_STREAM_INDEX); + Bitmap generatedBitmap = AndroidPacketGetter.getBitmapFromRgb(packet); + BitmapImageBuilder bitmapImageBuilder = new BitmapImageBuilder(generatedBitmap); + return ImageGeneratorResult.create( + bitmapImageBuilder.build(), packet.getTimestamp() / MICROSECONDS_PER_MILLISECOND); + } else { return null; } - 
Log.i("ImageGenerator", "processing generated image"); - Packet packet = packets.get(GENERATED_IMAGE_OUT_STREAM_INDEX); - Bitmap generatedBitmap = AndroidPacketGetter.getBitmapFromRgb(packet); - BitmapImageBuilder bitmapImageBuilder = new BitmapImageBuilder(generatedBitmap); - return ImageGeneratorResult.create( - bitmapImageBuilder.build(), packet.getTimestamp() / MICROSECONDS_PER_MILLISECOND); } @Override @@ -143,16 +159,6 @@ public final class ImageGenerator extends BaseVisionTaskApi { } }); handler.setHandleTimestampBoundChanges(true); - if (generatorOptions.resultListener().isPresent()) { - ResultListener resultListener = - new ResultListener() { - @Override - public void run(ImageGeneratorResult imageGeneratorResult, Void input) { - generatorOptions.resultListener().get().run(imageGeneratorResult); - } - }; - handler.setResultListener(resultListener); - } generatorOptions.errorListener().ifPresent(handler::setErrorListener); TaskRunner runner = TaskRunner.create( @@ -229,11 +235,19 @@ public final class ImageGenerator extends BaseVisionTaskApi { * Generates an image for iterations and the given random seed. Only valid when the ImageGenerator * is created without condition options. * + *

This is an e2e API, which runs {@code iterations} to generate an image. Consider using the + * iterative API instead to fetch the intermediate results. + * * @param prompt The text prompt describing the image to be generated. * @param iterations The total iterations to generate the image. * @param seed The random seed used during image generation. */ public ImageGeneratorResult generate(String prompt, int iterations, int seed) { + if (useConditionImage) { + throw new IllegalArgumentException( + "ImageGenerator is created with condition options. Must use the methods with condition " + + "options."); + } return runIterations(prompt, iterations, seed, null, 0); } @@ -241,6 +255,9 @@ public final class ImageGenerator extends BaseVisionTaskApi { * Generates an image based on the source image for iterations and the given random seed. Only * valid when the ImageGenerator is created with condition options. * + *

This is an e2e API, which runs {@code iterations} to generate an image. Consider using the + * iterative API instead to fetch the intermediate results. + * * @param prompt The text prompt describing the image to be generated. * @param sourceConditionImage The source image used to create the condition image, which is used * as a guidance for the image generation. @@ -255,6 +272,11 @@ public final class ImageGenerator extends BaseVisionTaskApi { ConditionOptions.ConditionType conditionType, int iterations, int seed) { + if (!useConditionImage) { + throw new IllegalArgumentException( + "ImageGenerator is created without condition options. Must use the methods without " + + "condition options."); + } return runIterations( prompt, iterations, @@ -263,6 +285,97 @@ public final class ImageGenerator extends BaseVisionTaskApi { conditionTypeIndex.get(conditionType)); } + /** + * Sets the inputs of the ImageGenerator. There is {@link setInputs} and {@link execute} method + * pair for iterative usage. Users must call {@link setInputs} before {@link execute}. Only valid + * when the ImageGenerator is created without condition options. + * + * @param prompt The text prompt describing the image to be generated. + * @param iterations The total iterations to generate the image. + * @param seed The random seed used during image generation. + */ + public void setInputs(String prompt, int iterations, int seed) { + if (useConditionImage) { + throw new IllegalArgumentException( + "ImageGenerator is created with condition options. Must use the methods with condition " + + "options."); + } + cachedInputs = new CachedInputs(); + cachedInputs.prompt = prompt; + cachedInputs.iterations = iterations; + cachedInputs.seed = seed; + cachedInputs.step = 0; + inProcessing = true; + } + + /** + * Sets the inputs of the ImageGenerator. For iterative usage, use {@link setInputs} and {@link + * execute} in pairs. Users must call {@link setInputs} before {@link execute}. 
Only valid when + * the ImageGenerator is created with condition options. + * + * @param prompt The text prompt describing the image to be generated. + * @param sourceConditionImage The source image used to create the condition image, which is used + * as a guidance for the image generation. + * @param conditionType The {@link ConditionOptions.ConditionType} specifying the type of + * condition image. + * @param iterations The total iterations to generate the image. + * @param seed The random seed used during image generation. + */ + public void setInputs( + String prompt, + MPImage sourceConditionImage, + ConditionOptions.ConditionType conditionType, + int iterations, + int seed) { + if (!useConditionImage) { + throw new IllegalArgumentException( + "ImageGenerator is created without condition options. Must use the methods without " + + "condition options."); + } + cachedInputs = new CachedInputs(); + cachedInputs.prompt = prompt; + cachedInputs.iterations = iterations; + cachedInputs.seed = seed; + cachedInputs.step = 0; + cachedInputs.cachedConditionImage = createConditionImage(sourceConditionImage, conditionType); + cachedInputs.conditionType = conditionType; + inProcessing = true; + } + + /** + * Executes one iteration of image generation. The method must be called {@code iterations} times + * to generate the final image. Must call {@link setInputs} before calling this method. + * + *

<p>This is an iterative API: each call runs one iteration, so it must be called repeatedly. + * + *

<p>This API is useful for showing intermediate image generation results and the image + * generation progress. Note that requesting intermediate results increases latency; consider + * using the e2e API instead when latency matters. + * + *

<p>Example usage: + * + *

imageGenerator.setInputs(prompt, iterations, seed); for (int step = 0; step < iterations; + * step++) { ImageGeneratorResult result = imageGenerator.execute(true); } + * + * @param showResult Whether to get the generated image result in the intermediate iterations. If + * false, null is returned. The generated image result is always returned at the last + * iteration, regardless of showResult value. + */ + @Nullable + public ImageGeneratorResult execute(boolean showResult) { + if (!inProcessing) { + throw new IllegalArgumentException("Must call setInputs before execute."); + } + return runStep( + cachedInputs.prompt, + cachedInputs.iterations, + cachedInputs.step++, + cachedInputs.seed, + cachedInputs.cachedConditionImage, + useConditionImage ? conditionTypeIndex.get(cachedInputs.conditionType) : 0, + showResult); + } + /** * Create the condition image of specified condition type from the source image. Currently support * face landmarks, depth image and edge image as the condition image. @@ -295,6 +408,11 @@ public final class ImageGenerator extends BaseVisionTaskApi { private ImageGeneratorResult runIterations( String prompt, int steps, int seed, @Nullable MPImage conditionImage, int select) { + if (inProcessing) { + throw new IllegalArgumentException( + "Iterative API was called previously. 
It is not allowed to called batch API during" + + "iterative processing."); + } ImageGeneratorResult result = null; long timestamp = System.currentTimeMillis() * MICROSECONDS_PER_MILLISECOND; for (int i = 0; i < steps; i++) { @@ -318,11 +436,84 @@ public final class ImageGenerator extends BaseVisionTaskApi { return result; } + @Nullable + private ImageGeneratorResult runStep( + String prompt, + int iterations, + int step, + int seed, + @Nullable MPImage conditionImage, + int select, + boolean showResult) { + if (step == 0) { + cachedInputs.cachedTimestamp = System.currentTimeMillis() * MICROSECONDS_PER_MILLISECOND; + } + Map inputPackets = new HashMap<>(); + if (step == 0 && useConditionImage) { + inputPackets.put( + CONDITION_IMAGE_STREAM_NAME, runner.getPacketCreator().createImage(conditionImage)); + inputPackets.put(SELECT_STREAM_NAME, runner.getPacketCreator().createInt32(select)); + } + inputPackets.put(PROMPT_STREAM_NAME, runner.getPacketCreator().createString(prompt)); + inputPackets.put(STEPS_STREAM_NAME, runner.getPacketCreator().createInt32(iterations)); + inputPackets.put(ITERATION_STREAM_NAME, runner.getPacketCreator().createInt32(step)); + inputPackets.put(RAND_SEED_STREAM_NAME, runner.getPacketCreator().createInt32(seed)); + inputPackets.put(SHOW_RESULT_STREAM_NAME, runner.getPacketCreator().createBool(showResult)); + ImageGeneratorResult result = + (ImageGeneratorResult) runner.process(inputPackets, cachedInputs.cachedTimestamp++); + if (result != null && useConditionImage) { + // Add condition image to the ImageGeneratorResult. + result = + ImageGeneratorResult.create( + result.generatedImage(), conditionImage, result.timestampMs()); + } + if (step == iterations - 1) { + inProcessing = false; + cachedInputs = new CachedInputs(); + } + return result; + } + /** Closes and cleans up the task runners. 
*/ @Override public void close() { - runner.close(); - conditionImageGraphsContainerTaskRunner.close(); + if (runner != null) { + runner.close(); + } + if (conditionImageGraphsContainerTaskRunner != null) { + conditionImageGraphsContainerTaskRunner.close(); + } + } + + // Helper class to holder inputs to be checked with the inputs of next step. + private static class CachedInputs { + public CachedInputs() { + this.prompt = ""; + this.iterations = 0; + this.step = 0; + this.seed = 0; + } + + @Override + public final String toString() { + return "Prompt: " + + prompt + + ", Iterations: " + + iterations + + ", Step: " + + step + + ", Seed: " + + seed + + (conditionType == null ? "" : ", ConditionType: " + conditionType.name()); + } + + public String prompt; + public int iterations; + public int step; + public int seed; + public ConditionOptions.ConditionType conditionType; + public MPImage cachedConditionImage; + public long cachedTimestamp; } /** A container class for the condition image. */ @@ -343,15 +534,12 @@ public final class ImageGenerator extends BaseVisionTaskApi { @AutoValue.Builder public abstract static class Builder { - /** Sets the text to image model directory storing the model weights. */ - public abstract Builder setText2ImageModelDirectory(String modelDirectory); + /** Sets the image generator model directory storing the model weights. */ + public abstract Builder setImageGeneratorModelDirectory(String modelDirectory); /** Sets the path to LoRA weights file. */ public abstract Builder setLoraWeightsFilePath(String loraWeightsFilePath); - public abstract Builder setResultListener( - PureResultListener resultListener); - /** Sets an optional {@link ErrorListener}}. 
*/ public abstract Builder setErrorListener(ErrorListener value); @@ -363,19 +551,17 @@ public final class ImageGenerator extends BaseVisionTaskApi { } } - abstract String text2ImageModelDirectory(); + abstract String imageGeneratorModelDirectory(); abstract Optional loraWeightsFilePath(); - abstract Optional> resultListener(); - abstract Optional errorListener(); private Optional conditionOptions; public static Builder builder() { return new AutoValue_ImageGenerator_ImageGeneratorOptions.Builder() - .setText2ImageModelDirectory(""); + .setImageGeneratorModelDirectory(""); } /** Converts an {@link ImageGeneratorOptions} to a {@link Any} protobuf message. */ @@ -393,13 +579,23 @@ public final class ImageGenerator extends BaseVisionTaskApi { e.printStackTrace(); } } - taskOptionsBuilder.setText2ImageModelDirectory(text2ImageModelDirectory()); + taskOptionsBuilder.setText2ImageModelDirectory(imageGeneratorModelDirectory()); if (loraWeightsFilePath().isPresent()) { ExternalFileProto.ExternalFile.Builder externalFileBuilder = ExternalFileProto.ExternalFile.newBuilder(); externalFileBuilder.setFileName(loraWeightsFilePath().get()); taskOptionsBuilder.setLoraWeightsFile(externalFileBuilder.build()); } + taskOptionsBuilder.setStableDiffusionIterateOptions( + StableDiffusionIterateCalculatorOptionsProto.StableDiffusionIterateCalculatorOptions + .newBuilder() + .setBaseSeed(0) + .setFileFolder(imageGeneratorModelDirectory()) + .setOutputImageWidth(GENERATED_IMAGE_WIDTH) + .setOutputImageHeight(GENERATED_IMAGE_HEIGHT) + .setEmitEmptyPacket(true) + .setShowEveryNIteration(100) + .build()); return Any.newBuilder() .setTypeUrl( "type.googleapis.com/mediapipe.tasks.vision.image_generator.proto.ImageGeneratorGraphOptions") @@ -465,7 +661,8 @@ public final class ImageGenerator extends BaseVisionTaskApi { taskOptionsBuilder.addControlPluginGraphsOptions( ControlPluginGraphOptionsProto.ControlPluginGraphOptions.newBuilder() .setBaseOptions( - 
convertBaseOptionsToProto(faceConditionOptions().get().baseOptions())) + convertBaseOptionsToProto( + faceConditionOptions().get().pluginModelBaseOptions())) .setConditionedImageGraphOptions( ConditionedImageGraphOptions.newBuilder() .setFaceConditionTypeOptions(faceConditionOptions().get().convertToProto()) @@ -476,7 +673,8 @@ public final class ImageGenerator extends BaseVisionTaskApi { taskOptionsBuilder.addControlPluginGraphsOptions( ControlPluginGraphOptionsProto.ControlPluginGraphOptions.newBuilder() .setBaseOptions( - convertBaseOptionsToProto(edgeConditionOptions().get().baseOptions())) + convertBaseOptionsToProto( + edgeConditionOptions().get().pluginModelBaseOptions())) .setConditionedImageGraphOptions( ConditionedImageGraphOptions.newBuilder() .setEdgeConditionTypeOptions(edgeConditionOptions().get().convertToProto()) @@ -486,7 +684,8 @@ public final class ImageGenerator extends BaseVisionTaskApi { taskOptionsBuilder.addControlPluginGraphsOptions( ControlPluginGraphOptionsProto.ControlPluginGraphOptions.newBuilder() .setBaseOptions( - convertBaseOptionsToProto(depthConditionOptions().get().baseOptions())) + convertBaseOptionsToProto( + depthConditionOptions().get().pluginModelBaseOptions())) .setConditionedImageGraphOptions( ConditionedImageGraphOptions.newBuilder() .setDepthConditionTypeOptions( @@ -510,11 +709,16 @@ public final class ImageGenerator extends BaseVisionTaskApi { @AutoValue.Builder public abstract static class Builder { /** Set the base options for plugin model. */ - public abstract Builder setBaseOptions(BaseOptions baseOptions); + public abstract Builder setPluginModelBaseOptions(BaseOptions baseOptions); - /* {@link FaceLandmarkerOptions} used to detect face landmarks in the source image. */ - public abstract Builder setFaceLandmarkerOptions( - FaceLandmarkerOptions faceLandmarkerOptions); + /** Set base options for face landmarks model. 
*/ + public abstract Builder setFaceModelBaseOptions(BaseOptions baseOptions); + + /** Set minimum confidence score of face presence score in the face landmark detection. */ + public abstract Builder setMinFaceDetectionConfidence(float minFaceDetectionConfidence); + + /** Set the face presence threshold */ + public abstract Builder setMinFacePresenceConfidence(float minFacePresenceConfidence); abstract FaceConditionOptions autoBuild(); @@ -524,23 +728,36 @@ public final class ImageGenerator extends BaseVisionTaskApi { } } - abstract BaseOptions baseOptions(); + abstract BaseOptions pluginModelBaseOptions(); - abstract FaceLandmarkerOptions faceLandmarkerOptions(); + abstract BaseOptions faceModelBaseOptions(); + + abstract float minFaceDetectionConfidence(); + + abstract float minFacePresenceConfidence(); public static Builder builder() { - return new AutoValue_ImageGenerator_ConditionOptions_FaceConditionOptions.Builder(); + return new AutoValue_ImageGenerator_ConditionOptions_FaceConditionOptions.Builder() + .setMinFaceDetectionConfidence(0.5f) + .setMinFacePresenceConfidence(0.5f); } ConditionedImageGraphOptions.FaceConditionTypeOptions convertToProto() { + FaceLandmarkerOptions faceLandmarkerOptions = + FaceLandmarkerOptions.builder() + .setBaseOptions(faceModelBaseOptions()) + .setMinFaceDetectionConfidence(minFaceDetectionConfidence()) + .setMinFacePresenceConfidence(minFacePresenceConfidence()) + .setRunningMode(RunningMode.IMAGE) + .setOutputFaceBlendshapes(false) + .setOutputFacialTransformationMatrixes(false) + .setNumFaces(1) + .build(); return ConditionedImageGraphOptions.FaceConditionTypeOptions.newBuilder() .setFaceLandmarkerGraphOptions( - FaceLandmarkerGraphOptions.newBuilder() - .mergeFrom( - faceLandmarkerOptions() - .convertToCalculatorOptionsProto() - .getExtension(FaceLandmarkerGraphOptions.ext)) - .build()) + faceLandmarkerOptions + .convertToCalculatorOptionsProto() + .getExtension(FaceLandmarkerGraphOptions.ext)) .build(); } } @@ -553,12 
+770,11 @@ public final class ImageGenerator extends BaseVisionTaskApi { @AutoValue.Builder public abstract static class Builder { - /** Set the base options for plugin model. */ - public abstract Builder setBaseOptions(BaseOptions baseOptions); + /** Set the base options for the plugin model. */ + public abstract Builder setPluginModelBaseOptions(BaseOptions baseOptions); - /** {@link ImageSegmenterOptions} used to detect depth image from the source image. */ - public abstract Builder setImageSegmenterOptions( - ImageSegmenterOptions imageSegmenterOptions); + /** Set the base options for the depth model. */ + public abstract Builder setDepthModelBaseOptions(BaseOptions baseOptions); abstract DepthConditionOptions autoBuild(); @@ -569,18 +785,25 @@ public final class ImageGenerator extends BaseVisionTaskApi { } } - abstract BaseOptions baseOptions(); + abstract BaseOptions pluginModelBaseOptions(); - abstract ImageSegmenterOptions imageSegmenterOptions(); + abstract BaseOptions depthModelBaseOptions(); public static Builder builder() { return new AutoValue_ImageGenerator_ConditionOptions_DepthConditionOptions.Builder(); } ConditionedImageGraphOptions.DepthConditionTypeOptions convertToProto() { + ImageSegmenterOptions imageSegmenterOptions = + ImageSegmenterOptions.builder() + .setBaseOptions(depthModelBaseOptions()) + .setOutputConfidenceMasks(true) + .setOutputCategoryMask(false) + .setRunningMode(RunningMode.IMAGE) + .build(); return ConditionedImageGraphOptions.DepthConditionTypeOptions.newBuilder() .setImageSegmenterGraphOptions( - imageSegmenterOptions() + imageSegmenterOptions .convertToCalculatorOptionsProto() .getExtension(ImageSegmenterGraphOptions.ext)) .build(); @@ -603,7 +826,7 @@ public final class ImageGenerator extends BaseVisionTaskApi { public abstract static class Builder { /** Set the base options for plugin model. 
*/ - public abstract Builder setBaseOptions(BaseOptions baseOptions); + public abstract Builder setPluginModelBaseOptions(BaseOptions baseOptions); /** First threshold for the hysteresis procedure. */ public abstract Builder setThreshold1(Float threshold1); @@ -629,7 +852,7 @@ public final class ImageGenerator extends BaseVisionTaskApi { } } - abstract BaseOptions baseOptions(); + abstract BaseOptions pluginModelBaseOptions(); abstract Float threshold1();