From 737c103940f474dc25d9fe70b19d78c3dbe41e5f Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Mon, 21 Aug 2023 10:42:14 -0700 Subject: [PATCH] Add output size as parameters in Java ImageSegmenter PiperOrigin-RevId: 558834692 --- .../mediapipe/framework/PacketCreator.java | 6 + .../framework/jni/packet_creator_jni.cc | 11 + .../framework/jni/packet_creator_jni.h | 3 + .../tasks/vision/core/BaseVisionTaskApi.java | 73 +++-- .../vision/imagesegmenter/ImageSegmenter.java | 310 +++++++++++++++--- .../imagesegmenter/ImageSegmenterTest.java | 1 + 6 files changed, 343 insertions(+), 61 deletions(-) diff --git a/mediapipe/java/com/google/mediapipe/framework/PacketCreator.java b/mediapipe/java/com/google/mediapipe/framework/PacketCreator.java index 04265cab5..e71749d09 100644 --- a/mediapipe/java/com/google/mediapipe/framework/PacketCreator.java +++ b/mediapipe/java/com/google/mediapipe/framework/PacketCreator.java @@ -237,6 +237,10 @@ public class PacketCreator { return Packet.create(nativeCreateInt32Array(mediapipeGraph.getNativeHandle(), data)); } + public Packet createInt32Pair(int first, int second) { + return Packet.create(nativeCreateInt32Pair(mediapipeGraph.getNativeHandle(), first, second)); + } + public Packet createFloat32Array(float[] data) { return Packet.create(nativeCreateFloat32Array(mediapipeGraph.getNativeHandle(), data)); } @@ -449,6 +453,8 @@ public class PacketCreator { private native long nativeCreateInt32Array(long context, int[] data); + private native long nativeCreateInt32Pair(long context, int first, int second); + private native long nativeCreateFloat32Array(long context, float[] data); private native long nativeCreateFloat32Vector(long context, float[] data); diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/packet_creator_jni.cc b/mediapipe/java/com/google/mediapipe/framework/jni/packet_creator_jni.cc index f7430e6e8..56ddd5e09 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/packet_creator_jni.cc +++ b/mediapipe/java/com/google/mediapipe/framework/jni/packet_creator_jni.cc @@ -16,6 +16,7 @@ #include #include +#include #include "absl/status/status.h" #include "absl/strings/str_cat.h" @@ -27,6 +28,7 @@ #include "mediapipe/framework/formats/matrix.h" #include "mediapipe/framework/formats/time_series_header.pb.h" #include "mediapipe/framework/formats/video_stream_header.h" +#include "mediapipe/framework/packet.h" #include "mediapipe/framework/port/core_proto_inc.h" #include "mediapipe/framework/port/logging.h" #include "mediapipe/java/com/google/mediapipe/framework/jni/colorspace.h" @@ -481,6 +483,15 @@ JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateInt32Array)( return CreatePacketWithContext(context, packet); } +JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateInt32Pair)( + JNIEnv* env, jobject thiz, jlong context, jint first, jint second) { + static_assert(std::is_same::value, "jint must be int32_t"); + + mediapipe::Packet packet = mediapipe::MakePacket>( + std::make_pair(first, second)); + return CreatePacketWithContext(context, packet); +} + JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateStringFromByteArray)( JNIEnv* env, jobject thiz, jlong context, jbyteArray data) { jsize count = env->GetArrayLength(data); diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/packet_creator_jni.h b/mediapipe/java/com/google/mediapipe/framework/jni/packet_creator_jni.h index b3b1043fb..92f48261c 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/packet_creator_jni.h +++ b/mediapipe/java/com/google/mediapipe/framework/jni/packet_creator_jni.h @@ -118,6 +118,9 @@ JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateFloat32Vector)( JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateInt32Array)( JNIEnv* env, jobject thiz, jlong context, jintArray data); +JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateInt32Pair)( + JNIEnv* env, jobject thiz, jlong context, jint first, jint second); + JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateStringFromByteArray)( JNIEnv* env, jobject thiz, jlong context, jbyteArray data); diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core/BaseVisionTaskApi.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core/BaseVisionTaskApi.java index 9ea057b0d..0405e6dbf 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core/BaseVisionTaskApi.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core/BaseVisionTaskApi.java @@ -27,7 +27,7 @@ import java.util.Map; /** The base class of MediaPipe vision tasks. */ public class BaseVisionTaskApi implements AutoCloseable { - private static final long MICROSECONDS_PER_MILLISECOND = 1000; + protected static final long MICROSECONDS_PER_MILLISECOND = 1000; protected final TaskRunner runner; protected final RunningMode runningMode; protected final String imageStreamName; @@ -69,12 +69,6 @@ public class BaseVisionTaskApi implements AutoCloseable { */ protected TaskResult processImageData( MPImage image, ImageProcessingOptions imageProcessingOptions) { - if (runningMode != RunningMode.IMAGE) { - throw new MediaPipeException( - MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(), - "Task is not initialized with the image mode. Current running mode:" - + runningMode.name()); - } Map inputPackets = new HashMap<>(); inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image)); if (!normRectStreamName.isEmpty()) { @@ -84,6 +78,23 @@ public class BaseVisionTaskApi implements AutoCloseable { .getPacketCreator() .createProto(convertToNormalizedRect(imageProcessingOptions, image))); } + return processImageData(inputPackets); + } + + /** + * A synchronous method to process single image inputs. The call blocks the current thread until a + * failure status or a successful result is returned. + * + * @param inputPackets the maps of input stream names to the input packets. + * @throws MediaPipeException if the task is not in the image mode. + */ + protected TaskResult processImageData(Map inputPackets) { + if (runningMode != RunningMode.IMAGE) { + throw new MediaPipeException( + MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(), + "Task is not initialized with the image mode. Current running mode:" + + runningMode.name()); + } return runner.process(inputPackets); } @@ -99,12 +110,6 @@ public class BaseVisionTaskApi implements AutoCloseable { */ protected TaskResult processVideoData( MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) { - if (runningMode != RunningMode.VIDEO) { - throw new MediaPipeException( - MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(), - "Task is not initialized with the video mode. Current running mode:" - + runningMode.name()); - } Map inputPackets = new HashMap<>(); inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image)); if (!normRectStreamName.isEmpty()) { @@ -114,6 +119,24 @@ public class BaseVisionTaskApi implements AutoCloseable { .getPacketCreator() .createProto(convertToNormalizedRect(imageProcessingOptions, image))); } + return processVideoData(inputPackets, timestampMs * MICROSECONDS_PER_MILLISECOND); + } + + /** + * A synchronous method to process continuous video frames. The call blocks the current thread + * until a failure status or a successful result is returned. + * + * @param inputPackets the maps of input stream names to the input packets. + * @param timestampMs the corresponding timestamp of the input image in milliseconds. + * @throws MediaPipeException if the task is not in the video mode. + */ + protected TaskResult processVideoData(Map inputPackets, long timestampMs) { + if (runningMode != RunningMode.VIDEO) { + throw new MediaPipeException( + MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(), + "Task is not initialized with the video mode. Current running mode:" + + runningMode.name()); + } return runner.process(inputPackets, timestampMs * MICROSECONDS_PER_MILLISECOND); } @@ -129,12 +152,6 @@ public class BaseVisionTaskApi implements AutoCloseable { */ protected void sendLiveStreamData( MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) { - if (runningMode != RunningMode.LIVE_STREAM) { - throw new MediaPipeException( - MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(), - "Task is not initialized with the live stream mode. Current running mode:" - + runningMode.name()); - } Map inputPackets = new HashMap<>(); inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image)); if (!normRectStreamName.isEmpty()) { @@ -144,6 +161,24 @@ public class BaseVisionTaskApi implements AutoCloseable { .getPacketCreator() .createProto(convertToNormalizedRect(imageProcessingOptions, image))); } + sendLiveStreamData(inputPackets, timestampMs * MICROSECONDS_PER_MILLISECOND); + } + + /** + * An asynchronous method to send live stream data to the {@link TaskRunner}. The results will be + * available in the user-defined result listener. + * + * @param inputPackets the maps of input stream names to the input packets. + * @param timestampMs the corresponding timestamp of the input image in milliseconds. + * @throws MediaPipeException if the task is not in the stream mode. + */ + protected void sendLiveStreamData(Map inputPackets, long timestampMs) { + if (runningMode != RunningMode.LIVE_STREAM) { + throw new MediaPipeException( + MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(), + "Task is not initialized with the live stream mode. Current running mode:" + + runningMode.name()); + } runner.send(inputPackets, timestampMs * MICROSECONDS_PER_MILLISECOND); } diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenter.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenter.java index f977c0159..3c9a135e9 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenter.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenter.java @@ -43,7 +43,9 @@ import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.Optional; import java.util.function.BiFunction; @@ -77,9 +79,13 @@ public final class ImageSegmenter extends BaseVisionTaskApi { private static final String TAG = ImageSegmenter.class.getSimpleName(); private static final String IMAGE_IN_STREAM_NAME = "image_in"; private static final String NORM_RECT_IN_STREAM_NAME = "norm_rect_in"; + private static final String OUTPUT_SIZE_IN_STREAM_NAME = "output_size_in"; private static final List INPUT_STREAMS = Collections.unmodifiableList( - Arrays.asList("IMAGE:" + IMAGE_IN_STREAM_NAME, "NORM_RECT:" + NORM_RECT_IN_STREAM_NAME)); + Arrays.asList( + "IMAGE:" + IMAGE_IN_STREAM_NAME, + "NORM_RECT:" + NORM_RECT_IN_STREAM_NAME, + "OUTPUT_SIZE:" + OUTPUT_SIZE_IN_STREAM_NAME)); private static final String TASK_GRAPH_NAME = "mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph"; private static final String TENSORS_TO_SEGMENTATION_CALCULATOR_NAME = @@ -238,6 +244,7 @@ public final class ImageSegmenter extends BaseVisionTaskApi { this.hasResultListener = hasResultListener; populateLabels(); } + /** * Populate the labelmap in TensorsToSegmentationCalculator to labels field. * @@ -275,9 +282,9 @@ public final class ImageSegmenter extends BaseVisionTaskApi { /** * Performs image segmentation on the provided single image with default image processing options, - * i.e. without any rotation applied. Only use this method when the {@link ImageSegmenter} is - * created with {@link RunningMode.IMAGE}. TODO update java doc for input image - * format. + * i.e. without any rotation applied. The output mask has the same size as the input image. Only + * use this method when the {@link ImageSegmenter} is created with {@link RunningMode.IMAGE}. + * TODO update java doc for input image format. * *

{@link ImageSegmenter} supports the following color space types: * @@ -294,9 +301,9 @@ public final class ImageSegmenter extends BaseVisionTaskApi { } /** - * Performs image segmentation on the provided single image. Only use this method when the {@link - * ImageSegmenter} is created with {@link RunningMode.IMAGE}. TODO update java doc - * for input image format. + * Performs image segmentation on the provided single image. The output mask has the same size as + * the input image. Only use this method when the {@link ImageSegmenter} is created with {@link + * RunningMode.IMAGE}. TODO update java doc for input image format. * *

{@link ImageSegmenter} supports the following color space types: * @@ -316,21 +323,47 @@ public final class ImageSegmenter extends BaseVisionTaskApi { */ public ImageSegmenterResult segment( MPImage image, ImageProcessingOptions imageProcessingOptions) { + return segment( + image, + SegmentationOptions.builder() + .setOutputWidth(image.getWidth()) + .setOutputHeight(image.getHeight()) + .setImageProcessingOptions(imageProcessingOptions) + .build()); + } + + /** + * Performs image segmentation on the provided single image. Only use this method when the {@link + * ImageSegmenter} is created with {@link RunningMode.IMAGE}. TODO update java doc + * for input image format. + * + *

{@link ImageSegmenter} supports the following color space types: + * + *

    + *
  • {@link Bitmap.Config.ARGB_8888} + *
+ * + * @param image a MediaPipe {@link MPImage} object for processing. + * @param segmentationOptions the {@link SegmentationOptions} used to configure the runtime + * behavior of the {@link ImageSegmenter}. + * @throws MediaPipeException if there is an internal error. Or if {@link ImageSegmenter} is + * created with a {@link ResultListener}. + */ + public ImageSegmenterResult segment(MPImage image, SegmentationOptions segmentationOptions) { if (hasResultListener) { throw new MediaPipeException( MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(), "ResultListener is provided in the ImageSegmenterOptions, but this method will return an" + " ImageSegmentationResult."); } - validateImageProcessingOptions(imageProcessingOptions); - return (ImageSegmenterResult) processImageData(image, imageProcessingOptions); + return (ImageSegmenterResult) processImageData(buildInputPackets(image, segmentationOptions)); } /** * Performs image segmentation on the provided single image with default image processing options, * i.e. without any rotation applied, and provides zero-copied results via {@link ResultListener} - * in {@link ImageSegmenterOptions}. Only use this method when the {@link ImageSegmenter} is - * created with {@link RunningMode.IMAGE}. + * in {@link ImageSegmenterOptions}. The output mask has the same size as the input image. Only + * use this method when the {@link ImageSegmenter} is created with {@link RunningMode.IMAGE}. * *

TODO update java doc for input image format. * @@ -341,8 +374,6 @@ public final class ImageSegmenter extends BaseVisionTaskApi { * * * @param image a MediaPipe {@link MPImage} object for processing. - * @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a - * region-of-interest. * @throws MediaPipeException if there is an internal error. Or if {@link ImageSegmenter} is not * created with {@link ResultListener} set in {@link ImageSegmenterOptions}. */ @@ -352,8 +383,9 @@ public final class ImageSegmenter extends BaseVisionTaskApi { /** * Performs image segmentation on the provided single image, and provides zero-copied results via - * {@link ResultListener} in {@link ImageSegmenterOptions}. Only use this method when the {@link - * ImageSegmenter} is created with {@link RunningMode.IMAGE}. + * {@link ResultListener} in {@link ImageSegmenterOptions}. The output mask has the same size as + * the input image. Only use this method when the {@link ImageSegmenter} is created with {@link + * RunningMode.IMAGE}. * *

TODO update java doc for input image format. * @@ -375,21 +407,53 @@ public final class ImageSegmenter extends BaseVisionTaskApi { */ public void segmentWithResultListener( MPImage image, ImageProcessingOptions imageProcessingOptions) { + segmentWithResultListener( + image, + SegmentationOptions.builder() + .setOutputWidth(image.getWidth()) + .setOutputHeight(image.getHeight()) + .setImageProcessingOptions(imageProcessingOptions) + .build()); + } + + /** + * Performs image segmentation on the provided single image, and provides zero-copied results via + * {@link ResultListener} in {@link ImageSegmenterOptions}. Only use this method when the {@link + * ImageSegmenter} is created with {@link RunningMode.IMAGE}. + * + *

TODO update java doc for input image format. + * + *

{@link ImageSegmenter} supports the following color space types: + * + *

    + *
  • {@link Bitmap.Config.ARGB_8888} + *
+ * + * @param image a MediaPipe {@link MPImage} object for processing. + * @param segmentationOptions the {@link SegmentationOptions} used to configure the runtime + * behavior of the {@link ImageSegmenter}. + * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the + * input image before running inference. Note that region-of-interest is not supported + * by this task: specifying {@link ImageProcessingOptions#regionOfInterest()} will result in + * this method throwing an IllegalArgumentException. + * @throws MediaPipeException if there is an internal error. Or if {@link ImageSegmenter} is not + * created with {@link ResultListener} set in {@link ImageSegmenterOptions}. + */ + public void segmentWithResultListener(MPImage image, SegmentationOptions segmentationOptions) { if (!hasResultListener) { throw new MediaPipeException( MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(), "ResultListener is not set in the ImageSegmenterOptions, but this method expects a" + " ResultListener to process ImageSegmentationResult."); } - validateImageProcessingOptions(imageProcessingOptions); ImageSegmenterResult unused = - (ImageSegmenterResult) processImageData(image, imageProcessingOptions); + (ImageSegmenterResult) processImageData(buildInputPackets(image, segmentationOptions)); } /** * Performs image segmentation on the provided video frame with default image processing options, - * i.e. without any rotation applied. Only use this method when the {@link ImageSegmenter} is - * created with {@link RunningMode.VIDEO}. + * i.e. without any rotation applied. The output mask has the same size as the input image. Only + * use this method when the {@link ImageSegmenter} is created with {@link RunningMode.VIDEO}. * *

It's required to provide the video frame's timestamp (in milliseconds). The input timestamps * must be monotonically increasing. @@ -410,8 +474,9 @@ public final class ImageSegmenter extends BaseVisionTaskApi { } /** - * Performs image segmentation on the provided video frame. Only use this method when the {@link - * ImageSegmenter} is created with {@link RunningMode.VIDEO}. + * Performs image segmentation on the provided video frame. The output mask has the same size as + * the input image. Only use this method when the {@link ImageSegmenter} is created with {@link + * RunningMode.VIDEO}. * *

It's required to provide the video frame's timestamp (in milliseconds). The input timestamps * must be monotonically increasing. @@ -435,21 +500,53 @@ public final class ImageSegmenter extends BaseVisionTaskApi { */ public ImageSegmenterResult segmentForVideo( MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) { + return segmentForVideo( + image, + SegmentationOptions.builder() + .setOutputWidth(image.getWidth()) + .setOutputHeight(image.getHeight()) + .setImageProcessingOptions(imageProcessingOptions) + .build(), + timestampMs); + } + + /** + * Performs image segmentation on the provided video frame. Only use this method when the {@link + * ImageSegmenter} is created with {@link RunningMode.VIDEO}. + * + *

It's required to provide the video frame's timestamp (in milliseconds). The input timestamps + * must be monotonically increasing. + * + *

{@link ImageSegmenter} supports the following color space types: + * + *

    + *
  • {@link Bitmap.Config.ARGB_8888} + *
+ * + * @param image a MediaPipe {@link MPImage} object for processing. + * @param segmentationOptions the {@link SegmentationOptions} used to configure the runtime + * behavior of the {@link ImageSegmenter}. + * @param timestampMs the input timestamp (in milliseconds). + * @throws MediaPipeException if there is an internal error. Or if {@link ImageSegmenter} is + * created with a {@link ResultListener}. + */ + public ImageSegmenterResult segmentForVideo( + MPImage image, SegmentationOptions segmentationOptions, long timestampMs) { if (hasResultListener) { throw new MediaPipeException( MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(), "ResultListener is provided in the ImageSegmenterOptions, but this method will return an" + " ImageSegmentationResult."); } - validateImageProcessingOptions(imageProcessingOptions); - return (ImageSegmenterResult) processVideoData(image, imageProcessingOptions, timestampMs); + return (ImageSegmenterResult) + processVideoData(buildInputPackets(image, segmentationOptions), timestampMs); } /** * Performs image segmentation on the provided video frame with default image processing options, * i.e. without any rotation applied, and provides zero-copied results via {@link ResultListener} - * in {@link ImageSegmenterOptions}. Only use this method when the {@link ImageSegmenter} is - * created with {@link RunningMode.VIDEO}. + * in {@link ImageSegmenterOptions}. The output mask has the same size as the input image. Only + * use this method when the {@link ImageSegmenter} is created with {@link RunningMode.VIDEO}. * *

It's required to provide the video frame's timestamp (in milliseconds). The input timestamps * must be monotonically increasing. @@ -469,6 +566,40 @@ public final class ImageSegmenter extends BaseVisionTaskApi { segmentForVideoWithResultListener(image, ImageProcessingOptions.builder().build(), timestampMs); } + /** + * Performs image segmentation on the provided video frame, and provides zero-copied results via + * {@link ResultListener} in {@link ImageSegmenterOptions}. The output mask has the same size as + * the input image. Only use this method when the {@link ImageSegmenter} is created with {@link + * RunningMode.VIDEO}. + * + *

It's required to provide the video frame's timestamp (in milliseconds). The input timestamps + * must be monotonically increasing. + * + *

{@link ImageSegmenter} supports the following color space types: + * + *

    + *
  • {@link Bitmap.Config.ARGB_8888} + *
+ * + * @param image a MediaPipe {@link MPImage} object for processing. + * @param timestampMs the input timestamp (in milliseconds). + * @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a + * region-of-interest. + * @throws MediaPipeException if there is an internal error. Or if {@link ImageSegmenter} is not + * created with {@link ResultListener} set in {@link ImageSegmenterOptions}. + */ + public void segmentForVideoWithResultListener( + MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) { + segmentForVideoWithResultListener( + image, + SegmentationOptions.builder() + .setOutputWidth(image.getWidth()) + .setOutputHeight(image.getHeight()) + .setImageProcessingOptions(imageProcessingOptions) + .build(), + timestampMs); + } + /** * Performs image segmentation on the provided video frame, and provides zero-copied results via * {@link ResultListener} in {@link ImageSegmenterOptions}. Only use this method when the {@link @@ -484,28 +615,31 @@ public final class ImageSegmenter extends BaseVisionTaskApi { * * * @param image a MediaPipe {@link MPImage} object for processing. + * @param segmentationOptions the {@link SegmentationOptions} used to configure the runtime + * behavior of the {@link ImageSegmenter}. * @param timestampMs the input timestamp (in milliseconds). * @throws MediaPipeException if there is an internal error. Or if {@link ImageSegmenter} is not * created with {@link ResultListener} set in {@link ImageSegmenterOptions}. */ public void segmentForVideoWithResultListener( - MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) { + MPImage image, SegmentationOptions segmentationOptions, long timestampMs) { if (!hasResultListener) { throw new MediaPipeException( MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(), "ResultListener is not set in the ImageSegmenterOptions, but this method expects a" + " ResultListener to process ImageSegmentationResult."); } - validateImageProcessingOptions(imageProcessingOptions); ImageSegmenterResult unused = - (ImageSegmenterResult) processVideoData(image, imageProcessingOptions, timestampMs); + (ImageSegmenterResult) + processVideoData(buildInputPackets(image, segmentationOptions), timestampMs); } /** * Sends live image data to perform image segmentation with default image processing options, i.e. * without any rotation applied, and the results will be available via the {@link ResultListener} - * provided in the {@link ImageSegmenterOptions}. Only use this method when the {@link - * ImageSegmenter } is created with {@link RunningMode.LIVE_STREAM}. + * provided in the {@link ImageSegmenterOptions}. The output mask has the same size as the input + * image. Only use this method when the {@link ImageSegmenter } is created with {@link + * RunningMode.LIVE_STREAM}. * *

It's required to provide a timestamp (in milliseconds) to indicate when the input image is * sent to the image segmenter. The input timestamps must be monotonically increasing. @@ -526,8 +660,9 @@ public final class ImageSegmenter extends BaseVisionTaskApi { /** * Sends live image data to perform image segmentation, and the results will be available via the - * {@link ResultListener} provided in the {@link ImageSegmenterOptions}. Only use this method when - * the {@link ImageSegmenter} is created with {@link RunningMode.LIVE_STREAM}. + * {@link ResultListener} provided in the {@link ImageSegmenterOptions}. The output mask has the + * same size as the input image. Only use this method when the {@link ImageSegmenter} is created + * with {@link RunningMode.LIVE_STREAM}. * *

It's required to provide a timestamp (in milliseconds) to indicate when the input image is * sent to the image segmenter. The input timestamps must be monotonically increasing. @@ -550,8 +685,39 @@ public final class ImageSegmenter extends BaseVisionTaskApi { */ public void segmentAsync( MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) { - validateImageProcessingOptions(imageProcessingOptions); - sendLiveStreamData(image, imageProcessingOptions, timestampMs); + segmentAsync( + image, + SegmentationOptions.builder() + .setOutputWidth(image.getWidth()) + .setOutputHeight(image.getHeight()) + .setImageProcessingOptions(imageProcessingOptions) + .build(), + timestampMs); + } + + /** + * Sends live image data to perform image segmentation, and the results will be available via the + * {@link ResultListener} provided in the {@link ImageSegmenterOptions}. Only use this method when + * the {@link ImageSegmenter} is created with {@link RunningMode.LIVE_STREAM}. + * + *

It's required to provide a timestamp (in milliseconds) to indicate when the input image is + * sent to the image segmenter. The input timestamps must be monotonically increasing. + * + *

{@link ImageSegmenter} supports the following color space types: + * + *

    + *
  • {@link Bitmap.Config.ARGB_8888} + *
+ * + * @param image a MediaPipe {@link MPImage} object for processing. + * @param segmentationOptions the {@link SegmentationOptions} used to configure the runtime + * behavior of the {@link ImageSegmenter}. + * @param timestampMs the input timestamp (in milliseconds). + * @throws MediaPipeException if there is an internal error. + */ + public void segmentAsync( + MPImage image, SegmentationOptions segmentationOptions, long timestampMs) { + sendLiveStreamData(buildInputPackets(image, segmentationOptions), timestampMs); } /** @@ -565,6 +731,56 @@ public final class ImageSegmenter extends BaseVisionTaskApi { return labels; } + /** Options for configuring runtime behavior of {@link ImageSegmenter}. */ + @AutoValue + public abstract static class SegmentationOptions { + + /** Builder fo {@link SegmentationOptions} */ + @AutoValue.Builder + public abstract static class Builder { + + /** Set the width of the output segmentation masks. */ + public abstract Builder setOutputWidth(int value); + + /** Set the height of the output segmentation masks. */ + public abstract Builder setOutputHeight(int value); + + /** Set the image processing options. */ + public abstract Builder setImageProcessingOptions(ImageProcessingOptions value); + + abstract SegmentationOptions autoBuild(); + + /** + * Validates and builds the {@link SegmentationOptions} instance. + * + * @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a + * region-of-interest. + */ + public final SegmentationOptions build() { + SegmentationOptions options = autoBuild(); + if (options.outputWidth() <= 0 || options.outputHeight() <= 0) { + throw new IllegalArgumentException( + "Both outputWidth and outputHeight must be larger than 0."); + } + if (options.imageProcessingOptions().regionOfInterest().isPresent()) { + throw new IllegalArgumentException("ImageSegmenter doesn't support region-of-interest."); + } + return options; + } + } + + abstract int outputWidth(); + + abstract int outputHeight(); + + abstract ImageProcessingOptions imageProcessingOptions(); + + static Builder builder() { + return new AutoValue_ImageSegmenter_SegmentationOptions.Builder() + .setImageProcessingOptions(ImageProcessingOptions.builder().build()); + } + } + /** Options for setting up an {@link ImageSegmenter}. */ @AutoValue public abstract static class ImageSegmenterOptions extends TaskOptions { @@ -680,14 +896,24 @@ public final class ImageSegmenter extends BaseVisionTaskApi { } } - /** - * Validates that the provided {@link ImageProcessingOptions} doesn't contain a - * region-of-interest. - */ - private static void validateImageProcessingOptions( - ImageProcessingOptions imageProcessingOptions) { - if (imageProcessingOptions.regionOfInterest().isPresent()) { - throw new IllegalArgumentException("ImageSegmenter doesn't support region-of-interest."); + private Map buildInputPackets( + MPImage image, SegmentationOptions segmentationOptions) { + Map inputPackets = new HashMap<>(); + inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image)); + inputPackets.put( + OUTPUT_SIZE_IN_STREAM_NAME, + runner + .getPacketCreator() + .createInt32Pair( + segmentationOptions.outputWidth(), segmentationOptions.outputHeight())); + if (!normRectStreamName.isEmpty()) { + inputPackets.put( + normRectStreamName, + runner + .getPacketCreator() + .createProto( + convertToNormalizedRect(segmentationOptions.imageProcessingOptions(), image))); } + return inputPackets; } } diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenterTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenterTest.java index 959f444cd..49ab0be13 100644 --- a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenterTest.java +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenterTest.java @@ -31,6 +31,7 @@ import com.google.mediapipe.framework.image.MPImage; import com.google.mediapipe.tasks.core.BaseOptions; import com.google.mediapipe.tasks.vision.core.RunningMode; import com.google.mediapipe.tasks.vision.imagesegmenter.ImageSegmenter.ImageSegmenterOptions; +import com.google.mediapipe.tasks.vision.imagesegmenter.ImageSegmenter.SegmentationOptions; import java.io.InputStream; import java.nio.ByteBuffer; import java.nio.FloatBuffer;