Update the ImageSegmenter API for image/video mode to support both a callback API and a returned-result API.

PiperOrigin-RevId: 512697585
Authored by MediaPipe Team on 2023-02-27 12:19:24 -08:00, committed by Copybara-Service
parent aa61abe386
commit a60d67eb10
2 changed files with 334 additions and 129 deletions

View File

@@ -47,10 +47,14 @@ import java.util.Optional;
 /**
  * Performs image segmentation on images.
  *
- * <p>Note that, unlike other vision tasks, the output of ImageSegmenter is provided through a
- * user-defined callback function even for the synchronous API. This makes it possible for
- * ImageSegmenter to return the output masks without any copy. {@link ResultListener} must be set in
- * the {@link ImageSegmenterOptions} for all {@link RunningMode}.
+ * <p>Note that, in addition to the standard segmentation API, {@link segment} and {@link
+ * segmentForVideo}, which take an input image and return the outputs at the cost of a deep copy
+ * of the results, ImageSegmenter also supports the callback API, {@link
+ * segmentWithResultListener} and {@link segmentForVideoWithResultListener}, which give access to
+ * the outputs with zero copy.
+ *
+ * <p>The callback API is available for all {@link RunningMode}s in ImageSegmenter. To use it, set
+ * a {@link ResultListener} in the {@link ImageSegmenterOptions}.
  *
  * <p>The API expects a TFLite model with,<a
  * href="https://www.tensorflow.org/lite/convert/metadata">TFLite Model Metadata.</a>.
@@ -85,6 +89,8 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
   private static final String TASK_GRAPH_NAME =
       "mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph";

+  private boolean hasResultListener = false;
+
   /**
    * Creates an {@link ImageSegmenter} instance from an {@link ImageSegmenterOptions}.
    *
@@ -116,8 +122,19 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
            int imageListSize =
                PacketGetter.getImageListSize(packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX));
            ByteBuffer[] buffersArray = new ByteBuffer[imageListSize];
+            // If no resultListener is provided, the returned MPImage is deep copied from the
+            // mediapipe graph. If one is provided, the returned MPImage wraps the packet memory.
+            if (!segmenterOptions.resultListener().isPresent()) {
+              for (int i = 0; i < imageListSize; i++) {
+                buffersArray[i] =
+                    ByteBuffer.allocateDirect(
+                        width * height * (imageFormat == MPImage.IMAGE_FORMAT_VEC32F1 ? 4 : 1));
+              }
+            }
            if (!PacketGetter.getImageList(
-                packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX), buffersArray, false)) {
+                packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX),
+                buffersArray,
+                !segmenterOptions.resultListener().isPresent())) {
              throw new MediaPipeException(
                  MediaPipeException.StatusCode.INTERNAL.ordinal(),
                  "There is an error getting segmented masks. It usually results from incorrect"
@@ -143,7 +160,7 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
                    .build();
              }
            });
-    handler.setResultListener(segmenterOptions.resultListener());
+    segmenterOptions.resultListener().ifPresent(handler::setResultListener);
    segmenterOptions.errorListener().ifPresent(handler::setErrorListener);
    TaskRunner runner =
        TaskRunner.create(
@@ -158,7 +175,8 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
                .setEnableFlowLimiting(segmenterOptions.runningMode() == RunningMode.LIVE_STREAM)
                .build(),
            handler);
-    return new ImageSegmenter(runner, segmenterOptions.runningMode());
+    return new ImageSegmenter(
+        runner, segmenterOptions.runningMode(), segmenterOptions.resultListener().isPresent());
  }

  /**
@@ -168,16 +186,17 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
   * @param taskRunner a {@link TaskRunner}.
   * @param runningMode a mediapipe vision task {@link RunningMode}.
   */
-  private ImageSegmenter(TaskRunner taskRunner, RunningMode runningMode) {
+  private ImageSegmenter(
+      TaskRunner taskRunner, RunningMode runningMode, boolean hasResultListener) {
    super(taskRunner, runningMode, IMAGE_IN_STREAM_NAME, NORM_RECT_IN_STREAM_NAME);
+    this.hasResultListener = hasResultListener;
  }

  /**
   * Performs image segmentation on the provided single image with default image processing options,
-   * i.e. without any rotation applied, and the results will be available via the {@link
-   * ResultListener} provided in the {@link ImageSegmenterOptions}. Only use this method when the
-   * {@link ImageSegmenter} is created with {@link RunningMode.IMAGE}. TODO update java
-   * doc for input image format.
+   * i.e. without any rotation applied. Only use this method when the {@link ImageSegmenter} is
+   * created with {@link RunningMode.IMAGE}. TODO update java doc for input image
+   * format.
   *
   * <p>{@link ImageSegmenter} supports the following color space types:
   *
@@ -186,19 +205,19 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
   * </ul>
   *
   * @param image a MediaPipe {@link MPImage} object for processing.
-   * @throws MediaPipeException if there is an internal error.
+   * @throws MediaPipeException if there is an internal error, or if the {@link ImageSegmenter} is
+   *     created with a {@link ResultListener}.
   */
-  public void segment(MPImage image) {
-    segment(image, ImageProcessingOptions.builder().build());
+  public ImageSegmenterResult segment(MPImage image) {
+    return segment(image, ImageProcessingOptions.builder().build());
  }

  /**
-   * Performs image segmentation on the provided single image, and the results will be available via
-   * the {@link ResultListener} provided in the {@link ImageSegmenterOptions}. Only use this method
-   * when the {@link ImageSegmenter} is created with {@link RunningMode.IMAGE}. TODO
-   * update java doc for input image format.
+   * Performs image segmentation on the provided single image. Only use this method when the {@link
+   * ImageSegmenter} is created with {@link RunningMode.IMAGE}. TODO update java doc
+   * for input image format.
   *
-   * <p>{@link HandLandmarker} supports the following color space types:
+   * <p>{@link ImageSegmenter} supports the following color space types:
   *
   * <ul>
   *   <li>{@link Bitmap.Config.ARGB_8888}
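Editor's note: the overload below accepts an `ImageProcessingOptions` for per-call preprocessing. A minimal fragment, assuming a `segmenter` created in IMAGE mode without a ResultListener, an `MPImage` named `image`, and the standard `setRotationDegrees` builder method; region-of-interest is rejected by this task as documented.

```java
import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions;

// Rotate the input 90 degrees before inference. Do not set a region-of-interest here:
// for this task it makes segment(...) throw an IllegalArgumentException.
ImageProcessingOptions processingOptions =
    ImageProcessingOptions.builder().setRotationDegrees(90).build();
ImageSegmenterResult rotatedResult = segmenter.segment(image, processingOptions);
```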
@@ -211,19 +230,84 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
   *     this method throwing an IllegalArgumentException.
   * @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a
   *     region-of-interest.
-   * @throws MediaPipeException if there is an internal error.
+   * @throws MediaPipeException if there is an internal error, or if the {@link ImageSegmenter} is
+   *     created with a {@link ResultListener}.
   */
-  public void segment(MPImage image, ImageProcessingOptions imageProcessingOptions) {
+  public ImageSegmenterResult segment(
+      MPImage image, ImageProcessingOptions imageProcessingOptions) {
+    if (hasResultListener) {
+      throw new MediaPipeException(
+          MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
+          "ResultListener is provided in the ImageSegmenterOptions, but this method will return an"
+              + " ImageSegmentationResult.");
+    }
    validateImageProcessingOptions(imageProcessingOptions);
-    ImageSegmenterResult unused =
-        (ImageSegmenterResult) processImageData(image, imageProcessingOptions);
+    return (ImageSegmenterResult) processImageData(image, imageProcessingOptions);
+  }
+
+  /**
+   * Performs image segmentation on the provided single image with default image processing options,
+   * i.e. without any rotation applied, and provides zero-copied results via {@link ResultListener}
+   * in {@link ImageSegmenterOptions}. Only use this method when the {@link ImageSegmenter} is
+   * created with {@link RunningMode.IMAGE}.
+   *
+   * <p>TODO update java doc for input image format.
+   *
+   * <p>{@link ImageSegmenter} supports the following color space types:
+   *
+   * <ul>
+   *   <li>{@link Bitmap.Config.ARGB_8888}
+   * </ul>
+   *
+   * @param image a MediaPipe {@link MPImage} object for processing.
+   * @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a
+   *     region-of-interest.
+   * @throws MediaPipeException if there is an internal error, or if the {@link ImageSegmenter} is
+   *     not created with a {@link ResultListener} set in the {@link ImageSegmenterOptions}.
+   */
+  public void segmentWithResultListener(MPImage image) {
+    segmentWithResultListener(image, ImageProcessingOptions.builder().build());
+  }
+
+  /**
+   * Performs image segmentation on the provided single image, and provides zero-copied results via
+   * {@link ResultListener} in {@link ImageSegmenterOptions}. Only use this method when the {@link
+   * ImageSegmenter} is created with {@link RunningMode.IMAGE}.
+   *
+   * <p>TODO update java doc for input image format.
+   *
+   * <p>{@link ImageSegmenter} supports the following color space types:
+   *
+   * <ul>
+   *   <li>{@link Bitmap.Config.ARGB_8888}
+   * </ul>
+   *
+   * @param image a MediaPipe {@link MPImage} object for processing.
+   * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
+   *     input image before running inference. Note that region-of-interest is <b>not</b> supported
+   *     by this task: specifying {@link ImageProcessingOptions#regionOfInterest()} will result in
+   *     this method throwing an IllegalArgumentException.
+   * @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a
+   *     region-of-interest.
+   * @throws MediaPipeException if there is an internal error, or if the {@link ImageSegmenter} is
+   *     not created with a {@link ResultListener} set in the {@link ImageSegmenterOptions}.
+   */
+  public void segmentWithResultListener(
+      MPImage image, ImageProcessingOptions imageProcessingOptions) {
+    if (!hasResultListener) {
+      throw new MediaPipeException(
+          MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
+          "ResultListener is not set in the ImageSegmenterOptions, but this method expects a"
+              + " ResultListener to process ImageSegmentationResult.");
+    }
+    validateImageProcessingOptions(imageProcessingOptions);
+    var unused = processImageData(image, imageProcessingOptions);
  }

  /**
   * Performs image segmentation on the provided video frame with default image processing options,
-   * i.e. without any rotation applied, and the results will be available via the {@link
-   * ResultListener} provided in the {@link ImageSegmenterOptions}. Only use this method when the
-   * {@link HandLandmarker} is created with {@link RunningMode.VIDEO}.
+   * i.e. without any rotation applied. Only use this method when the {@link ImageSegmenter} is
+   * created with {@link RunningMode.VIDEO}.
   *
   * <p>It's required to provide the video frame's timestamp (in milliseconds). The input timestamps
   * must be monotonically increasing.
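Editor's note: the two IMAGE-mode entry points are mutually exclusive on a given instance, `segment` refuses to run when a ResultListener was configured and `segmentWithResultListener` refuses to run without one. A small illustrative fragment, with hypothetical variable names:

```java
// `listenerSegmenter` was created WITH a ResultListener, so only the callback API is valid.
try {
  listenerSegmenter.segment(image); // throws MediaPipeException (FAILED_PRECONDITION)
} catch (MediaPipeException e) {
  // "ResultListener is provided in the ImageSegmenterOptions, but this method will return an ..."
}
listenerSegmenter.segmentWithResultListener(image); // OK: masks delivered to the listener
```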
@@ -236,21 +320,21 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
   *
   * @param image a MediaPipe {@link MPImage} object for processing.
   * @param timestampMs the input timestamp (in milliseconds).
-   * @throws MediaPipeException if there is an internal error.
+   * @throws MediaPipeException if there is an internal error, or if the {@link ImageSegmenter} is
+   *     created with a {@link ResultListener}.
   */
-  public void segmentForVideo(MPImage image, long timestampMs) {
-    segmentForVideo(image, ImageProcessingOptions.builder().build(), timestampMs);
+  public ImageSegmenterResult segmentForVideo(MPImage image, long timestampMs) {
+    return segmentForVideo(image, ImageProcessingOptions.builder().build(), timestampMs);
  }

  /**
-   * Performs image segmentation on the provided video frame, and the results will be available via
-   * the {@link ResultListener} provided in the {@link ImageSegmenterOptions}. Only use this method
-   * when the {@link ImageSegmenter} is created with {@link RunningMode.VIDEO}.
+   * Performs image segmentation on the provided video frame. Only use this method when the {@link
+   * ImageSegmenter} is created with {@link RunningMode.VIDEO}.
   *
   * <p>It's required to provide the video frame's timestamp (in milliseconds). The input timestamps
   * must be monotonically increasing.
   *
-   * <p>{@link HandLandmarker} supports the following color space types:
+   * <p>{@link ImageSegmenter} supports the following color space types:
   *
   * <ul>
   *   <li>{@link Bitmap.Config.ARGB_8888}
@@ -264,20 +348,81 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
   * @param timestampMs the input timestamp (in milliseconds).
   * @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a
   *     region-of-interest.
-   * @throws MediaPipeException if there is an internal error.
+   * @throws MediaPipeException if there is an internal error, or if the {@link ImageSegmenter} is
+   *     created with a {@link ResultListener}.
   */
-  public void segmentForVideo(
+  public ImageSegmenterResult segmentForVideo(
      MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) {
+    if (hasResultListener) {
+      throw new MediaPipeException(
+          MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
+          "ResultListener is provided in the ImageSegmenterOptions, but this method will return an"
+              + " ImageSegmentationResult.");
+    }
    validateImageProcessingOptions(imageProcessingOptions);
-    ImageSegmenterResult unused =
-        (ImageSegmenterResult) processVideoData(image, imageProcessingOptions, timestampMs);
+    return (ImageSegmenterResult) processVideoData(image, imageProcessingOptions, timestampMs);
  }

  /**
+   * Performs image segmentation on the provided video frame with default image processing options,
+   * i.e. without any rotation applied, and provides zero-copied results via {@link ResultListener}
+   * in {@link ImageSegmenterOptions}. Only use this method when the {@link ImageSegmenter} is
+   * created with {@link RunningMode.VIDEO}.
+   *
+   * <p>It's required to provide the video frame's timestamp (in milliseconds). The input timestamps
+   * must be monotonically increasing.
+   *
+   * <p>{@link ImageSegmenter} supports the following color space types:
+   *
+   * <ul>
+   *   <li>{@link Bitmap.Config.ARGB_8888}
+   * </ul>
+   *
+   * @param image a MediaPipe {@link MPImage} object for processing.
+   * @param timestampMs the input timestamp (in milliseconds).
+   * @throws MediaPipeException if there is an internal error, or if the {@link ImageSegmenter} is
+   *     not created with a {@link ResultListener} set in the {@link ImageSegmenterOptions}.
+   */
+  public void segmentForVideoWithResultListener(MPImage image, long timestampMs) {
+    segmentForVideoWithResultListener(image, ImageProcessingOptions.builder().build(), timestampMs);
+  }
+
+  /**
+   * Performs image segmentation on the provided video frame, and provides zero-copied results via
+   * {@link ResultListener} in {@link ImageSegmenterOptions}. Only use this method when the {@link
+   * ImageSegmenter} is created with {@link RunningMode.VIDEO}.
+   *
+   * <p>It's required to provide the video frame's timestamp (in milliseconds). The input timestamps
+   * must be monotonically increasing.
+   *
+   * <p>{@link ImageSegmenter} supports the following color space types:
+   *
+   * <ul>
+   *   <li>{@link Bitmap.Config.ARGB_8888}
+   * </ul>
+   *
+   * @param image a MediaPipe {@link MPImage} object for processing.
+   * @param timestampMs the input timestamp (in milliseconds).
+   * @throws MediaPipeException if there is an internal error, or if the {@link ImageSegmenter} is
+   *     not created with a {@link ResultListener} set in the {@link ImageSegmenterOptions}.
+   */
+  public void segmentForVideoWithResultListener(
+      MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) {
+    if (!hasResultListener) {
+      throw new MediaPipeException(
+          MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
+          "ResultListener is not set in the ImageSegmenterOptions, but this method expects a"
+              + " ResultListener to process ImageSegmentationResult.");
+    }
+    validateImageProcessingOptions(imageProcessingOptions);
+    var unused = processVideoData(image, imageProcessingOptions, timestampMs);
+  }
+
+  /**
-   * Sends live image data to perform hand landmarks detection with default image processing
-   * options, i.e. without any rotation applied, and the results will be available via the {@link
-   * ResultListener} provided in the {@link ImageSegmenterOptions}. Only use this method when the
-   * {@link ImageSegmenter } is created with {@link RunningMode.LIVE_STREAM}.
+   * Sends live image data to perform image segmentation with default image processing options, i.e.
+   * without any rotation applied, and the results will be available via the {@link ResultListener}
+   * provided in the {@link ImageSegmenterOptions}. Only use this method when the {@link
+   * ImageSegmenter } is created with {@link RunningMode.LIVE_STREAM}.
   *
   * <p>It's required to provide a timestamp (in milliseconds) to indicate when the input image is
   * sent to the image segmenter. The input timestamps must be monotonically increasing.
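Editor's note: a minimal fragment of VIDEO-mode usage with the returned-result API, assuming `videoSegmenter` was created with `RunningMode.VIDEO` and no ResultListener, and `frames` is a hypothetical list of decoded `MPImage` frames at roughly 30 fps.

```java
long frameIntervalMs = 1000 / 30;
for (int i = 0; i < frames.size(); i++) {
  long timestampMs = i * frameIntervalMs; // timestamps must be monotonically increasing
  ImageSegmenterResult result = videoSegmenter.segmentForVideo(frames.get(i), timestampMs);
  // Consume result.segmentations() for this frame.
}
```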
@@ -360,8 +505,8 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
      public abstract Builder setOutputType(OutputType value);

      /**
-       * Sets the {@link ResultListener} to receive the segmentation results when the graph pipeline
-       * is done processing an image.
+       * Sets an optional {@link ResultListener} to receive the segmentation results when the graph
+       * pipeline is done processing an image.
       */
      public abstract Builder setResultListener(
          ResultListener<ImageSegmenterResult, MPImage> value);
@@ -375,11 +520,18 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
       * Validates and builds the {@link ImageSegmenterOptions} instance.
       *
       * @throws IllegalArgumentException if the result listener and the running mode are not
-       *     properly configured. The result listener should only be set when the image segmenter is
-       *     in the live stream mode.
+       *     properly configured. The result listener must be set when the image segmenter is in the
+       *     live stream mode.
       */
      public final ImageSegmenterOptions build() {
        ImageSegmenterOptions options = autoBuild();
+        if (options.runningMode() == RunningMode.LIVE_STREAM) {
+          if (!options.resultListener().isPresent()) {
+            throw new IllegalArgumentException(
+                "The image segmenter is in the live stream mode, a user-defined result listener"
+                    + " must be provided in ImageSegmenterOptions.");
+          }
+        }
        return options;
      }
    }
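Editor's note: a minimal LIVE_STREAM configuration fragment showing the rule the builder now enforces; only builder methods appearing in this diff are used, and the model path is an illustrative assumption.

```java
ImageSegmenterOptions liveStreamOptions =
    ImageSegmenterOptions.builder()
        .setBaseOptions(BaseOptions.builder().setModelAssetPath("deeplab.tflite").build())
        .setRunningMode(RunningMode.LIVE_STREAM)
        .setResultListener(
            (segmenterResult, inputImage) -> {
              // Invoked for every frame sent via segmentAsync().
            })
        .build(); // without setResultListener(...) this build() now throws IllegalArgumentException
```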
@@ -392,7 +544,7 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
    abstract OutputType outputType();

-    abstract ResultListener<ImageSegmenterResult, MPImage> resultListener();
+    abstract Optional<ResultListener<ImageSegmenterResult, MPImage>> resultListener();

    abstract Optional<ErrorListener> errorListener();
@@ -410,8 +562,7 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
      return new AutoValue_ImageSegmenter_ImageSegmenterOptions.Builder()
          .setRunningMode(RunningMode.IMAGE)
          .setDisplayNamesLocale("en")
-          .setOutputType(OutputType.CATEGORY_MASK)
-          .setResultListener((result, image) -> {});
+          .setOutputType(OutputType.CATEGORY_MASK);
    }

    /**
@@ -437,6 +588,7 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
      segmenterOptionsBuilder.setOutputType(
          SegmenterOptionsProto.SegmenterOptions.OutputType.CATEGORY_MASK);
    }
+
    // TODO: remove this once activation is handled in metadata and grpah level.
    segmenterOptionsBuilder.setActivation(
        SegmenterOptionsProto.SegmenterOptions.Activation.SOFTMAX);

View File

@@ -53,112 +53,108 @@ public class ImageSegmenterTest {
  @RunWith(AndroidJUnit4.class)
  public static final class General extends ImageSegmenterTest {
    @Test
    public void segment_successWithCategoryMask() throws Exception {
      final String inputImageName = "segmentation_input_rotation0.jpg";
      final String goldenImageName = "segmentation_golden_rotation0.png";
-      MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName);
      ImageSegmenterOptions options =
          ImageSegmenterOptions.builder()
              .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
              .setOutputType(ImageSegmenterOptions.OutputType.CATEGORY_MASK)
-              .setResultListener(
-                  (actualResult, inputImage) -> {
-                    List<MPImage> segmentations = actualResult.segmentations();
-                    assertThat(segmentations.size()).isEqualTo(1);
-                    MPImage actualMaskBuffer = actualResult.segmentations().get(0);
-                    verifyCategoryMask(
-                        actualMaskBuffer,
-                        expectedMaskBuffer,
-                        GOLDEN_MASK_SIMILARITY,
-                        MAGNIFICATION_FACTOR);
-                  })
              .build();
      ImageSegmenter imageSegmenter =
          ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options);
-      imageSegmenter.segment(getImageFromAsset(inputImageName));
+      ImageSegmenterResult actualResult = imageSegmenter.segment(getImageFromAsset(inputImageName));
+      List<MPImage> segmentations = actualResult.segmentations();
+      assertThat(segmentations.size()).isEqualTo(1);
+      MPImage actualMaskBuffer = actualResult.segmentations().get(0);
+      MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName);
+      verifyCategoryMask(
+          actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY, MAGNIFICATION_FACTOR);
    }

    @Test
    public void segment_successWithConfidenceMask() throws Exception {
      final String inputImageName = "cat.jpg";
      final String goldenImageName = "cat_mask.jpg";
-      MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName);
      ImageSegmenterOptions options =
          ImageSegmenterOptions.builder()
              .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
              .setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK)
-              .setResultListener(
-                  (actualResult, inputImage) -> {
-                    List<MPImage> segmentations = actualResult.segmentations();
-                    assertThat(segmentations.size()).isEqualTo(21);
-                    // Cat category index 8.
-                    MPImage actualMaskBuffer = actualResult.segmentations().get(8);
-                    verifyConfidenceMask(
-                        actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY);
-                  })
              .build();
      ImageSegmenter imageSegmenter =
          ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options);
-      imageSegmenter.segment(getImageFromAsset(inputImageName));
+      ImageSegmenterResult actualResult = imageSegmenter.segment(getImageFromAsset(inputImageName));
+      List<MPImage> segmentations = actualResult.segmentations();
+      assertThat(segmentations.size()).isEqualTo(21);
+      // Cat category index 8.
+      MPImage actualMaskBuffer = actualResult.segmentations().get(8);
+      MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName);
+      verifyConfidenceMask(actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY);
    }

    @Test
    public void segment_successWith128x128Segmentation() throws Exception {
      final String inputImageName = "mozart_square.jpg";
      final String goldenImageName = "selfie_segm_128_128_3_expected_mask.jpg";
-      MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName);
      ImageSegmenterOptions options =
          ImageSegmenterOptions.builder()
              .setBaseOptions(
                  BaseOptions.builder().setModelAssetPath(SELFIE_128x128_MODEL_FILE).build())
              .setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK)
-              .setResultListener(
-                  (actualResult, inputImage) -> {
-                    List<MPImage> segmentations = actualResult.segmentations();
-                    assertThat(segmentations.size()).isEqualTo(2);
-                    // Selfie category index 1.
-                    MPImage actualMaskBuffer = actualResult.segmentations().get(1);
-                    verifyConfidenceMask(
-                        actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY);
-                  })
              .build();
      ImageSegmenter imageSegmenter =
          ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options);
-      imageSegmenter.segment(getImageFromAsset(inputImageName));
+      ImageSegmenterResult actualResult = imageSegmenter.segment(getImageFromAsset(inputImageName));
+      List<MPImage> segmentations = actualResult.segmentations();
+      assertThat(segmentations.size()).isEqualTo(2);
+      // Selfie category index 1.
+      MPImage actualMaskBuffer = actualResult.segmentations().get(1);
+      MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName);
+      verifyConfidenceMask(actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY);
    }

    // TODO: enable this unit test once activation option is supported in metadata.
    // @Test
    // public void segment_successWith144x256Segmentation() throws Exception {
    //   final String inputImageName = "mozart_square.jpg";
    //   final String goldenImageName = "selfie_segm_144_256_3_expected_mask.jpg";
-    //   MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName);
    //   ImageSegmenterOptions options =
    //       ImageSegmenterOptions.builder()
    //           .setBaseOptions(
    //               BaseOptions.builder().setModelAssetPath(SELFIE_144x256_MODEL_FILE).build())
    //           .setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK)
-    //           .setActivation(ImageSegmenterOptions.Activation.NONE)
-    //           .setResultListener(
-    //               (actualResult, inputImage) -> {
-    //                 List<MPImage> segmentations = actualResult.segmentations();
-    //                 assertThat(segmentations.size()).isEqualTo(1);
-    //                 MPImage actualMaskBuffer = actualResult.segmentations().get(0);
-    //                 verifyConfidenceMask(
-    //                     actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY);
-    //               })
    //           .build();
    //   ImageSegmenter imageSegmenter =
-    //       ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(),
-    //           options);
-    //   imageSegmenter.segment(getImageFromAsset(inputImageName));
+    //       ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options);
+    //   ImageSegmenterResult actualResult =
+    //       imageSegmenter.segment(getImageFromAsset(inputImageName));
+    //   List<MPImage> segmentations = actualResult.segmentations();
+    //   assertThat(segmentations.size()).isEqualTo(1);
+    //   MPImage actualMaskBuffer = actualResult.segmentations().get(0);
+    //   MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName);
+    //   verifyConfidenceMask(actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY);
    // }
  }

  @RunWith(AndroidJUnit4.class)
  public static final class RunningModeTest extends ImageSegmenterTest {
+    @Test
+    public void create_failsWithMissingResultListenerInLiveSteamMode() throws Exception {
+      IllegalArgumentException exception =
+          assertThrows(
+              IllegalArgumentException.class,
+              () ->
+                  ImageSegmenterOptions.builder()
+                      .setBaseOptions(
+                          BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
+                      .setRunningMode(RunningMode.LIVE_STREAM)
+                      .build());
+      assertThat(exception)
+          .hasMessageThat()
+          .contains("a user-defined result listener must be provided");
+    }
+
    @Test
    public void segment_failsWithCallingWrongApiInImageMode() throws Exception {
      ImageSegmenterOptions options =
@@ -166,7 +162,6 @@ public class ImageSegmenterTest {
              .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
              .setRunningMode(RunningMode.IMAGE)
              .build();
-
      ImageSegmenter imageSegmenter =
          ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options);
      MediaPipeException exception =
@@ -182,6 +177,13 @@ public class ImageSegmenterTest {
              () ->
                  imageSegmenter.segmentAsync(getImageFromAsset(CAT_IMAGE), /* timestampsMs= */ 0));
      assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode");
+      exception =
+          assertThrows(
+              MediaPipeException.class,
+              () -> imageSegmenter.segmentWithResultListener(getImageFromAsset(CAT_IMAGE)));
+      assertThat(exception)
+          .hasMessageThat()
+          .contains("ResultListener is not set in the ImageSegmenterOptions");
    }

    @Test
@@ -191,7 +193,6 @@ public class ImageSegmenterTest {
              .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
              .setRunningMode(RunningMode.VIDEO)
              .build();
-
      ImageSegmenter imageSegmenter =
          ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options);
      MediaPipeException exception =
@@ -204,6 +205,15 @@ public class ImageSegmenterTest {
              () ->
                  imageSegmenter.segmentAsync(getImageFromAsset(CAT_IMAGE), /* timestampsMs= */ 0));
      assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode");
+      exception =
+          assertThrows(
+              MediaPipeException.class,
+              () ->
+                  imageSegmenter.segmentForVideoWithResultListener(
+                      getImageFromAsset(CAT_IMAGE), /* timestampsMs= */ 0));
+      assertThat(exception)
+          .hasMessageThat()
+          .contains("ResultListener is not set in the ImageSegmenterOptions");
    }

    @Test
@@ -214,18 +224,18 @@ public class ImageSegmenterTest {
              .setRunningMode(RunningMode.LIVE_STREAM)
              .setResultListener((result, inputImage) -> {})
              .build();
      ImageSegmenter imageSegmenter =
          ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options);
      MediaPipeException exception =
          assertThrows(
-              MediaPipeException.class, () -> imageSegmenter.segment(getImageFromAsset(CAT_IMAGE)));
+              MediaPipeException.class,
+              () -> imageSegmenter.segmentWithResultListener(getImageFromAsset(CAT_IMAGE)));
      assertThat(exception).hasMessageThat().contains("not initialized with the image mode");
      exception =
          assertThrows(
              MediaPipeException.class,
              () ->
-                  imageSegmenter.segmentForVideo(
+                  imageSegmenter.segmentForVideoWithResultListener(
                      getImageFromAsset(CAT_IMAGE), /* timestampsMs= */ 0));
      assertThat(exception).hasMessageThat().contains("not initialized with the video mode");
    }
@@ -234,51 +244,94 @@ public class ImageSegmenterTest {
    public void segment_successWithImageMode() throws Exception {
      final String inputImageName = "cat.jpg";
      final String goldenImageName = "cat_mask.jpg";
+      ImageSegmenterOptions options =
+          ImageSegmenterOptions.builder()
+              .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
+              .setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK)
+              .setRunningMode(RunningMode.IMAGE)
+              .build();
+      ImageSegmenter imageSegmenter =
+          ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options);
+      ImageSegmenterResult actualResult = imageSegmenter.segment(getImageFromAsset(inputImageName));
+      List<MPImage> segmentations = actualResult.segmentations();
+      assertThat(segmentations.size()).isEqualTo(21);
+      // Cat category index 8.
+      MPImage actualMaskBuffer = actualResult.segmentations().get(8);
      MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName);
+      verifyConfidenceMask(actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY);
+    }
+
+    @Test
+    public void segment_successWithImageModeWithResultListener() throws Exception {
+      final String inputImageName = "cat.jpg";
+      final String goldenImageName = "cat_mask.jpg";
+      MPImage expectedResult = getImageFromAsset(goldenImageName);
      ImageSegmenterOptions options =
          ImageSegmenterOptions.builder()
              .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
              .setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK)
              .setRunningMode(RunningMode.IMAGE)
              .setResultListener(
-                  (actualResult, inputImage) -> {
-                    List<MPImage> segmentations = actualResult.segmentations();
-                    assertThat(segmentations.size()).isEqualTo(21);
-                    // Cat category index 8.
-                    MPImage actualMaskBuffer = actualResult.segmentations().get(8);
+                  (segmenterResult, inputImage) -> {
                    verifyConfidenceMask(
-                        actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY);
+                        segmenterResult.segmentations().get(8),
+                        expectedResult,
+                        GOLDEN_MASK_SIMILARITY);
                  })
              .build();
      ImageSegmenter imageSegmenter =
          ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options);
-      imageSegmenter.segment(getImageFromAsset(inputImageName));
+      imageSegmenter.segmentWithResultListener(getImageFromAsset(inputImageName));
    }

    @Test
    public void segment_successWithVideoMode() throws Exception {
      final String inputImageName = "cat.jpg";
      final String goldenImageName = "cat_mask.jpg";
+      ImageSegmenterOptions options =
+          ImageSegmenterOptions.builder()
+              .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
+              .setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK)
+              .setRunningMode(RunningMode.VIDEO)
+              .build();
+      ImageSegmenter imageSegmenter =
+          ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options);
      MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName);
+      for (int i = 0; i < 3; i++) {
+        ImageSegmenterResult actualResult =
+            imageSegmenter.segmentForVideo(
+                getImageFromAsset(inputImageName), /* timestampsMs= */ i);
+        List<MPImage> segmentations = actualResult.segmentations();
+        assertThat(segmentations.size()).isEqualTo(21);
+        // Cat category index 8.
+        MPImage actualMaskBuffer = actualResult.segmentations().get(8);
+        verifyConfidenceMask(actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY);
+      }
+    }
+
+    @Test
+    public void segment_successWithVideoModeWithResultListener() throws Exception {
+      final String inputImageName = "cat.jpg";
+      final String goldenImageName = "cat_mask.jpg";
+      MPImage expectedResult = getImageFromAsset(goldenImageName);
      ImageSegmenterOptions options =
          ImageSegmenterOptions.builder()
              .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
              .setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK)
              .setRunningMode(RunningMode.VIDEO)
              .setResultListener(
-                  (actualResult, inputImage) -> {
-                    List<MPImage> segmentations = actualResult.segmentations();
-                    assertThat(segmentations.size()).isEqualTo(21);
-                    // Cat category index 8.
-                    MPImage actualMaskBuffer = actualResult.segmentations().get(8);
+                  (segmenterResult, inputImage) -> {
                    verifyConfidenceMask(
-                        actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY);
+                        segmenterResult.segmentations().get(8),
+                        expectedResult,
+                        GOLDEN_MASK_SIMILARITY);
                  })
              .build();
      ImageSegmenter imageSegmenter =
          ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options);
      for (int i = 0; i < 3; i++) {
-        imageSegmenter.segmentForVideo(getImageFromAsset(inputImageName), /* timestampsMs= */ i);
+        imageSegmenter.segmentForVideoWithResultListener(
+            getImageFromAsset(inputImageName), /* timestampsMs= */ i);
      }
    }