Provide API/options to show intermediate results and generating progress for Java Image Generator.

PiperOrigin-RevId: 562014712
This commit is contained in:
MediaPipe Team 2023-09-01 12:04:02 -07:00 committed by Copybara-Service
parent ceb8cd3c78
commit 6c43d37e5a
6 changed files with 323 additions and 67 deletions

View File

@ -68,6 +68,7 @@ constexpr absl::string_view kRandSeedTag = "RAND_SEED";
constexpr absl::string_view kPluginTensorsTag = "PLUGIN_TENSORS";
constexpr absl::string_view kConditionImageTag = "CONDITION_IMAGE";
constexpr absl::string_view kSelectTag = "SELECT";
constexpr absl::string_view kShowResultTag = "SHOW_RESULT";
constexpr absl::string_view kMetadataFilename = "metadata";
constexpr absl::string_view kLoraRankStr = "lora_rank";
@ -78,6 +79,7 @@ struct ImageGeneratorInputs {
Source<int> rand_seed;
std::optional<Source<Image>> condition_image;
std::optional<Source<int>> select_condition_type;
std::optional<Source<bool>> show_result;
};
struct ImageGeneratorOutputs {
@ -209,6 +211,9 @@ REGISTER_MEDIAPIPE_GRAPH(
// valid, if condtrol plugin graph options are set in the graph options.
// SELECT - int
// The index of the selected the control plugin graph.
// SHOW_RESULT - bool @Optional
// Whether to show the diffusion result at the current step. If this stream
// is not empty, regardless show_every_n_iteration in the options.
//
// Outputs:
// IMAGE - Image
@ -218,6 +223,9 @@ REGISTER_MEDIAPIPE_GRAPH(
// ITERATION - int @optional
// The current iteration in the generating steps. The same as ITERATION
// input.
// SHOW_RESULT - bool @Optional
// Whether to show the diffusion result at the current step. The same as
// input SHOW_RESULT.
class ImageGeneratorGraph : public core::ModelTaskGraph {
public:
absl::StatusOr<CalculatorGraphConfig> GetConfig(
@ -239,6 +247,10 @@ class ImageGeneratorGraph : public core::ModelTaskGraph {
condition_image = graph.In(kConditionImageTag).Cast<Image>();
select_condition_type = graph.In(kSelectTag).Cast<int>();
}
std::optional<Source<bool>> show_result;
if (HasInput(sc->OriginalNode(), kShowResultTag)) {
show_result = graph.In(kShowResultTag).Cast<bool>();
}
ASSIGN_OR_RETURN(
auto outputs,
BuildImageGeneratorGraph(
@ -251,6 +263,7 @@ class ImageGeneratorGraph : public core::ModelTaskGraph {
/*rand_seed=*/graph.In(kRandSeedTag).Cast<int>(),
/*condition_image*/ condition_image,
/*select_condition_type*/ select_condition_type,
/*show_result*/ show_result,
},
graph));
outputs.generated_image >> graph.Out(kImageTag).Cast<Image>();
@ -261,6 +274,10 @@ class ImageGeneratorGraph : public core::ModelTaskGraph {
graph.In(kStepsTag) >> pass_through.In(1);
pass_through.Out(0) >> graph[Output<int>::Optional(kIterationTag)];
pass_through.Out(1) >> graph[Output<int>::Optional(kStepsTag)];
if (HasOutput(sc->OriginalNode(), kShowResultTag)) {
graph.In(kShowResultTag) >> pass_through.In(2);
pass_through.Out(2) >> graph[Output<bool>::Optional(kShowResultTag)];
}
return graph.GetConfig();
}
@ -299,15 +316,22 @@ class ImageGeneratorGraph : public core::ModelTaskGraph {
inputs.steps >> stable_diff.In(kStepsTag);
inputs.iteration >> stable_diff.In(kIterationTag);
inputs.rand_seed >> stable_diff.In(kRandSeedTag);
if (inputs.show_result.has_value()) {
*inputs.show_result >> stable_diff.In(kShowResultTag);
}
mediapipe::StableDiffusionIterateCalculatorOptions& options =
stable_diff
.GetOptions<mediapipe::StableDiffusionIterateCalculatorOptions>();
options.set_base_seed(0);
options.set_output_image_height(kPluginsOutputSize);
options.set_output_image_width(kPluginsOutputSize);
options.set_file_folder(subgraph_options.text2image_model_directory());
options.set_show_every_n_iteration(100);
options.set_emit_empty_packet(true);
if (subgraph_options.has_stable_diffusion_iterate_options()) {
options = subgraph_options.stable_diffusion_iterate_options();
} else {
options.set_base_seed(0);
options.set_output_image_height(kPluginsOutputSize);
options.set_output_image_width(kPluginsOutputSize);
options.set_file_folder(subgraph_options.text2image_model_directory());
options.set_show_every_n_iteration(100);
options.set_emit_empty_packet(true);
}
if (lora_resources.has_value()) {
auto& lora_layer_weights_mapping =
*options.mutable_lora_weights_layer_mapping();

View File

@ -48,5 +48,6 @@ mediapipe_proto_library(
deps = [
":control_plugin_graph_options_proto",
"//mediapipe/tasks/cc/core/proto:external_file_proto",
"//mediapipe/tasks/cc/vision/image_generator/diffuser:stable_diffusion_iterate_calculator_proto",
],
)

View File

@ -18,6 +18,7 @@ syntax = "proto3";
package mediapipe.tasks.vision.image_generator.proto;
import "mediapipe/tasks/cc/core/proto/external_file.proto";
import "mediapipe/tasks/cc/vision/image_generator/diffuser/stable_diffusion_iterate_calculator.proto";
import "mediapipe/tasks/cc/vision/image_generator/proto/control_plugin_graph_options.proto";
option java_package = "com.google.mediapipe.tasks.vision.imagegenerator.proto";
@ -32,4 +33,7 @@ message ImageGeneratorGraphOptions {
core.proto.ExternalFile lora_weights_file = 2;
repeated proto.ControlPluginGraphOptions control_plugin_graphs_options = 3;
mediapipe.StableDiffusionIterateCalculatorOptions
stable_diffusion_iterate_options = 4;
}

View File

@ -67,6 +67,10 @@ _VISION_TASKS_IMAGE_GENERATOR_JAVA_PROTO_LITE_TARGETS = [
"//mediapipe/tasks/cc/vision/face_landmarker/proto:face_blendshapes_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarker_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarks_detector_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/image_generator/diffuser:stable_diffusion_iterate_calculator_java_proto_lite",
"//mediapipe/tasks/cc/vision/image_generator/proto:conditioned_image_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/image_generator/proto:control_plugin_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/image_generator/proto:image_generator_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/image_segmenter/calculators:tensors_to_segmentation_calculator_java_proto_lite",

View File

@ -67,8 +67,8 @@ android_library(
"//mediapipe/java/com/google/mediapipe/framework:android_framework",
"//mediapipe/java/com/google/mediapipe/framework/image",
"//mediapipe/tasks/cc/core/proto:external_file_java_proto_lite",
"//mediapipe/tasks/cc/vision/face_detector/proto:face_detector_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarker_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/image_generator/diffuser:stable_diffusion_iterate_calculator_java_proto_lite",
"//mediapipe/tasks/cc/vision/image_generator/proto:conditioned_image_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/image_generator/proto:control_plugin_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/image_generator/proto:image_generator_graph_options_java_proto_lite",

View File

@ -19,6 +19,7 @@ import android.graphics.Bitmap;
import android.util.Log;
import androidx.annotation.Nullable;
import com.google.auto.value.AutoValue;
import com.google.mediapipe.calculator.proto.StableDiffusionIterateCalculatorOptionsProto;
import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions;
import com.google.mediapipe.framework.AndroidPacketGetter;
import com.google.mediapipe.framework.Packet;
@ -28,8 +29,6 @@ import com.google.mediapipe.framework.image.MPImage;
import com.google.mediapipe.tasks.core.BaseOptions;
import com.google.mediapipe.tasks.core.ErrorListener;
import com.google.mediapipe.tasks.core.OutputHandler;
import com.google.mediapipe.tasks.core.OutputHandler.PureResultListener;
import com.google.mediapipe.tasks.core.OutputHandler.ResultListener;
import com.google.mediapipe.tasks.core.TaskInfo;
import com.google.mediapipe.tasks.core.TaskOptions;
import com.google.mediapipe.tasks.core.TaskResult;
@ -64,17 +63,23 @@ public final class ImageGenerator extends BaseVisionTaskApi {
private static final String SOURCE_CONDITION_IMAGE_STREAM_NAME = "source_condition_image";
private static final String CONDITION_IMAGE_STREAM_NAME = "condition_image";
private static final String SELECT_STREAM_NAME = "select";
private static final String SHOW_RESULT_STREAM_NAME = "show_result";
private static final int GENERATED_IMAGE_OUT_STREAM_INDEX = 0;
private static final int STEPS_OUT_STREAM_INDEX = 1;
private static final int ITERATION_OUT_STREAM_INDEX = 2;
private static final int SHOW_RESULT_OUT_STREAM_INDEX = 3;
private static final String TASK_GRAPH_NAME =
"mediapipe.tasks.vision.image_generator.ImageGeneratorGraph";
private static final String CONDITION_IMAGE_GRAPHS_CONTAINER_NAME =
"mediapipe.tasks.vision.image_generator.ConditionedImageGraphContainer";
private static final String TAG = "ImageGenerator";
private static final int GENERATED_IMAGE_WIDTH = 512;
private static final int GENERATED_IMAGE_HEIGHT = 512;
private TaskRunner conditionImageGraphsContainerTaskRunner;
private Map<ConditionOptions.ConditionType, Integer> conditionTypeIndex;
private boolean useConditionImage = false;
private CachedInputs cachedInputs = new CachedInputs();
private boolean inProcessing = false;
/**
* Creates an {@link ImageGenerator} instance from an {@link ImageGeneratorOptions}.
@ -107,7 +112,8 @@ public final class ImageGenerator extends BaseVisionTaskApi {
"STEPS:" + STEPS_STREAM_NAME,
"ITERATION:" + ITERATION_STREAM_NAME,
"PROMPT:" + PROMPT_STREAM_NAME,
"RAND_SEED:" + RAND_SEED_STREAM_NAME));
"RAND_SEED:" + RAND_SEED_STREAM_NAME,
"SHOW_RESULT:" + SHOW_RESULT_STREAM_NAME));
final boolean useConditionImage = conditionOptions != null;
if (useConditionImage) {
inputStreams.add("SELECT:" + SELECT_STREAM_NAME);
@ -115,7 +121,11 @@ public final class ImageGenerator extends BaseVisionTaskApi {
generatorOptions.conditionOptions = Optional.of(conditionOptions);
}
List<String> outputStreams =
Arrays.asList("IMAGE:image_out", "STEPS:steps_out", "ITERATION:iteration_out");
Arrays.asList(
"IMAGE:image_out",
"STEPS:steps_out",
"ITERATION:iteration_out",
"SHOW_RESULT:show_result_out");
OutputHandler<ImageGeneratorResult, Void> handler = new OutputHandler<>();
handler.setOutputPacketConverter(
@ -125,16 +135,22 @@ public final class ImageGenerator extends BaseVisionTaskApi {
public ImageGeneratorResult convertToTaskResult(List<Packet> packets) {
int iteration = PacketGetter.getInt32(packets.get(ITERATION_OUT_STREAM_INDEX));
int steps = PacketGetter.getInt32(packets.get(STEPS_OUT_STREAM_INDEX));
Log.i("ImageGenerator", "Iteration: " + iteration + ", Steps: " + steps);
if (iteration != steps - 1) {
boolean showResult =
PacketGetter.getBool(packets.get(SHOW_RESULT_OUT_STREAM_INDEX))
|| iteration == steps - 1;
Log.i(
"ImageGenerator",
"Iteration: " + iteration + ", Steps: " + steps + ", ShowResult: " + showResult);
if (showResult) {
Log.i("ImageGenerator", "processing generated image");
Packet packet = packets.get(GENERATED_IMAGE_OUT_STREAM_INDEX);
Bitmap generatedBitmap = AndroidPacketGetter.getBitmapFromRgb(packet);
BitmapImageBuilder bitmapImageBuilder = new BitmapImageBuilder(generatedBitmap);
return ImageGeneratorResult.create(
bitmapImageBuilder.build(), packet.getTimestamp() / MICROSECONDS_PER_MILLISECOND);
} else {
return null;
}
Log.i("ImageGenerator", "processing generated image");
Packet packet = packets.get(GENERATED_IMAGE_OUT_STREAM_INDEX);
Bitmap generatedBitmap = AndroidPacketGetter.getBitmapFromRgb(packet);
BitmapImageBuilder bitmapImageBuilder = new BitmapImageBuilder(generatedBitmap);
return ImageGeneratorResult.create(
bitmapImageBuilder.build(), packet.getTimestamp() / MICROSECONDS_PER_MILLISECOND);
}
@Override
@ -143,16 +159,6 @@ public final class ImageGenerator extends BaseVisionTaskApi {
}
});
handler.setHandleTimestampBoundChanges(true);
if (generatorOptions.resultListener().isPresent()) {
ResultListener<ImageGeneratorResult, Void> resultListener =
new ResultListener<ImageGeneratorResult, Void>() {
@Override
public void run(ImageGeneratorResult imageGeneratorResult, Void input) {
generatorOptions.resultListener().get().run(imageGeneratorResult);
}
};
handler.setResultListener(resultListener);
}
generatorOptions.errorListener().ifPresent(handler::setErrorListener);
TaskRunner runner =
TaskRunner.create(
@ -229,11 +235,19 @@ public final class ImageGenerator extends BaseVisionTaskApi {
* Generates an image for iterations and the given random seed. Only valid when the ImageGenerator
* is created without condition options.
*
* <p>This is an e2e API, which runs {@code iterations} to generate an image. Consider using the
* iterative API instead to fetch the intermediate results.
*
* @param prompt The text prompt describing the image to be generated.
* @param iterations The total iterations to generate the image.
* @param seed The random seed used during image generation.
*/
public ImageGeneratorResult generate(String prompt, int iterations, int seed) {
if (useConditionImage) {
throw new IllegalArgumentException(
"ImageGenerator is created with condition options. Must use the methods with condition "
+ "options.");
}
return runIterations(prompt, iterations, seed, null, 0);
}
@ -241,6 +255,9 @@ public final class ImageGenerator extends BaseVisionTaskApi {
* Generates an image based on the source image for iterations and the given random seed. Only
* valid when the ImageGenerator is created with condition options.
*
* <p>This is an e2e API, which runs {@code iterations} to generate an image. Consider using the
* iterative API instead to fetch the intermediate results.
*
* @param prompt The text prompt describing the image to be generated.
* @param sourceConditionImage The source image used to create the condition image, which is used
* as a guidance for the image generation.
@ -255,6 +272,11 @@ public final class ImageGenerator extends BaseVisionTaskApi {
ConditionOptions.ConditionType conditionType,
int iterations,
int seed) {
if (!useConditionImage) {
throw new IllegalArgumentException(
"ImageGenerator is created without condition options. Must use the methods without "
+ "condition options.");
}
return runIterations(
prompt,
iterations,
@ -263,6 +285,97 @@ public final class ImageGenerator extends BaseVisionTaskApi {
conditionTypeIndex.get(conditionType));
}
/**
* Sets the inputs of the ImageGenerator. There is {@link setInputs} and {@link execute} method
* pair for iterative usage. Users must call {@link setInputs} before {@link execute}. Only valid
* when the ImageGenerator is created without condition options.
*
* @param prompt The text prompt describing the image to be generated.
* @param iterations The total iterations to generate the image.
* @param seed The random seed used during image generation.
*/
public void setInputs(String prompt, int iterations, int seed) {
if (useConditionImage) {
throw new IllegalArgumentException(
"ImageGenerator is created with condition options. Must use the methods with condition "
+ "options.");
}
cachedInputs = new CachedInputs();
cachedInputs.prompt = prompt;
cachedInputs.iterations = iterations;
cachedInputs.seed = seed;
cachedInputs.step = 0;
inProcessing = true;
}
/**
* Sets the inputs of the ImageGenerator. For iterative usage, use {@link setInputs} and {@link
* execute} in pairs. Users must call {@link setInputs} before {@link execute}. Only valid when
* the ImageGenerator is created with condition options.
*
* @param prompt The text prompt describing the image to be generated.
* @param sourceConditionImage The source image used to create the condition image, which is used
* as a guidance for the image generation.
* @param conditionType The {@link ConditionOptions.ConditionType} specifying the type of
* condition image.
* @param iterations The total iterations to generate the image.
* @param seed The random seed used during image generation.
*/
public void setInputs(
String prompt,
MPImage sourceConditionImage,
ConditionOptions.ConditionType conditionType,
int iterations,
int seed) {
if (!useConditionImage) {
throw new IllegalArgumentException(
"ImageGenerator is created without condition options. Must use the methods without "
+ "condition options.");
}
cachedInputs = new CachedInputs();
cachedInputs.prompt = prompt;
cachedInputs.iterations = iterations;
cachedInputs.seed = seed;
cachedInputs.step = 0;
cachedInputs.cachedConditionImage = createConditionImage(sourceConditionImage, conditionType);
cachedInputs.conditionType = conditionType;
inProcessing = true;
}
/**
* Executes one iteration of image generation. The method must be called {@code iterations} times
* to generate the final image. Must call {@link setInputs} before calling this method.
*
* <p>This is an iterative API, which must be called iteratively.
*
* <p>This API is useful for showing the intermediate image generation results and the image
* generation progress. Note that requesting the intermediate results will result in a larger
* latency. Consider using the e2e API instead for latency consideration.
*
* <p>Example usage:
*
* <p>imageGenerator.setInputs(prompt, iterations, seed); for (int step = 0; step < iterations;
* step++) { ImageGeneratorResult result = imageGenerator.execute(true); }
*
* @param showResult Whether to get the generated image result in the intermediate iterations. If
* false, null is returned. The generated image result is always returned at the last
* iteration, regardless of showResult value.
*/
@Nullable
public ImageGeneratorResult execute(boolean showResult) {
if (!inProcessing) {
throw new IllegalArgumentException("Must call setInputs before execute.");
}
return runStep(
cachedInputs.prompt,
cachedInputs.iterations,
cachedInputs.step++,
cachedInputs.seed,
cachedInputs.cachedConditionImage,
useConditionImage ? conditionTypeIndex.get(cachedInputs.conditionType) : 0,
showResult);
}
/**
* Create the condition image of specified condition type from the source image. Currently support
* face landmarks, depth image and edge image as the condition image.
@ -295,6 +408,11 @@ public final class ImageGenerator extends BaseVisionTaskApi {
private ImageGeneratorResult runIterations(
String prompt, int steps, int seed, @Nullable MPImage conditionImage, int select) {
if (inProcessing) {
throw new IllegalArgumentException(
"Iterative API was called previously. It is not allowed to called batch API during"
+ "iterative processing.");
}
ImageGeneratorResult result = null;
long timestamp = System.currentTimeMillis() * MICROSECONDS_PER_MILLISECOND;
for (int i = 0; i < steps; i++) {
@ -318,11 +436,84 @@ public final class ImageGenerator extends BaseVisionTaskApi {
return result;
}
@Nullable
private ImageGeneratorResult runStep(
String prompt,
int iterations,
int step,
int seed,
@Nullable MPImage conditionImage,
int select,
boolean showResult) {
if (step == 0) {
cachedInputs.cachedTimestamp = System.currentTimeMillis() * MICROSECONDS_PER_MILLISECOND;
}
Map<String, Packet> inputPackets = new HashMap<>();
if (step == 0 && useConditionImage) {
inputPackets.put(
CONDITION_IMAGE_STREAM_NAME, runner.getPacketCreator().createImage(conditionImage));
inputPackets.put(SELECT_STREAM_NAME, runner.getPacketCreator().createInt32(select));
}
inputPackets.put(PROMPT_STREAM_NAME, runner.getPacketCreator().createString(prompt));
inputPackets.put(STEPS_STREAM_NAME, runner.getPacketCreator().createInt32(iterations));
inputPackets.put(ITERATION_STREAM_NAME, runner.getPacketCreator().createInt32(step));
inputPackets.put(RAND_SEED_STREAM_NAME, runner.getPacketCreator().createInt32(seed));
inputPackets.put(SHOW_RESULT_STREAM_NAME, runner.getPacketCreator().createBool(showResult));
ImageGeneratorResult result =
(ImageGeneratorResult) runner.process(inputPackets, cachedInputs.cachedTimestamp++);
if (result != null && useConditionImage) {
// Add condition image to the ImageGeneratorResult.
result =
ImageGeneratorResult.create(
result.generatedImage(), conditionImage, result.timestampMs());
}
if (step == iterations - 1) {
inProcessing = false;
cachedInputs = new CachedInputs();
}
return result;
}
/** Closes and cleans up the task runners. */
@Override
public void close() {
runner.close();
conditionImageGraphsContainerTaskRunner.close();
if (runner != null) {
runner.close();
}
if (conditionImageGraphsContainerTaskRunner != null) {
conditionImageGraphsContainerTaskRunner.close();
}
}
// Helper class to holder inputs to be checked with the inputs of next step.
private static class CachedInputs {
public CachedInputs() {
this.prompt = "";
this.iterations = 0;
this.step = 0;
this.seed = 0;
}
@Override
public final String toString() {
return "Prompt: "
+ prompt
+ ", Iterations: "
+ iterations
+ ", Step: "
+ step
+ ", Seed: "
+ seed
+ (conditionType == null ? "" : ", ConditionType: " + conditionType.name());
}
public String prompt;
public int iterations;
public int step;
public int seed;
public ConditionOptions.ConditionType conditionType;
public MPImage cachedConditionImage;
public long cachedTimestamp;
}
/** A container class for the condition image. */
@ -343,15 +534,12 @@ public final class ImageGenerator extends BaseVisionTaskApi {
@AutoValue.Builder
public abstract static class Builder {
/** Sets the text to image model directory storing the model weights. */
public abstract Builder setText2ImageModelDirectory(String modelDirectory);
/** Sets the image generator model directory storing the model weights. */
public abstract Builder setImageGeneratorModelDirectory(String modelDirectory);
/** Sets the path to LoRA weights file. */
public abstract Builder setLoraWeightsFilePath(String loraWeightsFilePath);
public abstract Builder setResultListener(
PureResultListener<ImageGeneratorResult> resultListener);
/** Sets an optional {@link ErrorListener}}. */
public abstract Builder setErrorListener(ErrorListener value);
@ -363,19 +551,17 @@ public final class ImageGenerator extends BaseVisionTaskApi {
}
}
abstract String text2ImageModelDirectory();
abstract String imageGeneratorModelDirectory();
abstract Optional<String> loraWeightsFilePath();
abstract Optional<PureResultListener<ImageGeneratorResult>> resultListener();
abstract Optional<ErrorListener> errorListener();
private Optional<ConditionOptions> conditionOptions;
public static Builder builder() {
return new AutoValue_ImageGenerator_ImageGeneratorOptions.Builder()
.setText2ImageModelDirectory("");
.setImageGeneratorModelDirectory("");
}
/** Converts an {@link ImageGeneratorOptions} to a {@link Any} protobuf message. */
@ -393,13 +579,23 @@ public final class ImageGenerator extends BaseVisionTaskApi {
e.printStackTrace();
}
}
taskOptionsBuilder.setText2ImageModelDirectory(text2ImageModelDirectory());
taskOptionsBuilder.setText2ImageModelDirectory(imageGeneratorModelDirectory());
if (loraWeightsFilePath().isPresent()) {
ExternalFileProto.ExternalFile.Builder externalFileBuilder =
ExternalFileProto.ExternalFile.newBuilder();
externalFileBuilder.setFileName(loraWeightsFilePath().get());
taskOptionsBuilder.setLoraWeightsFile(externalFileBuilder.build());
}
taskOptionsBuilder.setStableDiffusionIterateOptions(
StableDiffusionIterateCalculatorOptionsProto.StableDiffusionIterateCalculatorOptions
.newBuilder()
.setBaseSeed(0)
.setFileFolder(imageGeneratorModelDirectory())
.setOutputImageWidth(GENERATED_IMAGE_WIDTH)
.setOutputImageHeight(GENERATED_IMAGE_HEIGHT)
.setEmitEmptyPacket(true)
.setShowEveryNIteration(100)
.build());
return Any.newBuilder()
.setTypeUrl(
"type.googleapis.com/mediapipe.tasks.vision.image_generator.proto.ImageGeneratorGraphOptions")
@ -465,7 +661,8 @@ public final class ImageGenerator extends BaseVisionTaskApi {
taskOptionsBuilder.addControlPluginGraphsOptions(
ControlPluginGraphOptionsProto.ControlPluginGraphOptions.newBuilder()
.setBaseOptions(
convertBaseOptionsToProto(faceConditionOptions().get().baseOptions()))
convertBaseOptionsToProto(
faceConditionOptions().get().pluginModelBaseOptions()))
.setConditionedImageGraphOptions(
ConditionedImageGraphOptions.newBuilder()
.setFaceConditionTypeOptions(faceConditionOptions().get().convertToProto())
@ -476,7 +673,8 @@ public final class ImageGenerator extends BaseVisionTaskApi {
taskOptionsBuilder.addControlPluginGraphsOptions(
ControlPluginGraphOptionsProto.ControlPluginGraphOptions.newBuilder()
.setBaseOptions(
convertBaseOptionsToProto(edgeConditionOptions().get().baseOptions()))
convertBaseOptionsToProto(
edgeConditionOptions().get().pluginModelBaseOptions()))
.setConditionedImageGraphOptions(
ConditionedImageGraphOptions.newBuilder()
.setEdgeConditionTypeOptions(edgeConditionOptions().get().convertToProto())
@ -486,7 +684,8 @@ public final class ImageGenerator extends BaseVisionTaskApi {
taskOptionsBuilder.addControlPluginGraphsOptions(
ControlPluginGraphOptionsProto.ControlPluginGraphOptions.newBuilder()
.setBaseOptions(
convertBaseOptionsToProto(depthConditionOptions().get().baseOptions()))
convertBaseOptionsToProto(
depthConditionOptions().get().pluginModelBaseOptions()))
.setConditionedImageGraphOptions(
ConditionedImageGraphOptions.newBuilder()
.setDepthConditionTypeOptions(
@ -510,11 +709,16 @@ public final class ImageGenerator extends BaseVisionTaskApi {
@AutoValue.Builder
public abstract static class Builder {
/** Set the base options for plugin model. */
public abstract Builder setBaseOptions(BaseOptions baseOptions);
public abstract Builder setPluginModelBaseOptions(BaseOptions baseOptions);
/* {@link FaceLandmarkerOptions} used to detect face landmarks in the source image. */
public abstract Builder setFaceLandmarkerOptions(
FaceLandmarkerOptions faceLandmarkerOptions);
/** Set base options for face landmarks model. */
public abstract Builder setFaceModelBaseOptions(BaseOptions baseOptions);
/** Set minimum confidence score of face presence score in the face landmark detection. */
public abstract Builder setMinFaceDetectionConfidence(float minFaceDetectionConfidence);
/** Set the face presence threshold */
public abstract Builder setMinFacePresenceConfidence(float minFacePresenceConfidence);
abstract FaceConditionOptions autoBuild();
@ -524,23 +728,36 @@ public final class ImageGenerator extends BaseVisionTaskApi {
}
}
abstract BaseOptions baseOptions();
abstract BaseOptions pluginModelBaseOptions();
abstract FaceLandmarkerOptions faceLandmarkerOptions();
abstract BaseOptions faceModelBaseOptions();
abstract float minFaceDetectionConfidence();
abstract float minFacePresenceConfidence();
public static Builder builder() {
return new AutoValue_ImageGenerator_ConditionOptions_FaceConditionOptions.Builder();
return new AutoValue_ImageGenerator_ConditionOptions_FaceConditionOptions.Builder()
.setMinFaceDetectionConfidence(0.5f)
.setMinFacePresenceConfidence(0.5f);
}
ConditionedImageGraphOptions.FaceConditionTypeOptions convertToProto() {
FaceLandmarkerOptions faceLandmarkerOptions =
FaceLandmarkerOptions.builder()
.setBaseOptions(faceModelBaseOptions())
.setMinFaceDetectionConfidence(minFaceDetectionConfidence())
.setMinFacePresenceConfidence(minFacePresenceConfidence())
.setRunningMode(RunningMode.IMAGE)
.setOutputFaceBlendshapes(false)
.setOutputFacialTransformationMatrixes(false)
.setNumFaces(1)
.build();
return ConditionedImageGraphOptions.FaceConditionTypeOptions.newBuilder()
.setFaceLandmarkerGraphOptions(
FaceLandmarkerGraphOptions.newBuilder()
.mergeFrom(
faceLandmarkerOptions()
.convertToCalculatorOptionsProto()
.getExtension(FaceLandmarkerGraphOptions.ext))
.build())
faceLandmarkerOptions
.convertToCalculatorOptionsProto()
.getExtension(FaceLandmarkerGraphOptions.ext))
.build();
}
}
@ -553,12 +770,11 @@ public final class ImageGenerator extends BaseVisionTaskApi {
@AutoValue.Builder
public abstract static class Builder {
/** Set the base options for plugin model. */
public abstract Builder setBaseOptions(BaseOptions baseOptions);
/** Set the base options for the plugin model. */
public abstract Builder setPluginModelBaseOptions(BaseOptions baseOptions);
/** {@link ImageSegmenterOptions} used to detect depth image from the source image. */
public abstract Builder setImageSegmenterOptions(
ImageSegmenterOptions imageSegmenterOptions);
/** Set the base options for the depth model. */
public abstract Builder setDepthModelBaseOptions(BaseOptions baseOptions);
abstract DepthConditionOptions autoBuild();
@ -569,18 +785,25 @@ public final class ImageGenerator extends BaseVisionTaskApi {
}
}
abstract BaseOptions baseOptions();
abstract BaseOptions pluginModelBaseOptions();
abstract ImageSegmenterOptions imageSegmenterOptions();
abstract BaseOptions depthModelBaseOptions();
public static Builder builder() {
return new AutoValue_ImageGenerator_ConditionOptions_DepthConditionOptions.Builder();
}
ConditionedImageGraphOptions.DepthConditionTypeOptions convertToProto() {
ImageSegmenterOptions imageSegmenterOptions =
ImageSegmenterOptions.builder()
.setBaseOptions(depthModelBaseOptions())
.setOutputConfidenceMasks(true)
.setOutputCategoryMask(false)
.setRunningMode(RunningMode.IMAGE)
.build();
return ConditionedImageGraphOptions.DepthConditionTypeOptions.newBuilder()
.setImageSegmenterGraphOptions(
imageSegmenterOptions()
imageSegmenterOptions
.convertToCalculatorOptionsProto()
.getExtension(ImageSegmenterGraphOptions.ext))
.build();
@ -603,7 +826,7 @@ public final class ImageGenerator extends BaseVisionTaskApi {
public abstract static class Builder {
/** Set the base options for plugin model. */
public abstract Builder setBaseOptions(BaseOptions baseOptions);
public abstract Builder setPluginModelBaseOptions(BaseOptions baseOptions);
/** First threshold for the hysteresis procedure. */
public abstract Builder setThreshold1(Float threshold1);
@ -629,7 +852,7 @@ public final class ImageGenerator extends BaseVisionTaskApi {
}
}
abstract BaseOptions baseOptions();
abstract BaseOptions pluginModelBaseOptions();
abstract Float threshold1();