Provide API/options to show intermediate results and generating progress for Java Image Generator.
PiperOrigin-RevId: 562014712
This commit is contained in:
parent
ceb8cd3c78
commit
6c43d37e5a
|
@ -68,6 +68,7 @@ constexpr absl::string_view kRandSeedTag = "RAND_SEED";
|
||||||
constexpr absl::string_view kPluginTensorsTag = "PLUGIN_TENSORS";
|
constexpr absl::string_view kPluginTensorsTag = "PLUGIN_TENSORS";
|
||||||
constexpr absl::string_view kConditionImageTag = "CONDITION_IMAGE";
|
constexpr absl::string_view kConditionImageTag = "CONDITION_IMAGE";
|
||||||
constexpr absl::string_view kSelectTag = "SELECT";
|
constexpr absl::string_view kSelectTag = "SELECT";
|
||||||
|
constexpr absl::string_view kShowResultTag = "SHOW_RESULT";
|
||||||
constexpr absl::string_view kMetadataFilename = "metadata";
|
constexpr absl::string_view kMetadataFilename = "metadata";
|
||||||
constexpr absl::string_view kLoraRankStr = "lora_rank";
|
constexpr absl::string_view kLoraRankStr = "lora_rank";
|
||||||
|
|
||||||
|
@ -78,6 +79,7 @@ struct ImageGeneratorInputs {
|
||||||
Source<int> rand_seed;
|
Source<int> rand_seed;
|
||||||
std::optional<Source<Image>> condition_image;
|
std::optional<Source<Image>> condition_image;
|
||||||
std::optional<Source<int>> select_condition_type;
|
std::optional<Source<int>> select_condition_type;
|
||||||
|
std::optional<Source<bool>> show_result;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct ImageGeneratorOutputs {
|
struct ImageGeneratorOutputs {
|
||||||
|
@ -209,6 +211,9 @@ REGISTER_MEDIAPIPE_GRAPH(
|
||||||
// valid, if condtrol plugin graph options are set in the graph options.
|
// valid, if condtrol plugin graph options are set in the graph options.
|
||||||
// SELECT - int
|
// SELECT - int
|
||||||
// The index of the selected the control plugin graph.
|
// The index of the selected the control plugin graph.
|
||||||
|
// SHOW_RESULT - bool @Optional
|
||||||
|
// Whether to show the diffusion result at the current step. If this stream
|
||||||
|
// is not empty, regardless show_every_n_iteration in the options.
|
||||||
//
|
//
|
||||||
// Outputs:
|
// Outputs:
|
||||||
// IMAGE - Image
|
// IMAGE - Image
|
||||||
|
@ -218,6 +223,9 @@ REGISTER_MEDIAPIPE_GRAPH(
|
||||||
// ITERATION - int @optional
|
// ITERATION - int @optional
|
||||||
// The current iteration in the generating steps. The same as ITERATION
|
// The current iteration in the generating steps. The same as ITERATION
|
||||||
// input.
|
// input.
|
||||||
|
// SHOW_RESULT - bool @Optional
|
||||||
|
// Whether to show the diffusion result at the current step. The same as
|
||||||
|
// input SHOW_RESULT.
|
||||||
class ImageGeneratorGraph : public core::ModelTaskGraph {
|
class ImageGeneratorGraph : public core::ModelTaskGraph {
|
||||||
public:
|
public:
|
||||||
absl::StatusOr<CalculatorGraphConfig> GetConfig(
|
absl::StatusOr<CalculatorGraphConfig> GetConfig(
|
||||||
|
@ -239,6 +247,10 @@ class ImageGeneratorGraph : public core::ModelTaskGraph {
|
||||||
condition_image = graph.In(kConditionImageTag).Cast<Image>();
|
condition_image = graph.In(kConditionImageTag).Cast<Image>();
|
||||||
select_condition_type = graph.In(kSelectTag).Cast<int>();
|
select_condition_type = graph.In(kSelectTag).Cast<int>();
|
||||||
}
|
}
|
||||||
|
std::optional<Source<bool>> show_result;
|
||||||
|
if (HasInput(sc->OriginalNode(), kShowResultTag)) {
|
||||||
|
show_result = graph.In(kShowResultTag).Cast<bool>();
|
||||||
|
}
|
||||||
ASSIGN_OR_RETURN(
|
ASSIGN_OR_RETURN(
|
||||||
auto outputs,
|
auto outputs,
|
||||||
BuildImageGeneratorGraph(
|
BuildImageGeneratorGraph(
|
||||||
|
@ -251,6 +263,7 @@ class ImageGeneratorGraph : public core::ModelTaskGraph {
|
||||||
/*rand_seed=*/graph.In(kRandSeedTag).Cast<int>(),
|
/*rand_seed=*/graph.In(kRandSeedTag).Cast<int>(),
|
||||||
/*condition_image*/ condition_image,
|
/*condition_image*/ condition_image,
|
||||||
/*select_condition_type*/ select_condition_type,
|
/*select_condition_type*/ select_condition_type,
|
||||||
|
/*show_result*/ show_result,
|
||||||
},
|
},
|
||||||
graph));
|
graph));
|
||||||
outputs.generated_image >> graph.Out(kImageTag).Cast<Image>();
|
outputs.generated_image >> graph.Out(kImageTag).Cast<Image>();
|
||||||
|
@ -261,6 +274,10 @@ class ImageGeneratorGraph : public core::ModelTaskGraph {
|
||||||
graph.In(kStepsTag) >> pass_through.In(1);
|
graph.In(kStepsTag) >> pass_through.In(1);
|
||||||
pass_through.Out(0) >> graph[Output<int>::Optional(kIterationTag)];
|
pass_through.Out(0) >> graph[Output<int>::Optional(kIterationTag)];
|
||||||
pass_through.Out(1) >> graph[Output<int>::Optional(kStepsTag)];
|
pass_through.Out(1) >> graph[Output<int>::Optional(kStepsTag)];
|
||||||
|
if (HasOutput(sc->OriginalNode(), kShowResultTag)) {
|
||||||
|
graph.In(kShowResultTag) >> pass_through.In(2);
|
||||||
|
pass_through.Out(2) >> graph[Output<bool>::Optional(kShowResultTag)];
|
||||||
|
}
|
||||||
return graph.GetConfig();
|
return graph.GetConfig();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -299,15 +316,22 @@ class ImageGeneratorGraph : public core::ModelTaskGraph {
|
||||||
inputs.steps >> stable_diff.In(kStepsTag);
|
inputs.steps >> stable_diff.In(kStepsTag);
|
||||||
inputs.iteration >> stable_diff.In(kIterationTag);
|
inputs.iteration >> stable_diff.In(kIterationTag);
|
||||||
inputs.rand_seed >> stable_diff.In(kRandSeedTag);
|
inputs.rand_seed >> stable_diff.In(kRandSeedTag);
|
||||||
|
if (inputs.show_result.has_value()) {
|
||||||
|
*inputs.show_result >> stable_diff.In(kShowResultTag);
|
||||||
|
}
|
||||||
mediapipe::StableDiffusionIterateCalculatorOptions& options =
|
mediapipe::StableDiffusionIterateCalculatorOptions& options =
|
||||||
stable_diff
|
stable_diff
|
||||||
.GetOptions<mediapipe::StableDiffusionIterateCalculatorOptions>();
|
.GetOptions<mediapipe::StableDiffusionIterateCalculatorOptions>();
|
||||||
|
if (subgraph_options.has_stable_diffusion_iterate_options()) {
|
||||||
|
options = subgraph_options.stable_diffusion_iterate_options();
|
||||||
|
} else {
|
||||||
options.set_base_seed(0);
|
options.set_base_seed(0);
|
||||||
options.set_output_image_height(kPluginsOutputSize);
|
options.set_output_image_height(kPluginsOutputSize);
|
||||||
options.set_output_image_width(kPluginsOutputSize);
|
options.set_output_image_width(kPluginsOutputSize);
|
||||||
options.set_file_folder(subgraph_options.text2image_model_directory());
|
options.set_file_folder(subgraph_options.text2image_model_directory());
|
||||||
options.set_show_every_n_iteration(100);
|
options.set_show_every_n_iteration(100);
|
||||||
options.set_emit_empty_packet(true);
|
options.set_emit_empty_packet(true);
|
||||||
|
}
|
||||||
if (lora_resources.has_value()) {
|
if (lora_resources.has_value()) {
|
||||||
auto& lora_layer_weights_mapping =
|
auto& lora_layer_weights_mapping =
|
||||||
*options.mutable_lora_weights_layer_mapping();
|
*options.mutable_lora_weights_layer_mapping();
|
||||||
|
|
|
@ -48,5 +48,6 @@ mediapipe_proto_library(
|
||||||
deps = [
|
deps = [
|
||||||
":control_plugin_graph_options_proto",
|
":control_plugin_graph_options_proto",
|
||||||
"//mediapipe/tasks/cc/core/proto:external_file_proto",
|
"//mediapipe/tasks/cc/core/proto:external_file_proto",
|
||||||
|
"//mediapipe/tasks/cc/vision/image_generator/diffuser:stable_diffusion_iterate_calculator_proto",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
|
@ -18,6 +18,7 @@ syntax = "proto3";
|
||||||
package mediapipe.tasks.vision.image_generator.proto;
|
package mediapipe.tasks.vision.image_generator.proto;
|
||||||
|
|
||||||
import "mediapipe/tasks/cc/core/proto/external_file.proto";
|
import "mediapipe/tasks/cc/core/proto/external_file.proto";
|
||||||
|
import "mediapipe/tasks/cc/vision/image_generator/diffuser/stable_diffusion_iterate_calculator.proto";
|
||||||
import "mediapipe/tasks/cc/vision/image_generator/proto/control_plugin_graph_options.proto";
|
import "mediapipe/tasks/cc/vision/image_generator/proto/control_plugin_graph_options.proto";
|
||||||
|
|
||||||
option java_package = "com.google.mediapipe.tasks.vision.imagegenerator.proto";
|
option java_package = "com.google.mediapipe.tasks.vision.imagegenerator.proto";
|
||||||
|
@ -32,4 +33,7 @@ message ImageGeneratorGraphOptions {
|
||||||
core.proto.ExternalFile lora_weights_file = 2;
|
core.proto.ExternalFile lora_weights_file = 2;
|
||||||
|
|
||||||
repeated proto.ControlPluginGraphOptions control_plugin_graphs_options = 3;
|
repeated proto.ControlPluginGraphOptions control_plugin_graphs_options = 3;
|
||||||
|
|
||||||
|
mediapipe.StableDiffusionIterateCalculatorOptions
|
||||||
|
stable_diffusion_iterate_options = 4;
|
||||||
}
|
}
|
||||||
|
|
|
@ -67,6 +67,10 @@ _VISION_TASKS_IMAGE_GENERATOR_JAVA_PROTO_LITE_TARGETS = [
|
||||||
"//mediapipe/tasks/cc/vision/face_landmarker/proto:face_blendshapes_graph_options_java_proto_lite",
|
"//mediapipe/tasks/cc/vision/face_landmarker/proto:face_blendshapes_graph_options_java_proto_lite",
|
||||||
"//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarker_graph_options_java_proto_lite",
|
"//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarker_graph_options_java_proto_lite",
|
||||||
"//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarks_detector_graph_options_java_proto_lite",
|
"//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarks_detector_graph_options_java_proto_lite",
|
||||||
|
"//mediapipe/tasks/cc/vision/image_generator/diffuser:stable_diffusion_iterate_calculator_java_proto_lite",
|
||||||
|
"//mediapipe/tasks/cc/vision/image_generator/proto:conditioned_image_graph_options_java_proto_lite",
|
||||||
|
"//mediapipe/tasks/cc/vision/image_generator/proto:control_plugin_graph_options_java_proto_lite",
|
||||||
|
"//mediapipe/tasks/cc/vision/image_generator/proto:image_generator_graph_options_java_proto_lite",
|
||||||
"//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_java_proto_lite",
|
"//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_java_proto_lite",
|
||||||
"//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_java_proto_lite",
|
"//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_java_proto_lite",
|
||||||
"//mediapipe/tasks/cc/vision/image_segmenter/calculators:tensors_to_segmentation_calculator_java_proto_lite",
|
"//mediapipe/tasks/cc/vision/image_segmenter/calculators:tensors_to_segmentation_calculator_java_proto_lite",
|
||||||
|
|
|
@ -67,8 +67,8 @@ android_library(
|
||||||
"//mediapipe/java/com/google/mediapipe/framework:android_framework",
|
"//mediapipe/java/com/google/mediapipe/framework:android_framework",
|
||||||
"//mediapipe/java/com/google/mediapipe/framework/image",
|
"//mediapipe/java/com/google/mediapipe/framework/image",
|
||||||
"//mediapipe/tasks/cc/core/proto:external_file_java_proto_lite",
|
"//mediapipe/tasks/cc/core/proto:external_file_java_proto_lite",
|
||||||
"//mediapipe/tasks/cc/vision/face_detector/proto:face_detector_graph_options_java_proto_lite",
|
|
||||||
"//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarker_graph_options_java_proto_lite",
|
"//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarker_graph_options_java_proto_lite",
|
||||||
|
"//mediapipe/tasks/cc/vision/image_generator/diffuser:stable_diffusion_iterate_calculator_java_proto_lite",
|
||||||
"//mediapipe/tasks/cc/vision/image_generator/proto:conditioned_image_graph_options_java_proto_lite",
|
"//mediapipe/tasks/cc/vision/image_generator/proto:conditioned_image_graph_options_java_proto_lite",
|
||||||
"//mediapipe/tasks/cc/vision/image_generator/proto:control_plugin_graph_options_java_proto_lite",
|
"//mediapipe/tasks/cc/vision/image_generator/proto:control_plugin_graph_options_java_proto_lite",
|
||||||
"//mediapipe/tasks/cc/vision/image_generator/proto:image_generator_graph_options_java_proto_lite",
|
"//mediapipe/tasks/cc/vision/image_generator/proto:image_generator_graph_options_java_proto_lite",
|
||||||
|
|
|
@ -19,6 +19,7 @@ import android.graphics.Bitmap;
|
||||||
import android.util.Log;
|
import android.util.Log;
|
||||||
import androidx.annotation.Nullable;
|
import androidx.annotation.Nullable;
|
||||||
import com.google.auto.value.AutoValue;
|
import com.google.auto.value.AutoValue;
|
||||||
|
import com.google.mediapipe.calculator.proto.StableDiffusionIterateCalculatorOptionsProto;
|
||||||
import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions;
|
import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions;
|
||||||
import com.google.mediapipe.framework.AndroidPacketGetter;
|
import com.google.mediapipe.framework.AndroidPacketGetter;
|
||||||
import com.google.mediapipe.framework.Packet;
|
import com.google.mediapipe.framework.Packet;
|
||||||
|
@ -28,8 +29,6 @@ import com.google.mediapipe.framework.image.MPImage;
|
||||||
import com.google.mediapipe.tasks.core.BaseOptions;
|
import com.google.mediapipe.tasks.core.BaseOptions;
|
||||||
import com.google.mediapipe.tasks.core.ErrorListener;
|
import com.google.mediapipe.tasks.core.ErrorListener;
|
||||||
import com.google.mediapipe.tasks.core.OutputHandler;
|
import com.google.mediapipe.tasks.core.OutputHandler;
|
||||||
import com.google.mediapipe.tasks.core.OutputHandler.PureResultListener;
|
|
||||||
import com.google.mediapipe.tasks.core.OutputHandler.ResultListener;
|
|
||||||
import com.google.mediapipe.tasks.core.TaskInfo;
|
import com.google.mediapipe.tasks.core.TaskInfo;
|
||||||
import com.google.mediapipe.tasks.core.TaskOptions;
|
import com.google.mediapipe.tasks.core.TaskOptions;
|
||||||
import com.google.mediapipe.tasks.core.TaskResult;
|
import com.google.mediapipe.tasks.core.TaskResult;
|
||||||
|
@ -64,17 +63,23 @@ public final class ImageGenerator extends BaseVisionTaskApi {
|
||||||
private static final String SOURCE_CONDITION_IMAGE_STREAM_NAME = "source_condition_image";
|
private static final String SOURCE_CONDITION_IMAGE_STREAM_NAME = "source_condition_image";
|
||||||
private static final String CONDITION_IMAGE_STREAM_NAME = "condition_image";
|
private static final String CONDITION_IMAGE_STREAM_NAME = "condition_image";
|
||||||
private static final String SELECT_STREAM_NAME = "select";
|
private static final String SELECT_STREAM_NAME = "select";
|
||||||
|
private static final String SHOW_RESULT_STREAM_NAME = "show_result";
|
||||||
private static final int GENERATED_IMAGE_OUT_STREAM_INDEX = 0;
|
private static final int GENERATED_IMAGE_OUT_STREAM_INDEX = 0;
|
||||||
private static final int STEPS_OUT_STREAM_INDEX = 1;
|
private static final int STEPS_OUT_STREAM_INDEX = 1;
|
||||||
private static final int ITERATION_OUT_STREAM_INDEX = 2;
|
private static final int ITERATION_OUT_STREAM_INDEX = 2;
|
||||||
|
private static final int SHOW_RESULT_OUT_STREAM_INDEX = 3;
|
||||||
private static final String TASK_GRAPH_NAME =
|
private static final String TASK_GRAPH_NAME =
|
||||||
"mediapipe.tasks.vision.image_generator.ImageGeneratorGraph";
|
"mediapipe.tasks.vision.image_generator.ImageGeneratorGraph";
|
||||||
private static final String CONDITION_IMAGE_GRAPHS_CONTAINER_NAME =
|
private static final String CONDITION_IMAGE_GRAPHS_CONTAINER_NAME =
|
||||||
"mediapipe.tasks.vision.image_generator.ConditionedImageGraphContainer";
|
"mediapipe.tasks.vision.image_generator.ConditionedImageGraphContainer";
|
||||||
private static final String TAG = "ImageGenerator";
|
private static final String TAG = "ImageGenerator";
|
||||||
|
private static final int GENERATED_IMAGE_WIDTH = 512;
|
||||||
|
private static final int GENERATED_IMAGE_HEIGHT = 512;
|
||||||
private TaskRunner conditionImageGraphsContainerTaskRunner;
|
private TaskRunner conditionImageGraphsContainerTaskRunner;
|
||||||
private Map<ConditionOptions.ConditionType, Integer> conditionTypeIndex;
|
private Map<ConditionOptions.ConditionType, Integer> conditionTypeIndex;
|
||||||
private boolean useConditionImage = false;
|
private boolean useConditionImage = false;
|
||||||
|
private CachedInputs cachedInputs = new CachedInputs();
|
||||||
|
private boolean inProcessing = false;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Creates an {@link ImageGenerator} instance from an {@link ImageGeneratorOptions}.
|
* Creates an {@link ImageGenerator} instance from an {@link ImageGeneratorOptions}.
|
||||||
|
@ -107,7 +112,8 @@ public final class ImageGenerator extends BaseVisionTaskApi {
|
||||||
"STEPS:" + STEPS_STREAM_NAME,
|
"STEPS:" + STEPS_STREAM_NAME,
|
||||||
"ITERATION:" + ITERATION_STREAM_NAME,
|
"ITERATION:" + ITERATION_STREAM_NAME,
|
||||||
"PROMPT:" + PROMPT_STREAM_NAME,
|
"PROMPT:" + PROMPT_STREAM_NAME,
|
||||||
"RAND_SEED:" + RAND_SEED_STREAM_NAME));
|
"RAND_SEED:" + RAND_SEED_STREAM_NAME,
|
||||||
|
"SHOW_RESULT:" + SHOW_RESULT_STREAM_NAME));
|
||||||
final boolean useConditionImage = conditionOptions != null;
|
final boolean useConditionImage = conditionOptions != null;
|
||||||
if (useConditionImage) {
|
if (useConditionImage) {
|
||||||
inputStreams.add("SELECT:" + SELECT_STREAM_NAME);
|
inputStreams.add("SELECT:" + SELECT_STREAM_NAME);
|
||||||
|
@ -115,7 +121,11 @@ public final class ImageGenerator extends BaseVisionTaskApi {
|
||||||
generatorOptions.conditionOptions = Optional.of(conditionOptions);
|
generatorOptions.conditionOptions = Optional.of(conditionOptions);
|
||||||
}
|
}
|
||||||
List<String> outputStreams =
|
List<String> outputStreams =
|
||||||
Arrays.asList("IMAGE:image_out", "STEPS:steps_out", "ITERATION:iteration_out");
|
Arrays.asList(
|
||||||
|
"IMAGE:image_out",
|
||||||
|
"STEPS:steps_out",
|
||||||
|
"ITERATION:iteration_out",
|
||||||
|
"SHOW_RESULT:show_result_out");
|
||||||
|
|
||||||
OutputHandler<ImageGeneratorResult, Void> handler = new OutputHandler<>();
|
OutputHandler<ImageGeneratorResult, Void> handler = new OutputHandler<>();
|
||||||
handler.setOutputPacketConverter(
|
handler.setOutputPacketConverter(
|
||||||
|
@ -125,16 +135,22 @@ public final class ImageGenerator extends BaseVisionTaskApi {
|
||||||
public ImageGeneratorResult convertToTaskResult(List<Packet> packets) {
|
public ImageGeneratorResult convertToTaskResult(List<Packet> packets) {
|
||||||
int iteration = PacketGetter.getInt32(packets.get(ITERATION_OUT_STREAM_INDEX));
|
int iteration = PacketGetter.getInt32(packets.get(ITERATION_OUT_STREAM_INDEX));
|
||||||
int steps = PacketGetter.getInt32(packets.get(STEPS_OUT_STREAM_INDEX));
|
int steps = PacketGetter.getInt32(packets.get(STEPS_OUT_STREAM_INDEX));
|
||||||
Log.i("ImageGenerator", "Iteration: " + iteration + ", Steps: " + steps);
|
boolean showResult =
|
||||||
if (iteration != steps - 1) {
|
PacketGetter.getBool(packets.get(SHOW_RESULT_OUT_STREAM_INDEX))
|
||||||
return null;
|
|| iteration == steps - 1;
|
||||||
}
|
Log.i(
|
||||||
|
"ImageGenerator",
|
||||||
|
"Iteration: " + iteration + ", Steps: " + steps + ", ShowResult: " + showResult);
|
||||||
|
if (showResult) {
|
||||||
Log.i("ImageGenerator", "processing generated image");
|
Log.i("ImageGenerator", "processing generated image");
|
||||||
Packet packet = packets.get(GENERATED_IMAGE_OUT_STREAM_INDEX);
|
Packet packet = packets.get(GENERATED_IMAGE_OUT_STREAM_INDEX);
|
||||||
Bitmap generatedBitmap = AndroidPacketGetter.getBitmapFromRgb(packet);
|
Bitmap generatedBitmap = AndroidPacketGetter.getBitmapFromRgb(packet);
|
||||||
BitmapImageBuilder bitmapImageBuilder = new BitmapImageBuilder(generatedBitmap);
|
BitmapImageBuilder bitmapImageBuilder = new BitmapImageBuilder(generatedBitmap);
|
||||||
return ImageGeneratorResult.create(
|
return ImageGeneratorResult.create(
|
||||||
bitmapImageBuilder.build(), packet.getTimestamp() / MICROSECONDS_PER_MILLISECOND);
|
bitmapImageBuilder.build(), packet.getTimestamp() / MICROSECONDS_PER_MILLISECOND);
|
||||||
|
} else {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -143,16 +159,6 @@ public final class ImageGenerator extends BaseVisionTaskApi {
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
handler.setHandleTimestampBoundChanges(true);
|
handler.setHandleTimestampBoundChanges(true);
|
||||||
if (generatorOptions.resultListener().isPresent()) {
|
|
||||||
ResultListener<ImageGeneratorResult, Void> resultListener =
|
|
||||||
new ResultListener<ImageGeneratorResult, Void>() {
|
|
||||||
@Override
|
|
||||||
public void run(ImageGeneratorResult imageGeneratorResult, Void input) {
|
|
||||||
generatorOptions.resultListener().get().run(imageGeneratorResult);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
handler.setResultListener(resultListener);
|
|
||||||
}
|
|
||||||
generatorOptions.errorListener().ifPresent(handler::setErrorListener);
|
generatorOptions.errorListener().ifPresent(handler::setErrorListener);
|
||||||
TaskRunner runner =
|
TaskRunner runner =
|
||||||
TaskRunner.create(
|
TaskRunner.create(
|
||||||
|
@ -229,11 +235,19 @@ public final class ImageGenerator extends BaseVisionTaskApi {
|
||||||
* Generates an image for iterations and the given random seed. Only valid when the ImageGenerator
|
* Generates an image for iterations and the given random seed. Only valid when the ImageGenerator
|
||||||
* is created without condition options.
|
* is created without condition options.
|
||||||
*
|
*
|
||||||
|
* <p>This is an e2e API, which runs {@code iterations} to generate an image. Consider using the
|
||||||
|
* iterative API instead to fetch the intermediate results.
|
||||||
|
*
|
||||||
* @param prompt The text prompt describing the image to be generated.
|
* @param prompt The text prompt describing the image to be generated.
|
||||||
* @param iterations The total iterations to generate the image.
|
* @param iterations The total iterations to generate the image.
|
||||||
* @param seed The random seed used during image generation.
|
* @param seed The random seed used during image generation.
|
||||||
*/
|
*/
|
||||||
public ImageGeneratorResult generate(String prompt, int iterations, int seed) {
|
public ImageGeneratorResult generate(String prompt, int iterations, int seed) {
|
||||||
|
if (useConditionImage) {
|
||||||
|
throw new IllegalArgumentException(
|
||||||
|
"ImageGenerator is created with condition options. Must use the methods with condition "
|
||||||
|
+ "options.");
|
||||||
|
}
|
||||||
return runIterations(prompt, iterations, seed, null, 0);
|
return runIterations(prompt, iterations, seed, null, 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -241,6 +255,9 @@ public final class ImageGenerator extends BaseVisionTaskApi {
|
||||||
* Generates an image based on the source image for iterations and the given random seed. Only
|
* Generates an image based on the source image for iterations and the given random seed. Only
|
||||||
* valid when the ImageGenerator is created with condition options.
|
* valid when the ImageGenerator is created with condition options.
|
||||||
*
|
*
|
||||||
|
* <p>This is an e2e API, which runs {@code iterations} to generate an image. Consider using the
|
||||||
|
* iterative API instead to fetch the intermediate results.
|
||||||
|
*
|
||||||
* @param prompt The text prompt describing the image to be generated.
|
* @param prompt The text prompt describing the image to be generated.
|
||||||
* @param sourceConditionImage The source image used to create the condition image, which is used
|
* @param sourceConditionImage The source image used to create the condition image, which is used
|
||||||
* as a guidance for the image generation.
|
* as a guidance for the image generation.
|
||||||
|
@ -255,6 +272,11 @@ public final class ImageGenerator extends BaseVisionTaskApi {
|
||||||
ConditionOptions.ConditionType conditionType,
|
ConditionOptions.ConditionType conditionType,
|
||||||
int iterations,
|
int iterations,
|
||||||
int seed) {
|
int seed) {
|
||||||
|
if (!useConditionImage) {
|
||||||
|
throw new IllegalArgumentException(
|
||||||
|
"ImageGenerator is created without condition options. Must use the methods without "
|
||||||
|
+ "condition options.");
|
||||||
|
}
|
||||||
return runIterations(
|
return runIterations(
|
||||||
prompt,
|
prompt,
|
||||||
iterations,
|
iterations,
|
||||||
|
@ -263,6 +285,97 @@ public final class ImageGenerator extends BaseVisionTaskApi {
|
||||||
conditionTypeIndex.get(conditionType));
|
conditionTypeIndex.get(conditionType));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Sets the inputs of the ImageGenerator. There is {@link setInputs} and {@link execute} method
|
||||||
|
* pair for iterative usage. Users must call {@link setInputs} before {@link execute}. Only valid
|
||||||
|
* when the ImageGenerator is created without condition options.
|
||||||
|
*
|
||||||
|
* @param prompt The text prompt describing the image to be generated.
|
||||||
|
* @param iterations The total iterations to generate the image.
|
||||||
|
* @param seed The random seed used during image generation.
|
||||||
|
*/
|
||||||
|
public void setInputs(String prompt, int iterations, int seed) {
|
||||||
|
if (useConditionImage) {
|
||||||
|
throw new IllegalArgumentException(
|
||||||
|
"ImageGenerator is created with condition options. Must use the methods with condition "
|
||||||
|
+ "options.");
|
||||||
|
}
|
||||||
|
cachedInputs = new CachedInputs();
|
||||||
|
cachedInputs.prompt = prompt;
|
||||||
|
cachedInputs.iterations = iterations;
|
||||||
|
cachedInputs.seed = seed;
|
||||||
|
cachedInputs.step = 0;
|
||||||
|
inProcessing = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Sets the inputs of the ImageGenerator. For iterative usage, use {@link setInputs} and {@link
|
||||||
|
* execute} in pairs. Users must call {@link setInputs} before {@link execute}. Only valid when
|
||||||
|
* the ImageGenerator is created with condition options.
|
||||||
|
*
|
||||||
|
* @param prompt The text prompt describing the image to be generated.
|
||||||
|
* @param sourceConditionImage The source image used to create the condition image, which is used
|
||||||
|
* as a guidance for the image generation.
|
||||||
|
* @param conditionType The {@link ConditionOptions.ConditionType} specifying the type of
|
||||||
|
* condition image.
|
||||||
|
* @param iterations The total iterations to generate the image.
|
||||||
|
* @param seed The random seed used during image generation.
|
||||||
|
*/
|
||||||
|
public void setInputs(
|
||||||
|
String prompt,
|
||||||
|
MPImage sourceConditionImage,
|
||||||
|
ConditionOptions.ConditionType conditionType,
|
||||||
|
int iterations,
|
||||||
|
int seed) {
|
||||||
|
if (!useConditionImage) {
|
||||||
|
throw new IllegalArgumentException(
|
||||||
|
"ImageGenerator is created without condition options. Must use the methods without "
|
||||||
|
+ "condition options.");
|
||||||
|
}
|
||||||
|
cachedInputs = new CachedInputs();
|
||||||
|
cachedInputs.prompt = prompt;
|
||||||
|
cachedInputs.iterations = iterations;
|
||||||
|
cachedInputs.seed = seed;
|
||||||
|
cachedInputs.step = 0;
|
||||||
|
cachedInputs.cachedConditionImage = createConditionImage(sourceConditionImage, conditionType);
|
||||||
|
cachedInputs.conditionType = conditionType;
|
||||||
|
inProcessing = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Executes one iteration of image generation. The method must be called {@code iterations} times
|
||||||
|
* to generate the final image. Must call {@link setInputs} before calling this method.
|
||||||
|
*
|
||||||
|
* <p>This is an iterative API, which must be called iteratively.
|
||||||
|
*
|
||||||
|
* <p>This API is useful for showing the intermediate image generation results and the image
|
||||||
|
* generation progress. Note that requesting the intermediate results will result in a larger
|
||||||
|
* latency. Consider using the e2e API instead for latency consideration.
|
||||||
|
*
|
||||||
|
* <p>Example usage:
|
||||||
|
*
|
||||||
|
* <p>imageGenerator.setInputs(prompt, iterations, seed); for (int step = 0; step < iterations;
|
||||||
|
* step++) { ImageGeneratorResult result = imageGenerator.execute(true); }
|
||||||
|
*
|
||||||
|
* @param showResult Whether to get the generated image result in the intermediate iterations. If
|
||||||
|
* false, null is returned. The generated image result is always returned at the last
|
||||||
|
* iteration, regardless of showResult value.
|
||||||
|
*/
|
||||||
|
@Nullable
|
||||||
|
public ImageGeneratorResult execute(boolean showResult) {
|
||||||
|
if (!inProcessing) {
|
||||||
|
throw new IllegalArgumentException("Must call setInputs before execute.");
|
||||||
|
}
|
||||||
|
return runStep(
|
||||||
|
cachedInputs.prompt,
|
||||||
|
cachedInputs.iterations,
|
||||||
|
cachedInputs.step++,
|
||||||
|
cachedInputs.seed,
|
||||||
|
cachedInputs.cachedConditionImage,
|
||||||
|
useConditionImage ? conditionTypeIndex.get(cachedInputs.conditionType) : 0,
|
||||||
|
showResult);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Create the condition image of specified condition type from the source image. Currently support
|
* Create the condition image of specified condition type from the source image. Currently support
|
||||||
* face landmarks, depth image and edge image as the condition image.
|
* face landmarks, depth image and edge image as the condition image.
|
||||||
|
@ -295,6 +408,11 @@ public final class ImageGenerator extends BaseVisionTaskApi {
|
||||||
|
|
||||||
private ImageGeneratorResult runIterations(
|
private ImageGeneratorResult runIterations(
|
||||||
String prompt, int steps, int seed, @Nullable MPImage conditionImage, int select) {
|
String prompt, int steps, int seed, @Nullable MPImage conditionImage, int select) {
|
||||||
|
if (inProcessing) {
|
||||||
|
throw new IllegalArgumentException(
|
||||||
|
"Iterative API was called previously. It is not allowed to called batch API during"
|
||||||
|
+ "iterative processing.");
|
||||||
|
}
|
||||||
ImageGeneratorResult result = null;
|
ImageGeneratorResult result = null;
|
||||||
long timestamp = System.currentTimeMillis() * MICROSECONDS_PER_MILLISECOND;
|
long timestamp = System.currentTimeMillis() * MICROSECONDS_PER_MILLISECOND;
|
||||||
for (int i = 0; i < steps; i++) {
|
for (int i = 0; i < steps; i++) {
|
||||||
|
@ -318,12 +436,85 @@ public final class ImageGenerator extends BaseVisionTaskApi {
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Nullable
|
||||||
|
private ImageGeneratorResult runStep(
|
||||||
|
String prompt,
|
||||||
|
int iterations,
|
||||||
|
int step,
|
||||||
|
int seed,
|
||||||
|
@Nullable MPImage conditionImage,
|
||||||
|
int select,
|
||||||
|
boolean showResult) {
|
||||||
|
if (step == 0) {
|
||||||
|
cachedInputs.cachedTimestamp = System.currentTimeMillis() * MICROSECONDS_PER_MILLISECOND;
|
||||||
|
}
|
||||||
|
Map<String, Packet> inputPackets = new HashMap<>();
|
||||||
|
if (step == 0 && useConditionImage) {
|
||||||
|
inputPackets.put(
|
||||||
|
CONDITION_IMAGE_STREAM_NAME, runner.getPacketCreator().createImage(conditionImage));
|
||||||
|
inputPackets.put(SELECT_STREAM_NAME, runner.getPacketCreator().createInt32(select));
|
||||||
|
}
|
||||||
|
inputPackets.put(PROMPT_STREAM_NAME, runner.getPacketCreator().createString(prompt));
|
||||||
|
inputPackets.put(STEPS_STREAM_NAME, runner.getPacketCreator().createInt32(iterations));
|
||||||
|
inputPackets.put(ITERATION_STREAM_NAME, runner.getPacketCreator().createInt32(step));
|
||||||
|
inputPackets.put(RAND_SEED_STREAM_NAME, runner.getPacketCreator().createInt32(seed));
|
||||||
|
inputPackets.put(SHOW_RESULT_STREAM_NAME, runner.getPacketCreator().createBool(showResult));
|
||||||
|
ImageGeneratorResult result =
|
||||||
|
(ImageGeneratorResult) runner.process(inputPackets, cachedInputs.cachedTimestamp++);
|
||||||
|
if (result != null && useConditionImage) {
|
||||||
|
// Add condition image to the ImageGeneratorResult.
|
||||||
|
result =
|
||||||
|
ImageGeneratorResult.create(
|
||||||
|
result.generatedImage(), conditionImage, result.timestampMs());
|
||||||
|
}
|
||||||
|
if (step == iterations - 1) {
|
||||||
|
inProcessing = false;
|
||||||
|
cachedInputs = new CachedInputs();
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
/** Closes and cleans up the task runners. */
|
/** Closes and cleans up the task runners. */
|
||||||
@Override
|
@Override
|
||||||
public void close() {
|
public void close() {
|
||||||
|
if (runner != null) {
|
||||||
runner.close();
|
runner.close();
|
||||||
|
}
|
||||||
|
if (conditionImageGraphsContainerTaskRunner != null) {
|
||||||
conditionImageGraphsContainerTaskRunner.close();
|
conditionImageGraphsContainerTaskRunner.close();
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper class to holder inputs to be checked with the inputs of next step.
|
||||||
|
private static class CachedInputs {
|
||||||
|
public CachedInputs() {
|
||||||
|
this.prompt = "";
|
||||||
|
this.iterations = 0;
|
||||||
|
this.step = 0;
|
||||||
|
this.seed = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public final String toString() {
|
||||||
|
return "Prompt: "
|
||||||
|
+ prompt
|
||||||
|
+ ", Iterations: "
|
||||||
|
+ iterations
|
||||||
|
+ ", Step: "
|
||||||
|
+ step
|
||||||
|
+ ", Seed: "
|
||||||
|
+ seed
|
||||||
|
+ (conditionType == null ? "" : ", ConditionType: " + conditionType.name());
|
||||||
|
}
|
||||||
|
|
||||||
|
public String prompt;
|
||||||
|
public int iterations;
|
||||||
|
public int step;
|
||||||
|
public int seed;
|
||||||
|
public ConditionOptions.ConditionType conditionType;
|
||||||
|
public MPImage cachedConditionImage;
|
||||||
|
public long cachedTimestamp;
|
||||||
|
}
|
||||||
|
|
||||||
/** A container class for the condition image. */
|
/** A container class for the condition image. */
|
||||||
@AutoValue
|
@AutoValue
|
||||||
|
@ -343,15 +534,12 @@ public final class ImageGenerator extends BaseVisionTaskApi {
|
||||||
@AutoValue.Builder
|
@AutoValue.Builder
|
||||||
public abstract static class Builder {
|
public abstract static class Builder {
|
||||||
|
|
||||||
/** Sets the text to image model directory storing the model weights. */
|
/** Sets the image generator model directory storing the model weights. */
|
||||||
public abstract Builder setText2ImageModelDirectory(String modelDirectory);
|
public abstract Builder setImageGeneratorModelDirectory(String modelDirectory);
|
||||||
|
|
||||||
/** Sets the path to LoRA weights file. */
|
/** Sets the path to LoRA weights file. */
|
||||||
public abstract Builder setLoraWeightsFilePath(String loraWeightsFilePath);
|
public abstract Builder setLoraWeightsFilePath(String loraWeightsFilePath);
|
||||||
|
|
||||||
public abstract Builder setResultListener(
|
|
||||||
PureResultListener<ImageGeneratorResult> resultListener);
|
|
||||||
|
|
||||||
/** Sets an optional {@link ErrorListener}}. */
|
/** Sets an optional {@link ErrorListener}}. */
|
||||||
public abstract Builder setErrorListener(ErrorListener value);
|
public abstract Builder setErrorListener(ErrorListener value);
|
||||||
|
|
||||||
|
@ -363,19 +551,17 @@ public final class ImageGenerator extends BaseVisionTaskApi {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
abstract String text2ImageModelDirectory();
|
abstract String imageGeneratorModelDirectory();
|
||||||
|
|
||||||
abstract Optional<String> loraWeightsFilePath();
|
abstract Optional<String> loraWeightsFilePath();
|
||||||
|
|
||||||
abstract Optional<PureResultListener<ImageGeneratorResult>> resultListener();
|
|
||||||
|
|
||||||
abstract Optional<ErrorListener> errorListener();
|
abstract Optional<ErrorListener> errorListener();
|
||||||
|
|
||||||
private Optional<ConditionOptions> conditionOptions;
|
private Optional<ConditionOptions> conditionOptions;
|
||||||
|
|
||||||
public static Builder builder() {
|
public static Builder builder() {
|
||||||
return new AutoValue_ImageGenerator_ImageGeneratorOptions.Builder()
|
return new AutoValue_ImageGenerator_ImageGeneratorOptions.Builder()
|
||||||
.setText2ImageModelDirectory("");
|
.setImageGeneratorModelDirectory("");
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Converts an {@link ImageGeneratorOptions} to a {@link Any} protobuf message. */
|
/** Converts an {@link ImageGeneratorOptions} to a {@link Any} protobuf message. */
|
||||||
|
@ -393,13 +579,23 @@ public final class ImageGenerator extends BaseVisionTaskApi {
|
||||||
e.printStackTrace();
|
e.printStackTrace();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
taskOptionsBuilder.setText2ImageModelDirectory(text2ImageModelDirectory());
|
taskOptionsBuilder.setText2ImageModelDirectory(imageGeneratorModelDirectory());
|
||||||
if (loraWeightsFilePath().isPresent()) {
|
if (loraWeightsFilePath().isPresent()) {
|
||||||
ExternalFileProto.ExternalFile.Builder externalFileBuilder =
|
ExternalFileProto.ExternalFile.Builder externalFileBuilder =
|
||||||
ExternalFileProto.ExternalFile.newBuilder();
|
ExternalFileProto.ExternalFile.newBuilder();
|
||||||
externalFileBuilder.setFileName(loraWeightsFilePath().get());
|
externalFileBuilder.setFileName(loraWeightsFilePath().get());
|
||||||
taskOptionsBuilder.setLoraWeightsFile(externalFileBuilder.build());
|
taskOptionsBuilder.setLoraWeightsFile(externalFileBuilder.build());
|
||||||
}
|
}
|
||||||
|
taskOptionsBuilder.setStableDiffusionIterateOptions(
|
||||||
|
StableDiffusionIterateCalculatorOptionsProto.StableDiffusionIterateCalculatorOptions
|
||||||
|
.newBuilder()
|
||||||
|
.setBaseSeed(0)
|
||||||
|
.setFileFolder(imageGeneratorModelDirectory())
|
||||||
|
.setOutputImageWidth(GENERATED_IMAGE_WIDTH)
|
||||||
|
.setOutputImageHeight(GENERATED_IMAGE_HEIGHT)
|
||||||
|
.setEmitEmptyPacket(true)
|
||||||
|
.setShowEveryNIteration(100)
|
||||||
|
.build());
|
||||||
return Any.newBuilder()
|
return Any.newBuilder()
|
||||||
.setTypeUrl(
|
.setTypeUrl(
|
||||||
"type.googleapis.com/mediapipe.tasks.vision.image_generator.proto.ImageGeneratorGraphOptions")
|
"type.googleapis.com/mediapipe.tasks.vision.image_generator.proto.ImageGeneratorGraphOptions")
|
||||||
|
@ -465,7 +661,8 @@ public final class ImageGenerator extends BaseVisionTaskApi {
|
||||||
taskOptionsBuilder.addControlPluginGraphsOptions(
|
taskOptionsBuilder.addControlPluginGraphsOptions(
|
||||||
ControlPluginGraphOptionsProto.ControlPluginGraphOptions.newBuilder()
|
ControlPluginGraphOptionsProto.ControlPluginGraphOptions.newBuilder()
|
||||||
.setBaseOptions(
|
.setBaseOptions(
|
||||||
convertBaseOptionsToProto(faceConditionOptions().get().baseOptions()))
|
convertBaseOptionsToProto(
|
||||||
|
faceConditionOptions().get().pluginModelBaseOptions()))
|
||||||
.setConditionedImageGraphOptions(
|
.setConditionedImageGraphOptions(
|
||||||
ConditionedImageGraphOptions.newBuilder()
|
ConditionedImageGraphOptions.newBuilder()
|
||||||
.setFaceConditionTypeOptions(faceConditionOptions().get().convertToProto())
|
.setFaceConditionTypeOptions(faceConditionOptions().get().convertToProto())
|
||||||
|
@ -476,7 +673,8 @@ public final class ImageGenerator extends BaseVisionTaskApi {
|
||||||
taskOptionsBuilder.addControlPluginGraphsOptions(
|
taskOptionsBuilder.addControlPluginGraphsOptions(
|
||||||
ControlPluginGraphOptionsProto.ControlPluginGraphOptions.newBuilder()
|
ControlPluginGraphOptionsProto.ControlPluginGraphOptions.newBuilder()
|
||||||
.setBaseOptions(
|
.setBaseOptions(
|
||||||
convertBaseOptionsToProto(edgeConditionOptions().get().baseOptions()))
|
convertBaseOptionsToProto(
|
||||||
|
edgeConditionOptions().get().pluginModelBaseOptions()))
|
||||||
.setConditionedImageGraphOptions(
|
.setConditionedImageGraphOptions(
|
||||||
ConditionedImageGraphOptions.newBuilder()
|
ConditionedImageGraphOptions.newBuilder()
|
||||||
.setEdgeConditionTypeOptions(edgeConditionOptions().get().convertToProto())
|
.setEdgeConditionTypeOptions(edgeConditionOptions().get().convertToProto())
|
||||||
|
@ -486,7 +684,8 @@ public final class ImageGenerator extends BaseVisionTaskApi {
|
||||||
taskOptionsBuilder.addControlPluginGraphsOptions(
|
taskOptionsBuilder.addControlPluginGraphsOptions(
|
||||||
ControlPluginGraphOptionsProto.ControlPluginGraphOptions.newBuilder()
|
ControlPluginGraphOptionsProto.ControlPluginGraphOptions.newBuilder()
|
||||||
.setBaseOptions(
|
.setBaseOptions(
|
||||||
convertBaseOptionsToProto(depthConditionOptions().get().baseOptions()))
|
convertBaseOptionsToProto(
|
||||||
|
depthConditionOptions().get().pluginModelBaseOptions()))
|
||||||
.setConditionedImageGraphOptions(
|
.setConditionedImageGraphOptions(
|
||||||
ConditionedImageGraphOptions.newBuilder()
|
ConditionedImageGraphOptions.newBuilder()
|
||||||
.setDepthConditionTypeOptions(
|
.setDepthConditionTypeOptions(
|
||||||
|
@ -510,11 +709,16 @@ public final class ImageGenerator extends BaseVisionTaskApi {
|
||||||
@AutoValue.Builder
|
@AutoValue.Builder
|
||||||
public abstract static class Builder {
|
public abstract static class Builder {
|
||||||
/** Set the base options for plugin model. */
|
/** Set the base options for plugin model. */
|
||||||
public abstract Builder setBaseOptions(BaseOptions baseOptions);
|
public abstract Builder setPluginModelBaseOptions(BaseOptions baseOptions);
|
||||||
|
|
||||||
/* {@link FaceLandmarkerOptions} used to detect face landmarks in the source image. */
|
/** Set base options for face landmarks model. */
|
||||||
public abstract Builder setFaceLandmarkerOptions(
|
public abstract Builder setFaceModelBaseOptions(BaseOptions baseOptions);
|
||||||
FaceLandmarkerOptions faceLandmarkerOptions);
|
|
||||||
|
/** Set minimum confidence score of face presence score in the face landmark detection. */
|
||||||
|
public abstract Builder setMinFaceDetectionConfidence(float minFaceDetectionConfidence);
|
||||||
|
|
||||||
|
/** Set the face presence threshold */
|
||||||
|
public abstract Builder setMinFacePresenceConfidence(float minFacePresenceConfidence);
|
||||||
|
|
||||||
abstract FaceConditionOptions autoBuild();
|
abstract FaceConditionOptions autoBuild();
|
||||||
|
|
||||||
|
@ -524,23 +728,36 @@ public final class ImageGenerator extends BaseVisionTaskApi {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
abstract BaseOptions baseOptions();
|
abstract BaseOptions pluginModelBaseOptions();
|
||||||
|
|
||||||
abstract FaceLandmarkerOptions faceLandmarkerOptions();
|
abstract BaseOptions faceModelBaseOptions();
|
||||||
|
|
||||||
|
abstract float minFaceDetectionConfidence();
|
||||||
|
|
||||||
|
abstract float minFacePresenceConfidence();
|
||||||
|
|
||||||
public static Builder builder() {
|
public static Builder builder() {
|
||||||
return new AutoValue_ImageGenerator_ConditionOptions_FaceConditionOptions.Builder();
|
return new AutoValue_ImageGenerator_ConditionOptions_FaceConditionOptions.Builder()
|
||||||
|
.setMinFaceDetectionConfidence(0.5f)
|
||||||
|
.setMinFacePresenceConfidence(0.5f);
|
||||||
}
|
}
|
||||||
|
|
||||||
ConditionedImageGraphOptions.FaceConditionTypeOptions convertToProto() {
|
ConditionedImageGraphOptions.FaceConditionTypeOptions convertToProto() {
|
||||||
|
FaceLandmarkerOptions faceLandmarkerOptions =
|
||||||
|
FaceLandmarkerOptions.builder()
|
||||||
|
.setBaseOptions(faceModelBaseOptions())
|
||||||
|
.setMinFaceDetectionConfidence(minFaceDetectionConfidence())
|
||||||
|
.setMinFacePresenceConfidence(minFacePresenceConfidence())
|
||||||
|
.setRunningMode(RunningMode.IMAGE)
|
||||||
|
.setOutputFaceBlendshapes(false)
|
||||||
|
.setOutputFacialTransformationMatrixes(false)
|
||||||
|
.setNumFaces(1)
|
||||||
|
.build();
|
||||||
return ConditionedImageGraphOptions.FaceConditionTypeOptions.newBuilder()
|
return ConditionedImageGraphOptions.FaceConditionTypeOptions.newBuilder()
|
||||||
.setFaceLandmarkerGraphOptions(
|
.setFaceLandmarkerGraphOptions(
|
||||||
FaceLandmarkerGraphOptions.newBuilder()
|
faceLandmarkerOptions
|
||||||
.mergeFrom(
|
|
||||||
faceLandmarkerOptions()
|
|
||||||
.convertToCalculatorOptionsProto()
|
.convertToCalculatorOptionsProto()
|
||||||
.getExtension(FaceLandmarkerGraphOptions.ext))
|
.getExtension(FaceLandmarkerGraphOptions.ext))
|
||||||
.build())
|
|
||||||
.build();
|
.build();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -553,12 +770,11 @@ public final class ImageGenerator extends BaseVisionTaskApi {
|
||||||
@AutoValue.Builder
|
@AutoValue.Builder
|
||||||
public abstract static class Builder {
|
public abstract static class Builder {
|
||||||
|
|
||||||
/** Set the base options for plugin model. */
|
/** Set the base options for the plugin model. */
|
||||||
public abstract Builder setBaseOptions(BaseOptions baseOptions);
|
public abstract Builder setPluginModelBaseOptions(BaseOptions baseOptions);
|
||||||
|
|
||||||
/** {@link ImageSegmenterOptions} used to detect depth image from the source image. */
|
/** Set the base options for the depth model. */
|
||||||
public abstract Builder setImageSegmenterOptions(
|
public abstract Builder setDepthModelBaseOptions(BaseOptions baseOptions);
|
||||||
ImageSegmenterOptions imageSegmenterOptions);
|
|
||||||
|
|
||||||
abstract DepthConditionOptions autoBuild();
|
abstract DepthConditionOptions autoBuild();
|
||||||
|
|
||||||
|
@ -569,18 +785,25 @@ public final class ImageGenerator extends BaseVisionTaskApi {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
abstract BaseOptions baseOptions();
|
abstract BaseOptions pluginModelBaseOptions();
|
||||||
|
|
||||||
abstract ImageSegmenterOptions imageSegmenterOptions();
|
abstract BaseOptions depthModelBaseOptions();
|
||||||
|
|
||||||
public static Builder builder() {
|
public static Builder builder() {
|
||||||
return new AutoValue_ImageGenerator_ConditionOptions_DepthConditionOptions.Builder();
|
return new AutoValue_ImageGenerator_ConditionOptions_DepthConditionOptions.Builder();
|
||||||
}
|
}
|
||||||
|
|
||||||
ConditionedImageGraphOptions.DepthConditionTypeOptions convertToProto() {
|
ConditionedImageGraphOptions.DepthConditionTypeOptions convertToProto() {
|
||||||
|
ImageSegmenterOptions imageSegmenterOptions =
|
||||||
|
ImageSegmenterOptions.builder()
|
||||||
|
.setBaseOptions(depthModelBaseOptions())
|
||||||
|
.setOutputConfidenceMasks(true)
|
||||||
|
.setOutputCategoryMask(false)
|
||||||
|
.setRunningMode(RunningMode.IMAGE)
|
||||||
|
.build();
|
||||||
return ConditionedImageGraphOptions.DepthConditionTypeOptions.newBuilder()
|
return ConditionedImageGraphOptions.DepthConditionTypeOptions.newBuilder()
|
||||||
.setImageSegmenterGraphOptions(
|
.setImageSegmenterGraphOptions(
|
||||||
imageSegmenterOptions()
|
imageSegmenterOptions
|
||||||
.convertToCalculatorOptionsProto()
|
.convertToCalculatorOptionsProto()
|
||||||
.getExtension(ImageSegmenterGraphOptions.ext))
|
.getExtension(ImageSegmenterGraphOptions.ext))
|
||||||
.build();
|
.build();
|
||||||
|
@ -603,7 +826,7 @@ public final class ImageGenerator extends BaseVisionTaskApi {
|
||||||
public abstract static class Builder {
|
public abstract static class Builder {
|
||||||
|
|
||||||
/** Set the base options for plugin model. */
|
/** Set the base options for plugin model. */
|
||||||
public abstract Builder setBaseOptions(BaseOptions baseOptions);
|
public abstract Builder setPluginModelBaseOptions(BaseOptions baseOptions);
|
||||||
|
|
||||||
/** First threshold for the hysteresis procedure. */
|
/** First threshold for the hysteresis procedure. */
|
||||||
public abstract Builder setThreshold1(Float threshold1);
|
public abstract Builder setThreshold1(Float threshold1);
|
||||||
|
@ -629,7 +852,7 @@ public final class ImageGenerator extends BaseVisionTaskApi {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
abstract BaseOptions baseOptions();
|
abstract BaseOptions pluginModelBaseOptions();
|
||||||
|
|
||||||
abstract Float threshold1();
|
abstract Float threshold1();
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user