Update the ImageSegmenter API for image/video mode to provide both a callback API and a returned-result API.

PiperOrigin-RevId: 512697585
MediaPipe Team 2023-02-27 12:19:24 -08:00 committed by Copybara-Service
parent aa61abe386
commit a60d67eb10
2 changed files with 334 additions and 129 deletions

@@ -47,10 +47,14 @@ import java.util.Optional;
/**
* Performs image segmentation on images.
*
* <p>Note that, unlike other vision tasks, the output of ImageSegmenter is provided through a
* user-defined callback function even for the synchronous API. This makes it possible for
* ImageSegmenter to return the output masks without any copy. {@link ResultListener} must be set in
* the {@link ImageSegmenterOptions} for all {@link RunningMode}.
* <p>Note that, in addition to the standard segmentation API, {@link segment} and {@link
* segmentForVideo}, which take an input image and return the output masks as deep copies,
* ImageSegmenter also supports the callback API, {@link segmentWithResultListener} and {@link
* segmentForVideoWithResultListener}, which let you access the output masks with zero copy.
*
* <p>The callback API is available in all {@link RunningMode}s of ImageSegmenter. To use it, set a
* {@link ResultListener} in {@link ImageSegmenterOptions}.
*
* <p>The API expects a TFLite model with <a
* href="https://www.tensorflow.org/lite/convert/metadata">TFLite Model Metadata</a>.
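A minimal usage sketch of the two image-mode entry points described above. The option setters and createFromOptions follow the test code later in this commit; the model path, context, and bitmap arguments are placeholders, and the import paths are assumed from the MediaPipe Tasks Java packages.

import android.content.Context;
import android.graphics.Bitmap;
import com.google.mediapipe.framework.image.BitmapImageBuilder;
import com.google.mediapipe.framework.image.MPImage;
import com.google.mediapipe.tasks.core.BaseOptions;
import com.google.mediapipe.tasks.vision.imagesegmenter.ImageSegmenter;
import com.google.mediapipe.tasks.vision.imagesegmenter.ImageSegmenter.ImageSegmenterOptions;
import com.google.mediapipe.tasks.vision.imagesegmenter.ImageSegmenterResult;

final class ImageSegmenterUsageSketch {
  static void runImageModeApis(Context context, Bitmap bitmap) {
    MPImage image = new BitmapImageBuilder(bitmap).build();

    // Returned-result API: no ResultListener in the options, so segment() returns a
    // deep-copied ImageSegmenterResult.
    ImageSegmenterOptions options =
        ImageSegmenterOptions.builder()
            .setBaseOptions(
                BaseOptions.builder().setModelAssetPath("segmenter_model.tflite").build()) // placeholder path
            .setOutputType(ImageSegmenterOptions.OutputType.CATEGORY_MASK)
            .build();
    ImageSegmenter segmenter = ImageSegmenter.createFromOptions(context, options);
    ImageSegmenterResult result = segmenter.segment(image);
    // result.segmentations() holds one MPImage mask per output channel.

    // Callback API: set a ResultListener and call the *WithResultListener variant; the
    // masks are delivered to the listener without an extra copy.
    ImageSegmenterOptions callbackOptions =
        ImageSegmenterOptions.builder()
            .setBaseOptions(
                BaseOptions.builder().setModelAssetPath("segmenter_model.tflite").build())
            .setResultListener((segmenterResult, inputImage) -> { /* consume masks here */ })
            .build();
    ImageSegmenter callbackSegmenter = ImageSegmenter.createFromOptions(context, callbackOptions);
    callbackSegmenter.segmentWithResultListener(image);
  }
}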
@@ -85,6 +89,8 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
private static final String TASK_GRAPH_NAME =
"mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph";
private boolean hasResultListener = false;
/**
* Creates an {@link ImageSegmenter} instance from an {@link ImageSegmenterOptions}.
*
@@ -116,8 +122,19 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
int imageListSize =
PacketGetter.getImageListSize(packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX));
ByteBuffer[] buffersArray = new ByteBuffer[imageListSize];
// If resultListener is not provided, the resulting MPImage is deep copied from the mediapipe
// graph. If provided, the resulting MPImage wraps the mediapipe packet memory.
if (!segmenterOptions.resultListener().isPresent()) {
for (int i = 0; i < imageListSize; i++) {
buffersArray[i] =
ByteBuffer.allocateDirect(
width * height * (imageFormat == MPImage.IMAGE_FORMAT_VEC32F1 ? 4 : 1));
}
}
if (!PacketGetter.getImageList(
packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX), buffersArray, false)) {
packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX),
buffersArray,
!segmenterOptions.resultListener().isPresent())) {
throw new MediaPipeException(
MediaPipeException.StatusCode.INTERNAL.ordinal(),
"There is an error getting segmented masks. It usually results from incorrect"
@@ -143,7 +160,7 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
.build();
}
});
handler.setResultListener(segmenterOptions.resultListener());
segmenterOptions.resultListener().ifPresent(handler::setResultListener);
segmenterOptions.errorListener().ifPresent(handler::setErrorListener);
TaskRunner runner =
TaskRunner.create(
@@ -158,7 +175,8 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
.setEnableFlowLimiting(segmenterOptions.runningMode() == RunningMode.LIVE_STREAM)
.build(),
handler);
return new ImageSegmenter(runner, segmenterOptions.runningMode());
return new ImageSegmenter(
runner, segmenterOptions.runningMode(), segmenterOptions.resultListener().isPresent());
}
/**
@@ -168,16 +186,17 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
* @param taskRunner a {@link TaskRunner}.
* @param runningMode a mediapipe vision task {@link RunningMode}.
*/
private ImageSegmenter(TaskRunner taskRunner, RunningMode runningMode) {
private ImageSegmenter(
TaskRunner taskRunner, RunningMode runningMode, boolean hasResultListener) {
super(taskRunner, runningMode, IMAGE_IN_STREAM_NAME, NORM_RECT_IN_STREAM_NAME);
this.hasResultListener = hasResultListener;
}
/**
* Performs image segmentation on the provided single image with default image processing options,
* i.e. without any rotation applied, and the results will be available via the {@link
* ResultListener} provided in the {@link ImageSegmenterOptions}. Only use this method when the
* {@link ImageSegmenter} is created with {@link RunningMode.IMAGE}. TODO update java
* doc for input image format.
* i.e. without any rotation applied. Only use this method when the {@link ImageSegmenter} is
* created with {@link RunningMode.IMAGE}. TODO update java doc for input image
* format.
*
* <p>{@link ImageSegmenter} supports the following color space types:
*
@@ -186,19 +205,19 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
* </ul>
*
* @param image a MediaPipe {@link MPImage} object for processing.
* @throws MediaPipeException if there is an internal error.
* @throws MediaPipeException if there is an internal error, or if the {@link ImageSegmenter} is
* created with a {@link ResultListener}.
*/
public void segment(MPImage image) {
segment(image, ImageProcessingOptions.builder().build());
public ImageSegmenterResult segment(MPImage image) {
return segment(image, ImageProcessingOptions.builder().build());
}
/**
* Performs image segmentation on the provided single image, and the results will be available via
* the {@link ResultListener} provided in the {@link ImageSegmenterOptions}. Only use this method
* when the {@link ImageSegmenter} is created with {@link RunningMode.IMAGE}. TODO
* update java doc for input image format.
* Performs image segmentation on the provided single image. Only use this method when the {@link
* ImageSegmenter} is created with {@link RunningMode.IMAGE}. TODO update java doc
* for input image format.
*
* <p>{@link HandLandmarker} supports the following color space types:
* <p>{@link ImageSegmenter} supports the following color space types:
*
* <ul>
* <li>{@link Bitmap.Config.ARGB_8888}
@@ -211,19 +230,84 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
* this method throwing an IllegalArgumentException.
* @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a
* region-of-interest.
* @throws MediaPipeException if there is an internal error.
* @throws MediaPipeException if there is an internal error, or if the {@link ImageSegmenter} is
* created with a {@link ResultListener}.
*/
public void segment(MPImage image, ImageProcessingOptions imageProcessingOptions) {
public ImageSegmenterResult segment(
MPImage image, ImageProcessingOptions imageProcessingOptions) {
if (hasResultListener) {
throw new MediaPipeException(
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
"ResultListener is provided in the ImageSegmenterOptions, but this method will return an"
+ " ImageSegmentationResult.");
}
validateImageProcessingOptions(imageProcessingOptions);
ImageSegmenterResult unused =
(ImageSegmenterResult) processImageData(image, imageProcessingOptions);
return (ImageSegmenterResult) processImageData(image, imageProcessingOptions);
}
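A brief follow-up sketch for this overload: a rotation can be applied through the shared vision ImageProcessingOptions. The setRotationDegrees call and the import path are assumptions about that builder, and the snippet reuses the segmenter and image variables from the earlier image-mode example; region-of-interest must be left unset, or this overload throws IllegalArgumentException.

import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions;

// Rotate the input by 90 degrees before segmentation (sketch only).
ImageProcessingOptions processingOptions =
    ImageProcessingOptions.builder().setRotationDegrees(90).build();
ImageSegmenterResult rotatedResult = segmenter.segment(image, processingOptions);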
/**
* Performs image segmentation on the provided single image with default image processing options,
* i.e. without any rotation applied, and provides zero-copied results via {@link ResultListener}
* in {@link ImageSegmenterOptions}. Only use this method when the {@link ImageSegmenter} is
* created with {@link RunningMode.IMAGE}.
*
* <p>TODO update java doc for input image format.
*
* <p>{@link ImageSegmenter} supports the following color space types:
*
* <ul>
* <li>{@link Bitmap.Config.ARGB_8888}
* </ul>
*
* @param image a MediaPipe {@link MPImage} object for processing.
* @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a
* region-of-interest.
* @throws MediaPipeException if there is an internal error, or if the {@link ImageSegmenter} is not
* created with a {@link ResultListener} set in {@link ImageSegmenterOptions}.
*/
public void segmentWithResultListener(MPImage image) {
segmentWithResultListener(image, ImageProcessingOptions.builder().build());
}
/**
* Performs image segmentation on the provided single image, and provides zero-copied results via
* {@link ResultListener} in {@link ImageSegmenterOptions}. Only use this method when the {@link
* ImageSegmenter} is created with {@link RunningMode.IMAGE}.
*
* <p>TODO update java doc for input image format.
*
* <p>{@link ImageSegmenter} supports the following color space types:
*
* <ul>
* <li>{@link Bitmap.Config.ARGB_8888}
* </ul>
*
* @param image a MediaPipe {@link MPImage} object for processing.
* @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
* input image before running inference. Note that region-of-interest is <b>not</b> supported
* by this task: specifying {@link ImageProcessingOptions#regionOfInterest()} will result in
* this method throwing an IllegalArgumentException.
* @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a
* region-of-interest.
* @throws MediaPipeException if there is an internal error, or if the {@link ImageSegmenter} is not
* created with a {@link ResultListener} set in {@link ImageSegmenterOptions}.
*/
public void segmentWithResultListener(
MPImage image, ImageProcessingOptions imageProcessingOptions) {
if (!hasResultListener) {
throw new MediaPipeException(
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
"ResultListener is not set in the ImageSegmenterOptions, but this method expects a"
+ " ResultListener to process ImageSegmentationResult.");
}
validateImageProcessingOptions(imageProcessingOptions);
var unused = processImageData(image, imageProcessingOptions);
}
/**
* Performs image segmentation on the provided video frame with default image processing options,
* i.e. without any rotation applied, and the results will be available via the {@link
* ResultListener} provided in the {@link ImageSegmenterOptions}. Only use this method when the
* {@link HandLandmarker} is created with {@link RunningMode.VIDEO}.
* i.e. without any rotation applied. Only use this method when the {@link ImageSegmenter} is
* created with {@link RunningMode.VIDEO}.
*
* <p>It's required to provide the video frame's timestamp (in milliseconds). The input timestamps
* must be monotonically increasing.
@@ -236,21 +320,21 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
*
* @param image a MediaPipe {@link MPImage} object for processing.
* @param timestampMs the input timestamp (in milliseconds).
* @throws MediaPipeException if there is an internal error.
* @throws MediaPipeException if there is an internal error, or if the {@link ImageSegmenter} is
* created with a {@link ResultListener}.
*/
public void segmentForVideo(MPImage image, long timestampMs) {
segmentForVideo(image, ImageProcessingOptions.builder().build(), timestampMs);
public ImageSegmenterResult segmentForVideo(MPImage image, long timestampMs) {
return segmentForVideo(image, ImageProcessingOptions.builder().build(), timestampMs);
}
/**
* Performs image segmentation on the provided video frame, and the results will be available via
* the {@link ResultListener} provided in the {@link ImageSegmenterOptions}. Only use this method
* when the {@link ImageSegmenter} is created with {@link RunningMode.VIDEO}.
* Performs image segmentation on the provided video frame. Only use this method when the {@link
* ImageSegmenter} is created with {@link RunningMode.VIDEO}.
*
* <p>It's required to provide the video frame's timestamp (in milliseconds). The input timestamps
* must be monotonically increasing.
*
* <p>{@link HandLandmarker} supports the following color space types:
* <p>{@link ImageSegmenter} supports the following color space types:
*
* <ul>
* <li>{@link Bitmap.Config.ARGB_8888}
@@ -264,20 +348,81 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
* @param timestampMs the input timestamp (in milliseconds).
* @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a
* region-of-interest.
* @throws MediaPipeException if there is an internal error.
* @throws MediaPipeException if there is an internal error, or if the {@link ImageSegmenter} is
* created with a {@link ResultListener}.
*/
public void segmentForVideo(
public ImageSegmenterResult segmentForVideo(
MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) {
if (hasResultListener) {
throw new MediaPipeException(
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
"ResultListener is provided in the ImageSegmenterOptions, but this method will return an"
+ " ImageSegmentationResult.");
}
validateImageProcessingOptions(imageProcessingOptions);
ImageSegmenterResult unused =
(ImageSegmenterResult) processVideoData(image, imageProcessingOptions, timestampMs);
return (ImageSegmenterResult) processVideoData(image, imageProcessingOptions, timestampMs);
}
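A short sketch of the returned-result video API under the rules above. Here videoSegmenter is assumed to be created with RunningMode.VIDEO and without a ResultListener, and frames is assumed to be a List<MPImage> supplied by the caller; the millisecond timestamps must be monotonically increasing.

// Segment a sequence of video frames, reading back a deep-copied result per frame.
for (int i = 0; i < frames.size(); i++) {
  MPImage frame = frames.get(i);
  long timestampMs = i * 33L; // roughly 30 fps; timestamps must strictly increase
  ImageSegmenterResult result = videoSegmenter.segmentForVideo(frame, timestampMs);
  // result.segmentations() holds one MPImage mask per output channel.
}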
/**
* Sends live image data to perform hand landmarks detection with default image processing
* options, i.e. without any rotation applied, and the results will be available via the {@link
* ResultListener} provided in the {@link ImageSegmenterOptions}. Only use this method when the
* {@link ImageSegmenter } is created with {@link RunningMode.LIVE_STREAM}.
* Performs image segmentation on the provided video frame with default image processing options,
* i.e. without any rotation applied, and provides zero-copied results via {@link ResultListener}
* in {@link ImageSegmenterOptions}. Only use this method when the {@link ImageSegmenter} is
* created with {@link RunningMode.VIDEO}.
*
* <p>It's required to provide the video frame's timestamp (in milliseconds). The input timestamps
* must be monotonically increasing.
*
* <p>{@link ImageSegmenter} supports the following color space types:
*
* <ul>
* <li>{@link Bitmap.Config.ARGB_8888}
* </ul>
*
* @param image a MediaPipe {@link MPImage} object for processing.
* @param timestampMs the input timestamp (in milliseconds).
* @throws MediaPipeException if there is an internal error, or if the {@link ImageSegmenter} is not
* created with a {@link ResultListener} set in {@link ImageSegmenterOptions}.
*/
public void segmentForVideoWithResultListener(MPImage image, long timestampMs) {
segmentForVideoWithResultListener(image, ImageProcessingOptions.builder().build(), timestampMs);
}
/**
* Performs image segmentation on the provided video frame, and provides zero-copied results via
* {@link ResultListener} in {@link ImageSegmenterOptions}. Only use this method when the {@link
* ImageSegmenter} is created with {@link RunningMode.VIDEO}.
*
* <p>It's required to provide the video frame's timestamp (in milliseconds). The input timestamps
* must be monotonically increasing.
*
* <p>{@link ImageSegmenter} supports the following color space types:
*
* <ul>
* <li>{@link Bitmap.Config.ARGB_8888}
* </ul>
*
* @param image a MediaPipe {@link MPImage} object for processing.
* @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
* input image before running inference. Note that region-of-interest is <b>not</b> supported
* by this task: specifying {@link ImageProcessingOptions#regionOfInterest()} will result in
* this method throwing an IllegalArgumentException.
* @param timestampMs the input timestamp (in milliseconds).
* @throws MediaPipeException if there is an internal error, or if the {@link ImageSegmenter} is not
* created with a {@link ResultListener} set in {@link ImageSegmenterOptions}.
*/
public void segmentForVideoWithResultListener(
MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) {
if (!hasResultListener) {
throw new MediaPipeException(
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
"ResultListener is not set in the ImageSegmenterOptions, but this method expects a"
+ " ResultListener to process ImageSegmentationResult.");
}
validateImageProcessingOptions(imageProcessingOptions);
var unused = processVideoData(image, imageProcessingOptions, timestampMs);
}
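The zero-copy counterpart, again as a sketch: videoCallbackSegmenter is assumed to be created with RunningMode.VIDEO and with a ResultListener set in ImageSegmenterOptions, so each call delivers the masks to that listener instead of returning them.

for (int i = 0; i < frames.size(); i++) {
  // Results arrive on the ResultListener registered in ImageSegmenterOptions.
  videoCallbackSegmenter.segmentForVideoWithResultListener(
      frames.get(i), /* timestampMs= */ i * 33L);
}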
/**
* Sends live image data to perform image segmentation with default image processing options, i.e.
* without any rotation applied, and the results will be available via the {@link ResultListener}
* provided in the {@link ImageSegmenterOptions}. Only use this method when the {@link
* ImageSegmenter} is created with {@link RunningMode.LIVE_STREAM}.
*
* <p>It's required to provide a timestamp (in milliseconds) to indicate when the input image is
* sent to the image segmenter. The input timestamps must be monotonically increasing.
@@ -360,8 +505,8 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
public abstract Builder setOutputType(OutputType value);
/**
* Sets the {@link ResultListener} to receive the segmentation results when the graph pipeline
* is done processing an image.
* Sets an optional {@link ResultListener} to receive the segmentation results when the graph
* pipeline is done processing an image.
*/
public abstract Builder setResultListener(
ResultListener<ImageSegmenterResult, MPImage> value);
@@ -375,11 +520,18 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
* Validates and builds the {@link ImageSegmenterOptions} instance.
*
* @throws IllegalArgumentException if the result listener and the running mode are not
* properly configured. The result listener should only be set when the image segmenter is
* in the live stream mode.
* properly configured. The result listener must be set when the image segmenter is in the
* live stream mode.
*/
public final ImageSegmenterOptions build() {
ImageSegmenterOptions options = autoBuild();
if (options.runningMode() == RunningMode.LIVE_STREAM) {
if (!options.resultListener().isPresent()) {
throw new IllegalArgumentException(
"The image segmenter is in the live stream mode, a user-defined result listener"
+ " must be provided in ImageSegmenterOptions.");
}
}
return options;
}
}
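For completeness, a hedged sketch of the LIVE_STREAM configuration that this build() check enforces. The model path, context, cameraFrame, and frameTimestampMs are placeholders; segmentAsync is the live-stream entry point exercised by the tests below.

// LIVE_STREAM mode: build() now rejects options without a ResultListener.
ImageSegmenterOptions liveOptions =
    ImageSegmenterOptions.builder()
        .setBaseOptions(
            BaseOptions.builder().setModelAssetPath("segmenter_model.tflite").build()) // placeholder
        .setRunningMode(RunningMode.LIVE_STREAM)
        .setResultListener((result, inputImage) -> { /* consume masks per frame */ })
        .build(); // omitting setResultListener here throws IllegalArgumentException
ImageSegmenter liveSegmenter = ImageSegmenter.createFromOptions(context, liveOptions);
liveSegmenter.segmentAsync(cameraFrame, frameTimestampMs);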
@@ -392,7 +544,7 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
abstract OutputType outputType();
abstract ResultListener<ImageSegmenterResult, MPImage> resultListener();
abstract Optional<ResultListener<ImageSegmenterResult, MPImage>> resultListener();
abstract Optional<ErrorListener> errorListener();
@@ -410,8 +562,7 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
return new AutoValue_ImageSegmenter_ImageSegmenterOptions.Builder()
.setRunningMode(RunningMode.IMAGE)
.setDisplayNamesLocale("en")
.setOutputType(OutputType.CATEGORY_MASK)
.setResultListener((result, image) -> {});
.setOutputType(OutputType.CATEGORY_MASK);
}
/**
@@ -437,6 +588,7 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
segmenterOptionsBuilder.setOutputType(
SegmenterOptionsProto.SegmenterOptions.OutputType.CATEGORY_MASK);
}
// TODO: remove this once activation is handled in metadata and graph level.
segmenterOptionsBuilder.setActivation(
SegmenterOptionsProto.SegmenterOptions.Activation.SOFTMAX);

@@ -53,112 +53,108 @@ public class ImageSegmenterTest {
@RunWith(AndroidJUnit4.class)
public static final class General extends ImageSegmenterTest {
@Test
public void segment_successWithCategoryMask() throws Exception {
final String inputImageName = "segmentation_input_rotation0.jpg";
final String goldenImageName = "segmentation_golden_rotation0.png";
MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName);
ImageSegmenterOptions options =
ImageSegmenterOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
.setOutputType(ImageSegmenterOptions.OutputType.CATEGORY_MASK)
.setResultListener(
(actualResult, inputImage) -> {
List<MPImage> segmentations = actualResult.segmentations();
assertThat(segmentations.size()).isEqualTo(1);
MPImage actualMaskBuffer = actualResult.segmentations().get(0);
verifyCategoryMask(
actualMaskBuffer,
expectedMaskBuffer,
GOLDEN_MASK_SIMILARITY,
MAGNIFICATION_FACTOR);
})
.build();
ImageSegmenter imageSegmenter =
ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options);
imageSegmenter.segment(getImageFromAsset(inputImageName));
ImageSegmenterResult actualResult = imageSegmenter.segment(getImageFromAsset(inputImageName));
List<MPImage> segmentations = actualResult.segmentations();
assertThat(segmentations.size()).isEqualTo(1);
MPImage actualMaskBuffer = actualResult.segmentations().get(0);
MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName);
verifyCategoryMask(
actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY, MAGNIFICATION_FACTOR);
}
@Test
public void segment_successWithConfidenceMask() throws Exception {
final String inputImageName = "cat.jpg";
final String goldenImageName = "cat_mask.jpg";
MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName);
ImageSegmenterOptions options =
ImageSegmenterOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
.setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK)
.setResultListener(
(actualResult, inputImage) -> {
List<MPImage> segmentations = actualResult.segmentations();
assertThat(segmentations.size()).isEqualTo(21);
// Cat category index 8.
MPImage actualMaskBuffer = actualResult.segmentations().get(8);
verifyConfidenceMask(
actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY);
})
.build();
ImageSegmenter imageSegmenter =
ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options);
imageSegmenter.segment(getImageFromAsset(inputImageName));
ImageSegmenterResult actualResult = imageSegmenter.segment(getImageFromAsset(inputImageName));
List<MPImage> segmentations = actualResult.segmentations();
assertThat(segmentations.size()).isEqualTo(21);
// Cat category index 8.
MPImage actualMaskBuffer = actualResult.segmentations().get(8);
MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName);
verifyConfidenceMask(actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY);
}
@Test
public void segment_successWith128x128Segmentation() throws Exception {
final String inputImageName = "mozart_square.jpg";
final String goldenImageName = "selfie_segm_128_128_3_expected_mask.jpg";
MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName);
ImageSegmenterOptions options =
ImageSegmenterOptions.builder()
.setBaseOptions(
BaseOptions.builder().setModelAssetPath(SELFIE_128x128_MODEL_FILE).build())
.setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK)
.setResultListener(
(actualResult, inputImage) -> {
List<MPImage> segmentations = actualResult.segmentations();
assertThat(segmentations.size()).isEqualTo(2);
// Selfie category index 1.
MPImage actualMaskBuffer = actualResult.segmentations().get(1);
verifyConfidenceMask(
actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY);
})
.build();
ImageSegmenter imageSegmenter =
ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options);
imageSegmenter.segment(getImageFromAsset(inputImageName));
ImageSegmenterResult actualResult = imageSegmenter.segment(getImageFromAsset(inputImageName));
List<MPImage> segmentations = actualResult.segmentations();
assertThat(segmentations.size()).isEqualTo(2);
// Selfie category index 1.
MPImage actualMaskBuffer = actualResult.segmentations().get(1);
MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName);
verifyConfidenceMask(actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY);
}
// TODO: enable this unit test once activation option is supported in metadata.
// @Test
// public void segment_successWith144x256Segmentation() throws Exception {
// final String inputImageName = "mozart_square.jpg";
// final String goldenImageName = "selfie_segm_144_256_3_expected_mask.jpg";
// MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName);
// ImageSegmenterOptions options =
// ImageSegmenterOptions.builder()
// .setBaseOptions(
// BaseOptions.builder().setModelAssetPath(SELFIE_144x256_MODEL_FILE).build())
// .setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK)
// .setActivation(ImageSegmenterOptions.Activation.NONE)
// .setResultListener(
// (actualResult, inputImage) -> {
// List<MPImage> segmentations = actualResult.segmentations();
// assertThat(segmentations.size()).isEqualTo(1);
// MPImage actualMaskBuffer = actualResult.segmentations().get(0);
// verifyConfidenceMask(
// actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY);
// })
// .build();
// ImageSegmenter imageSegmenter =
// ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(),
// options);
// imageSegmenter.segment(getImageFromAsset(inputImageName));
// }
// @Test
// public void segment_successWith144x256Segmentation() throws Exception {
// final String inputImageName = "mozart_square.jpg";
// final String goldenImageName = "selfie_segm_144_256_3_expected_mask.jpg";
// ImageSegmenterOptions options =
// ImageSegmenterOptions.builder()
// .setBaseOptions(
// BaseOptions.builder().setModelAssetPath(SELFIE_144x256_MODEL_FILE).build())
// .setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK)
// .build();
// ImageSegmenter imageSegmenter =
// ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options);
// ImageSegmenterResult actualResult =
// imageSegmenter.segment(getImageFromAsset(inputImageName));
// List<MPImage> segmentations = actualResult.segmentations();
// assertThat(segmentations.size()).isEqualTo(1);
// MPImage actualMaskBuffer = actualResult.segmentations().get(0);
// MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName);
// verifyConfidenceMask(actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY);
// }
}
@RunWith(AndroidJUnit4.class)
public static final class RunningModeTest extends ImageSegmenterTest {
@Test
public void create_failsWithMissingResultListenerInLiveSteamMode() throws Exception {
IllegalArgumentException exception =
assertThrows(
IllegalArgumentException.class,
() ->
ImageSegmenterOptions.builder()
.setBaseOptions(
BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
.setRunningMode(RunningMode.LIVE_STREAM)
.build());
assertThat(exception)
.hasMessageThat()
.contains("a user-defined result listener must be provided");
}
@Test
public void segment_failsWithCallingWrongApiInImageMode() throws Exception {
ImageSegmenterOptions options =
@@ -166,7 +162,6 @@ public class ImageSegmenterTest {
.setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
.setRunningMode(RunningMode.IMAGE)
.build();
ImageSegmenter imageSegmenter =
ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options);
MediaPipeException exception =
@@ -182,6 +177,13 @@ public class ImageSegmenterTest {
() ->
imageSegmenter.segmentAsync(getImageFromAsset(CAT_IMAGE), /* timestampsMs= */ 0));
assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode");
exception =
assertThrows(
MediaPipeException.class,
() -> imageSegmenter.segmentWithResultListener(getImageFromAsset(CAT_IMAGE)));
assertThat(exception)
.hasMessageThat()
.contains("ResultListener is not set in the ImageSegmenterOptions");
}
@Test
@@ -191,7 +193,6 @@ public class ImageSegmenterTest {
.setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
.setRunningMode(RunningMode.VIDEO)
.build();
ImageSegmenter imageSegmenter =
ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options);
MediaPipeException exception =
@@ -204,6 +205,15 @@ public class ImageSegmenterTest {
() ->
imageSegmenter.segmentAsync(getImageFromAsset(CAT_IMAGE), /* timestampsMs= */ 0));
assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode");
exception =
assertThrows(
MediaPipeException.class,
() ->
imageSegmenter.segmentForVideoWithResultListener(
getImageFromAsset(CAT_IMAGE), /* timestampsMs= */ 0));
assertThat(exception)
.hasMessageThat()
.contains("ResultListener is not set in the ImageSegmenterOptions");
}
@Test
@@ -214,18 +224,18 @@ public class ImageSegmenterTest {
.setRunningMode(RunningMode.LIVE_STREAM)
.setResultListener((result, inputImage) -> {})
.build();
ImageSegmenter imageSegmenter =
ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options);
MediaPipeException exception =
assertThrows(
MediaPipeException.class, () -> imageSegmenter.segment(getImageFromAsset(CAT_IMAGE)));
MediaPipeException.class,
() -> imageSegmenter.segmentWithResultListener(getImageFromAsset(CAT_IMAGE)));
assertThat(exception).hasMessageThat().contains("not initialized with the image mode");
exception =
assertThrows(
MediaPipeException.class,
() ->
imageSegmenter.segmentForVideo(
imageSegmenter.segmentForVideoWithResultListener(
getImageFromAsset(CAT_IMAGE), /* timestampsMs= */ 0));
assertThat(exception).hasMessageThat().contains("not initialized with the video mode");
}
@@ -234,51 +244,94 @@ public class ImageSegmenterTest {
public void segment_successWithImageMode() throws Exception {
final String inputImageName = "cat.jpg";
final String goldenImageName = "cat_mask.jpg";
ImageSegmenterOptions options =
ImageSegmenterOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
.setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK)
.setRunningMode(RunningMode.IMAGE)
.build();
ImageSegmenter imageSegmenter =
ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options);
ImageSegmenterResult actualResult = imageSegmenter.segment(getImageFromAsset(inputImageName));
List<MPImage> segmentations = actualResult.segmentations();
assertThat(segmentations.size()).isEqualTo(21);
// Cat category index 8.
MPImage actualMaskBuffer = actualResult.segmentations().get(8);
MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName);
verifyConfidenceMask(actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY);
}
@Test
public void segment_successWithImageModeWithResultListener() throws Exception {
final String inputImageName = "cat.jpg";
final String goldenImageName = "cat_mask.jpg";
MPImage expectedResult = getImageFromAsset(goldenImageName);
ImageSegmenterOptions options =
ImageSegmenterOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
.setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK)
.setRunningMode(RunningMode.IMAGE)
.setResultListener(
(actualResult, inputImage) -> {
List<MPImage> segmentations = actualResult.segmentations();
assertThat(segmentations.size()).isEqualTo(21);
// Cat category index 8.
MPImage actualMaskBuffer = actualResult.segmentations().get(8);
(segmenterResult, inputImage) -> {
verifyConfidenceMask(
actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY);
segmenterResult.segmentations().get(8),
expectedResult,
GOLDEN_MASK_SIMILARITY);
})
.build();
ImageSegmenter imageSegmenter =
ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options);
imageSegmenter.segment(getImageFromAsset(inputImageName));
imageSegmenter.segmentWithResultListener(getImageFromAsset(inputImageName));
}
@Test
public void segment_successWithVideoMode() throws Exception {
final String inputImageName = "cat.jpg";
final String goldenImageName = "cat_mask.jpg";
ImageSegmenterOptions options =
ImageSegmenterOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
.setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK)
.setRunningMode(RunningMode.VIDEO)
.build();
ImageSegmenter imageSegmenter =
ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options);
MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName);
for (int i = 0; i < 3; i++) {
ImageSegmenterResult actualResult =
imageSegmenter.segmentForVideo(
getImageFromAsset(inputImageName), /* timestampsMs= */ i);
List<MPImage> segmentations = actualResult.segmentations();
assertThat(segmentations.size()).isEqualTo(21);
// Cat category index 8.
MPImage actualMaskBuffer = actualResult.segmentations().get(8);
verifyConfidenceMask(actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY);
}
}
@Test
public void segment_successWithVideoModeWithResultListener() throws Exception {
final String inputImageName = "cat.jpg";
final String goldenImageName = "cat_mask.jpg";
MPImage expectedResult = getImageFromAsset(goldenImageName);
ImageSegmenterOptions options =
ImageSegmenterOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
.setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK)
.setRunningMode(RunningMode.VIDEO)
.setResultListener(
(actualResult, inputImage) -> {
List<MPImage> segmentations = actualResult.segmentations();
assertThat(segmentations.size()).isEqualTo(21);
// Cat category index 8.
MPImage actualMaskBuffer = actualResult.segmentations().get(8);
(segmenterResult, inputImage) -> {
verifyConfidenceMask(
actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY);
segmenterResult.segmentations().get(8),
expectedResult,
GOLDEN_MASK_SIMILARITY);
})
.build();
ImageSegmenter imageSegmenter =
ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options);
for (int i = 0; i < 3; i++) {
imageSegmenter.segmentForVideo(getImageFromAsset(inputImageName), /* timestampsMs= */ i);
imageSegmenter.segmentForVideoWithResultListener(
getImageFromAsset(inputImageName), /* timestampsMs= */ i);
}
}