Provide API/options to show intermediate results and generating progress for Java Image Generator.
PiperOrigin-RevId: 562014712
This commit is contained in:
parent
ceb8cd3c78
commit
6c43d37e5a
|
@ -68,6 +68,7 @@ constexpr absl::string_view kRandSeedTag = "RAND_SEED";
|
|||
constexpr absl::string_view kPluginTensorsTag = "PLUGIN_TENSORS";
|
||||
constexpr absl::string_view kConditionImageTag = "CONDITION_IMAGE";
|
||||
constexpr absl::string_view kSelectTag = "SELECT";
|
||||
constexpr absl::string_view kShowResultTag = "SHOW_RESULT";
|
||||
constexpr absl::string_view kMetadataFilename = "metadata";
|
||||
constexpr absl::string_view kLoraRankStr = "lora_rank";
|
||||
|
||||
|
@ -78,6 +79,7 @@ struct ImageGeneratorInputs {
|
|||
Source<int> rand_seed;
|
||||
std::optional<Source<Image>> condition_image;
|
||||
std::optional<Source<int>> select_condition_type;
|
||||
std::optional<Source<bool>> show_result;
|
||||
};
|
||||
|
||||
struct ImageGeneratorOutputs {
|
||||
|
@ -209,6 +211,9 @@ REGISTER_MEDIAPIPE_GRAPH(
|
|||
// valid, if condtrol plugin graph options are set in the graph options.
|
||||
// SELECT - int
|
||||
// The index of the selected the control plugin graph.
|
||||
// SHOW_RESULT - bool @Optional
|
||||
// Whether to show the diffusion result at the current step. If this stream
|
||||
// is not empty, regardless show_every_n_iteration in the options.
|
||||
//
|
||||
// Outputs:
|
||||
// IMAGE - Image
|
||||
|
@ -218,6 +223,9 @@ REGISTER_MEDIAPIPE_GRAPH(
|
|||
// ITERATION - int @optional
|
||||
// The current iteration in the generating steps. The same as ITERATION
|
||||
// input.
|
||||
// SHOW_RESULT - bool @Optional
|
||||
// Whether to show the diffusion result at the current step. The same as
|
||||
// input SHOW_RESULT.
|
||||
class ImageGeneratorGraph : public core::ModelTaskGraph {
|
||||
public:
|
||||
absl::StatusOr<CalculatorGraphConfig> GetConfig(
|
||||
|
@ -239,6 +247,10 @@ class ImageGeneratorGraph : public core::ModelTaskGraph {
|
|||
condition_image = graph.In(kConditionImageTag).Cast<Image>();
|
||||
select_condition_type = graph.In(kSelectTag).Cast<int>();
|
||||
}
|
||||
std::optional<Source<bool>> show_result;
|
||||
if (HasInput(sc->OriginalNode(), kShowResultTag)) {
|
||||
show_result = graph.In(kShowResultTag).Cast<bool>();
|
||||
}
|
||||
ASSIGN_OR_RETURN(
|
||||
auto outputs,
|
||||
BuildImageGeneratorGraph(
|
||||
|
@ -251,6 +263,7 @@ class ImageGeneratorGraph : public core::ModelTaskGraph {
|
|||
/*rand_seed=*/graph.In(kRandSeedTag).Cast<int>(),
|
||||
/*condition_image*/ condition_image,
|
||||
/*select_condition_type*/ select_condition_type,
|
||||
/*show_result*/ show_result,
|
||||
},
|
||||
graph));
|
||||
outputs.generated_image >> graph.Out(kImageTag).Cast<Image>();
|
||||
|
@ -261,6 +274,10 @@ class ImageGeneratorGraph : public core::ModelTaskGraph {
|
|||
graph.In(kStepsTag) >> pass_through.In(1);
|
||||
pass_through.Out(0) >> graph[Output<int>::Optional(kIterationTag)];
|
||||
pass_through.Out(1) >> graph[Output<int>::Optional(kStepsTag)];
|
||||
if (HasOutput(sc->OriginalNode(), kShowResultTag)) {
|
||||
graph.In(kShowResultTag) >> pass_through.In(2);
|
||||
pass_through.Out(2) >> graph[Output<bool>::Optional(kShowResultTag)];
|
||||
}
|
||||
return graph.GetConfig();
|
||||
}
|
||||
|
||||
|
@ -299,15 +316,22 @@ class ImageGeneratorGraph : public core::ModelTaskGraph {
|
|||
inputs.steps >> stable_diff.In(kStepsTag);
|
||||
inputs.iteration >> stable_diff.In(kIterationTag);
|
||||
inputs.rand_seed >> stable_diff.In(kRandSeedTag);
|
||||
if (inputs.show_result.has_value()) {
|
||||
*inputs.show_result >> stable_diff.In(kShowResultTag);
|
||||
}
|
||||
mediapipe::StableDiffusionIterateCalculatorOptions& options =
|
||||
stable_diff
|
||||
.GetOptions<mediapipe::StableDiffusionIterateCalculatorOptions>();
|
||||
options.set_base_seed(0);
|
||||
options.set_output_image_height(kPluginsOutputSize);
|
||||
options.set_output_image_width(kPluginsOutputSize);
|
||||
options.set_file_folder(subgraph_options.text2image_model_directory());
|
||||
options.set_show_every_n_iteration(100);
|
||||
options.set_emit_empty_packet(true);
|
||||
if (subgraph_options.has_stable_diffusion_iterate_options()) {
|
||||
options = subgraph_options.stable_diffusion_iterate_options();
|
||||
} else {
|
||||
options.set_base_seed(0);
|
||||
options.set_output_image_height(kPluginsOutputSize);
|
||||
options.set_output_image_width(kPluginsOutputSize);
|
||||
options.set_file_folder(subgraph_options.text2image_model_directory());
|
||||
options.set_show_every_n_iteration(100);
|
||||
options.set_emit_empty_packet(true);
|
||||
}
|
||||
if (lora_resources.has_value()) {
|
||||
auto& lora_layer_weights_mapping =
|
||||
*options.mutable_lora_weights_layer_mapping();
|
||||
|
|
|
@ -48,5 +48,6 @@ mediapipe_proto_library(
|
|||
deps = [
|
||||
":control_plugin_graph_options_proto",
|
||||
"//mediapipe/tasks/cc/core/proto:external_file_proto",
|
||||
"//mediapipe/tasks/cc/vision/image_generator/diffuser:stable_diffusion_iterate_calculator_proto",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -18,6 +18,7 @@ syntax = "proto3";
|
|||
package mediapipe.tasks.vision.image_generator.proto;
|
||||
|
||||
import "mediapipe/tasks/cc/core/proto/external_file.proto";
|
||||
import "mediapipe/tasks/cc/vision/image_generator/diffuser/stable_diffusion_iterate_calculator.proto";
|
||||
import "mediapipe/tasks/cc/vision/image_generator/proto/control_plugin_graph_options.proto";
|
||||
|
||||
option java_package = "com.google.mediapipe.tasks.vision.imagegenerator.proto";
|
||||
|
@ -32,4 +33,7 @@ message ImageGeneratorGraphOptions {
|
|||
core.proto.ExternalFile lora_weights_file = 2;
|
||||
|
||||
repeated proto.ControlPluginGraphOptions control_plugin_graphs_options = 3;
|
||||
|
||||
mediapipe.StableDiffusionIterateCalculatorOptions
|
||||
stable_diffusion_iterate_options = 4;
|
||||
}
|
||||
|
|
|
@ -67,6 +67,10 @@ _VISION_TASKS_IMAGE_GENERATOR_JAVA_PROTO_LITE_TARGETS = [
|
|||
"//mediapipe/tasks/cc/vision/face_landmarker/proto:face_blendshapes_graph_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarker_graph_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarks_detector_graph_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/vision/image_generator/diffuser:stable_diffusion_iterate_calculator_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/vision/image_generator/proto:conditioned_image_graph_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/vision/image_generator/proto:control_plugin_graph_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/vision/image_generator/proto:image_generator_graph_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/vision/image_segmenter/calculators:tensors_to_segmentation_calculator_java_proto_lite",
|
||||
|
|
|
@ -67,8 +67,8 @@ android_library(
|
|||
"//mediapipe/java/com/google/mediapipe/framework:android_framework",
|
||||
"//mediapipe/java/com/google/mediapipe/framework/image",
|
||||
"//mediapipe/tasks/cc/core/proto:external_file_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/vision/face_detector/proto:face_detector_graph_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarker_graph_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/vision/image_generator/diffuser:stable_diffusion_iterate_calculator_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/vision/image_generator/proto:conditioned_image_graph_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/vision/image_generator/proto:control_plugin_graph_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/vision/image_generator/proto:image_generator_graph_options_java_proto_lite",
|
||||
|
|
|
@ -19,6 +19,7 @@ import android.graphics.Bitmap;
|
|||
import android.util.Log;
|
||||
import androidx.annotation.Nullable;
|
||||
import com.google.auto.value.AutoValue;
|
||||
import com.google.mediapipe.calculator.proto.StableDiffusionIterateCalculatorOptionsProto;
|
||||
import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions;
|
||||
import com.google.mediapipe.framework.AndroidPacketGetter;
|
||||
import com.google.mediapipe.framework.Packet;
|
||||
|
@ -28,8 +29,6 @@ import com.google.mediapipe.framework.image.MPImage;
|
|||
import com.google.mediapipe.tasks.core.BaseOptions;
|
||||
import com.google.mediapipe.tasks.core.ErrorListener;
|
||||
import com.google.mediapipe.tasks.core.OutputHandler;
|
||||
import com.google.mediapipe.tasks.core.OutputHandler.PureResultListener;
|
||||
import com.google.mediapipe.tasks.core.OutputHandler.ResultListener;
|
||||
import com.google.mediapipe.tasks.core.TaskInfo;
|
||||
import com.google.mediapipe.tasks.core.TaskOptions;
|
||||
import com.google.mediapipe.tasks.core.TaskResult;
|
||||
|
@ -64,17 +63,23 @@ public final class ImageGenerator extends BaseVisionTaskApi {
|
|||
private static final String SOURCE_CONDITION_IMAGE_STREAM_NAME = "source_condition_image";
|
||||
private static final String CONDITION_IMAGE_STREAM_NAME = "condition_image";
|
||||
private static final String SELECT_STREAM_NAME = "select";
|
||||
private static final String SHOW_RESULT_STREAM_NAME = "show_result";
|
||||
private static final int GENERATED_IMAGE_OUT_STREAM_INDEX = 0;
|
||||
private static final int STEPS_OUT_STREAM_INDEX = 1;
|
||||
private static final int ITERATION_OUT_STREAM_INDEX = 2;
|
||||
private static final int SHOW_RESULT_OUT_STREAM_INDEX = 3;
|
||||
private static final String TASK_GRAPH_NAME =
|
||||
"mediapipe.tasks.vision.image_generator.ImageGeneratorGraph";
|
||||
private static final String CONDITION_IMAGE_GRAPHS_CONTAINER_NAME =
|
||||
"mediapipe.tasks.vision.image_generator.ConditionedImageGraphContainer";
|
||||
private static final String TAG = "ImageGenerator";
|
||||
private static final int GENERATED_IMAGE_WIDTH = 512;
|
||||
private static final int GENERATED_IMAGE_HEIGHT = 512;
|
||||
private TaskRunner conditionImageGraphsContainerTaskRunner;
|
||||
private Map<ConditionOptions.ConditionType, Integer> conditionTypeIndex;
|
||||
private boolean useConditionImage = false;
|
||||
private CachedInputs cachedInputs = new CachedInputs();
|
||||
private boolean inProcessing = false;
|
||||
|
||||
/**
|
||||
* Creates an {@link ImageGenerator} instance from an {@link ImageGeneratorOptions}.
|
||||
|
@ -107,7 +112,8 @@ public final class ImageGenerator extends BaseVisionTaskApi {
|
|||
"STEPS:" + STEPS_STREAM_NAME,
|
||||
"ITERATION:" + ITERATION_STREAM_NAME,
|
||||
"PROMPT:" + PROMPT_STREAM_NAME,
|
||||
"RAND_SEED:" + RAND_SEED_STREAM_NAME));
|
||||
"RAND_SEED:" + RAND_SEED_STREAM_NAME,
|
||||
"SHOW_RESULT:" + SHOW_RESULT_STREAM_NAME));
|
||||
final boolean useConditionImage = conditionOptions != null;
|
||||
if (useConditionImage) {
|
||||
inputStreams.add("SELECT:" + SELECT_STREAM_NAME);
|
||||
|
@ -115,7 +121,11 @@ public final class ImageGenerator extends BaseVisionTaskApi {
|
|||
generatorOptions.conditionOptions = Optional.of(conditionOptions);
|
||||
}
|
||||
List<String> outputStreams =
|
||||
Arrays.asList("IMAGE:image_out", "STEPS:steps_out", "ITERATION:iteration_out");
|
||||
Arrays.asList(
|
||||
"IMAGE:image_out",
|
||||
"STEPS:steps_out",
|
||||
"ITERATION:iteration_out",
|
||||
"SHOW_RESULT:show_result_out");
|
||||
|
||||
OutputHandler<ImageGeneratorResult, Void> handler = new OutputHandler<>();
|
||||
handler.setOutputPacketConverter(
|
||||
|
@ -125,16 +135,22 @@ public final class ImageGenerator extends BaseVisionTaskApi {
|
|||
public ImageGeneratorResult convertToTaskResult(List<Packet> packets) {
|
||||
int iteration = PacketGetter.getInt32(packets.get(ITERATION_OUT_STREAM_INDEX));
|
||||
int steps = PacketGetter.getInt32(packets.get(STEPS_OUT_STREAM_INDEX));
|
||||
Log.i("ImageGenerator", "Iteration: " + iteration + ", Steps: " + steps);
|
||||
if (iteration != steps - 1) {
|
||||
boolean showResult =
|
||||
PacketGetter.getBool(packets.get(SHOW_RESULT_OUT_STREAM_INDEX))
|
||||
|| iteration == steps - 1;
|
||||
Log.i(
|
||||
"ImageGenerator",
|
||||
"Iteration: " + iteration + ", Steps: " + steps + ", ShowResult: " + showResult);
|
||||
if (showResult) {
|
||||
Log.i("ImageGenerator", "processing generated image");
|
||||
Packet packet = packets.get(GENERATED_IMAGE_OUT_STREAM_INDEX);
|
||||
Bitmap generatedBitmap = AndroidPacketGetter.getBitmapFromRgb(packet);
|
||||
BitmapImageBuilder bitmapImageBuilder = new BitmapImageBuilder(generatedBitmap);
|
||||
return ImageGeneratorResult.create(
|
||||
bitmapImageBuilder.build(), packet.getTimestamp() / MICROSECONDS_PER_MILLISECOND);
|
||||
} else {
|
||||
return null;
|
||||
}
|
||||
Log.i("ImageGenerator", "processing generated image");
|
||||
Packet packet = packets.get(GENERATED_IMAGE_OUT_STREAM_INDEX);
|
||||
Bitmap generatedBitmap = AndroidPacketGetter.getBitmapFromRgb(packet);
|
||||
BitmapImageBuilder bitmapImageBuilder = new BitmapImageBuilder(generatedBitmap);
|
||||
return ImageGeneratorResult.create(
|
||||
bitmapImageBuilder.build(), packet.getTimestamp() / MICROSECONDS_PER_MILLISECOND);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -143,16 +159,6 @@ public final class ImageGenerator extends BaseVisionTaskApi {
|
|||
}
|
||||
});
|
||||
handler.setHandleTimestampBoundChanges(true);
|
||||
if (generatorOptions.resultListener().isPresent()) {
|
||||
ResultListener<ImageGeneratorResult, Void> resultListener =
|
||||
new ResultListener<ImageGeneratorResult, Void>() {
|
||||
@Override
|
||||
public void run(ImageGeneratorResult imageGeneratorResult, Void input) {
|
||||
generatorOptions.resultListener().get().run(imageGeneratorResult);
|
||||
}
|
||||
};
|
||||
handler.setResultListener(resultListener);
|
||||
}
|
||||
generatorOptions.errorListener().ifPresent(handler::setErrorListener);
|
||||
TaskRunner runner =
|
||||
TaskRunner.create(
|
||||
|
@ -229,11 +235,19 @@ public final class ImageGenerator extends BaseVisionTaskApi {
|
|||
* Generates an image for iterations and the given random seed. Only valid when the ImageGenerator
|
||||
* is created without condition options.
|
||||
*
|
||||
* <p>This is an e2e API, which runs {@code iterations} to generate an image. Consider using the
|
||||
* iterative API instead to fetch the intermediate results.
|
||||
*
|
||||
* @param prompt The text prompt describing the image to be generated.
|
||||
* @param iterations The total iterations to generate the image.
|
||||
* @param seed The random seed used during image generation.
|
||||
*/
|
||||
public ImageGeneratorResult generate(String prompt, int iterations, int seed) {
|
||||
if (useConditionImage) {
|
||||
throw new IllegalArgumentException(
|
||||
"ImageGenerator is created with condition options. Must use the methods with condition "
|
||||
+ "options.");
|
||||
}
|
||||
return runIterations(prompt, iterations, seed, null, 0);
|
||||
}
|
||||
|
||||
|
@ -241,6 +255,9 @@ public final class ImageGenerator extends BaseVisionTaskApi {
|
|||
* Generates an image based on the source image for iterations and the given random seed. Only
|
||||
* valid when the ImageGenerator is created with condition options.
|
||||
*
|
||||
* <p>This is an e2e API, which runs {@code iterations} to generate an image. Consider using the
|
||||
* iterative API instead to fetch the intermediate results.
|
||||
*
|
||||
* @param prompt The text prompt describing the image to be generated.
|
||||
* @param sourceConditionImage The source image used to create the condition image, which is used
|
||||
* as a guidance for the image generation.
|
||||
|
@ -255,6 +272,11 @@ public final class ImageGenerator extends BaseVisionTaskApi {
|
|||
ConditionOptions.ConditionType conditionType,
|
||||
int iterations,
|
||||
int seed) {
|
||||
if (!useConditionImage) {
|
||||
throw new IllegalArgumentException(
|
||||
"ImageGenerator is created without condition options. Must use the methods without "
|
||||
+ "condition options.");
|
||||
}
|
||||
return runIterations(
|
||||
prompt,
|
||||
iterations,
|
||||
|
@ -263,6 +285,97 @@ public final class ImageGenerator extends BaseVisionTaskApi {
|
|||
conditionTypeIndex.get(conditionType));
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets the inputs of the ImageGenerator. There is {@link setInputs} and {@link execute} method
|
||||
* pair for iterative usage. Users must call {@link setInputs} before {@link execute}. Only valid
|
||||
* when the ImageGenerator is created without condition options.
|
||||
*
|
||||
* @param prompt The text prompt describing the image to be generated.
|
||||
* @param iterations The total iterations to generate the image.
|
||||
* @param seed The random seed used during image generation.
|
||||
*/
|
||||
public void setInputs(String prompt, int iterations, int seed) {
|
||||
if (useConditionImage) {
|
||||
throw new IllegalArgumentException(
|
||||
"ImageGenerator is created with condition options. Must use the methods with condition "
|
||||
+ "options.");
|
||||
}
|
||||
cachedInputs = new CachedInputs();
|
||||
cachedInputs.prompt = prompt;
|
||||
cachedInputs.iterations = iterations;
|
||||
cachedInputs.seed = seed;
|
||||
cachedInputs.step = 0;
|
||||
inProcessing = true;
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets the inputs of the ImageGenerator. For iterative usage, use {@link setInputs} and {@link
|
||||
* execute} in pairs. Users must call {@link setInputs} before {@link execute}. Only valid when
|
||||
* the ImageGenerator is created with condition options.
|
||||
*
|
||||
* @param prompt The text prompt describing the image to be generated.
|
||||
* @param sourceConditionImage The source image used to create the condition image, which is used
|
||||
* as a guidance for the image generation.
|
||||
* @param conditionType The {@link ConditionOptions.ConditionType} specifying the type of
|
||||
* condition image.
|
||||
* @param iterations The total iterations to generate the image.
|
||||
* @param seed The random seed used during image generation.
|
||||
*/
|
||||
public void setInputs(
|
||||
String prompt,
|
||||
MPImage sourceConditionImage,
|
||||
ConditionOptions.ConditionType conditionType,
|
||||
int iterations,
|
||||
int seed) {
|
||||
if (!useConditionImage) {
|
||||
throw new IllegalArgumentException(
|
||||
"ImageGenerator is created without condition options. Must use the methods without "
|
||||
+ "condition options.");
|
||||
}
|
||||
cachedInputs = new CachedInputs();
|
||||
cachedInputs.prompt = prompt;
|
||||
cachedInputs.iterations = iterations;
|
||||
cachedInputs.seed = seed;
|
||||
cachedInputs.step = 0;
|
||||
cachedInputs.cachedConditionImage = createConditionImage(sourceConditionImage, conditionType);
|
||||
cachedInputs.conditionType = conditionType;
|
||||
inProcessing = true;
|
||||
}
|
||||
|
||||
/**
|
||||
* Executes one iteration of image generation. The method must be called {@code iterations} times
|
||||
* to generate the final image. Must call {@link setInputs} before calling this method.
|
||||
*
|
||||
* <p>This is an iterative API, which must be called iteratively.
|
||||
*
|
||||
* <p>This API is useful for showing the intermediate image generation results and the image
|
||||
* generation progress. Note that requesting the intermediate results will result in a larger
|
||||
* latency. Consider using the e2e API instead for latency consideration.
|
||||
*
|
||||
* <p>Example usage:
|
||||
*
|
||||
* <p>imageGenerator.setInputs(prompt, iterations, seed); for (int step = 0; step < iterations;
|
||||
* step++) { ImageGeneratorResult result = imageGenerator.execute(true); }
|
||||
*
|
||||
* @param showResult Whether to get the generated image result in the intermediate iterations. If
|
||||
* false, null is returned. The generated image result is always returned at the last
|
||||
* iteration, regardless of showResult value.
|
||||
*/
|
||||
@Nullable
|
||||
public ImageGeneratorResult execute(boolean showResult) {
|
||||
if (!inProcessing) {
|
||||
throw new IllegalArgumentException("Must call setInputs before execute.");
|
||||
}
|
||||
return runStep(
|
||||
cachedInputs.prompt,
|
||||
cachedInputs.iterations,
|
||||
cachedInputs.step++,
|
||||
cachedInputs.seed,
|
||||
cachedInputs.cachedConditionImage,
|
||||
useConditionImage ? conditionTypeIndex.get(cachedInputs.conditionType) : 0,
|
||||
showResult);
|
||||
}
|
||||
|
||||
/**
|
||||
* Create the condition image of specified condition type from the source image. Currently support
|
||||
* face landmarks, depth image and edge image as the condition image.
|
||||
|
@ -295,6 +408,11 @@ public final class ImageGenerator extends BaseVisionTaskApi {
|
|||
|
||||
private ImageGeneratorResult runIterations(
|
||||
String prompt, int steps, int seed, @Nullable MPImage conditionImage, int select) {
|
||||
if (inProcessing) {
|
||||
throw new IllegalArgumentException(
|
||||
"Iterative API was called previously. It is not allowed to called batch API during"
|
||||
+ "iterative processing.");
|
||||
}
|
||||
ImageGeneratorResult result = null;
|
||||
long timestamp = System.currentTimeMillis() * MICROSECONDS_PER_MILLISECOND;
|
||||
for (int i = 0; i < steps; i++) {
|
||||
|
@ -318,11 +436,84 @@ public final class ImageGenerator extends BaseVisionTaskApi {
|
|||
return result;
|
||||
}
|
||||
|
||||
@Nullable
|
||||
private ImageGeneratorResult runStep(
|
||||
String prompt,
|
||||
int iterations,
|
||||
int step,
|
||||
int seed,
|
||||
@Nullable MPImage conditionImage,
|
||||
int select,
|
||||
boolean showResult) {
|
||||
if (step == 0) {
|
||||
cachedInputs.cachedTimestamp = System.currentTimeMillis() * MICROSECONDS_PER_MILLISECOND;
|
||||
}
|
||||
Map<String, Packet> inputPackets = new HashMap<>();
|
||||
if (step == 0 && useConditionImage) {
|
||||
inputPackets.put(
|
||||
CONDITION_IMAGE_STREAM_NAME, runner.getPacketCreator().createImage(conditionImage));
|
||||
inputPackets.put(SELECT_STREAM_NAME, runner.getPacketCreator().createInt32(select));
|
||||
}
|
||||
inputPackets.put(PROMPT_STREAM_NAME, runner.getPacketCreator().createString(prompt));
|
||||
inputPackets.put(STEPS_STREAM_NAME, runner.getPacketCreator().createInt32(iterations));
|
||||
inputPackets.put(ITERATION_STREAM_NAME, runner.getPacketCreator().createInt32(step));
|
||||
inputPackets.put(RAND_SEED_STREAM_NAME, runner.getPacketCreator().createInt32(seed));
|
||||
inputPackets.put(SHOW_RESULT_STREAM_NAME, runner.getPacketCreator().createBool(showResult));
|
||||
ImageGeneratorResult result =
|
||||
(ImageGeneratorResult) runner.process(inputPackets, cachedInputs.cachedTimestamp++);
|
||||
if (result != null && useConditionImage) {
|
||||
// Add condition image to the ImageGeneratorResult.
|
||||
result =
|
||||
ImageGeneratorResult.create(
|
||||
result.generatedImage(), conditionImage, result.timestampMs());
|
||||
}
|
||||
if (step == iterations - 1) {
|
||||
inProcessing = false;
|
||||
cachedInputs = new CachedInputs();
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/** Closes and cleans up the task runners. */
|
||||
@Override
|
||||
public void close() {
|
||||
runner.close();
|
||||
conditionImageGraphsContainerTaskRunner.close();
|
||||
if (runner != null) {
|
||||
runner.close();
|
||||
}
|
||||
if (conditionImageGraphsContainerTaskRunner != null) {
|
||||
conditionImageGraphsContainerTaskRunner.close();
|
||||
}
|
||||
}
|
||||
|
||||
// Helper class to holder inputs to be checked with the inputs of next step.
|
||||
private static class CachedInputs {
|
||||
public CachedInputs() {
|
||||
this.prompt = "";
|
||||
this.iterations = 0;
|
||||
this.step = 0;
|
||||
this.seed = 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public final String toString() {
|
||||
return "Prompt: "
|
||||
+ prompt
|
||||
+ ", Iterations: "
|
||||
+ iterations
|
||||
+ ", Step: "
|
||||
+ step
|
||||
+ ", Seed: "
|
||||
+ seed
|
||||
+ (conditionType == null ? "" : ", ConditionType: " + conditionType.name());
|
||||
}
|
||||
|
||||
public String prompt;
|
||||
public int iterations;
|
||||
public int step;
|
||||
public int seed;
|
||||
public ConditionOptions.ConditionType conditionType;
|
||||
public MPImage cachedConditionImage;
|
||||
public long cachedTimestamp;
|
||||
}
|
||||
|
||||
/** A container class for the condition image. */
|
||||
|
@ -343,15 +534,12 @@ public final class ImageGenerator extends BaseVisionTaskApi {
|
|||
@AutoValue.Builder
|
||||
public abstract static class Builder {
|
||||
|
||||
/** Sets the text to image model directory storing the model weights. */
|
||||
public abstract Builder setText2ImageModelDirectory(String modelDirectory);
|
||||
/** Sets the image generator model directory storing the model weights. */
|
||||
public abstract Builder setImageGeneratorModelDirectory(String modelDirectory);
|
||||
|
||||
/** Sets the path to LoRA weights file. */
|
||||
public abstract Builder setLoraWeightsFilePath(String loraWeightsFilePath);
|
||||
|
||||
public abstract Builder setResultListener(
|
||||
PureResultListener<ImageGeneratorResult> resultListener);
|
||||
|
||||
/** Sets an optional {@link ErrorListener}}. */
|
||||
public abstract Builder setErrorListener(ErrorListener value);
|
||||
|
||||
|
@ -363,19 +551,17 @@ public final class ImageGenerator extends BaseVisionTaskApi {
|
|||
}
|
||||
}
|
||||
|
||||
abstract String text2ImageModelDirectory();
|
||||
abstract String imageGeneratorModelDirectory();
|
||||
|
||||
abstract Optional<String> loraWeightsFilePath();
|
||||
|
||||
abstract Optional<PureResultListener<ImageGeneratorResult>> resultListener();
|
||||
|
||||
abstract Optional<ErrorListener> errorListener();
|
||||
|
||||
private Optional<ConditionOptions> conditionOptions;
|
||||
|
||||
public static Builder builder() {
|
||||
return new AutoValue_ImageGenerator_ImageGeneratorOptions.Builder()
|
||||
.setText2ImageModelDirectory("");
|
||||
.setImageGeneratorModelDirectory("");
|
||||
}
|
||||
|
||||
/** Converts an {@link ImageGeneratorOptions} to a {@link Any} protobuf message. */
|
||||
|
@ -393,13 +579,23 @@ public final class ImageGenerator extends BaseVisionTaskApi {
|
|||
e.printStackTrace();
|
||||
}
|
||||
}
|
||||
taskOptionsBuilder.setText2ImageModelDirectory(text2ImageModelDirectory());
|
||||
taskOptionsBuilder.setText2ImageModelDirectory(imageGeneratorModelDirectory());
|
||||
if (loraWeightsFilePath().isPresent()) {
|
||||
ExternalFileProto.ExternalFile.Builder externalFileBuilder =
|
||||
ExternalFileProto.ExternalFile.newBuilder();
|
||||
externalFileBuilder.setFileName(loraWeightsFilePath().get());
|
||||
taskOptionsBuilder.setLoraWeightsFile(externalFileBuilder.build());
|
||||
}
|
||||
taskOptionsBuilder.setStableDiffusionIterateOptions(
|
||||
StableDiffusionIterateCalculatorOptionsProto.StableDiffusionIterateCalculatorOptions
|
||||
.newBuilder()
|
||||
.setBaseSeed(0)
|
||||
.setFileFolder(imageGeneratorModelDirectory())
|
||||
.setOutputImageWidth(GENERATED_IMAGE_WIDTH)
|
||||
.setOutputImageHeight(GENERATED_IMAGE_HEIGHT)
|
||||
.setEmitEmptyPacket(true)
|
||||
.setShowEveryNIteration(100)
|
||||
.build());
|
||||
return Any.newBuilder()
|
||||
.setTypeUrl(
|
||||
"type.googleapis.com/mediapipe.tasks.vision.image_generator.proto.ImageGeneratorGraphOptions")
|
||||
|
@ -465,7 +661,8 @@ public final class ImageGenerator extends BaseVisionTaskApi {
|
|||
taskOptionsBuilder.addControlPluginGraphsOptions(
|
||||
ControlPluginGraphOptionsProto.ControlPluginGraphOptions.newBuilder()
|
||||
.setBaseOptions(
|
||||
convertBaseOptionsToProto(faceConditionOptions().get().baseOptions()))
|
||||
convertBaseOptionsToProto(
|
||||
faceConditionOptions().get().pluginModelBaseOptions()))
|
||||
.setConditionedImageGraphOptions(
|
||||
ConditionedImageGraphOptions.newBuilder()
|
||||
.setFaceConditionTypeOptions(faceConditionOptions().get().convertToProto())
|
||||
|
@ -476,7 +673,8 @@ public final class ImageGenerator extends BaseVisionTaskApi {
|
|||
taskOptionsBuilder.addControlPluginGraphsOptions(
|
||||
ControlPluginGraphOptionsProto.ControlPluginGraphOptions.newBuilder()
|
||||
.setBaseOptions(
|
||||
convertBaseOptionsToProto(edgeConditionOptions().get().baseOptions()))
|
||||
convertBaseOptionsToProto(
|
||||
edgeConditionOptions().get().pluginModelBaseOptions()))
|
||||
.setConditionedImageGraphOptions(
|
||||
ConditionedImageGraphOptions.newBuilder()
|
||||
.setEdgeConditionTypeOptions(edgeConditionOptions().get().convertToProto())
|
||||
|
@ -486,7 +684,8 @@ public final class ImageGenerator extends BaseVisionTaskApi {
|
|||
taskOptionsBuilder.addControlPluginGraphsOptions(
|
||||
ControlPluginGraphOptionsProto.ControlPluginGraphOptions.newBuilder()
|
||||
.setBaseOptions(
|
||||
convertBaseOptionsToProto(depthConditionOptions().get().baseOptions()))
|
||||
convertBaseOptionsToProto(
|
||||
depthConditionOptions().get().pluginModelBaseOptions()))
|
||||
.setConditionedImageGraphOptions(
|
||||
ConditionedImageGraphOptions.newBuilder()
|
||||
.setDepthConditionTypeOptions(
|
||||
|
@ -510,11 +709,16 @@ public final class ImageGenerator extends BaseVisionTaskApi {
|
|||
@AutoValue.Builder
|
||||
public abstract static class Builder {
|
||||
/** Set the base options for plugin model. */
|
||||
public abstract Builder setBaseOptions(BaseOptions baseOptions);
|
||||
public abstract Builder setPluginModelBaseOptions(BaseOptions baseOptions);
|
||||
|
||||
/* {@link FaceLandmarkerOptions} used to detect face landmarks in the source image. */
|
||||
public abstract Builder setFaceLandmarkerOptions(
|
||||
FaceLandmarkerOptions faceLandmarkerOptions);
|
||||
/** Set base options for face landmarks model. */
|
||||
public abstract Builder setFaceModelBaseOptions(BaseOptions baseOptions);
|
||||
|
||||
/** Set minimum confidence score of face presence score in the face landmark detection. */
|
||||
public abstract Builder setMinFaceDetectionConfidence(float minFaceDetectionConfidence);
|
||||
|
||||
/** Set the face presence threshold */
|
||||
public abstract Builder setMinFacePresenceConfidence(float minFacePresenceConfidence);
|
||||
|
||||
abstract FaceConditionOptions autoBuild();
|
||||
|
||||
|
@ -524,23 +728,36 @@ public final class ImageGenerator extends BaseVisionTaskApi {
|
|||
}
|
||||
}
|
||||
|
||||
abstract BaseOptions baseOptions();
|
||||
abstract BaseOptions pluginModelBaseOptions();
|
||||
|
||||
abstract FaceLandmarkerOptions faceLandmarkerOptions();
|
||||
abstract BaseOptions faceModelBaseOptions();
|
||||
|
||||
abstract float minFaceDetectionConfidence();
|
||||
|
||||
abstract float minFacePresenceConfidence();
|
||||
|
||||
public static Builder builder() {
|
||||
return new AutoValue_ImageGenerator_ConditionOptions_FaceConditionOptions.Builder();
|
||||
return new AutoValue_ImageGenerator_ConditionOptions_FaceConditionOptions.Builder()
|
||||
.setMinFaceDetectionConfidence(0.5f)
|
||||
.setMinFacePresenceConfidence(0.5f);
|
||||
}
|
||||
|
||||
ConditionedImageGraphOptions.FaceConditionTypeOptions convertToProto() {
|
||||
FaceLandmarkerOptions faceLandmarkerOptions =
|
||||
FaceLandmarkerOptions.builder()
|
||||
.setBaseOptions(faceModelBaseOptions())
|
||||
.setMinFaceDetectionConfidence(minFaceDetectionConfidence())
|
||||
.setMinFacePresenceConfidence(minFacePresenceConfidence())
|
||||
.setRunningMode(RunningMode.IMAGE)
|
||||
.setOutputFaceBlendshapes(false)
|
||||
.setOutputFacialTransformationMatrixes(false)
|
||||
.setNumFaces(1)
|
||||
.build();
|
||||
return ConditionedImageGraphOptions.FaceConditionTypeOptions.newBuilder()
|
||||
.setFaceLandmarkerGraphOptions(
|
||||
FaceLandmarkerGraphOptions.newBuilder()
|
||||
.mergeFrom(
|
||||
faceLandmarkerOptions()
|
||||
.convertToCalculatorOptionsProto()
|
||||
.getExtension(FaceLandmarkerGraphOptions.ext))
|
||||
.build())
|
||||
faceLandmarkerOptions
|
||||
.convertToCalculatorOptionsProto()
|
||||
.getExtension(FaceLandmarkerGraphOptions.ext))
|
||||
.build();
|
||||
}
|
||||
}
|
||||
|
@ -553,12 +770,11 @@ public final class ImageGenerator extends BaseVisionTaskApi {
|
|||
@AutoValue.Builder
|
||||
public abstract static class Builder {
|
||||
|
||||
/** Set the base options for plugin model. */
|
||||
public abstract Builder setBaseOptions(BaseOptions baseOptions);
|
||||
/** Set the base options for the plugin model. */
|
||||
public abstract Builder setPluginModelBaseOptions(BaseOptions baseOptions);
|
||||
|
||||
/** {@link ImageSegmenterOptions} used to detect depth image from the source image. */
|
||||
public abstract Builder setImageSegmenterOptions(
|
||||
ImageSegmenterOptions imageSegmenterOptions);
|
||||
/** Set the base options for the depth model. */
|
||||
public abstract Builder setDepthModelBaseOptions(BaseOptions baseOptions);
|
||||
|
||||
abstract DepthConditionOptions autoBuild();
|
||||
|
||||
|
@ -569,18 +785,25 @@ public final class ImageGenerator extends BaseVisionTaskApi {
|
|||
}
|
||||
}
|
||||
|
||||
abstract BaseOptions baseOptions();
|
||||
abstract BaseOptions pluginModelBaseOptions();
|
||||
|
||||
abstract ImageSegmenterOptions imageSegmenterOptions();
|
||||
abstract BaseOptions depthModelBaseOptions();
|
||||
|
||||
public static Builder builder() {
|
||||
return new AutoValue_ImageGenerator_ConditionOptions_DepthConditionOptions.Builder();
|
||||
}
|
||||
|
||||
ConditionedImageGraphOptions.DepthConditionTypeOptions convertToProto() {
|
||||
ImageSegmenterOptions imageSegmenterOptions =
|
||||
ImageSegmenterOptions.builder()
|
||||
.setBaseOptions(depthModelBaseOptions())
|
||||
.setOutputConfidenceMasks(true)
|
||||
.setOutputCategoryMask(false)
|
||||
.setRunningMode(RunningMode.IMAGE)
|
||||
.build();
|
||||
return ConditionedImageGraphOptions.DepthConditionTypeOptions.newBuilder()
|
||||
.setImageSegmenterGraphOptions(
|
||||
imageSegmenterOptions()
|
||||
imageSegmenterOptions
|
||||
.convertToCalculatorOptionsProto()
|
||||
.getExtension(ImageSegmenterGraphOptions.ext))
|
||||
.build();
|
||||
|
@ -603,7 +826,7 @@ public final class ImageGenerator extends BaseVisionTaskApi {
|
|||
public abstract static class Builder {
|
||||
|
||||
/** Set the base options for plugin model. */
|
||||
public abstract Builder setBaseOptions(BaseOptions baseOptions);
|
||||
public abstract Builder setPluginModelBaseOptions(BaseOptions baseOptions);
|
||||
|
||||
/** First threshold for the hysteresis procedure. */
|
||||
public abstract Builder setThreshold1(Float threshold1);
|
||||
|
@ -629,7 +852,7 @@ public final class ImageGenerator extends BaseVisionTaskApi {
|
|||
}
|
||||
}
|
||||
|
||||
abstract BaseOptions baseOptions();
|
||||
abstract BaseOptions pluginModelBaseOptions();
|
||||
|
||||
abstract Float threshold1();
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user