Provide API/options to show intermediate results and generation progress for the Java Image Generator.

PiperOrigin-RevId: 562014712
This commit is contained in:
MediaPipe Team 2023-09-01 12:04:02 -07:00 committed by Copybara-Service
parent ceb8cd3c78
commit 6c43d37e5a
6 changed files with 323 additions and 67 deletions

View File

@ -68,6 +68,7 @@ constexpr absl::string_view kRandSeedTag = "RAND_SEED";
constexpr absl::string_view kPluginTensorsTag = "PLUGIN_TENSORS"; constexpr absl::string_view kPluginTensorsTag = "PLUGIN_TENSORS";
constexpr absl::string_view kConditionImageTag = "CONDITION_IMAGE"; constexpr absl::string_view kConditionImageTag = "CONDITION_IMAGE";
constexpr absl::string_view kSelectTag = "SELECT"; constexpr absl::string_view kSelectTag = "SELECT";
constexpr absl::string_view kShowResultTag = "SHOW_RESULT";
constexpr absl::string_view kMetadataFilename = "metadata"; constexpr absl::string_view kMetadataFilename = "metadata";
constexpr absl::string_view kLoraRankStr = "lora_rank"; constexpr absl::string_view kLoraRankStr = "lora_rank";
@ -78,6 +79,7 @@ struct ImageGeneratorInputs {
Source<int> rand_seed; Source<int> rand_seed;
std::optional<Source<Image>> condition_image; std::optional<Source<Image>> condition_image;
std::optional<Source<int>> select_condition_type; std::optional<Source<int>> select_condition_type;
std::optional<Source<bool>> show_result;
}; };
struct ImageGeneratorOutputs { struct ImageGeneratorOutputs {
@ -209,6 +211,9 @@ REGISTER_MEDIAPIPE_GRAPH(
// valid, if control plugin graph options are set in the graph options. // valid, if control plugin graph options are set in the graph options.
// SELECT - int // SELECT - int
// The index of the selected control plugin graph. // The index of the selected control plugin graph.
// SHOW_RESULT - bool @Optional
// Whether to show the diffusion result at the current step. If this stream
// is not empty, its value is used regardless of show_every_n_iteration in
// the options.
// //
// Outputs: // Outputs:
// IMAGE - Image // IMAGE - Image
@ -218,6 +223,9 @@ REGISTER_MEDIAPIPE_GRAPH(
// ITERATION - int @optional // ITERATION - int @optional
// The current iteration in the generating steps. The same as ITERATION // The current iteration in the generating steps. The same as ITERATION
// input. // input.
// SHOW_RESULT - bool @Optional
// Whether to show the diffusion result at the current step. The same as
// input SHOW_RESULT.
class ImageGeneratorGraph : public core::ModelTaskGraph { class ImageGeneratorGraph : public core::ModelTaskGraph {
public: public:
absl::StatusOr<CalculatorGraphConfig> GetConfig( absl::StatusOr<CalculatorGraphConfig> GetConfig(
@ -239,6 +247,10 @@ class ImageGeneratorGraph : public core::ModelTaskGraph {
condition_image = graph.In(kConditionImageTag).Cast<Image>(); condition_image = graph.In(kConditionImageTag).Cast<Image>();
select_condition_type = graph.In(kSelectTag).Cast<int>(); select_condition_type = graph.In(kSelectTag).Cast<int>();
} }
std::optional<Source<bool>> show_result;
if (HasInput(sc->OriginalNode(), kShowResultTag)) {
show_result = graph.In(kShowResultTag).Cast<bool>();
}
ASSIGN_OR_RETURN( ASSIGN_OR_RETURN(
auto outputs, auto outputs,
BuildImageGeneratorGraph( BuildImageGeneratorGraph(
@ -251,6 +263,7 @@ class ImageGeneratorGraph : public core::ModelTaskGraph {
/*rand_seed=*/graph.In(kRandSeedTag).Cast<int>(), /*rand_seed=*/graph.In(kRandSeedTag).Cast<int>(),
/*condition_image*/ condition_image, /*condition_image*/ condition_image,
/*select_condition_type*/ select_condition_type, /*select_condition_type*/ select_condition_type,
/*show_result*/ show_result,
}, },
graph)); graph));
outputs.generated_image >> graph.Out(kImageTag).Cast<Image>(); outputs.generated_image >> graph.Out(kImageTag).Cast<Image>();
@ -261,6 +274,10 @@ class ImageGeneratorGraph : public core::ModelTaskGraph {
graph.In(kStepsTag) >> pass_through.In(1); graph.In(kStepsTag) >> pass_through.In(1);
pass_through.Out(0) >> graph[Output<int>::Optional(kIterationTag)]; pass_through.Out(0) >> graph[Output<int>::Optional(kIterationTag)];
pass_through.Out(1) >> graph[Output<int>::Optional(kStepsTag)]; pass_through.Out(1) >> graph[Output<int>::Optional(kStepsTag)];
if (HasOutput(sc->OriginalNode(), kShowResultTag)) {
graph.In(kShowResultTag) >> pass_through.In(2);
pass_through.Out(2) >> graph[Output<bool>::Optional(kShowResultTag)];
}
return graph.GetConfig(); return graph.GetConfig();
} }
@ -299,15 +316,22 @@ class ImageGeneratorGraph : public core::ModelTaskGraph {
inputs.steps >> stable_diff.In(kStepsTag); inputs.steps >> stable_diff.In(kStepsTag);
inputs.iteration >> stable_diff.In(kIterationTag); inputs.iteration >> stable_diff.In(kIterationTag);
inputs.rand_seed >> stable_diff.In(kRandSeedTag); inputs.rand_seed >> stable_diff.In(kRandSeedTag);
if (inputs.show_result.has_value()) {
*inputs.show_result >> stable_diff.In(kShowResultTag);
}
mediapipe::StableDiffusionIterateCalculatorOptions& options = mediapipe::StableDiffusionIterateCalculatorOptions& options =
stable_diff stable_diff
.GetOptions<mediapipe::StableDiffusionIterateCalculatorOptions>(); .GetOptions<mediapipe::StableDiffusionIterateCalculatorOptions>();
if (subgraph_options.has_stable_diffusion_iterate_options()) {
options = subgraph_options.stable_diffusion_iterate_options();
} else {
options.set_base_seed(0); options.set_base_seed(0);
options.set_output_image_height(kPluginsOutputSize); options.set_output_image_height(kPluginsOutputSize);
options.set_output_image_width(kPluginsOutputSize); options.set_output_image_width(kPluginsOutputSize);
options.set_file_folder(subgraph_options.text2image_model_directory()); options.set_file_folder(subgraph_options.text2image_model_directory());
options.set_show_every_n_iteration(100); options.set_show_every_n_iteration(100);
options.set_emit_empty_packet(true); options.set_emit_empty_packet(true);
}
if (lora_resources.has_value()) { if (lora_resources.has_value()) {
auto& lora_layer_weights_mapping = auto& lora_layer_weights_mapping =
*options.mutable_lora_weights_layer_mapping(); *options.mutable_lora_weights_layer_mapping();

View File

@ -48,5 +48,6 @@ mediapipe_proto_library(
deps = [ deps = [
":control_plugin_graph_options_proto", ":control_plugin_graph_options_proto",
"//mediapipe/tasks/cc/core/proto:external_file_proto", "//mediapipe/tasks/cc/core/proto:external_file_proto",
"//mediapipe/tasks/cc/vision/image_generator/diffuser:stable_diffusion_iterate_calculator_proto",
], ],
) )

View File

@ -18,6 +18,7 @@ syntax = "proto3";
package mediapipe.tasks.vision.image_generator.proto; package mediapipe.tasks.vision.image_generator.proto;
import "mediapipe/tasks/cc/core/proto/external_file.proto"; import "mediapipe/tasks/cc/core/proto/external_file.proto";
import "mediapipe/tasks/cc/vision/image_generator/diffuser/stable_diffusion_iterate_calculator.proto";
import "mediapipe/tasks/cc/vision/image_generator/proto/control_plugin_graph_options.proto"; import "mediapipe/tasks/cc/vision/image_generator/proto/control_plugin_graph_options.proto";
option java_package = "com.google.mediapipe.tasks.vision.imagegenerator.proto"; option java_package = "com.google.mediapipe.tasks.vision.imagegenerator.proto";
@ -32,4 +33,7 @@ message ImageGeneratorGraphOptions {
core.proto.ExternalFile lora_weights_file = 2; core.proto.ExternalFile lora_weights_file = 2;
repeated proto.ControlPluginGraphOptions control_plugin_graphs_options = 3; repeated proto.ControlPluginGraphOptions control_plugin_graphs_options = 3;
mediapipe.StableDiffusionIterateCalculatorOptions
stable_diffusion_iterate_options = 4;
} }

View File

@ -67,6 +67,10 @@ _VISION_TASKS_IMAGE_GENERATOR_JAVA_PROTO_LITE_TARGETS = [
"//mediapipe/tasks/cc/vision/face_landmarker/proto:face_blendshapes_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/face_landmarker/proto:face_blendshapes_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarker_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarker_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarks_detector_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarks_detector_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/image_generator/diffuser:stable_diffusion_iterate_calculator_java_proto_lite",
"//mediapipe/tasks/cc/vision/image_generator/proto:conditioned_image_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/image_generator/proto:control_plugin_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/image_generator/proto:image_generator_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_java_proto_lite", "//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/image_segmenter/calculators:tensors_to_segmentation_calculator_java_proto_lite", "//mediapipe/tasks/cc/vision/image_segmenter/calculators:tensors_to_segmentation_calculator_java_proto_lite",

View File

@ -67,8 +67,8 @@ android_library(
"//mediapipe/java/com/google/mediapipe/framework:android_framework", "//mediapipe/java/com/google/mediapipe/framework:android_framework",
"//mediapipe/java/com/google/mediapipe/framework/image", "//mediapipe/java/com/google/mediapipe/framework/image",
"//mediapipe/tasks/cc/core/proto:external_file_java_proto_lite", "//mediapipe/tasks/cc/core/proto:external_file_java_proto_lite",
"//mediapipe/tasks/cc/vision/face_detector/proto:face_detector_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarker_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarker_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/image_generator/diffuser:stable_diffusion_iterate_calculator_java_proto_lite",
"//mediapipe/tasks/cc/vision/image_generator/proto:conditioned_image_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/image_generator/proto:conditioned_image_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/image_generator/proto:control_plugin_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/image_generator/proto:control_plugin_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/image_generator/proto:image_generator_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/image_generator/proto:image_generator_graph_options_java_proto_lite",

View File

@ -19,6 +19,7 @@ import android.graphics.Bitmap;
import android.util.Log; import android.util.Log;
import androidx.annotation.Nullable; import androidx.annotation.Nullable;
import com.google.auto.value.AutoValue; import com.google.auto.value.AutoValue;
import com.google.mediapipe.calculator.proto.StableDiffusionIterateCalculatorOptionsProto;
import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions; import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions;
import com.google.mediapipe.framework.AndroidPacketGetter; import com.google.mediapipe.framework.AndroidPacketGetter;
import com.google.mediapipe.framework.Packet; import com.google.mediapipe.framework.Packet;
@ -28,8 +29,6 @@ import com.google.mediapipe.framework.image.MPImage;
import com.google.mediapipe.tasks.core.BaseOptions; import com.google.mediapipe.tasks.core.BaseOptions;
import com.google.mediapipe.tasks.core.ErrorListener; import com.google.mediapipe.tasks.core.ErrorListener;
import com.google.mediapipe.tasks.core.OutputHandler; import com.google.mediapipe.tasks.core.OutputHandler;
import com.google.mediapipe.tasks.core.OutputHandler.PureResultListener;
import com.google.mediapipe.tasks.core.OutputHandler.ResultListener;
import com.google.mediapipe.tasks.core.TaskInfo; import com.google.mediapipe.tasks.core.TaskInfo;
import com.google.mediapipe.tasks.core.TaskOptions; import com.google.mediapipe.tasks.core.TaskOptions;
import com.google.mediapipe.tasks.core.TaskResult; import com.google.mediapipe.tasks.core.TaskResult;
@ -64,17 +63,23 @@ public final class ImageGenerator extends BaseVisionTaskApi {
private static final String SOURCE_CONDITION_IMAGE_STREAM_NAME = "source_condition_image"; private static final String SOURCE_CONDITION_IMAGE_STREAM_NAME = "source_condition_image";
private static final String CONDITION_IMAGE_STREAM_NAME = "condition_image"; private static final String CONDITION_IMAGE_STREAM_NAME = "condition_image";
private static final String SELECT_STREAM_NAME = "select"; private static final String SELECT_STREAM_NAME = "select";
private static final String SHOW_RESULT_STREAM_NAME = "show_result";
private static final int GENERATED_IMAGE_OUT_STREAM_INDEX = 0; private static final int GENERATED_IMAGE_OUT_STREAM_INDEX = 0;
private static final int STEPS_OUT_STREAM_INDEX = 1; private static final int STEPS_OUT_STREAM_INDEX = 1;
private static final int ITERATION_OUT_STREAM_INDEX = 2; private static final int ITERATION_OUT_STREAM_INDEX = 2;
private static final int SHOW_RESULT_OUT_STREAM_INDEX = 3;
private static final String TASK_GRAPH_NAME = private static final String TASK_GRAPH_NAME =
"mediapipe.tasks.vision.image_generator.ImageGeneratorGraph"; "mediapipe.tasks.vision.image_generator.ImageGeneratorGraph";
private static final String CONDITION_IMAGE_GRAPHS_CONTAINER_NAME = private static final String CONDITION_IMAGE_GRAPHS_CONTAINER_NAME =
"mediapipe.tasks.vision.image_generator.ConditionedImageGraphContainer"; "mediapipe.tasks.vision.image_generator.ConditionedImageGraphContainer";
private static final String TAG = "ImageGenerator"; private static final String TAG = "ImageGenerator";
private static final int GENERATED_IMAGE_WIDTH = 512;
private static final int GENERATED_IMAGE_HEIGHT = 512;
private TaskRunner conditionImageGraphsContainerTaskRunner; private TaskRunner conditionImageGraphsContainerTaskRunner;
private Map<ConditionOptions.ConditionType, Integer> conditionTypeIndex; private Map<ConditionOptions.ConditionType, Integer> conditionTypeIndex;
private boolean useConditionImage = false; private boolean useConditionImage = false;
private CachedInputs cachedInputs = new CachedInputs();
private boolean inProcessing = false;
/** /**
* Creates an {@link ImageGenerator} instance from an {@link ImageGeneratorOptions}. * Creates an {@link ImageGenerator} instance from an {@link ImageGeneratorOptions}.
@ -107,7 +112,8 @@ public final class ImageGenerator extends BaseVisionTaskApi {
"STEPS:" + STEPS_STREAM_NAME, "STEPS:" + STEPS_STREAM_NAME,
"ITERATION:" + ITERATION_STREAM_NAME, "ITERATION:" + ITERATION_STREAM_NAME,
"PROMPT:" + PROMPT_STREAM_NAME, "PROMPT:" + PROMPT_STREAM_NAME,
"RAND_SEED:" + RAND_SEED_STREAM_NAME)); "RAND_SEED:" + RAND_SEED_STREAM_NAME,
"SHOW_RESULT:" + SHOW_RESULT_STREAM_NAME));
final boolean useConditionImage = conditionOptions != null; final boolean useConditionImage = conditionOptions != null;
if (useConditionImage) { if (useConditionImage) {
inputStreams.add("SELECT:" + SELECT_STREAM_NAME); inputStreams.add("SELECT:" + SELECT_STREAM_NAME);
@ -115,7 +121,11 @@ public final class ImageGenerator extends BaseVisionTaskApi {
generatorOptions.conditionOptions = Optional.of(conditionOptions); generatorOptions.conditionOptions = Optional.of(conditionOptions);
} }
List<String> outputStreams = List<String> outputStreams =
Arrays.asList("IMAGE:image_out", "STEPS:steps_out", "ITERATION:iteration_out"); Arrays.asList(
"IMAGE:image_out",
"STEPS:steps_out",
"ITERATION:iteration_out",
"SHOW_RESULT:show_result_out");
OutputHandler<ImageGeneratorResult, Void> handler = new OutputHandler<>(); OutputHandler<ImageGeneratorResult, Void> handler = new OutputHandler<>();
handler.setOutputPacketConverter( handler.setOutputPacketConverter(
@ -125,16 +135,22 @@ public final class ImageGenerator extends BaseVisionTaskApi {
public ImageGeneratorResult convertToTaskResult(List<Packet> packets) { public ImageGeneratorResult convertToTaskResult(List<Packet> packets) {
int iteration = PacketGetter.getInt32(packets.get(ITERATION_OUT_STREAM_INDEX)); int iteration = PacketGetter.getInt32(packets.get(ITERATION_OUT_STREAM_INDEX));
int steps = PacketGetter.getInt32(packets.get(STEPS_OUT_STREAM_INDEX)); int steps = PacketGetter.getInt32(packets.get(STEPS_OUT_STREAM_INDEX));
Log.i("ImageGenerator", "Iteration: " + iteration + ", Steps: " + steps); boolean showResult =
if (iteration != steps - 1) { PacketGetter.getBool(packets.get(SHOW_RESULT_OUT_STREAM_INDEX))
return null; || iteration == steps - 1;
} Log.i(
"ImageGenerator",
"Iteration: " + iteration + ", Steps: " + steps + ", ShowResult: " + showResult);
if (showResult) {
Log.i("ImageGenerator", "processing generated image"); Log.i("ImageGenerator", "processing generated image");
Packet packet = packets.get(GENERATED_IMAGE_OUT_STREAM_INDEX); Packet packet = packets.get(GENERATED_IMAGE_OUT_STREAM_INDEX);
Bitmap generatedBitmap = AndroidPacketGetter.getBitmapFromRgb(packet); Bitmap generatedBitmap = AndroidPacketGetter.getBitmapFromRgb(packet);
BitmapImageBuilder bitmapImageBuilder = new BitmapImageBuilder(generatedBitmap); BitmapImageBuilder bitmapImageBuilder = new BitmapImageBuilder(generatedBitmap);
return ImageGeneratorResult.create( return ImageGeneratorResult.create(
bitmapImageBuilder.build(), packet.getTimestamp() / MICROSECONDS_PER_MILLISECOND); bitmapImageBuilder.build(), packet.getTimestamp() / MICROSECONDS_PER_MILLISECOND);
} else {
return null;
}
} }
@Override @Override
@ -143,16 +159,6 @@ public final class ImageGenerator extends BaseVisionTaskApi {
} }
}); });
handler.setHandleTimestampBoundChanges(true); handler.setHandleTimestampBoundChanges(true);
if (generatorOptions.resultListener().isPresent()) {
ResultListener<ImageGeneratorResult, Void> resultListener =
new ResultListener<ImageGeneratorResult, Void>() {
@Override
public void run(ImageGeneratorResult imageGeneratorResult, Void input) {
generatorOptions.resultListener().get().run(imageGeneratorResult);
}
};
handler.setResultListener(resultListener);
}
generatorOptions.errorListener().ifPresent(handler::setErrorListener); generatorOptions.errorListener().ifPresent(handler::setErrorListener);
TaskRunner runner = TaskRunner runner =
TaskRunner.create( TaskRunner.create(
@ -229,11 +235,19 @@ public final class ImageGenerator extends BaseVisionTaskApi {
* Generates an image for iterations and the given random seed. Only valid when the ImageGenerator * Generates an image for iterations and the given random seed. Only valid when the ImageGenerator
* is created without condition options. * is created without condition options.
* *
* <p>This is an e2e API, which runs {@code iterations} to generate an image. Consider using the
* iterative API instead to fetch the intermediate results.
*
* @param prompt The text prompt describing the image to be generated. * @param prompt The text prompt describing the image to be generated.
* @param iterations The total iterations to generate the image. * @param iterations The total iterations to generate the image.
* @param seed The random seed used during image generation. * @param seed The random seed used during image generation.
*/ */
public ImageGeneratorResult generate(String prompt, int iterations, int seed) { public ImageGeneratorResult generate(String prompt, int iterations, int seed) {
if (useConditionImage) {
throw new IllegalArgumentException(
"ImageGenerator is created with condition options. Must use the methods with condition "
+ "options.");
}
return runIterations(prompt, iterations, seed, null, 0); return runIterations(prompt, iterations, seed, null, 0);
} }
@ -241,6 +255,9 @@ public final class ImageGenerator extends BaseVisionTaskApi {
* Generates an image based on the source image for iterations and the given random seed. Only * Generates an image based on the source image for iterations and the given random seed. Only
* valid when the ImageGenerator is created with condition options. * valid when the ImageGenerator is created with condition options.
* *
* <p>This is an e2e API, which runs {@code iterations} to generate an image. Consider using the
* iterative API instead to fetch the intermediate results.
*
* @param prompt The text prompt describing the image to be generated. * @param prompt The text prompt describing the image to be generated.
* @param sourceConditionImage The source image used to create the condition image, which is used * @param sourceConditionImage The source image used to create the condition image, which is used
* as a guidance for the image generation. * as a guidance for the image generation.
@ -255,6 +272,11 @@ public final class ImageGenerator extends BaseVisionTaskApi {
ConditionOptions.ConditionType conditionType, ConditionOptions.ConditionType conditionType,
int iterations, int iterations,
int seed) { int seed) {
if (!useConditionImage) {
throw new IllegalArgumentException(
"ImageGenerator is created without condition options. Must use the methods without "
+ "condition options.");
}
return runIterations( return runIterations(
prompt, prompt,
iterations, iterations,
@ -263,6 +285,97 @@ public final class ImageGenerator extends BaseVisionTaskApi {
conditionTypeIndex.get(conditionType)); conditionTypeIndex.get(conditionType));
} }
/**
 * Caches the inputs for the iterative generation flow. The iterative flow is the {@link
 * #setInputs} / {@link #execute} pair: {@link #setInputs} must be called before {@link #execute}.
 * Only valid when the ImageGenerator is created without condition options.
 *
 * @param prompt The text prompt describing the image to be generated.
 * @param iterations The total iterations to generate the image.
 * @param seed The random seed used during image generation.
 * @throws IllegalArgumentException if the ImageGenerator was created with condition options.
 */
public void setInputs(String prompt, int iterations, int seed) {
  if (useConditionImage) {
    throw new IllegalArgumentException(
        "ImageGenerator is created with condition options. Must use the methods with condition "
            + "options.");
  }
  // Fill in a fresh input set, then publish it and mark the iterative session as started.
  CachedInputs freshInputs = new CachedInputs();
  freshInputs.step = 0;
  freshInputs.seed = seed;
  freshInputs.iterations = iterations;
  freshInputs.prompt = prompt;
  cachedInputs = freshInputs;
  inProcessing = true;
}
/**
 * Caches the inputs for the iterative generation flow. For iterative usage, use {@link
 * #setInputs} and {@link #execute} in pairs; {@link #setInputs} must be called before {@link
 * #execute}. Only valid when the ImageGenerator is created with condition options.
 *
 * <p>The condition image is created from {@code sourceConditionImage} during this call and cached
 * for the following {@link #execute} calls. The cached state is only replaced once the condition
 * image has been created, so a failure here leaves any previously cached inputs untouched.
 *
 * @param prompt The text prompt describing the image to be generated.
 * @param sourceConditionImage The source image used to create the condition image, which is used
 *     as a guidance for the image generation.
 * @param conditionType The {@link ConditionOptions.ConditionType} specifying the type of
 *     condition image.
 * @param iterations The total iterations to generate the image.
 * @param seed The random seed used during image generation.
 * @throws IllegalArgumentException if the ImageGenerator was created without condition options.
 */
public void setInputs(
    String prompt,
    MPImage sourceConditionImage,
    ConditionOptions.ConditionType conditionType,
    int iterations,
    int seed) {
  if (!useConditionImage) {
    throw new IllegalArgumentException(
        "ImageGenerator is created without condition options. Must use the methods without "
            + "condition options.");
  }
  // Create the condition image before touching any cached state: if this call throws, the
  // generator must not be left with a half-populated input set (prompt set, condition image
  // missing) that a later execute() would pick up.
  MPImage conditionImage = createConditionImage(sourceConditionImage, conditionType);

  CachedInputs newInputs = new CachedInputs();
  newInputs.prompt = prompt;
  newInputs.iterations = iterations;
  newInputs.seed = seed;
  newInputs.step = 0;
  newInputs.cachedConditionImage = conditionImage;
  newInputs.conditionType = conditionType;

  // Publish the fully built inputs and mark the iterative session as started.
  cachedInputs = newInputs;
  inProcessing = true;
}
/**
 * Runs a single iteration of image generation with the inputs cached by {@link #setInputs}. To
 * produce the final image this method must be called {@code iterations} times, and {@link
 * #setInputs} must have been called first.
 *
 * <p>This is the iterative API. It is useful for showing intermediate generation results and the
 * generation progress. Note that requesting the intermediate results will result in a larger
 * latency. Consider using the e2e API instead for latency consideration.
 *
 * <p>Example usage:
 *
 * <pre>{@code
 * imageGenerator.setInputs(prompt, iterations, seed);
 * for (int step = 0; step < iterations; step++) {
 *   ImageGeneratorResult result = imageGenerator.execute(true);
 * }
 * }</pre>
 *
 * @param showResult Whether to get the generated image result in the intermediate iterations. If
 *     false, null is returned. The generated image result is always returned at the last
 *     iteration, regardless of showResult value.
 * @throws IllegalArgumentException if {@link #setInputs} has not been called beforehand.
 */
@Nullable
public ImageGeneratorResult execute(boolean showResult) {
  if (!inProcessing) {
    throw new IllegalArgumentException("Must call setInputs before execute.");
  }
  final CachedInputs inputs = cachedInputs;
  // Take the current step index and advance the counter for the next call.
  final int currentStep = inputs.step++;
  final int selectedPlugin;
  if (useConditionImage) {
    selectedPlugin = conditionTypeIndex.get(inputs.conditionType);
  } else {
    selectedPlugin = 0;
  }
  return runStep(
      inputs.prompt,
      inputs.iterations,
      currentStep,
      inputs.seed,
      inputs.cachedConditionImage,
      selectedPlugin,
      showResult);
}
/** /**
* Create the condition image of specified condition type from the source image. Currently support * Create the condition image of specified condition type from the source image. Currently support
* face landmarks, depth image and edge image as the condition image. * face landmarks, depth image and edge image as the condition image.
@ -295,6 +408,11 @@ public final class ImageGenerator extends BaseVisionTaskApi {
private ImageGeneratorResult runIterations( private ImageGeneratorResult runIterations(
String prompt, int steps, int seed, @Nullable MPImage conditionImage, int select) { String prompt, int steps, int seed, @Nullable MPImage conditionImage, int select) {
if (inProcessing) {
throw new IllegalArgumentException(
"Iterative API was called previously. It is not allowed to called batch API during"
+ "iterative processing.");
}
ImageGeneratorResult result = null; ImageGeneratorResult result = null;
long timestamp = System.currentTimeMillis() * MICROSECONDS_PER_MILLISECOND; long timestamp = System.currentTimeMillis() * MICROSECONDS_PER_MILLISECOND;
for (int i = 0; i < steps; i++) { for (int i = 0; i < steps; i++) {
@ -318,12 +436,85 @@ public final class ImageGenerator extends BaseVisionTaskApi {
return result; return result;
} }
/**
 * Feeds the packets for one generation step to the task runner and returns what it produced.
 *
 * <p>The base timestamp is taken from the system clock on step 0 and incremented by one for every
 * processed step. The condition image and plugin selection packets are only sent on step 0, and
 * only when the generator uses condition images. After the last step ({@code iterations - 1}) the
 * iterative session is ended and the cached inputs are reset.
 *
 * @param prompt The text prompt describing the image to be generated.
 * @param iterations The total iterations to generate the image.
 * @param step The index of the step to run.
 * @param seed The random seed used during image generation.
 * @param conditionImage The condition image; attached to a non-null result when the generator
 *     uses condition images.
 * @param select The index sent on the select stream together with the condition image.
 * @param showResult The value sent on the show-result stream for this step.
 * @return the result returned by the task runner for this step, which may be null.
 */
@Nullable
private ImageGeneratorResult runStep(
    String prompt,
    int iterations,
    int step,
    int seed,
    @Nullable MPImage conditionImage,
    int select,
    boolean showResult) {
  final boolean isFirstStep = step == 0;
  final boolean isLastStep = step == iterations - 1;

  if (isFirstStep) {
    // Start of a new session: establish the base timestamp for all following steps.
    cachedInputs.cachedTimestamp = System.currentTimeMillis() * MICROSECONDS_PER_MILLISECOND;
  }

  Map<String, Packet> packetsByStream = new HashMap<>();
  if (isFirstStep && useConditionImage) {
    packetsByStream.put(
        CONDITION_IMAGE_STREAM_NAME, runner.getPacketCreator().createImage(conditionImage));
    packetsByStream.put(SELECT_STREAM_NAME, runner.getPacketCreator().createInt32(select));
  }
  packetsByStream.put(PROMPT_STREAM_NAME, runner.getPacketCreator().createString(prompt));
  packetsByStream.put(STEPS_STREAM_NAME, runner.getPacketCreator().createInt32(iterations));
  packetsByStream.put(ITERATION_STREAM_NAME, runner.getPacketCreator().createInt32(step));
  packetsByStream.put(RAND_SEED_STREAM_NAME, runner.getPacketCreator().createInt32(seed));
  packetsByStream.put(
      SHOW_RESULT_STREAM_NAME, runner.getPacketCreator().createBool(showResult));

  // Use the current timestamp for this step and bump it for the next one.
  final long stepTimestamp = cachedInputs.cachedTimestamp++;
  ImageGeneratorResult stepResult =
      (ImageGeneratorResult) runner.process(packetsByStream, stepTimestamp);

  if (stepResult != null && useConditionImage) {
    // Re-wrap the result so that it also carries the condition image.
    stepResult =
        ImageGeneratorResult.create(
            stepResult.generatedImage(), conditionImage, stepResult.timestampMs());
  }

  if (isLastStep) {
    // Final step done: leave iterative mode and drop the cached inputs.
    inProcessing = false;
    cachedInputs = new CachedInputs();
  }
  return stepResult;
}
/** Closes and cleans up the task runners. */ /** Closes and cleans up the task runners. */
@Override @Override
public void close() { public void close() {
if (runner != null) {
runner.close(); runner.close();
}
if (conditionImageGraphsContainerTaskRunner != null) {
conditionImageGraphsContainerTaskRunner.close(); conditionImageGraphsContainerTaskRunner.close();
} }
}
// Helper class that holds the inputs and per-session progress (step, timestamp) of the iterative
// API between calls.
private static class CachedInputs {
  public String prompt;
  public int iterations;
  public int step;
  public int seed;
  public ConditionOptions.ConditionType conditionType;
  public MPImage cachedConditionImage;
  public long cachedTimestamp;

  // Starts with an empty prompt and zeroed counters; the remaining fields keep their defaults.
  public CachedInputs() {
    prompt = "";
    iterations = 0;
    step = 0;
    seed = 0;
  }

  // Lists prompt, iterations, step and seed; the condition type is appended only when it is set.
  @Override
  public final String toString() {
    StringBuilder description = new StringBuilder();
    description.append("Prompt: ").append(prompt);
    description.append(", Iterations: ").append(iterations);
    description.append(", Step: ").append(step);
    description.append(", Seed: ").append(seed);
    if (conditionType != null) {
      description.append(", ConditionType: ").append(conditionType.name());
    }
    return description.toString();
  }
}
/** A container class for the condition image. */ /** A container class for the condition image. */
@AutoValue @AutoValue
@ -343,15 +534,12 @@ public final class ImageGenerator extends BaseVisionTaskApi {
@AutoValue.Builder @AutoValue.Builder
public abstract static class Builder { public abstract static class Builder {
/** Sets the text to image model directory storing the model weights. */ /** Sets the image generator model directory storing the model weights. */
public abstract Builder setText2ImageModelDirectory(String modelDirectory); public abstract Builder setImageGeneratorModelDirectory(String modelDirectory);
/** Sets the path to LoRA weights file. */ /** Sets the path to LoRA weights file. */
public abstract Builder setLoraWeightsFilePath(String loraWeightsFilePath); public abstract Builder setLoraWeightsFilePath(String loraWeightsFilePath);
public abstract Builder setResultListener(
PureResultListener<ImageGeneratorResult> resultListener);
/** Sets an optional {@link ErrorListener}. */ /** Sets an optional {@link ErrorListener}. */
public abstract Builder setErrorListener(ErrorListener value); public abstract Builder setErrorListener(ErrorListener value);
@ -363,19 +551,17 @@ public final class ImageGenerator extends BaseVisionTaskApi {
} }
} }
abstract String text2ImageModelDirectory(); abstract String imageGeneratorModelDirectory();
abstract Optional<String> loraWeightsFilePath(); abstract Optional<String> loraWeightsFilePath();
abstract Optional<PureResultListener<ImageGeneratorResult>> resultListener();
abstract Optional<ErrorListener> errorListener(); abstract Optional<ErrorListener> errorListener();
private Optional<ConditionOptions> conditionOptions; private Optional<ConditionOptions> conditionOptions;
public static Builder builder() { public static Builder builder() {
return new AutoValue_ImageGenerator_ImageGeneratorOptions.Builder() return new AutoValue_ImageGenerator_ImageGeneratorOptions.Builder()
.setText2ImageModelDirectory(""); .setImageGeneratorModelDirectory("");
} }
/** Converts an {@link ImageGeneratorOptions} to a {@link Any} protobuf message. */ /** Converts an {@link ImageGeneratorOptions} to a {@link Any} protobuf message. */
@ -393,13 +579,23 @@ public final class ImageGenerator extends BaseVisionTaskApi {
e.printStackTrace(); e.printStackTrace();
} }
} }
taskOptionsBuilder.setText2ImageModelDirectory(text2ImageModelDirectory()); taskOptionsBuilder.setText2ImageModelDirectory(imageGeneratorModelDirectory());
if (loraWeightsFilePath().isPresent()) { if (loraWeightsFilePath().isPresent()) {
ExternalFileProto.ExternalFile.Builder externalFileBuilder = ExternalFileProto.ExternalFile.Builder externalFileBuilder =
ExternalFileProto.ExternalFile.newBuilder(); ExternalFileProto.ExternalFile.newBuilder();
externalFileBuilder.setFileName(loraWeightsFilePath().get()); externalFileBuilder.setFileName(loraWeightsFilePath().get());
taskOptionsBuilder.setLoraWeightsFile(externalFileBuilder.build()); taskOptionsBuilder.setLoraWeightsFile(externalFileBuilder.build());
} }
taskOptionsBuilder.setStableDiffusionIterateOptions(
StableDiffusionIterateCalculatorOptionsProto.StableDiffusionIterateCalculatorOptions
.newBuilder()
.setBaseSeed(0)
.setFileFolder(imageGeneratorModelDirectory())
.setOutputImageWidth(GENERATED_IMAGE_WIDTH)
.setOutputImageHeight(GENERATED_IMAGE_HEIGHT)
.setEmitEmptyPacket(true)
.setShowEveryNIteration(100)
.build());
return Any.newBuilder() return Any.newBuilder()
.setTypeUrl( .setTypeUrl(
"type.googleapis.com/mediapipe.tasks.vision.image_generator.proto.ImageGeneratorGraphOptions") "type.googleapis.com/mediapipe.tasks.vision.image_generator.proto.ImageGeneratorGraphOptions")
@ -465,7 +661,8 @@ public final class ImageGenerator extends BaseVisionTaskApi {
taskOptionsBuilder.addControlPluginGraphsOptions( taskOptionsBuilder.addControlPluginGraphsOptions(
ControlPluginGraphOptionsProto.ControlPluginGraphOptions.newBuilder() ControlPluginGraphOptionsProto.ControlPluginGraphOptions.newBuilder()
.setBaseOptions( .setBaseOptions(
convertBaseOptionsToProto(faceConditionOptions().get().baseOptions())) convertBaseOptionsToProto(
faceConditionOptions().get().pluginModelBaseOptions()))
.setConditionedImageGraphOptions( .setConditionedImageGraphOptions(
ConditionedImageGraphOptions.newBuilder() ConditionedImageGraphOptions.newBuilder()
.setFaceConditionTypeOptions(faceConditionOptions().get().convertToProto()) .setFaceConditionTypeOptions(faceConditionOptions().get().convertToProto())
@ -476,7 +673,8 @@ public final class ImageGenerator extends BaseVisionTaskApi {
taskOptionsBuilder.addControlPluginGraphsOptions( taskOptionsBuilder.addControlPluginGraphsOptions(
ControlPluginGraphOptionsProto.ControlPluginGraphOptions.newBuilder() ControlPluginGraphOptionsProto.ControlPluginGraphOptions.newBuilder()
.setBaseOptions( .setBaseOptions(
convertBaseOptionsToProto(edgeConditionOptions().get().baseOptions())) convertBaseOptionsToProto(
edgeConditionOptions().get().pluginModelBaseOptions()))
.setConditionedImageGraphOptions( .setConditionedImageGraphOptions(
ConditionedImageGraphOptions.newBuilder() ConditionedImageGraphOptions.newBuilder()
.setEdgeConditionTypeOptions(edgeConditionOptions().get().convertToProto()) .setEdgeConditionTypeOptions(edgeConditionOptions().get().convertToProto())
@ -486,7 +684,8 @@ public final class ImageGenerator extends BaseVisionTaskApi {
taskOptionsBuilder.addControlPluginGraphsOptions( taskOptionsBuilder.addControlPluginGraphsOptions(
ControlPluginGraphOptionsProto.ControlPluginGraphOptions.newBuilder() ControlPluginGraphOptionsProto.ControlPluginGraphOptions.newBuilder()
.setBaseOptions( .setBaseOptions(
convertBaseOptionsToProto(depthConditionOptions().get().baseOptions())) convertBaseOptionsToProto(
depthConditionOptions().get().pluginModelBaseOptions()))
.setConditionedImageGraphOptions( .setConditionedImageGraphOptions(
ConditionedImageGraphOptions.newBuilder() ConditionedImageGraphOptions.newBuilder()
.setDepthConditionTypeOptions( .setDepthConditionTypeOptions(
@ -510,11 +709,16 @@ public final class ImageGenerator extends BaseVisionTaskApi {
@AutoValue.Builder @AutoValue.Builder
public abstract static class Builder { public abstract static class Builder {
/** Set the base options for plugin model. */ /** Set the base options for plugin model. */
public abstract Builder setBaseOptions(BaseOptions baseOptions); public abstract Builder setPluginModelBaseOptions(BaseOptions baseOptions);
/* {@link FaceLandmarkerOptions} used to detect face landmarks in the source image. */ /** Set base options for face landmarks model. */
public abstract Builder setFaceLandmarkerOptions( public abstract Builder setFaceModelBaseOptions(BaseOptions baseOptions);
FaceLandmarkerOptions faceLandmarkerOptions);
/** Set the minimum face detection confidence score used in the face landmark detection. */
public abstract Builder setMinFaceDetectionConfidence(float minFaceDetectionConfidence);
/** Set the minimum face presence confidence score used in the face landmark detection. */
public abstract Builder setMinFacePresenceConfidence(float minFacePresenceConfidence);
abstract FaceConditionOptions autoBuild(); abstract FaceConditionOptions autoBuild();
@ -524,23 +728,36 @@ public final class ImageGenerator extends BaseVisionTaskApi {
} }
} }
/** Base options for the plugin model. */
abstract BaseOptions baseOptions(); abstract BaseOptions pluginModelBaseOptions();
/** Base options for the face landmarks model. */
abstract FaceLandmarkerOptions faceLandmarkerOptions(); abstract BaseOptions faceModelBaseOptions();
/** Minimum face detection confidence forwarded to the face landmarker options. */
abstract float minFaceDetectionConfidence();
/** Minimum face presence confidence forwarded to the face landmarker options. */
abstract float minFacePresenceConfidence();
/**
 * Creates a new {@link FaceConditionOptions} builder. The chained setters preset both the face
 * detection confidence and the face presence confidence to 0.5f.
 */
public static Builder builder() { public static Builder builder() {
return new AutoValue_ImageGenerator_ConditionOptions_FaceConditionOptions.Builder(); return new AutoValue_ImageGenerator_ConditionOptions_FaceConditionOptions.Builder()
.setMinFaceDetectionConfidence(0.5f)
.setMinFacePresenceConfidence(0.5f);
} }
/**
 * Converts these face condition options into a {@code FaceConditionTypeOptions} proto whose
 * face landmarker graph options are taken from a {@link FaceLandmarkerOptions} instance.
 */
ConditionedImageGraphOptions.FaceConditionTypeOptions convertToProto() { ConditionedImageGraphOptions.FaceConditionTypeOptions convertToProto() {
// Assemble the landmarker options from the face model base options and the two confidence
// values; the remaining settings are fixed: IMAGE running mode, a single face, and no
// blendshape or transformation-matrix outputs.
FaceLandmarkerOptions faceLandmarkerOptions =
FaceLandmarkerOptions.builder()
.setBaseOptions(faceModelBaseOptions())
.setMinFaceDetectionConfidence(minFaceDetectionConfidence())
.setMinFacePresenceConfidence(minFacePresenceConfidence())
.setRunningMode(RunningMode.IMAGE)
.setOutputFaceBlendshapes(false)
.setOutputFacialTransformationMatrixes(false)
.setNumFaces(1)
.build();
// The graph options are read from the FaceLandmarkerGraphOptions extension of the calculator
// options proto produced by the landmarker options.
return ConditionedImageGraphOptions.FaceConditionTypeOptions.newBuilder() return ConditionedImageGraphOptions.FaceConditionTypeOptions.newBuilder()
.setFaceLandmarkerGraphOptions( .setFaceLandmarkerGraphOptions(
FaceLandmarkerGraphOptions.newBuilder() faceLandmarkerOptions
.mergeFrom(
faceLandmarkerOptions()
.convertToCalculatorOptionsProto() .convertToCalculatorOptionsProto()
.getExtension(FaceLandmarkerGraphOptions.ext)) .getExtension(FaceLandmarkerGraphOptions.ext))
.build())
.build(); .build();
} }
} }
@ -553,12 +770,11 @@ public final class ImageGenerator extends BaseVisionTaskApi {
@AutoValue.Builder @AutoValue.Builder
public abstract static class Builder { public abstract static class Builder {
/** Set the base options for plugin model. */ /** Set the base options for the plugin model. */
public abstract Builder setBaseOptions(BaseOptions baseOptions); public abstract Builder setPluginModelBaseOptions(BaseOptions baseOptions);
/** {@link ImageSegmenterOptions} used to detect depth image from the source image. */ /** Set the base options for the depth model. */
public abstract Builder setImageSegmenterOptions( public abstract Builder setDepthModelBaseOptions(BaseOptions baseOptions);
ImageSegmenterOptions imageSegmenterOptions);
abstract DepthConditionOptions autoBuild(); abstract DepthConditionOptions autoBuild();
@ -569,18 +785,25 @@ public final class ImageGenerator extends BaseVisionTaskApi {
} }
} }
/** Base options for the plugin model. */
abstract BaseOptions baseOptions(); abstract BaseOptions pluginModelBaseOptions();
/** Base options for the depth model. */
abstract ImageSegmenterOptions imageSegmenterOptions(); abstract BaseOptions depthModelBaseOptions();
/** Creates a new {@link DepthConditionOptions} builder; no defaults are preset here. */
public static Builder builder() { public static Builder builder() {
return new AutoValue_ImageGenerator_ConditionOptions_DepthConditionOptions.Builder(); return new AutoValue_ImageGenerator_ConditionOptions_DepthConditionOptions.Builder();
} }
ConditionedImageGraphOptions.DepthConditionTypeOptions convertToProto() { ConditionedImageGraphOptions.DepthConditionTypeOptions convertToProto() {
ImageSegmenterOptions imageSegmenterOptions =
ImageSegmenterOptions.builder()
.setBaseOptions(depthModelBaseOptions())
.setOutputConfidenceMasks(true)
.setOutputCategoryMask(false)
.setRunningMode(RunningMode.IMAGE)
.build();
return ConditionedImageGraphOptions.DepthConditionTypeOptions.newBuilder() return ConditionedImageGraphOptions.DepthConditionTypeOptions.newBuilder()
.setImageSegmenterGraphOptions( .setImageSegmenterGraphOptions(
imageSegmenterOptions() imageSegmenterOptions
.convertToCalculatorOptionsProto() .convertToCalculatorOptionsProto()
.getExtension(ImageSegmenterGraphOptions.ext)) .getExtension(ImageSegmenterGraphOptions.ext))
.build(); .build();
@ -603,7 +826,7 @@ public final class ImageGenerator extends BaseVisionTaskApi {
public abstract static class Builder { public abstract static class Builder {
/** Set the base options for plugin model. */ /** Set the base options for plugin model. */
public abstract Builder setBaseOptions(BaseOptions baseOptions); public abstract Builder setPluginModelBaseOptions(BaseOptions baseOptions);
/** First threshold for the hysteresis procedure. */ /** First threshold for the hysteresis procedure. */
public abstract Builder setThreshold1(Float threshold1); public abstract Builder setThreshold1(Float threshold1);
@ -629,7 +852,7 @@ public final class ImageGenerator extends BaseVisionTaskApi {
} }
} }
/** Base options for the plugin model. */
abstract BaseOptions baseOptions(); abstract BaseOptions pluginModelBaseOptions();
/** First threshold for the hysteresis procedure. */
abstract Float threshold1(); abstract Float threshold1();